├── .flake8 ├── .github └── workflows │ ├── mannu_build.yml │ ├── python-package.yml │ ├── release.yml │ └── scripts │ ├── cuda-install.sh │ ├── github_create_release.js │ └── pytorch-install.sh ├── .gitignore ├── .style.yapf ├── LICENSE.txt ├── MANIFEST.in ├── README.md ├── assets └── fb201d9c-f889-4504-9ef5-ac77ec1cd8e2.jpg ├── csrc ├── awq_cuda │ ├── attention │ │ ├── cuda_bf16_fallbacks.cuh │ │ ├── cuda_bf16_wrapper.h │ │ ├── decoder_masked_multihead_attention.cu │ │ ├── decoder_masked_multihead_attention.h │ │ ├── decoder_masked_multihead_attention_template.hpp │ │ ├── decoder_masked_multihead_attention_utils.h │ │ ├── ft_attention.cpp │ │ └── ft_attention.h │ ├── layernorm │ │ ├── layernorm.cu │ │ ├── layernorm.h │ │ └── reduction.cuh │ ├── position_embedding │ │ ├── pos_encoding.h │ │ └── pos_encoding_kernels.cu │ ├── pybind_awq.cpp │ ├── pybind_ft.cpp │ └── quantization │ │ ├── dequantize.cuh │ │ ├── gemm_cuda.h │ │ ├── gemm_cuda_gen.cu │ │ ├── gemv_cuda.cu │ │ ├── gemv_cuda.h │ │ ├── marlin_cuda.cpp │ │ ├── marlin_cuda_kernel.cu │ │ └── marlin_cuda_kernel.cuh └── ort_cuda │ ├── common.cuh │ ├── dq.cu │ ├── dq_gemv.cu │ └── ort_ops.cc ├── pyproject.toml ├── qllm ├── __init__.py ├── __main__.py ├── args_config.py ├── auto_datasets │ └── __init__.py ├── auto_model_quantization.py ├── custom │ ├── __init__.py │ ├── __main__.py │ ├── m_mpt.py │ └── run.py ├── modeling │ ├── __init__.py │ ├── base.py │ ├── config.py │ └── q_layers │ │ ├── __init__.py │ │ ├── compress_weight.py │ │ ├── custom_autotune.py │ │ ├── ext_package_checker.py │ │ ├── quant_linear_awq.py │ │ ├── quant_linear_gptq.py │ │ ├── quant_linear_hqq.py │ │ ├── quant_linear_marlin.py │ │ ├── quant_linear_onnxruntime.py │ │ ├── quant_linear_triton.py │ │ ├── quant_linear_vptq.py │ │ └── triton_norm.py ├── plugin │ ├── __init__.py │ ├── chatcli │ │ ├── README.md │ │ ├── __init__.py │ │ ├── chatio.py │ │ ├── conversation.py │ │ ├── generation.py │ │ └── inference.py │ ├── conversation.py │ └── perplexity_utils.py ├── quantization │ ├── __init__.py │ ├── awq │ │ ├── __init__.py │ │ ├── _awq_quantizer.py │ │ ├── quant_awq.py │ │ └── sequential_layes_awq_config.py │ ├── config_builder.py │ ├── gptq │ │ ├── __init__.py │ │ ├── _gptq_quantizer.py │ │ ├── gptq.py │ │ ├── quant_gptq.py │ │ └── sequential_layes_gptq_config.py │ ├── hqq │ │ ├── __init__.py │ │ ├── _hqq_quantizer.py │ │ └── quant_hqq.py │ ├── method.py │ ├── quant_frame_base.py │ └── vptq │ │ ├── __init__.py │ │ ├── _vptq_quantizer.py │ │ ├── hessian_collector.py │ │ ├── inv_hessian.py │ │ ├── merge_hessian.py │ │ ├── qllm_hessian.py │ │ └── quant_vptq.py ├── run.py └── utils │ ├── __init__.py │ ├── comm_utils.py │ ├── datautils.py │ ├── logger.py │ ├── modelutils.py │ └── onnx │ ├── __init__.py │ ├── exporter.py │ └── merge_encoder_decoder.py ├── qllm_colab.ipynb ├── requirements.txt └── setup.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501,E701,E731,W503 3 | exclude = .git,__pycache__,build,dist, 4 | ./qllm/plugin 5 | max-line-length = 120 6 | 7 | 8 | -------------------------------------------------------------------------------- /.github/workflows/mannu_build.yml: -------------------------------------------------------------------------------- 1 | name: build_wheels mannuly 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | name: 7 | description: "need OptionalCUDAGuard?" 
8 | default: 'false' 9 | 10 | jobs: 11 | hello: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Hello Step 16 | shell: bash 17 | run: echo "Hello ${{ github.event.inputs.name }}" 18 | build_wheels: 19 | name: Build qllm 20 | runs-on: ${{ matrix.os }} 21 | 22 | strategy: 23 | matrix: 24 | os: [ubuntu-20.04, windows-latest] 25 | pyver: ["3.10"] 26 | cuda: ["12.6.0"] 27 | defaults: 28 | run: 29 | shell: pwsh 30 | env: 31 | PYPI_CUDA_VERSION: "11.8.0" 32 | CUDA_VERSION: ${{ matrix.cuda }} 33 | 34 | steps: 35 | - name: Free Disk Space 36 | uses: jlumbroso/free-disk-space@v1.3.0 37 | if: runner.os == 'Linux' 38 | with: 39 | tool-cache: false 40 | android: true 41 | dotnet: true 42 | haskell: true 43 | large-packages: false 44 | docker-images: true 45 | swap-storage: false 46 | 47 | - uses: actions/checkout@v3 48 | 49 | - uses: actions/setup-python@v3 50 | with: 51 | python-version: ${{ matrix.pyver }} 52 | 53 | - name: Setup Mamba 54 | uses: conda-incubator/setup-miniconda@v2.2.0 55 | with: 56 | activate-environment: "build" 57 | python-version: ${{ matrix.pyver }} 58 | miniforge-variant: Miniforge3 59 | miniforge-version: latest 60 | use-mamba: true 61 | add-pip-as-python-dependency: true 62 | auto-activate-base: false 63 | 64 | - name: Install Dependencies 65 | run: | 66 | # Install CUDA toolkit 67 | mamba install -y 'cuda' -c "nvidia/label/cuda-${env:CUDA_VERSION}" 68 | # Env variables 69 | $env:CUDA_PATH = $env:CONDA_PREFIX 70 | $env:CUDA_HOME = $env:CONDA_PREFIX 71 | 72 | # Install torch 73 | $cudaVersion = $env:CUDA_VERSION.Replace('.', '') 74 | $cudaVersionPytorch = $cudaVersion.Substring(0, $cudaVersion.Length - 1) 75 | if ([int]$cudaVersionPytorch -gt 118) { $pytorchVersion = "torch==2.2.2" } else {$pytorchVersion = "torch==2.2.2"} 76 | python -m pip install --upgrade --no-cache-dir $pytorchVersion+cu$cudaVersionPytorch --index-url https://download.pytorch.org/whl/cu$cudaVersionPytorch 77 | python -m pip install build setuptools==69.5.1 wheel ninja 78 | # Print version information 79 | python --version 80 | python -c "import torch; print('PyTorch:', torch.__version__)" 81 | python -c "import torch; print('CUDA:', torch.version.cuda)" 82 | python -c "import os; print('CUDA_HOME:', os.getenv('CUDA_HOME', None))" 83 | python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" 84 | - name: Build Wheel 85 | run: | 86 | $env:CUDA_PATH = $env:CONDA_PREFIX 87 | $env:CUDA_HOME = $env:CONDA_PREFIX 88 | # Only add +cu118 to wheel if not releasing on PyPi 89 | if ( $env:CUDA_VERSION -eq $env:PYPI_CUDA_VERSION ){ 90 | $env:PYPI_BUILD = 1 91 | } 92 | # echo "{CUDA_VERSION}=$env:CUDA_VERSION" >> $GITHUB_ENV 93 | if (${{ github.event.inputs.name }} -eq 'true') { 94 | echo "set GENERAL_TORCH to true !!!!!!!!!!!!!!!" 
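# Hedged note (added, not from the original workflow): GENERAL_TORCH=1 appears to be the
# switch tied to the "need OptionalCUDAGuard?" workflow_dispatch input above; release.yml
# below sets it unconditionally with an "OptionalCUDAGuard" comment, and it is presumably
# consumed by setup.py when building the extension.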
95 | $env:GENERAL_TORCH = 1 96 | } 97 | python setup.py sdist bdist_wheel -k $env:PLAT_ARG.split() 98 | env: 99 | PLAT_ARG: ${{ contains(runner.os, 'Linux') && '--plat-name manylinux2014_x86_64' || '--plat-name win_amd64' }} 100 | 101 | - uses: actions/upload-artifact@v4 102 | with: 103 | name: wheel-${{ matrix.os }} 104 | path: ./dist/*.whl 105 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Build Wheels with CUDA/CPU 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | build_wheels: 14 | name: Build QLLM 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | os: [ubuntu-latest] 20 | python-version: ["3.10","3.12"] 21 | cuda-version: ["12.4"] 22 | 23 | steps: 24 | - name: Free Disk Space 25 | uses: jlumbroso/free-disk-space@v1.3.0 26 | if: runner.os == 'Linux' && github.event.pull_request.merged == true 27 | with: 28 | tool-cache: false 29 | android: true 30 | dotnet: true 31 | haskell: true 32 | large-packages: false 33 | docker-images: true 34 | swap-storage: false 35 | 36 | - uses: actions/checkout@v3 37 | 38 | - name: Set up Python ${{ matrix.python-version }} 39 | uses: actions/setup-python@v3 40 | with: 41 | python-version: ${{ matrix.python-version }} 42 | 43 | - name: Install CUDA ${{ matrix.cuda-version }} 44 | if: github.event.pull_request.merged == true 45 | run: | 46 | bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ubuntu-20.04 47 | 48 | - name: Install PyTorch 2.2.2 with CUDA ${{ matrix.cuda-version }} 49 | if: github.event.pull_request.merged == true 50 | run: | 51 | pip config set global.cache-dir "/tmp/.cache/pip" 52 | bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} 2.2.2 ${{ matrix.cuda-version }} 53 | 54 | - name: Install dependencies 55 | run: | 56 | echo "${{ github.event.pull_request.merged }}" 57 | python -m pip install --upgrade pip 58 | python -m pip install flake8 pytest build wheel packaging 59 | #if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 60 | - name: Lint with flake8 61 | run: | 62 | # stop the build if there are Python syntax errors or undefined names 63 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 64 | # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide 65 | flake8 ./qllm/modeling/q_layers --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 66 | flake8 ./qllm/quantization --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 67 | flake8 ./qllm/utils --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 68 | - name: build pypi package 69 | if: github.event.pull_request.merged == true 70 | run: | 71 | set -x 72 | python setup.py bdist_wheel --dist-dir=dist 73 | 74 | - name: build pypi package-cpu 75 | if: github.event.pull_request.merged == false 76 | run: | 77 | set -x 78 | MAX_JOBS=2 python -m build 79 | - name: upload wheel artifacts 80 | uses: actions/upload-artifact@v4 81 | with: 82 | name: wheel-${{ matrix.python-version }} 83 | path: dist/*.whl 84 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Build release Wheels 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*" 7 | 8 | jobs: 9 | build_wheels: 10 | name: Build qllm 11 | runs-on: ${{ matrix.os }} 12 | # needs: release 13 | 14 | strategy: 15 | matrix: 16 | os: [ubuntu-20.04, windows-latest] 17 | pyver: ["3.10", "3.11", "3.12"] 18 | cuda: ["12.4.1"] 19 | defaults: 20 | run: 21 | shell: pwsh 22 | env: 23 | PYPI_CUDA_VERSION: "12.1.1" 24 | CUDA_VERSION: ${{ matrix.cuda }} 25 | 26 | steps: 27 | - name: Free Disk Space 28 | uses: jlumbroso/free-disk-space@v1.3.0 29 | if: runner.os == 'Linux' 30 | with: 31 | tool-cache: false 32 | android: true 33 | dotnet: true 34 | haskell: true 35 | large-packages: false 36 | docker-images: true 37 | swap-storage: false 38 | 39 | - uses: actions/checkout@v3 40 | 41 | - uses: actions/setup-python@v3 42 | with: 43 | python-version: ${{ matrix.pyver }} 44 | 45 | - name: Setup Mamba 46 | uses: conda-incubator/setup-miniconda@v2.2.0 47 | with: 48 | activate-environment: "build" 49 | python-version: ${{ matrix.pyver }} 50 | miniforge-variant: Miniforge3 51 | miniforge-version: latest 52 | use-mamba: true 53 | add-pip-as-python-dependency: true 54 | auto-activate-base: false 55 | 56 | - name: Install Dependencies 57 | run: | 58 | # Install CUDA toolkit 59 | mamba install -y 'cuda' -c "nvidia/label/cuda-${env:CUDA_VERSION}" 60 | # Env variables 61 | $env:CUDA_PATH = $env:CONDA_PREFIX 62 | $env:CUDA_HOME = $env:CONDA_PREFIX 63 | 64 | # Install torch 65 | $cudaVersion = $env:CUDA_VERSION.Replace('.', '') 66 | $cudaVersionPytorch = $cudaVersion.Substring(0, $cudaVersion.Length - 1) 67 | if ([int]$cudaVersionPytorch -gt 121) { $pytorchVersion = "torch==2.5.1" } else {$pytorchVersion = "torch==2.4.1"} 68 | echo "pytorchVersion=$pytorchVersion" 69 | echo "cudaVersion=<$cudaVersion>" 70 | echo "cudaVersionPytorch=$cudaVersionPytorch" 71 | python -m pip install --upgrade --no-cache-dir $pytorchVersion+cu$cudaVersionPytorch --index-url https://download.pytorch.org/whl/cu$cudaVersionPytorch 72 | python -m pip install build setuptools==69.5.1 wheel ninja 73 | # Print version information 74 | python --version 75 | python -c "import torch; print('PyTorch:', torch.__version__)" 76 | python -c "import torch; print('CUDA:', torch.version.cuda)" 77 | python -c "import os; print('CUDA_HOME:', os.getenv('CUDA_HOME', None))" 78 | python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" 79 | - name: Build Wheel 80 | run: | 81 | $env:CUDA_PATH = $env:CONDA_PREFIX 82 | $env:CUDA_HOME = $env:CONDA_PREFIX 83 | # 
Only add +cu118 to wheel if not releasing on PyPi 84 | if ( $env:CUDA_VERSION -eq $env:PYPI_CUDA_VERSION ){ 85 | $env:PYPI_BUILD = 1 86 | } 87 | # echo "{CUDA_VERSION}=$env:CUDA_VERSION" >> $GITHUB_ENV 88 | $env:GENERAL_TORCH = 1 # OptionalCUDAGuard 89 | python setup.py sdist bdist_wheel -k $env:PLAT_ARG.split() 90 | ls dist/*.whl 91 | env: 92 | PLAT_ARG: ${{ contains(runner.os, 'Linux') && '--plat-name manylinux2014_x86_64' || '--plat-name win_amd64' }} 93 | 94 | - uses: actions/upload-artifact@v4 95 | with: 96 | name: 'wheels_${{runner.os}}_${{matrix.pyver}}_${{matrix.cuda}}' 97 | path: ./dist/*.whl 98 | overwrite: true 99 | 100 | release: 101 | # Retrieve tag and create release 102 | name: Create Release 103 | runs-on: ubuntu-latest 104 | needs: build_wheels 105 | outputs: 106 | upload_url: ${{ steps.create_release.outputs.upload_url }} 107 | steps: 108 | - name: Checkout 109 | uses: actions/checkout@v3 110 | 111 | - name: Extract branch info 112 | shell: bash 113 | run: | 114 | echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV 115 | - name: Create Release 116 | id: create_release 117 | uses: "actions/github-script@v6" 118 | env: 119 | RELEASE_TAG: ${{ env.release_tag }} 120 | with: 121 | github-token: "${{ secrets.GITHUB_TOKEN }}" 122 | script: | 123 | const script = require('.github/workflows/scripts/github_create_release.js') 124 | await script(github, context, core) 125 | - uses: actions/download-artifact@v4 126 | with: 127 | pattern: wheels_* 128 | merge-multiple: true 129 | path: ./dist/ 130 | - name: Upload Assets 131 | uses: shogo82148/actions-upload-release-asset@v1 132 | with: 133 | upload_url: ${{ steps.create_release.outputs.upload_url }} 134 | asset_path: ./dist/*.whl -------------------------------------------------------------------------------- /.github/workflows/scripts/cuda-install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir /tmp/cuda-$1 4 | sudo ln -s /tmp/cuda-$1 /usr/local/cuda-$1 5 | # Replace '.' with '-' ex: 11.8 -> 11-8 6 | cuda_version=$(echo $1 | tr "." "-") 7 | # Removes '-' and '.' 
ex: ubuntu-20.04 -> ubuntu2004 8 | OS=$(echo $2 | tr -d ".\-") 9 | 10 | # Installs CUDA 11 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb 12 | sudo dpkg -i cuda-keyring_1.1-1_all.deb 13 | rm cuda-keyring_1.1-1_all.deb 14 | sudo apt -qq update 15 | sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version} 16 | sudo apt clean 17 | 18 | # Test nvcc 19 | PATH=/usr/local/cuda-$1/bin:${PATH} 20 | nvcc --version 21 | 22 | # Log gcc, g++, c++ versions 23 | gcc --version 24 | g++ --version 25 | c++ --version 26 | -------------------------------------------------------------------------------- /.github/workflows/scripts/github_create_release.js: -------------------------------------------------------------------------------- 1 | module.exports = async (github, context, core) => { 2 | try { 3 | const response = await github.rest.repos.createRelease({ 4 | draft: false, 5 | generate_release_notes: true, 6 | name: process.env.RELEASE_TAG, 7 | owner: context.repo.owner, 8 | prerelease: false, 9 | repo: context.repo.repo, 10 | tag_name: process.env.RELEASE_TAG, 11 | }); 12 | 13 | core.setOutput('upload_url', response.data.upload_url); 14 | } catch (error) { 15 | core.setFailed(error.message); 16 | } 17 | } -------------------------------------------------------------------------------- /.github/workflows/scripts/pytorch-install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python_executable=python$1 4 | pytorch_version=$2 5 | cuda_version=$3 6 | 7 | # Install torch 8 | $python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools==69.5.1 && conda clean -ya 9 | $python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --extra-index-url https://download.pytorch.org/whl/cu${cuda_version//./} 10 | 11 | # Print version information 12 | $python_executable --version 13 | $python_executable -c "import torch; print('PyTorch:', torch.__version__)" 14 | $python_executable -c "import torch; print('CUDA:', torch.version.cuda)" 15 | $python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | build 3 | *.egg-info 4 | *.so 5 | Llama* 6 | *.json 7 | dist 8 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = pep8 3 | column_limit = 200 4 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE.txt 2 | include requirements.txt 3 | 4 | recursive-include src * 5 | -------------------------------------------------------------------------------- /assets/fb201d9c-f889-4504-9ef5-ac77ec1cd8e2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wejoncy/QLLM/df20c15920bfabfd0581d7fcccbe87e5c96cd5c7/assets/fb201d9c-f889-4504-9ef5-ac77ec1cd8e2.jpg -------------------------------------------------------------------------------- /csrc/awq_cuda/attention/cuda_bf16_wrapper.h: 
-------------------------------------------------------------------------------- 1 | // Downloaded from from FasterTransformer v5.2.1 2 | // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h 3 | /* 4 | * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | 19 | #pragma once 20 | 21 | #ifdef ENABLE_BF16 22 | #include 23 | #endif 24 | -------------------------------------------------------------------------------- /csrc/awq_cuda/attention/decoder_masked_multihead_attention.cu: -------------------------------------------------------------------------------- 1 | // Adapted from from FasterTransformer v5.2.1 2 | // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu 3 | /* 4 | * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | 19 | #include "decoder_masked_multihead_attention.h" 20 | #include "decoder_masked_multihead_attention_utils.h" 21 | #include "cuda_bf16_wrapper.h" 22 | #include 23 | #include 24 | #include 25 | 26 | #include "decoder_masked_multihead_attention_template.hpp" 27 | 28 | //////////////////////////////////////////////////////////////////////////////////////////////////// 29 | 30 | #define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ 31 | size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ 32 | auto kernel = mmha::masked_multihead_attention_kernel; \ 34 | if (smem_sz >= 48 * 1024) { \ 35 | cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ 36 | } \ 37 | dim3 grid(params.num_heads, params.batch_size); \ 38 | kernel<<>>(params) 39 | 40 | //////////////////////////////////////////////////////////////////////////////////////////////////// 41 | 42 | // !!! Specialize the launcher for Cross attention 43 | template 44 | void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) 45 | { 46 | constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; 47 | constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; 48 | int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; 49 | // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); 50 | if (tlength < 32) { 51 | MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); 52 | } 53 | else if (tlength < 2048) { 54 | MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); 55 | } 56 | else { 57 | MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); 58 | } 59 | } 60 | 61 | //////////////////////////////////////////////////////////////////////////////////////////////////// 62 | 63 | #undef MMHA_LAUNCH_KERNEL 64 | 65 | template 66 | void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) 67 | { 68 | switch (params.hidden_size_per_head) { 69 | case 32: 70 | mmha_launch_kernel(params, stream); 71 | break; 72 | case 48: 73 | mmha_launch_kernel(params, stream); 74 | break; 75 | case 64: 76 | mmha_launch_kernel(params, stream); 77 | break; 78 | case 80: 79 | mmha_launch_kernel(params, stream); 80 | break; 81 | case 96: 82 | mmha_launch_kernel(params, stream); 83 | break; 84 | case 112: 85 | mmha_launch_kernel(params, stream); 86 | break; 87 | case 128: 88 | mmha_launch_kernel(params, stream); 89 | break; 90 | case 160: 91 | mmha_launch_kernel(params, stream); 92 | break; 93 | case 192: 94 | mmha_launch_kernel(params, stream); 95 | break; 96 | case 224: 97 | mmha_launch_kernel(params, stream); 98 | break; 99 | case 256: 100 | mmha_launch_kernel(params, stream); 101 | break; 102 | default: 103 | assert(false); 104 | } 105 | } 106 | 107 | //////////////////////////////////////////////////////////////////////////////////////////////////// 108 | 109 | void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) 110 | { 111 | multihead_attention_>(params, stream); 112 | } 113 | 114 | //////////////////////////////////////////////////////////////////////////////////////////////////// 115 | 116 | void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) 117 | { 118 | multihead_attention_>(params, stream); 119 | } 120 | 121 | //////////////////////////////////////////////////////////////////////////////////////////////////// 122 | 123 | #ifdef ENABLE_BF16 124 | void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, 125 | const cudaStream_t& stream) 126 | { 127 | multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); 128 | } 129 | #endif 130 | //////////////////////////////////////////////////////////////////////////////////////////////////// 131 | 132 | void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) 133 | { 134 | multihead_attention_>(params, stream); 135 | } 136 | 137 | //////////////////////////////////////////////////////////////////////////////////////////////////// 138 | 139 | void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) 140 | { 141 | multihead_attention_>(params, stream); 142 | } 143 | 144 | //////////////////////////////////////////////////////////////////////////////////////////////////// 145 | 146 | #ifdef ENABLE_BF16 147 | void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, 148 | const cudaStream_t& stream) 149 | { 150 | multihead_attention_<__nv_bfloat16, 
Cross_multihead_attention_params<__nv_bfloat16>>(params, stream); 151 | } 152 | #endif 153 | 154 | //////////////////////////////////////////////////////////////////////////////////////////////////// 155 | -------------------------------------------------------------------------------- /csrc/awq_cuda/attention/decoder_masked_multihead_attention.h: -------------------------------------------------------------------------------- 1 | // Downloaded from from FasterTransformer v5.2.1 2 | // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h 3 | /* 4 | * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | 19 | #pragma once 20 | 21 | #include "cuda_bf16_wrapper.h" 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | //////////////////////////////////////////////////////////////////////////////////////////////////// 29 | 30 | #define CHECK_CUDA(call) \ 31 | do { \ 32 | cudaError_t status_ = call; \ 33 | if (status_ != cudaSuccess) { \ 34 | fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ 35 | exit(1); \ 36 | } \ 37 | } while (0) 38 | 39 | //////////////////////////////////////////////////////////////////////////////////////////////////// 40 | 41 | // The structure of parameters for the masked multihead attention kernel. 42 | // 43 | // We use the following terminology to describe the different dimensions. 44 | // 45 | // B: Batch size (number of sequences), 46 | // L: Sequence length, 47 | // D: Hidden dimension, 48 | // H: Number of heads, 49 | // Dh: Hidden dimension per head - Dh = D / H. 50 | 51 | template 52 | struct Multihead_attention_params_base { 53 | 54 | // The output buffer. Dimensions B x D. 55 | T* out = nullptr; 56 | 57 | // The input Qs and the associated bias. Dimensions B x D and D, resp. 58 | const T *q = nullptr, *q_bias = nullptr; 59 | // The input Ks and the associated bias. Dimensions B x D and D, resp. 60 | const T *k = nullptr, *k_bias = nullptr; 61 | // The input Vs and the associated bias. Dimensions B x D and D, resp. 62 | const T *v = nullptr, *v_bias = nullptr; 63 | 64 | // The cache for the Ks. The size must be at least B x L x D. 65 | T* k_cache = nullptr; 66 | // The cache for the Vs. The size must be at least B x L x D. 67 | T* v_cache = nullptr; 68 | // The indirections to use for cache when beam sampling. 69 | const int* cache_indir = nullptr; 70 | 71 | // Stride to handle the case when KQV is a single buffer 72 | int stride = 0; 73 | 74 | // The batch size. 75 | int batch_size = 0; 76 | // The beam width 77 | int beam_width = 0; 78 | // The sequence length. 79 | int memory_max_len = 0; 80 | // The number of heads (H). 81 | int num_heads = 0; 82 | // The number of heads for KV cache. 83 | int num_kv_heads = 0; 84 | // The hidden dimension per head (Dh). 
85 | int hidden_size_per_head = 0; 86 | // The per-head latent space reserved for rotary embeddings. 87 | int rotary_embedding_dim = 0; 88 | bool neox_rotary_style = false; 89 | float rotary_base = 0.0f; 90 | // The maximum length of input sentences. 91 | int max_input_length = 0; 92 | // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? 93 | int timestep = 0; 94 | // The current timestep of each sentences (support different timestep for different sentences) 95 | 96 | // The 1.f / sqrt(Dh). Computed on the host. 97 | float inv_sqrt_dh = 0.0f; 98 | 99 | // Used when we have some input context like gpt 100 | const int* total_padding_tokens = nullptr; 101 | 102 | const bool* masked_tokens = nullptr; 103 | const int* prefix_prompt_lengths = nullptr; 104 | int max_prefix_prompt_length = 0; 105 | 106 | const T* relative_attention_bias = nullptr; 107 | int relative_attention_bias_stride = 0; 108 | // The slope per head of linear position bias to attention score (H). 109 | const float* linear_bias_slopes = nullptr; 110 | 111 | const T* ia3_key_weights = nullptr; 112 | const T* ia3_value_weights = nullptr; 113 | const int* ia3_tasks = nullptr; 114 | 115 | const float* qkv_scale_out = nullptr; 116 | const float* attention_out_scale = nullptr; 117 | int int8_mode = 0; 118 | }; 119 | 120 | template 121 | struct Multihead_attention_params: public Multihead_attention_params_base { 122 | // output cross attentions 123 | float* cross_attention_out = nullptr; 124 | int max_decoder_seq_len = 0; 125 | bool is_return_cross_attentions = false; 126 | 127 | // allows to exist attention eary 128 | bool* finished = nullptr; 129 | 130 | // required in case of cross attention 131 | // will need it here till if constexpr in c++17 132 | int* memory_length_per_sample = nullptr; 133 | 134 | // required in case of masked attention with different length 135 | const int* length_per_sample = nullptr; 136 | }; 137 | 138 | template 139 | struct Multihead_attention_params: public Multihead_attention_params_base { 140 | // output cross attentions 141 | float* cross_attention_out = nullptr; 142 | int max_decoder_seq_len = 0; 143 | bool is_return_cross_attentions = false; 144 | 145 | // allows to exist attention eary 146 | bool* finished = nullptr; 147 | 148 | // required in case of cross attention 149 | int* memory_length_per_sample = nullptr; 150 | 151 | // required in case of masked attention with different length 152 | const int* length_per_sample = nullptr; 153 | }; 154 | 155 | template 156 | using Masked_multihead_attention_params = Multihead_attention_params; 157 | 158 | template 159 | using Cross_multihead_attention_params = Multihead_attention_params; 160 | 161 | template 162 | struct outputCrossAttentionParam { 163 | // max decoder output length 164 | int max_decoder_seq_len = 0; 165 | T* cross_attention_out = nullptr; 166 | bool is_return_cross_attentions = false; 167 | }; 168 | 169 | //////////////////////////////////////////////////////////////////////////////////////////////////// 170 | 171 | void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); 172 | void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); 173 | #ifdef ENABLE_BF16 174 | void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, 175 | const cudaStream_t& stream); 176 | #endif 177 | void cross_multihead_attention(const Cross_multihead_attention_params& params, const 
cudaStream_t& stream); 178 | void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); 179 | #ifdef ENABLE_BF16 180 | void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, 181 | const cudaStream_t& stream); 182 | #endif 183 | 184 | //////////////////////////////////////////////////////////////////////////////////////////////////// 185 | -------------------------------------------------------------------------------- /csrc/awq_cuda/attention/ft_attention.cpp: -------------------------------------------------------------------------------- 1 | // Adapted from NVIDIA/FasterTransformer and FlashAttention 2 | 3 | #include 4 | #include "ATen/cuda/CUDAContext.h" 5 | #include 6 | 7 | #include "ft_attention.h" 8 | #include "decoder_masked_multihead_attention.h" 9 | 10 | #define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") 11 | #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") 12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | 14 | #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) \ 15 | if (TYPE == at::ScalarType::Half) { \ 16 | using scalar_t = at::Half; \ 17 | __VA_ARGS__(); \ 18 | } else if (TYPE == at::ScalarType::BFloat16) { \ 19 | using scalar_t = at::BFloat16; \ 20 | __VA_ARGS__(); \ 21 | } else if (TYPE == at::ScalarType::Float) { \ 22 | using scalar_t = float; \ 23 | __VA_ARGS__(); \ 24 | } else { \ 25 | AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \ 26 | } 27 | 28 | template 29 | void masked_multihead_attention(const Masked_multihead_attention_params& params, 30 | const cudaStream_t& stream); 31 | 32 | template 33 | void cross_multihead_attention(const Masked_multihead_attention_params& params, 34 | const cudaStream_t& stream); 35 | 36 | template 37 | struct SATypeConverter { 38 | using Type = T; 39 | }; 40 | 41 | template<> 42 | struct SATypeConverter { 43 | using Type = uint16_t; 44 | }; 45 | 46 | template<> 47 | struct SATypeConverter { 48 | using Type = __nv_bfloat16; 49 | }; 50 | 51 | template 52 | void set_params(Masked_multihead_attention_params ¶ms, 53 | const size_t batch_size, 54 | const size_t nheads, 55 | const size_t nheads_kv, 56 | const size_t memory_max_seqlen, 57 | const size_t headdim, 58 | const int timestep, 59 | const int rotary_embedding_dim, 60 | const float rotary_base, 61 | const bool neox_rotary_style, 62 | const int qkv_batch_stride, 63 | T *q_ptr, 64 | T *k_ptr, 65 | T *v_ptr, 66 | T *k_cache_ptr, 67 | T *v_cache_ptr, 68 | int *length_per_sample, 69 | float *alibi_slopes_ptr, 70 | T *out_ptr) { 71 | // Reset the parameters 72 | memset(¶ms, 0, sizeof(params)); 73 | params.q = q_ptr; 74 | params.k = k_ptr; 75 | params.v = v_ptr; 76 | params.q_bias = nullptr; 77 | params.k_bias = nullptr; 78 | params.v_bias = nullptr; 79 | params.k_cache = k_cache_ptr; 80 | params.v_cache = v_cache_ptr; 81 | params.linear_bias_slopes = alibi_slopes_ptr; 82 | params.out = out_ptr; 83 | params.cache_indir = nullptr; 84 | params.stride = qkv_batch_stride; 85 | params.batch_size = batch_size; 86 | params.beam_width = 1; 87 | params.memory_max_len = memory_max_seqlen; 88 | params.num_heads = nheads; 89 | params.num_kv_heads = nheads_kv; 90 | params.hidden_size_per_head = headdim; 91 | params.rotary_embedding_dim = rotary_embedding_dim; 92 | params.rotary_base = rotary_base; 93 | params.neox_rotary_style 
= neox_rotary_style; 94 | params.timestep = timestep; 95 | params.inv_sqrt_dh = 1.f / sqrt(float(headdim)); 96 | params.total_padding_tokens = nullptr; 97 | params.masked_tokens = nullptr; 98 | params.prefix_prompt_lengths = nullptr; 99 | params.max_prefix_prompt_length = 0; 100 | params.relative_attention_bias = nullptr; 101 | params.relative_attention_bias_stride = 0; 102 | params.cross_attention_out = nullptr; 103 | params.max_decoder_seq_len = 0; 104 | params.is_return_cross_attentions = false; 105 | params.finished = nullptr; 106 | params.memory_length_per_sample = nullptr; 107 | params.length_per_sample = length_per_sample; 108 | } 109 | 110 | torch::Tensor single_query_attention(const torch::Tensor q, 111 | const torch::Tensor k, 112 | const torch::Tensor v, 113 | torch::Tensor k_cache, 114 | torch::Tensor v_cache, 115 | c10::optional length_per_sample_, 116 | c10::optional alibi_slopes_, 117 | const int timestep, 118 | const int rotary_embedding_dim, 119 | const float rotary_base, 120 | // neox_rotary_style = not interleaved 121 | const bool neox_rotary_style) { 122 | CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache); 123 | int batch_size = v_cache.size(0); 124 | int nheads = q.size(1); 125 | int nheads_kv = v_cache.size(1); 126 | int memory_max_seqlen = v_cache.size(2); 127 | int headdim = v_cache.size(3); 128 | CHECK_SHAPE(q, batch_size, nheads, headdim); 129 | CHECK_SHAPE(k, batch_size, nheads_kv, headdim); 130 | CHECK_SHAPE(v, batch_size, nheads_kv, headdim); 131 | CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim); 132 | // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32 133 | int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8; 134 | CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize); 135 | TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim); 136 | TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim); 137 | TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim); 138 | // TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0)); 139 | CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache); 140 | 141 | if (length_per_sample_.has_value()) { 142 | auto length_per_sample = length_per_sample_.value(); 143 | CHECK_DEVICE(length_per_sample); 144 | CHECK_SHAPE(length_per_sample, batch_size); 145 | CHECK_CONTIGUOUS(length_per_sample); 146 | TORCH_CHECK(length_per_sample.dtype() == torch::kInt32); 147 | } 148 | 149 | if (alibi_slopes_.has_value()) { 150 | auto alibi_slopes = alibi_slopes_.value(); 151 | CHECK_DEVICE(alibi_slopes); 152 | CHECK_SHAPE(alibi_slopes, nheads); 153 | CHECK_CONTIGUOUS(alibi_slopes); 154 | TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32); 155 | } 156 | 157 | // Otherwise the kernel will be launched from cuda:0 device 158 | // Cast to char to avoid compiler warning about narrowing 159 | at::cuda::CUDAGuard device_guard{(char)q.get_device()}; 160 | 161 | torch::Tensor out = torch::empty_like(q); 162 | 163 | DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] { 164 | using DataType = typename SATypeConverter::Type; 165 | Masked_multihead_attention_params params; 166 | set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, 167 | timestep, rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0), 168 | reinterpret_cast(q.data_ptr()), 169 | reinterpret_cast(k.data_ptr()), 170 | reinterpret_cast(v.data_ptr()), 171 | reinterpret_cast(k_cache.data_ptr()), 
172 | reinterpret_cast(v_cache.data_ptr()), 173 | length_per_sample_.has_value() 174 | ? length_per_sample_.value().data_ptr() : nullptr, 175 | alibi_slopes_.has_value() 176 | ? alibi_slopes_.value().data_ptr(): nullptr, 177 | reinterpret_cast(out.data_ptr())); 178 | auto stream = at::cuda::getCurrentCUDAStream(); 179 | masked_multihead_attention(params, stream); 180 | }); 181 | return out; 182 | } -------------------------------------------------------------------------------- /csrc/awq_cuda/attention/ft_attention.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | 5 | torch::Tensor single_query_attention(const torch::Tensor q, 6 | const torch::Tensor k, 7 | const torch::Tensor v, 8 | torch::Tensor k_cache, 9 | torch::Tensor v_cache, 10 | c10::optional length_per_sample_, 11 | c10::optional alibi_slopes_, 12 | const int timestep, 13 | const int rotary_embedding_dim = 0, 14 | const float rotary_base = 10000.0f, 15 | const bool neox_rotary_style=true); -------------------------------------------------------------------------------- /csrc/awq_cuda/layernorm/layernorm.cu: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Adapted from NVIDIA FasterTransformer: 4 | https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu 5 | 6 | */ 7 | 8 | #include 9 | #include 10 | #include "reduction.cuh" 11 | #include "layernorm.h" 12 | #include 13 | #include 14 | 15 | static inline __device__ float to_float(half src) 16 | { 17 | return __half2float(src); 18 | } 19 | 20 | static inline __device__ float to_float(float src) 21 | { 22 | return src; 23 | } 24 | 25 | template 26 | __global__ void generalT5LayerNorm( 27 | const T* __restrict input, const T* __restrict gamma, T* output, const float layernorm_eps, int m, int n) 28 | { 29 | // layernorm module in the T5 style No bias and no subtraction of mean. 30 | const int tid = threadIdx.x; 31 | 32 | __shared__ float s_variance; 33 | float variance = 0.0f; 34 | 35 | float local_var_sum = 0.0f; 36 | for (int i = tid; i < n; i += blockDim.x) { 37 | float diff = to_float(__ldg(&input[blockIdx.x * n + i])); 38 | local_var_sum += diff * diff; 39 | } 40 | variance = blockReduceSum(local_var_sum); 41 | 42 | if (threadIdx.x == 0) { 43 | s_variance = rsqrtf(variance / (float)n + layernorm_eps); 44 | } 45 | __syncthreads(); 46 | 47 | for (int i = tid; i < n; i += blockDim.x) { 48 | output[blockIdx.x * n + i] = 49 | clamp_inf_for_half((to_float(input[blockIdx.x * n + i]) * s_variance) * to_float(__ldg(&gamma[i]))); 50 | } 51 | } 52 | 53 | 54 | template 55 | void invokeGeneralT5LayerNorm(T* out, 56 | const T* input, 57 | const T* gamma, 58 | // const T* beta, 59 | const float layernorm_eps, 60 | const int m, 61 | const int n) 62 | { 63 | dim3 grid(m); 64 | dim3 block(min(n, 1024)); 65 | 66 | /* For general cases, n is equal to hidden_units, e.g., 512/1024. 67 | Since we have warp shuffle inside the code, block.x % 32 should be 0. 
68 | */ 69 | if (n % 32 != 0) { 70 | block.x = 1024; 71 | } 72 | 73 | block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x 74 | 75 | /* should pay attention to the rsqrt precision*/ 76 | generalT5LayerNorm<<>>(input, gamma, out, layernorm_eps, m, n); // For gpt-3 77 | } 78 | 79 | template void invokeGeneralT5LayerNorm(half* out, 80 | const half* input, 81 | const half* gamma, 82 | // const half* beta, 83 | const float layernorm_eps, 84 | const int m, 85 | const int n); 86 | 87 | template void invokeGeneralT5LayerNorm(float* out, 88 | const float* input, 89 | const float* gamma, 90 | // const half* beta, 91 | const float layernorm_eps, 92 | const int m, 93 | const int n); 94 | 95 | 96 | 97 | // input b, n, c 98 | void layernorm_forward_cuda( 99 | torch::Tensor _input, 100 | torch::Tensor _gamma, 101 | torch::Tensor _out, 102 | float eps) 103 | { 104 | int m = _input.size(0) * _input.size(1); 105 | int n = _input.size(2); 106 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_input)); 107 | 108 | auto input = reinterpret_cast(_input.data_ptr()); 109 | auto gamma = reinterpret_cast(_gamma.data_ptr()); 110 | auto out = reinterpret_cast(_out.data_ptr()); 111 | 112 | invokeGeneralT5LayerNorm(out, input, gamma, eps, m, n); 113 | } 114 | -------------------------------------------------------------------------------- /csrc/awq_cuda/layernorm/layernorm.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void layernorm_forward_cuda(torch::Tensor _input, torch::Tensor _gamma, torch::Tensor _out, float eps); 4 | -------------------------------------------------------------------------------- /csrc/awq_cuda/layernorm/reduction.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Adapted from NVIDIA FasterTransformer: 4 | https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/reduce_kernel_utils.cuh 5 | */ 6 | 7 | #pragma once 8 | #include 9 | #if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) 10 | #include 11 | #else 12 | #include 13 | #endif 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #define HALF_FLT_MAX 65504.F 20 | #define FINAL_MASK 0xffffffff 21 | 22 | 23 | template 24 | inline __device__ T add(T a, T b) { 25 | return a + b; 26 | } 27 | 28 | template<> 29 | inline __device__ half2 add(half2 a, half2 b) { 30 | return __hadd2(a, b); 31 | } 32 | 33 | template<> 34 | inline __device__ half add(half a, half b) { 35 | return __hadd(a, b); 36 | } 37 | 38 | template 39 | __inline__ __device__ T warpReduceSum(T val) 40 | { 41 | #pragma unroll 42 | for (int mask = 16; mask > 0; mask >>= 1) 43 | val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80 44 | return val; 45 | } 46 | 47 | /* Calculate the sum of all elements in a block */ 48 | template 49 | __inline__ __device__ T blockReduceSum(T val) 50 | { 51 | static __shared__ T shared[32]; 52 | int lane = threadIdx.x & 0x1f; 53 | int wid = threadIdx.x >> 5; 54 | 55 | val = warpReduceSum(val); 56 | 57 | if (lane == 0) 58 | shared[wid] = val; 59 | 60 | __syncthreads(); 61 | 62 | // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent 63 | // blockDim.x is not divided by 32 64 | val = (threadIdx.x < (blockDim.x / 32.f)) ? 
shared[lane] : (T)(0.0f); 65 | val = warpReduceSum(val); 66 | 67 | return val; 68 | } 69 | 70 | 71 | template 72 | __device__ __forceinline__ T clamp_inf_for_half(const float input) 73 | { 74 | return input; 75 | } 76 | 77 | template<> 78 | __device__ __forceinline__ half clamp_inf_for_half(const float input) 79 | { 80 | // clamp inf values to enable fp16 training 81 | return input > 0.0f ? __float2half(min(input, HALF_FLT_MAX - 1000)) : __float2half(max(input, -HALF_FLT_MAX + 1000)); 82 | } 83 | -------------------------------------------------------------------------------- /csrc/awq_cuda/position_embedding/pos_encoding.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | void rotary_embedding_neox( 5 | torch::Tensor& positions, 6 | torch::Tensor& query, 7 | torch::Tensor& key, 8 | int head_size, 9 | torch::Tensor& cos_sin_cache); -------------------------------------------------------------------------------- /csrc/awq_cuda/position_embedding/pos_encoding_kernels.cu: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Adapted from the VLLM project: 4 | https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu 5 | 6 | */ 7 | 8 | #include 9 | #include 10 | #include "pos_encoding.h" 11 | 12 | template 13 | __global__ void rotary_embedding_neox_kernel( 14 | const int64_t* __restrict__ positions, // [num_tokens] 15 | scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size] 16 | scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] 17 | const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] 18 | const int rot_dim, 19 | const int stride, 20 | const int num_heads, 21 | const int head_size) { 22 | // Each thread block is responsible for one token. 
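// Threads in the block then stride over this token's num_heads * (rot_dim / 2) rotation
// pairs; for each pair, channels x_index and y_index of query and key are rotated by the
// (cos, sin) values cached for position `pos`.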
23 | const int token_idx = blockIdx.x; 24 | int64_t pos = positions[token_idx]; 25 | const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; 26 | 27 | const int embed_dim = rot_dim / 2; 28 | const int n = num_heads * embed_dim; 29 | for (int i = threadIdx.x; i < n; i += blockDim.x) { 30 | const int head_idx = i / embed_dim; 31 | const int token_head = token_idx * stride + head_idx * head_size; 32 | 33 | const int rot_offset = i % embed_dim; 34 | const int x_index = rot_offset; 35 | const int y_index = embed_dim + rot_offset; 36 | 37 | const int out_x = token_idx * stride + head_idx * head_size + x_index; 38 | const int out_y = token_idx * stride + head_idx * head_size + y_index; 39 | 40 | const scalar_t cos = __ldg(cache_ptr + x_index); 41 | const scalar_t sin = __ldg(cache_ptr + y_index); 42 | 43 | const scalar_t q_x = query[token_head + x_index]; 44 | const scalar_t q_y = query[token_head + y_index]; 45 | query[out_x] = q_x * cos - q_y * sin; 46 | query[out_y] = q_y * cos + q_x * sin; 47 | 48 | const scalar_t k_x = key[token_head + x_index]; 49 | const scalar_t k_y = key[token_head + y_index]; 50 | key[out_x] = k_x * cos - k_y * sin; 51 | key[out_y] = k_y * cos + k_x * sin; 52 | } 53 | } 54 | 55 | void rotary_embedding_neox( 56 | torch::Tensor& positions, // [b, num_tokens] 57 | torch::Tensor& query, // [b, num_tokens, 1, num_heads, head_size] 58 | torch::Tensor& key, // [b, num_tokens, 1, num_heads, head_size] 59 | int head_size, 60 | torch::Tensor& cos_sin_cache) // [max_position, rot_dim] 61 | { 62 | int num_tokens = query.size(0) * query.size(1); 63 | int rot_dim = cos_sin_cache.size(1); 64 | int num_heads = query.size(-2); 65 | int stride = num_heads * head_size; 66 | // TORCH_CHECK(stride == key.stride(0)); 67 | 68 | dim3 grid(num_tokens); 69 | dim3 block(std::min(num_heads * rot_dim / 2, 512)); 70 | const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 71 | AT_DISPATCH_FLOATING_TYPES_AND2( 72 | at::ScalarType::Half, 73 | at::ScalarType::BFloat16, 74 | query.scalar_type(), 75 | "rotary_embedding_neox", 76 | [&] { 77 | rotary_embedding_neox_kernel<<>>( 78 | positions.data_ptr(), 79 | query.data_ptr(), 80 | key.data_ptr(), 81 | cos_sin_cache.data_ptr(), 82 | rot_dim, 83 | stride, 84 | num_heads, 85 | head_size); 86 | }); 87 | } 88 | 89 | -------------------------------------------------------------------------------- /csrc/awq_cuda/pybind_awq.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "layernorm/layernorm.h" 4 | #include "quantization/gemm_cuda.h" 5 | #include "quantization/gemv_cuda.h" 6 | #include "position_embedding/pos_encoding.h" 7 | 8 | extern void mul(const torch::Tensor &A, const torch::Tensor &B, 9 | torch::Tensor &C, const torch::Tensor &s, 10 | torch::Tensor &workspace, int thread_k = -1, int thread_n = -1, 11 | int sms = -1, int max_par = 8); 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 14 | { 15 | //m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel"); 16 | m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel."); 17 | m.def("gemmv2_forward_cuda", &gemmv2_forward_cuda, "Quantized v2 GEMM kernel."); 18 | m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel."); 19 | //m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key"); 20 | m.def("mul", &mul, "Marlin FP16xINT4 matmul."); 21 | } 
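
The rotary_embedding_neox_kernel above rotates, for every token, each head's first rot_dim channels in NeoX style: channel pairs (rot_offset, embed_dim + rot_offset) of query and key are rotated by the cos/sin values cached for the token's position. A minimal host-side reference sketch of that same rotation for a single token (plain C++; the function name and layout assumptions are illustrative, not part of the extension) can help when checking the index math:

// Reference sketch (hypothetical helper, not part of the CUDA extension):
// applies the NeoX rotation for ONE token, mirroring rotary_embedding_neox_kernel.
// query/key hold this token's [num_heads, head_size] values contiguously;
// cos_sin holds [cos(0..rot_dim/2), sin(0..rot_dim/2)] for the token's position.
#include <vector>

void rotary_neox_reference(std::vector<float>& query,
                           std::vector<float>& key,
                           const std::vector<float>& cos_sin,
                           int num_heads, int head_size, int rot_dim) {
  const int embed_dim = rot_dim / 2;
  for (int head = 0; head < num_heads; ++head) {
    for (int r = 0; r < embed_dim; ++r) {
      const int x = head * head_size + r;              // x_index in the kernel
      const int y = head * head_size + embed_dim + r;  // y_index in the kernel
      const float c = cos_sin[r];
      const float s = cos_sin[embed_dim + r];
      const float qx = query[x], qy = query[y];
      query[x] = qx * c - qy * s;
      query[y] = qy * c + qx * s;
      const float kx = key[x], ky = key[y];
      key[x] = kx * c - ky * s;
      key[y] = ky * c + kx * s;
    }
  }
}

The device kernel performs exactly this per blockIdx.x token, with threadIdx.x striding over the (head, rot_offset) pairs.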
-------------------------------------------------------------------------------- /csrc/awq_cuda/pybind_ft.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "attention/ft_attention.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 6 | { 7 | m.def("single_query_attention", &single_query_attention, "Attention with a single query", 8 | py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"), 9 | py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0, 10 | py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true); 11 | } -------------------------------------------------------------------------------- /csrc/awq_cuda/quantization/dequantize.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h 3 | 4 | @article{lin2023awq, 5 | title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, 6 | author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, 7 | journal={arXiv}, 8 | year={2023} 9 | } 10 | */ 11 | 12 | #pragma once 13 | 14 | #include 15 | __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) 16 | { 17 | uint4 result; 18 | 19 | uint32_t* h = reinterpret_cast(&result); 20 | uint32_t const i4s = reinterpret_cast(source); 21 | 22 | // First, we extract the i4s and construct an intermediate fp16 number. 23 | static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; 24 | static constexpr uint32_t BOTTOM_MASK = 0x000f000f; 25 | static constexpr uint32_t TOP_MASK = 0x00f000f0; 26 | static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; 27 | 28 | // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing 29 | // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. 30 | // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and 31 | // elt_67 to fp16 without having to shift them to the bottom bits before hand. 32 | 33 | // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue 34 | // immediately before required. 35 | const uint32_t top_i4s = i4s >> 8; 36 | // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 37 | asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" 38 | : "=r"(h[0]) 39 | : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); 40 | // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 41 | asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" 42 | : "=r"(h[1]) 43 | : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); 44 | // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 45 | asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" 46 | : "=r"(h[2]) 47 | : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); 48 | // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 49 | asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" 50 | : "=r"(h[3]) 51 | : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); 52 | 53 | // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the 54 | // half2 ctor. 
In this case, I chose performance reliability over code readability. 55 | 56 | // This is the half2 {1032, 1032} represented as an integer. 57 | // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; 58 | // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] 59 | static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; 60 | // This is the half2 {1 / 16, 1 / 16} represented as an integer. 61 | static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; 62 | // This is the half2 {-72, -72} represented as an integer. 63 | // static constexpr uint32_t NEG_72 = 0xd480d480; 64 | // Haotian: Let's use {-64, -64}. 65 | static constexpr uint32_t NEG_64 = 0xd400d400; 66 | 67 | // Finally, we construct the output numbers. 68 | // Convert elt_01 69 | asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); 70 | // Convert elt_23 71 | asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); 72 | // Convert elt_45 73 | asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); 74 | // Convert elt_67 75 | asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); 76 | 77 | return result; 78 | } 79 | 80 | -------------------------------------------------------------------------------- /csrc/awq_cuda/quantization/gemm_cuda.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel, 4 | torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters); 5 | 6 | torch::Tensor gemmv2_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel, 7 | torch::Tensor _scaling_factors, torch::Tensor _zeros, int group_size, int split_k_iters); -------------------------------------------------------------------------------- /csrc/awq_cuda/quantization/gemv_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | torch::Tensor gemv_forward_cuda( 5 | torch::Tensor _in_feats, 6 | torch::Tensor _kernel, 7 | torch::Tensor _scaling_factors, 8 | torch::Tensor _zeros, 9 | int group_size); 10 | -------------------------------------------------------------------------------- /csrc/awq_cuda/quantization/marlin_cuda.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #include "marlin_cuda_kernel.cuh" 25 | 26 | const int ERR_PROB_SHAPE = 1; 27 | const int ERR_KERN_SHAPE = 2; 28 | 29 | void mul( 30 | const torch::Tensor& A, 31 | const torch::Tensor& B, 32 | torch::Tensor& C, 33 | const torch::Tensor& s, 34 | torch::Tensor& workspace, 35 | int thread_k = -1, 36 | int thread_n = -1, 37 | int sms = -1, 38 | int max_par = 8 39 | ) { 40 | int prob_m = A.size(0); 41 | int prob_n = C.size(1); 42 | int prob_k = A.size(1); 43 | int groupsize = (s.size(0) == 1) ? -1 : prob_k / s.size(0); 44 | if (groupsize != -1 && groupsize * s.size(0) != prob_k) 45 | AT_ERROR("k=", prob_k, " not compatible with ", s.size(0), " groups."); 46 | if (workspace.numel() < prob_n / 128 * max_par) 47 | AT_ERROR("workspace must be of size at least ", prob_n / 128 * max_par, "."); 48 | int dev = A.get_device(); 49 | int err = marlin_cuda( 50 | A.data_ptr(), 51 | B.data_ptr(), 52 | C.data_ptr(), 53 | s.data_ptr(), 54 | prob_m, prob_n, prob_k, 55 | workspace.data_ptr(), 56 | groupsize, 57 | dev, 58 | at::cuda::getCurrentCUDAStream(dev), 59 | thread_k, 60 | thread_n, 61 | sms, 62 | max_par 63 | ); 64 | if (err == ERR_PROB_SHAPE) { 65 | AT_ERROR( 66 | "Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")", 67 | " not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "." 68 | ); 69 | } else if (err == ERR_KERN_SHAPE) { 70 | AT_ERROR( 71 | "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "." 72 | ); 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /csrc/awq_cuda/quantization/marlin_cuda_kernel.cuh: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | int marlin_cuda( 5 | const void* A, 6 | const void* B, 7 | void* C, 8 | void* s, 9 | int prob_m, 10 | int prob_n, 11 | int prob_k, 12 | void* workspace, 13 | int groupsize, 14 | int dev, 15 | cudaStream_t stream, 16 | int thread_k, 17 | int thread_n, 18 | int sms, 19 | int max_par 20 | ); 21 | -------------------------------------------------------------------------------- /csrc/ort_cuda/common.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #include 6 | 7 | // AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) 8 | #define XBITOPS_DISPATCH_CASE_FLOATING_TYPES(...) \ 9 | AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ 10 | AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) 11 | 12 | #define XBITOPS_DISPATCH_CASE_FLOATING_TYPES_HALF(...) \ 13 | AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ 14 | AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) 15 | 16 | #define XBITOPS_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ 17 | AT_DISPATCH_SWITCH( \ 18 | TYPE, NAME, XBITOPS_DISPATCH_CASE_FLOATING_TYPES_HALF(__VA_ARGS__)) 19 | 20 | #define XBITOPS_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) 
\ 21 | AT_DISPATCH_SWITCH( \ 22 | TYPE, NAME, XBITOPS_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) 23 | 24 | 25 | #if __CUDA_ARCH__ >= 800 26 | #define XBITOPS_DISPATCH_TYPES XBITOPS_DISPATCH_FLOATING_TYPES 27 | #else 28 | #define XBITOPS_DISPATCH_TYPES XBITOPS_DISPATCH_FLOATING_TYPES_AND_HALF 29 | #endif 30 | 31 | 32 | #if __CUDA_ARCH__ >= 800 33 | #define __has_bfloat16 1 34 | #else 35 | #define __has_bfloat16 0 36 | #endif 37 | 38 | namespace onnxruntime_gptq { 39 | namespace cuda_quant { 40 | template 41 | struct TYPE_VEC2 { 42 | }; 43 | 44 | template <> 45 | struct TYPE_VEC2 { 46 | using Type = half2; 47 | }; 48 | 49 | template <> 50 | struct TYPE_VEC2 { 51 | using Type = __half2; 52 | }; 53 | 54 | template <> 55 | struct TYPE_VEC2 { 56 | using Type = float2; 57 | }; 58 | 59 | template <> 60 | struct TYPE_VEC2<__nv_bfloat16> { 61 | using Type = __nv_bfloat162; 62 | }; 63 | 64 | template <> 65 | struct TYPE_VEC2 { 66 | using Type = __nv_bfloat162; 67 | }; 68 | 69 | template 70 | __device__ toT ConvertFromShort(const short a, toT v={}) { 71 | if constexpr (std::is_same::value) { 72 | return __short2half_rn(a); 73 | } else if constexpr (std::is_same::value) { 74 | return __short2half_rn(a); 75 | } 76 | #if __has_bfloat16 77 | if constexpr (std::is_same::value) { 78 | return __short2bfloat16_rn(a); 79 | } 80 | #endif 81 | else { 82 | //static_assert(false, "Not supported type"); 83 | return __short2half_rn(a); 84 | } 85 | } 86 | 87 | template 88 | __device__ toT ConvertFromInt(const int a) { 89 | if constexpr (std::is_same::value || 90 | std::is_same::value || 91 | std::is_same::value) { 92 | return __int2half_rn(a); 93 | } else if constexpr (std::is_same::value) { 94 | return __int2half_rn(a); 95 | } 96 | #if __has_bfloat16 97 | if constexpr (std::is_same::value) { 98 | return __int2bfloat16_rn(a); 99 | } 100 | #endif 101 | else { 102 | //static_assert(false, "Not supported type"); 103 | return __int2half_rn(a); 104 | } 105 | } 106 | 107 | template 108 | __device__ toT ConvertFromFloat(const float a) { 109 | if constexpr (std::is_same::value) { 110 | return __float2half_rn(a); 111 | } else if constexpr (std::is_same::value) { 112 | return __float2half_rn(a); 113 | } 114 | #if __has_bfloat16 115 | if constexpr (std::is_same::value) { 116 | return __float2bfloat16_rn(a); 117 | } 118 | #endif 119 | else { 120 | //static_assert(false, "Not supported type"); 121 | return __float2half_rn(a); 122 | } 123 | } 124 | 125 | template 126 | __device__ fromT ConvertToFloat(const float a) { 127 | if constexpr (std::is_same::value) { 128 | return __half2float(a); 129 | } else if constexpr (std::is_same::value) { 130 | return __half2float(a); 131 | } 132 | #if __has_bfloat16 133 | if constexpr (std::is_same::value) { 134 | return __bfloat162float(a); 135 | } 136 | #endif 137 | else { 138 | return __half2float(a); 139 | //static_assert(false, "Not supported type"); 140 | } 141 | } 142 | 143 | template 144 | __device__ auto MakeVec2(T a, T b) { 145 | if constexpr (std::is_same::value || 146 | std::is_same::value || 147 | std::is_same::value || 148 | std::is_same::value) { 149 | return __halves2half2(a, b); 150 | } 151 | #if __has_bfloat16 152 | if constexpr (std::is_same::value) { 153 | return __halves2bfloat162(a, b); 154 | } 155 | #endif 156 | else { 157 | return __halves2half2(a, b); 158 | //static_assert(false, "Not supported type"); 159 | } 160 | } 161 | 162 | template 163 | __device__ auto Short22Vec2(const short a, const short b) { 164 | return MakeVec2(ConvertFromShort(a), ConvertFromShort(b)); 165 | 
} 166 | 167 | template 168 | __device__ auto Int22Vec2(const int a, const int b) { 169 | return MakeVec2(ConvertFromInt(a), ConvertFromInt(b)); 170 | } 171 | 172 | template 173 | __device__ auto Element2Vec2(const T a) { 174 | if constexpr (std::is_same::value || 175 | std::is_same::value || 176 | std::is_same::value) { 177 | return __half2half2(a); 178 | } 179 | #if __has_bfloat16 180 | if constexpr (std::is_same::value) { 181 | return __bfloat162bfloat162(a); 182 | } 183 | #endif 184 | else { 185 | // static_assert(false, "Not supported type"); 186 | } 187 | } 188 | } // namespace cuda_quant 189 | } 190 | 191 | template 192 | struct C10Type2Type { 193 | }; 194 | 195 | template <> 196 | struct C10Type2Type { 197 | using Type = half; 198 | }; 199 | 200 | template <> 201 | struct C10Type2Type { 202 | using Type = __nv_bfloat16; 203 | }; -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools==69.5.1", "wheel", "packaging", "ninja>=1.11.1", "torch==2.2.2"] 3 | 4 | build-backend = "setuptools.build_meta" 5 | 6 | 7 | [tool.yapf] 8 | based_on_style = "pep8" 9 | column_limit = 120 10 | disable_split_list_with_comment = true 11 | each_dict_entry_on_separate_line=false 12 | split_before_named_assigns = false 13 | split_complex_comprehension = true 14 | 15 | [tool.yapfignore] 16 | ignore_patterns = [ 17 | ] 18 | 19 | [tool.ruff] 20 | line-length = 120 21 | src = ["qllm"] 22 | exclude = ["qllm/utils/onnx/merge_encoder_decoder.py", 23 | "qllm/plugin", 24 | "qllm/modeling/q_layers/fused_mlp.py", 25 | "qllm/modeling/q_layers/fused_attn.py", 26 | "qllm/modeling/q_layers/quant_linear_triton.py", 27 | "qllm/custom", 28 | ] 29 | 30 | [tool.ruff.lint] 31 | ignore = ["E501", "E701", "E731", "E741",] 32 | select = [ 33 | # pycodestyle 34 | "E", 35 | # Pyflakes 36 | "F", 37 | # pyupgrade 38 | # "UP", 39 | # flake8-bugbear 40 | "B", 41 | # flake8-simplify 42 | "SIM", 43 | # isort 44 | # "I", 45 | ] 46 | 47 | -------------------------------------------------------------------------------- /qllm/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | if os.getenv("PROXY_PORT", None): 3 | proxy_port = os.getenv("PROXY_PORT", None) 4 | os.environ["HTTP_PROXY"] = f"http://127.0.0.1:{proxy_port}" 5 | os.environ["HTTPS_PROXY"] = f"http://127.0.0.1:{proxy_port}" 6 | 7 | __version__ = '0.2.2.1' 8 | -------------------------------------------------------------------------------- /qllm/__main__.py: -------------------------------------------------------------------------------- 1 | from .run import main 2 | import sys 3 | 4 | if len(sys.argv) == 1: 5 | sys.argv = sys.argv+["-h"] 6 | main() 7 | -------------------------------------------------------------------------------- /qllm/args_config.py: -------------------------------------------------------------------------------- 1 | class FakeArgs: 2 | def __init__(self, **entries): 3 | self.quant_method = "gptq" 4 | self.dataset = "wikitext2" 5 | self.seed = 0 6 | self.nsamples = 128 7 | self.percdamp = 0.01 8 | self.wbits = 4 9 | self.groupsize = 128 10 | self.pack_mode = "AUTO" 11 | self.act_order = False 12 | self.true_sequential = False 13 | self.sym = False 14 | self.allow_mix_bits = False 15 | self.static_groups = False 16 | self.__dict__.update(entries) 17 | 18 | # def __getattr__(self, name): 19 | # if name not in self.__dict__: 20 | # return 
None 21 | # return self.__dict__[name] 22 | -------------------------------------------------------------------------------- /qllm/auto_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import tempfile 3 | import re 4 | import torch 5 | from .. import utils 6 | import os 7 | 8 | 9 | def cur_user(): 10 | try: 11 | return os.getlogin() 12 | except: # noqa: E722 13 | return "root" # in docker 14 | 15 | 16 | def get_sample_datas_for_quantization(tokenizer_or_path, dataset, nsamples, seed, seqlen=2048): 17 | logger = utils.logger.get_logger() 18 | tokenizer_name_or_path = tokenizer_or_path.name_or_path if not isinstance(tokenizer_or_path, str) else tokenizer_or_path 19 | normlized_tokenizer = re.sub(r"[^0-9a-zA-Z_-]", "", tokenizer_name_or_path) 20 | named_hash = f"{normlized_tokenizer}_{dataset}_{nsamples}_{seqlen}_{seed}" 21 | cache_dir = Path(f"{tempfile.gettempdir()}/qllm_v{cur_user()}/_{named_hash}_dataloader.pt") 22 | cache_dir.parent.mkdir(parents=True, exist_ok=True) 23 | logger.info(f"loading dataset from {dataset}") 24 | if cache_dir.exists(): 25 | logger.info(f"found cached dataloader in {cache_dir}") 26 | dataloader = torch.load(cache_dir, weights_only=True) 27 | else: 28 | dataloader, _ = utils.get_loaders(dataset, nsamples=nsamples, seed=seed, tokenizer=tokenizer_or_path, seqlen=seqlen) 29 | torch.save(dataloader, str(cache_dir)) 30 | assert len(dataloader) > 0, f"dataset {dataset} is empty" 31 | return dataloader 32 | -------------------------------------------------------------------------------- /qllm/custom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wejoncy/QLLM/df20c15920bfabfd0581d7fcccbe87e5c96cd5c7/qllm/custom/__init__.py -------------------------------------------------------------------------------- /qllm/custom/__main__.py: -------------------------------------------------------------------------------- 1 | from .run import main 2 | 3 | main() 4 | -------------------------------------------------------------------------------- /qllm/custom/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | if "CUDA_VISIBLE_DEVICES" not in os.environ: # NOQA 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" # NOQA 4 | 5 | import warnings 6 | import torch 7 | from pathlib import Path 8 | 9 | from ..utils.logger import get_logger 10 | from ..auto_model_quantization import AutoModelQuantization 11 | 12 | logger = get_logger() 13 | 14 | import sys 15 | sys.path.append(os.path.dirname(__file__)) # NOQA 16 | sys.path.append(os.getcwd()) # NOQA 17 | 18 | import loralib as lora 19 | 20 | NEED_CHECK_PACK = False 21 | 22 | 23 | class CustomModel(AutoModelQuantization): 24 | def __init__(self): 25 | super().__init__() 26 | self.argv_user = None 27 | self.quant_layers = [torch.nn.Linear, lora.MergedLinear, lora.Linear] 28 | self.datsets = None 29 | 30 | def get_torch_model(self, args): 31 | argv_user = self.argv_user 32 | if 'ckpt/mpt-' not in argv_user[argv_user.index('--model_name_or_path')+1]: 33 | lora_ind = argv_user.index('--use_lora') 34 | argv_user[lora_ind+1] = 'False' 35 | try: 36 | import examples_ads 37 | from examples_ads import run_mpt_prompt 38 | except: 39 | logger.error(f"Do you forget to run the command in the root directory of the project? 
`examples_ads` is not find in {os.getcwd()},\ 40 | please switch to the right directory and try again") 41 | raise 42 | argv_user.insert(0, run_mpt_prompt.__file__) 43 | argv_back = sys.argv 44 | sys.argv = argv_user 45 | 46 | os.environ['init_device'] = "cpu" 47 | model, data_sets = run_mpt_prompt.main(True) 48 | new_data = [] 49 | for idx, indata in enumerate(data_sets): 50 | if idx >= args.nsamples: 51 | break 52 | input_ = (torch.tensor([indata["input_ids"]]), 53 | torch.tensor([indata["attention_mask"]])) 54 | new_data.append(input_) 55 | self.datsets = new_data 56 | return model.half() 57 | 58 | def get_datasets(self, tokenizer_path, dataset, nsamples, seed): 59 | cache_dir = Path(f"/tmp/qllm_v1/{tokenizer_path.replace(' ','_')}_{dataset}_dataloader.pt") 60 | cache_dir.parent.mkdir(parents=True, exist_ok=True) 61 | logger.info(f"loading dataset from {dataset}") 62 | 63 | if self.datsets is not None: 64 | torch.save(self.datsets, str(cache_dir)) 65 | return self.datsets 66 | 67 | if cache_dir.exists(): 68 | logger.info(f"found cached dataloader in {cache_dir}") 69 | dataloader = torch.load(cache_dir) 70 | 71 | return dataloader 72 | 73 | @torch.no_grad() 74 | def eval_model(self, model, dev, args): 75 | logger.info('Evaluating ...') 76 | sys.argv = self.argv_user 77 | import examples_ads 78 | from examples_ads import run_llama_prompt 79 | run_llama_prompt.main(quant_model=model.to(dev)) 80 | 81 | def process_forward_args(self, args): 82 | argv_user = args.forward_args 83 | import re 84 | key_with_space = re.findall(r'(".*"|\'.*\')', argv_user) 85 | argv_map = {} 86 | for idx, v in enumerate(key_with_space): 87 | argv_user = re.sub(v, f'____{idx}___', argv_user) 88 | argv_map[f'____{idx}___'] = v.strip('"') 89 | argv_user = argv_user.split(' ') 90 | argv_user = list(filter(None, argv_user)) 91 | idx = 0 92 | for i in range(len(argv_user)): 93 | if argv_user[i] == f'____{idx}___': 94 | argv_user[i] = argv_map[f'____{idx}___'] 95 | idx += 1 96 | self.argv_user = argv_user 97 | 98 | def export_onnx(self, model: torch.nn.Module, onnx_path_str: str, sample_inputs: tuple, with_past: bool = False, args=None): 99 | try: 100 | import onnxruntime 101 | from packaging import version 102 | assert version.parse(onnxruntime.__version__) >= version.parse('1.17.0') 103 | assert version.parse(torch.__version__) >= version.parse('2.0.0') 104 | return super().export_onnx(model, onnx_path_str, sample_inputs, with_past, args) 105 | except: 106 | warnings.warn('this exporter will be deprecated, please upgrade to torch 2.1.0+ and onnxruntime 1.17+', 107 | DeprecationWarning, stacklevel=2) 108 | # model = self.pipeline_to_multiple_gpu(model, [torch.device(i) 109 | # for i in range(torch.cuda.device_count())], sample_inputs) 110 | # model = model.cpu().float() 111 | model = model.cuda() 112 | os.environ["export_onnx"] = "1" 113 | from pathlib import Path 114 | import shutil 115 | onnx_path = Path(onnx_path_str).absolute() 116 | assert onnx_path.suffix == '.onnx' 117 | inputs = {'input_ids': sample_inputs[0].to( 118 | model.device), "attention_mask": sample_inputs[1].to(model.device)} 119 | onnx_filepath_export_multi_files_tmp = onnx_path.parent/'tmp/tmp.onnx' 120 | onnx_filepath_export_multi_files_tmp.parent.exists() and shutil.rmtree( 121 | onnx_filepath_export_multi_files_tmp.parent) 122 | os.makedirs(onnx_filepath_export_multi_files_tmp.parent) 123 | 124 | input_ids = inputs['input_ids'] 125 | attention_mask = inputs['attention_mask'] 126 | past_key_values = None 127 | onnx_inputs = (input_ids, 
past_key_values, attention_mask, 128 | None, None, None, True, False, False, False) 129 | onnx_inp_names = ("input_ids", "attention_mask") 130 | onnx_out_names = ("logits",) 131 | onnx_dynamic_axes = {"input_ids": {0: 'batch_size', 1: "seq_len"}, 132 | "attention_mask": {0: 'batch_size', 1: "seq_len"}} 133 | torch.onnx.export(model=model, args=onnx_inputs, f=str(onnx_filepath_export_multi_files_tmp), verbose=False, opset_version=16, 134 | input_names=onnx_inp_names, output_names=onnx_out_names, dynamic_axes=onnx_dynamic_axes) 135 | import onnx 136 | onnx_model = onnx.load(str(onnx_filepath_export_multi_files_tmp)) 137 | 138 | onnx_path.exists() and onnx_path.unlink() 139 | (onnx_path.parent/'model_ext.data').exists() and (onnx_path.parent / 140 | 'model_ext.data').unlink() 141 | onnx.save_model(onnx_model, str(onnx_path), save_as_external_data=True, all_tensors_to_one_file=True, 142 | location="model_ext.data", size_threshold=1024, convert_attribute=False) 143 | 144 | 145 | def main(): 146 | from .. import run 147 | parser = run.define_basic_args() 148 | parser.add_argument('--forward_args', type=str,default=None, help='args for run_prompts_mpt.py') 149 | sys.argv = sys.argv + ["--model=./a"] 150 | args = parser.parse_args() 151 | 152 | mpt_quanter = CustomModel() 153 | mpt_quanter.process_forward_args(args) 154 | if args.load: 155 | mpt_quanter.argv_user[mpt_quanter.argv_user.index('--model_name_or_path')+1] = os.path.abspath(args.load) 156 | 157 | mpt_quanter.run(args) 158 | 159 | if __name__ == '__main__': 160 | main() 161 | -------------------------------------------------------------------------------- /qllm/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F401 2 | 3 | from .base import AutoQuantizedModelForCausalLM 4 | 5 | all = ["AutoQuantizedModelForCausalLM"] 6 | -------------------------------------------------------------------------------- /qllm/modeling/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | from transformers.utils.hub import cached_file 4 | import os 5 | from .. 
import utils 6 | from ..quantization.config_builder import MetaConfig 7 | logger = utils.logger.get_logger() 8 | 9 | 10 | class BaseQuantizeConfig: 11 | def __init__(self): 12 | self.quant_config = {} 13 | self.quant_config_by_op = {} 14 | self.quant_method = None 15 | self.COMPATIBLE_WITH_AUTOGPTQ = False 16 | 17 | def groupsize(self, layer_name: str = None): 18 | if layer_name is not None and layer_name in self.quant_config_by_op: 19 | return self.quant_config_by_op[layer_name]["groupsize"] 20 | return self.quant_config.get('group_size', None) or self.quant_config.get('q_group_size', None) 21 | 22 | def bits(self, layer_name: str = None): 23 | if layer_name is not None and layer_name in self.quant_config_by_op: 24 | return self.quant_config_by_op[layer_name]["wbits"] 25 | return self.quant_config.get('bits', None) or self.quant_config.get('w_bit', None) 26 | 27 | @property 28 | def to_meta(self): 29 | return MetaConfig(self.bits(), self.groupsize(), self.quant_method) 30 | 31 | @property 32 | def version(self): 33 | return self.quant_config["version"].upper() 34 | 35 | def dict(self): 36 | return self.quant_config 37 | 38 | def to_dict(self): 39 | return self.quant_config 40 | 41 | @version.setter 42 | def version(self, value): 43 | self.quant_config["version"] = value 44 | 45 | @staticmethod 46 | def get_resolved_base_dir(model_name_or_path, quantize_config_filename) -> Path: 47 | if os.path.isdir(model_name_or_path): # Local 48 | resolved_config_file = Path(model_name_or_path) / quantize_config_filename 49 | if not resolved_config_file.exists(): 50 | resolved_config_file = None 51 | else: # Remote 52 | user_agent = {"file_type": "config", "from_auto_class": True} 53 | try: 54 | resolved_config_file = cached_file( 55 | model_name_or_path, 56 | quantize_config_filename, 57 | cache_dir=None, 58 | user_agent=user_agent, 59 | ) 60 | resolved_config_file = Path(resolved_config_file) 61 | except: # noqa : E722 62 | resolved_config_file = None 63 | return resolved_config_file 64 | 65 | def try_make_default_quant_op_config(self): 66 | if self.quant_config_by_op: return 67 | # backward compatability, we just make a genaral config 68 | self.quant_config_by_op = { 69 | "groupsize": self.groupsize(), "wbits": self.bits()} 70 | 71 | def load_quant_op_config(self, model_name_or_path): 72 | if not (Path(model_name_or_path) / "quant_config_by_layer.json").exists(): 73 | return self.try_make_default_quant_op_config() 74 | # load quant info 75 | with open(Path(model_name_or_path) / "quant_config_by_layer.json") as fp: 76 | qunat_info = json.load(fp) 77 | self.quant_config_by_op = qunat_info 78 | if self.quant_method == "vptq": 79 | self.quant_config_by_op = self.quant_config.get('config_for_layers', None) 80 | 81 | def load_quant_config(self, model_name_or_path): 82 | while True: 83 | config_file = self.get_resolved_base_dir(model_name_or_path, "quant_config.json") 84 | quant_config = None 85 | if config_file is None: 86 | # GPTQ-for-llama/AutoGPTQ 87 | config_file = self.get_resolved_base_dir(model_name_or_path, "quantize_config.json") 88 | if config_file is not None: 89 | with open(config_file) as fp: 90 | quant_config = json.load(fp) 91 | break 92 | if config_file is None: 93 | config_file = self.get_resolved_base_dir(model_name_or_path, "config.json") 94 | if config_file is not None: 95 | with open(config_file) as fp: 96 | quant_config = json.load(fp) 97 | quant_config = quant_config.get("quantization_config", None) 98 | assert quant_config.get('use_exllama', False) is False, "use_exllama is not 
supported yet" 99 | break 100 | 101 | assert quant_config is not None, ("quant_config.json/quantize_config.json not found in checkpoint directory") 102 | 103 | wbits = quant_config.get("w_bit", quant_config.get("bits", None)) 104 | groupsize = quant_config.get("q_group_size", quant_config.get("group_size", None)) 105 | quant_method = quant_config.get("quant_method", None) 106 | if quant_method != "vptq": 107 | assert wbits is not None and groupsize is not None 108 | 109 | if quant_config.get('COMPATIBLE_WITH_AUTOGPTQ', None): 110 | self.COMPATIBLE_WITH_AUTOGPTQ = True 111 | if "version" not in quant_config: 112 | self.quant_method = "gptq" 113 | quant_config["version"] = "GPTQ" 114 | self.COMPATIBLE_WITH_AUTOGPTQ = True 115 | import os 116 | os.environ["COMPATIBLE_WITH_AUTOGPTQ"] = '1' # FixMe: hacky 117 | else: # FIXME is it correct? 118 | self.quant_method = quant_config.get("quant_method", "awq") 119 | self.quant_config = quant_config 120 | 121 | @classmethod 122 | def from_pretrained(cls, model_name_or_path): 123 | obj = cls() 124 | obj.load_quant_config(model_name_or_path) 125 | obj.load_quant_op_config(model_name_or_path) 126 | return obj 127 | -------------------------------------------------------------------------------- /qllm/modeling/q_layers/__init__.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F401 2 | 3 | # from .fused_attn import QuantLlamaAttention, make_quant_attn 4 | # from .fused_mlp import QuantLlamaMLP, make_fused_mlp, autotune_warmup_fused 5 | from .quant_linear_gptq import (QuantLinearGPTQ) 6 | # from .quant_linear_triton import autotune_warmup_linear 7 | # from .triton_norm import TritonLlamaRMSNorm, make_quant_norm 8 | -------------------------------------------------------------------------------- /qllm/modeling/q_layers/ext_package_checker.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch 3 | import functools 4 | from ...utils.logger import get_logger 5 | logger = get_logger() 6 | 7 | @functools.lru_cache() 8 | def has_package(package_name): 9 | try: 10 | if importlib.util.find_spec(package_name) is not None: 11 | importlib.import_module(package_name) 12 | return True 13 | except: # noqa: E722 14 | logger.warning(f"Failed to import {package_name}") 15 | return False 16 | 17 | 18 | def has_awq_inference_engine(): 19 | return (torch.cuda.get_device_properties(0).major * 10 + torch.cuda.get_device_properties(0).minor >= 75 20 | and has_package("qllm.awq_inference_engine")) 21 | 22 | 23 | def is_the_machine_support_awq_engine(nbits): 24 | return has_awq_inference_engine() and nbits == 4 25 | 26 | 27 | def has_ort_ops(): 28 | return has_package("qllm.ort_ops") 29 | -------------------------------------------------------------------------------- /qllm/modeling/q_layers/quant_linear_gptq.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | from .ext_package_checker import has_ort_ops 7 | from .compress_weight import (CompressWeight, general_pack_on_row, 8 | general_unpack_on_row) 9 | if has_ort_ops(): 10 | from qllm import ort_ops 11 | 12 | 13 | def DequantizeLinearBlockWise(qweight, scales, qzeros, groupsize, bits, in_features, g_idx): 14 | COMPATIBLE_WITH_AUTOGPTQ = int( 15 | os.environ.get("COMPATIBLE_WITH_AUTOGPTQ", "0")) 16 | scales = scales.reshape(-1, 1, scales.shape[-1]) 17 | if bits in [2, 4, 8]: 18 | wf = 
torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qweight.device).unsqueeze(0) 19 | # expand is removed as torch will auto broadcast to relavant dimension 20 | zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2), wf.unsqueeze(0) 21 | ).to(torch.int16 if bits == 8 else torch.int8) 22 | zeros = zeros + COMPATIBLE_WITH_AUTOGPTQ 23 | zeros = torch.bitwise_and(zeros, (2 ** bits) - 1) 24 | zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2]) 25 | # expand is removed as torch will auto broadcast to relavant dimension 26 | weight = torch.bitwise_right_shift(torch.unsqueeze( 27 | qweight, 1), wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8) 28 | torch.bitwise_and(weight, (2 ** bits) - 1, out=weight) 29 | else: 30 | weight = torch.zeros((in_features, qweight.shape[1]), dtype=torch.int32, device=qweight.device) 31 | general_unpack_on_row(qweight, weight, bits) 32 | zeros = torch.zeros((qzeros.shape[0], qweight.shape[1]), dtype=torch.int32, device=qweight.device) 33 | general_unpack_on_row(qzeros, zeros, bits) 34 | zeros = zeros.reshape(-1, 1, zeros.shape[1]) 35 | zeros = zeros + COMPATIBLE_WITH_AUTOGPTQ 36 | zeros = torch.bitwise_and(zeros, (2 ** bits) - 1) 37 | 38 | if g_idx is not None: 39 | zeros.squeeze_(1) 40 | scales.squeeze_(1) 41 | weight = weight.view(-1, weight.shape[-1]) 42 | scale_zeros = zeros * scales 43 | weight = (scales[g_idx] * weight - scale_zeros[g_idx]) 44 | weight = weight.view(-1, groupsize, weight.shape[-1]) 45 | else: 46 | scale_zeros = zeros * scales 47 | weight = weight.reshape(-1, groupsize, weight.shape[-1]) 48 | weight = (scales * weight - scale_zeros.to(scales.dtype)) 49 | 50 | # weight = (scales * (weight - zeros)) 51 | weight = weight.reshape(-1, weight.shape[2]) 52 | return weight 53 | 54 | 55 | class QuantLinearTorchFunction(torch.autograd.Function): 56 | @staticmethod 57 | def symbolic(g, inputs, qweight, scales, qzeros, groupsize, bits, in_features, g_idx): 58 | # bias = g.op("Constant", value_t=torch.tensor([], dtype=torch.float16)) 59 | g_idx = g.op("Constant", value_t=torch.tensor([], dtype=torch.int32)) if g_idx is None else g_idx 60 | use_gemm_op = True 61 | if use_gemm_op: 62 | out_features = qweight.type().sizes()[-1] 63 | return g.op("com.microsoft::MatMulNBits", inputs, qweight, scales, qzeros, g_idx, 64 | outputs=1, K_i=in_features, N_i=out_features, bits_i=bits, block_size_i=groupsize, packing_s="gptq") 65 | else: 66 | fp_weight = g.op("com.microsoft::DequantizeLinearBlockWise", qweight, scales, qzeros, g_idx, 67 | outputs=1, groupsize_i=groupsize, bits_i=bits, in_features_i=in_features) 68 | return g.op("MatMul", inputs, fp_weight) 69 | 70 | @staticmethod 71 | def forward(ctx, inputs, qweight, scales, qzeros, groupsize, bits, in_features, g_idx): 72 | if torch.onnx.is_in_onnx_export(): 73 | return torch.zeros(inputs.shape[:-1] + (qweight.size(1), ), dtype=inputs.dtype, device=inputs.device) 74 | 75 | COMPATIBLE_WITH_AUTOGPTQ = int(os.environ.get("COMPATIBLE_WITH_AUTOGPTQ", "0")) 76 | if (not torch.onnx.is_in_onnx_export() 77 | and inputs.numel() // inputs.shape[-1] <= 8 78 | and bits == 4 79 | and has_ort_ops()): 80 | return ort_ops.gemv(inputs, qweight, scales, qzeros, g_idx, groupsize, bits, in_features, COMPATIBLE_WITH_AUTOGPTQ) 81 | if qweight.is_cuda and has_ort_ops(): 82 | weight = ort_ops.dequant(qweight, scales, qzeros, g_idx, groupsize, bits, in_features, COMPATIBLE_WITH_AUTOGPTQ) 83 | else: 84 | weight = DequantizeLinearBlockWise(qweight, scales, qzeros, groupsize, bits, in_features, g_idx) 85 | 
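        # At this point `weight` holds the dequantized floating-point matrix, produced either by the
        # fused ort_ops.dequant CUDA kernel (when qllm.ort_ops is installed and qweight is on GPU) or
        # by the pure-PyTorch DequantizeLinearBlockWise fallback above; small-batch 4-bit inputs were
        # already served by the ort_ops.gemv fast path when that extension is available.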
return torch.matmul(inputs, weight) 86 | 87 | 88 | def QuantLinearTorchFunction_forward(input, qweight, scales, qzeros, g_idx, bits, groupsize, in_features): 89 | return QuantLinearTorchFunction().apply(input, qweight, scales, qzeros, groupsize, bits, in_features, g_idx) 90 | 91 | 92 | class QuantLinearGPTQ(nn.Module, CompressWeight): 93 | 94 | def __init__(self, bits, groupsize, infeatures, outfeatures, bias, dtype=None): 95 | super().__init__() 96 | if bits not in [2, 3, 4, 5, 6, 7, 8]: 97 | raise NotImplementedError("Only 2,4,5,6,7,8 bits are supported.") 98 | self.dtype = torch.get_default_dtype() if dtype is None else dtype 99 | self.infeatures = infeatures 100 | self.outfeatures = outfeatures 101 | self.bits = bits 102 | self.act_order = None 103 | self.orig_fp_weight = None 104 | self.maxq = 2**self.bits - 1 105 | self.groupsize = groupsize if groupsize != -1 else infeatures 106 | self.pack_mode = "GPTQ" 107 | 108 | self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) 109 | self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), 110 | outfeatures // 32 * self.bits), dtype=torch.int32)) 111 | self.register_buffer('scales', torch.zeros( 112 | (math.ceil(infeatures / self.groupsize), outfeatures), dtype=self.dtype)) 113 | self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) 114 | if bias: 115 | self.register_buffer("bias", torch.zeros((outfeatures), dtype=self.dtype)) 116 | else: 117 | self.bias = None 118 | 119 | def handle_qzeros_for_autogptq(self): 120 | if self.qzeros.numel() == 0: 121 | return 122 | device = "cuda" if torch.cuda.is_available() else "cpu" 123 | qzeros = self.qzeros.to(device) 124 | zeros = torch.zeros((self.outfeatures, self.infeatures // self.groupsize), 125 | dtype=torch.int32, device=qzeros.device).T.contiguous() 126 | 127 | general_unpack_on_row(qzeros, zeros, self.bits) 128 | 129 | zeros += 1 130 | torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros) 131 | 132 | general_pack_on_row(qzeros, zeros, self.bits) 133 | 134 | self.qzeros = qzeros.to("cpu", non_blocking=True) 135 | 136 | def forward(self, x): 137 | if self.act_order is None: 138 | self.act_order = self.g_idx[:self.groupsize].sum() != 0 139 | g_idx = self.g_idx if self.act_order else None 140 | out = QuantLinearTorchFunction_forward(x, self.qweight, self.scales, 141 | self.qzeros, g_idx, self.bits, self.groupsize, self.infeatures) 142 | out = out + self.bias if self.bias is not None else out 143 | return out 144 | -------------------------------------------------------------------------------- /qllm/modeling/q_layers/quant_linear_hqq.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .compress_weight import (CompressWeight, general_unpack_on_row) 6 | 7 | 8 | class DequantAndUnpack(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, qweight, scales, qzeros, groupsize, bits, in_features): 11 | scales = scales.reshape(-1, 1, scales.shape[-1]) 12 | qzeros = qzeros.reshape(-1, 1, qzeros.shape[-1]) 13 | if bits in [2, 4, 8]: 14 | wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qweight.device).unsqueeze(0) 15 | weight = torch.bitwise_right_shift(torch.unsqueeze( 16 | qweight, 1), wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8) 17 | torch.bitwise_and(weight, (2 ** bits) - 1, out=weight) 18 | else: 19 | weight 
= torch.zeros((in_features, qweight.shape[1]), dtype=torch.int32, device=qweight.device) 20 | general_unpack_on_row(qweight, weight, bits) 21 | 22 | scale_zeros = qzeros * scales 23 | weight = weight.reshape(-1, groupsize, weight.shape[-1]) 24 | weight = (scales * weight - scale_zeros.to(scales.dtype)) 25 | 26 | # weight = (scales * (weight - zeros)) 27 | weight = weight.reshape(-1, weight.shape[2]) 28 | return weight 29 | 30 | 31 | class QuantLinearTorchFunction(torch.autograd.Function): 32 | @staticmethod 33 | def forward(ctx, inputs, qweight, scales, qzeros, groupsize, bits, in_features): 34 | if torch.onnx.is_in_onnx_export(): 35 | return torch.zeros(inputs.shape[:-1] + (qweight.size(1), ), dtype=inputs.dtype, device=inputs.device) 36 | 37 | weight = DequantAndUnpack.apply(qweight, scales, qzeros, groupsize, bits, in_features) 38 | return torch.matmul(inputs, weight) 39 | 40 | @staticmethod 41 | def symbolic(g, inputs, qweight, scales, qzeros, groupsize, bits, in_features): 42 | out_features = qweight.type().sizes()[-1] 43 | return g.op("com.microsoft::MatMulNBits", inputs, qweight, scales, qzeros, 44 | outputs=1, K_i=in_features, N_i=out_features, bits_i=bits, block_size_i=groupsize, packing_s="hqq") 45 | 46 | 47 | class QuantLinearHQQ(nn.Module, CompressWeight): 48 | 49 | def __init__(self, bits, groupsize, infeatures, outfeatures, bias, dtype=None): 50 | super().__init__() 51 | self.dtype = torch.get_default_dtype() if dtype is None else dtype 52 | if bits not in [2, 3, 4, 5, 6, 7, 8]: 53 | raise NotImplementedError("Only 2,4,5,6,7,8 bits are supported.") 54 | self.infeatures = infeatures 55 | self.outfeatures = outfeatures 56 | self.bits = bits 57 | self.groupsize = groupsize if groupsize != -1 else infeatures 58 | self.pack_mode = "HQQ" 59 | self.orig_fp_weight = None 60 | self.g_idx = torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32) 61 | 62 | self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) 63 | self.register_buffer("qzeros", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=self.dtype)) 64 | self.register_buffer("scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=self.dtype)) 65 | if bias: 66 | self.register_buffer("bias", torch.zeros((outfeatures), dtype=self.dtype)) 67 | else: 68 | self.bias = None 69 | 70 | def unpack_qzeros(self, device): 71 | return self.qzeros.to(device) 72 | 73 | def pack_qzeros(self, intzeros, device): 74 | self.qzeros = intzeros.contiguous().to("cpu", non_blocking=True) 75 | 76 | def forward(self, x): 77 | out = QuantLinearTorchFunction.apply(x, self.qweight, self.scales, self.qzeros, 78 | self.groupsize, self.bits, self.infeatures) 79 | out = out + self.bias if self.bias is not None else out 80 | return out 81 | -------------------------------------------------------------------------------- /qllm/modeling/q_layers/quant_linear_marlin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from .compress_weight import CompressWeight 5 | from .ext_package_checker import has_awq_inference_engine 6 | from ...utils import comm_utils 7 | 8 | if has_awq_inference_engine(): 9 | from qllm import awq_inference_engine as marlin_cuda 10 | else: 11 | print("marlin_cuda is not installed.") 12 | 13 | DEBUG_ = False 14 | 15 | # Precompute permutations for Marlin weight and scale shuffling 16 | 17 | 18 | def _get_perms(): 19 | perm 
= [] 20 | for i in range(32): 21 | perm1 = [] 22 | col = i // 4 23 | for block in [0, 1]: 24 | for row in [2 * (i % 4), 2 * (i % 4) + 1, 2 * (i % 4 + 4), 2 * (i % 4 + 4) + 1]: 25 | perm1.append(16 * row + col + 8 * block) 26 | for j in range(4): 27 | perm.extend([p + 256 * j for p in perm1]) 28 | 29 | perm = np.array(perm) 30 | interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) 31 | perm = perm.reshape((-1, 8))[:, interleave].ravel() 32 | perm = torch.from_numpy(perm) 33 | scale_perm = [] 34 | for i in range(8): 35 | scale_perm.extend([i + 8 * j for j in range(8)]) 36 | scale_perm_single = [] 37 | for i in range(4): 38 | scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) 39 | return perm, scale_perm, scale_perm_single 40 | 41 | 42 | _perm, _scale_perm, _scale_perm_single = _get_perms() 43 | 44 | 45 | def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16): 46 | """Marlin FP16xINT4 multiply; can be used within `torch.compile`. 47 | @A: `torch.half` input matrix of shape `(m, k)` in standard row-major layout 48 | @B: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()` 49 | @C: `torch.half` out matrix of shape `(m, n)` in standard row-major layout 50 | @s: `torch.half` scales of shape `(m / group_size, n)` 51 | @workspace: `torch.int` tensor with at least as many entries as there a GPU SMs (256 is usually safe) 52 | @thread_k: `k` size of a thread_tile in `B` (can usually be left as auto -1) 53 | @thread_n: `n` size of a thread_tile in `B` (can usually be left as auto -1) 54 | @sms: number of SMs to use for the kernel (can usually be left as auto -1) 55 | @max_par: maximum number of batch 64 problems to solve in parallel for large input sizes 56 | """ 57 | marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms, max_par) 58 | 59 | 60 | class QuantLinearMarlin(nn.Module, CompressWeight): 61 | def __init__(self, bits, group_size, infeatures, outfeatures, bias): 62 | super().__init__() 63 | if bits not in [2, 3, 4, 5, 6, 7, 8]: 64 | raise NotImplementedError("Only 2,4,5,6,7,8 bits are supported.") 65 | self.infeatures = infeatures 66 | self.outfeatures = outfeatures 67 | self.bits = bits 68 | self.orig_fp_weight = None 69 | self.maxq = 2**self.bits - 1 70 | self.group_size = group_size if group_size != -1 else infeatures 71 | self.act_order = None 72 | self.pack_mode = "MARLIN" 73 | if infeatures % 128 != 0 or outfeatures != 256 == 0: 74 | raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.") 75 | if bits not in [4]: 76 | raise NotImplementedError("Only 4 bits are supported.") 77 | if group_size not in [-1, 128] and group_size != infeatures: 78 | raise ValueError("Only group_size -1 and 128 are supported.") 79 | if infeatures % group_size != 0: 80 | raise ValueError("`infeatures` must be divisible by `group_size`.") 81 | self.register_buffer( 82 | "qweight", 83 | torch.zeros((self.infeatures // 16, self.outfeatures * 16 // 8), dtype=torch.int), 84 | ) 85 | self.register_buffer( 86 | "scales", torch.zeros((self.infeatures // group_size, self.outfeatures), dtype=torch.float16) 87 | ) 88 | self.register_buffer("workspace", torch.zeros(self.outfeatures // 128 * 16, dtype=torch.int), persistent=False) 89 | self.g_idx = None 90 | if bias: 91 | self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) 92 | else: 93 | self.bias = None 94 | 95 | def pack(self, linear, scales, zeros, g_idx=None): 96 | if g_idx is not None: 97 | self.act_order = g_idx[: self.group_size 
// self.bits].sum().item() != 0 98 | assert self.act_order is False 99 | assert zeros is None or torch.all(zeros == 8), ("only support symmetry quantization, \ 100 | please add `--sym` and try again ") 101 | if linear.weight.dtype != torch.half: 102 | raise ValueError("Only `torch.half` weights are supported.") 103 | device = torch.device("cuda") 104 | tile = 16 105 | maxq = 2**4 - 1 106 | s = scales.to(device).t() 107 | w = linear.weight.data.to(device).t() 108 | if self.group_size != self.infeatures: 109 | w = w.reshape((-1, self.group_size, self.outfeatures)) 110 | w = w.permute(1, 0, 2) 111 | w = w.reshape((self.group_size, -1)) 112 | s = s.reshape((1, -1)) 113 | w = torch.round(w / s).int() 114 | w += (maxq + 1) // 2 115 | w = torch.clamp(w, 0, maxq) 116 | if self.group_size != self.infeatures: 117 | w = w.reshape((self.group_size, -1, self.outfeatures)) 118 | w = w.permute(1, 0, 2) 119 | w = w.reshape((self.infeatures, self.outfeatures)).contiguous() 120 | s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm] 121 | else: 122 | s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] 123 | s = s.reshape((-1, self.outfeatures)).contiguous() 124 | w = w.reshape((self.infeatures // tile, tile, self.outfeatures // tile, tile)) 125 | w = w.permute((0, 2, 1, 3)) 126 | w = w.reshape((self.infeatures // tile, self.outfeatures * tile)) 127 | res = w 128 | res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape) 129 | q = torch.zeros((res.shape[0], res.shape[1] // 8), dtype=torch.int32, device=device) 130 | 131 | for i in range(8): 132 | q |= res[:, i::8] << (4 * i) 133 | q = q.to(torch.int32) 134 | self.qweight[:, :] = q.to(self.qweight.device) 135 | self.scales[:, :] = s.to(self.scales.device) 136 | if self.bias is not None: 137 | self.bias[:] = linear.bias.data.to(self.bias.device) 138 | 139 | def unpack(self): 140 | raise NotImplementedError("Marlin unpacking is not supported.") 141 | 142 | def forward(self, x): 143 | C = torch.empty(x.shape[:-1] + (self.scales.shape[1],), dtype=x.dtype, device=x.device) 144 | mul(x.view((-1, x.shape[-1])), self.qweight, C.view((-1, C.shape[-1])), self.scales, self.workspace) 145 | C = C + self.bias if self.bias is not None else C 146 | return C 147 | -------------------------------------------------------------------------------- /qllm/modeling/q_layers/quant_linear_onnxruntime.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | from .compress_weight import CompressWeight 6 | from .ext_package_checker import has_ort_ops 7 | 8 | if has_ort_ops(): 9 | from qllm import ort_ops 10 | else: 11 | print("ort_ops is not installed. 
Will fallback to Torch Backend") 12 | 13 | DEBUG_ = False 14 | 15 | 16 | class QuantLinearTorchFunction(torch.autograd.Function): 17 | @staticmethod 18 | def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, groupsize, in_features, out_features): 19 | input_tuple = (x, qself_qweight, qself_scales, qself_qzeros) 20 | input_tuple += (g_idx,) if g_idx is not None else () 21 | return g.op( 22 | "com.microsoft::MatMulNBits", 23 | *input_tuple, 24 | outputs=1, 25 | K_i=in_features, 26 | N_i=out_features, 27 | bits_i=bits, 28 | block_size_i=groupsize, 29 | ) 30 | 31 | @staticmethod 32 | def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, groupsize, in_features, out_features): 33 | if torch.onnx.is_in_onnx_export(): 34 | return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype, device=x.device) 35 | if not has_ort_ops(): 36 | fp_weight = dequantize_blockwise_4bits( 37 | qself_qweight, qself_scales, qself_qzeros, g_idx, in_features, out_features 38 | )[0] 39 | else: 40 | fp_weight = ort_ops.Dequantize4Bits( 41 | qself_qweight, qself_scales, qself_qzeros, g_idx, groupsize, in_features, out_features 42 | ) 43 | return torch.matmul(x, fp_weight.T) 44 | 45 | 46 | def QuantLinearTorchFunction_forward(inputs, qweight, scales, qzeros, g_idx, bits, groupsize, in_features, out_features): 47 | assert bits == 4, "Only 4 bits are supported." 48 | out = QuantLinearTorchFunction().apply(inputs, qweight, scales, qzeros, g_idx, bits, groupsize, in_features, out_features) 49 | return out 50 | 51 | 52 | def dequantize_blockwise_4bits(quant_values, scale, zero_point, g_idx, rows, cols): 53 | expand_quant_value = ( 54 | quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32, device=quant_values.device) 55 | ) & 0x0F 56 | expand_quant_value = expand_quant_value.reshape(*quant_values.shape[:-1], -1) 57 | aligned_scale = scale.reshape(*quant_values.shape[:-1], 1) 58 | if zero_point.dtype == scale.dtype: 59 | expand_zero_point = zero_point.reshape(*quant_values.shape[:-1], -1) 60 | else: 61 | expand_zero_point = ( 62 | zero_point.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32, device=quant_values.device) 63 | ) & 0x0F 64 | expand_zero_point = expand_zero_point.reshape(*quant_values.shape[:-1], -1) 65 | expand_zero_point = expand_zero_point[..., : aligned_scale.shape[-1]] 66 | if g_idx is not None and g_idx[:32].sum().item() != 0: 67 | float_values = ( 68 | (expand_quant_value.reshape(expand_quant_value.shape[0], -1) - expand_zero_point[:, g_idx, 0]) 69 | * aligned_scale[:, g_idx, 0] 70 | ).to(scale.dtype) 71 | else: 72 | float_values = ((expand_quant_value - expand_zero_point) * aligned_scale).to(scale.dtype) 73 | float_values = float_values.reshape(cols, -1) 74 | if rows != float_values.shape[-1]: 75 | float_values = float_values[:, :rows] 76 | expand_zero_point = expand_zero_point[:, :rows] 77 | if expand_zero_point.ndim == 3: 78 | expand_zero_point = expand_zero_point.squeeze(-1) 79 | if aligned_scale.ndim == 3: 80 | aligned_scale = aligned_scale.squeeze(-1) 81 | 82 | return float_values, expand_zero_point, aligned_scale 83 | 84 | 85 | class QuantLinearORT(nn.Module, CompressWeight): 86 | def __init__(self, bits, groupsize, infeatures, outfeatures, bias, dtype=None): 87 | super().__init__() 88 | self.dtype = torch.get_default_dtype() if dtype is None else dtype 89 | if bits not in [2, 3, 4, 5, 6, 7, 8]: 90 | raise NotImplementedError("Only 2,4,5,6,7,8 bits are supported.") 91 | self.infeatures = infeatures 92 | self.outfeatures = 
outfeatures 93 | self.bits = bits 94 | self.orig_fp_weight = None 95 | self.maxq = 2**self.bits - 1 96 | self.groupsize = groupsize if groupsize != -1 else infeatures 97 | self.act_order = None 98 | self.pack_mode = "ORT" 99 | q_rows = infeatures // self.groupsize 100 | self.register_buffer( 101 | "qweight", 102 | torch.zeros((outfeatures, q_rows, self.groupsize // (8 // bits)), dtype=torch.uint8), 103 | ) 104 | self.register_buffer( 105 | "qzeros", 106 | torch.zeros((q_rows+(q_rows&1)) * (outfeatures // 8 * self.bits), dtype=torch.uint8), 107 | ) 108 | self.register_buffer("scales", torch.zeros((math.ceil(infeatures / self.groupsize) * outfeatures), dtype=self.dtype)) 109 | self.register_buffer("g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) 110 | if bias: 111 | self.register_buffer("bias", torch.zeros((outfeatures), dtype=self.dtype)) 112 | else: 113 | self.bias = None 114 | 115 | def pack_on_device(self, intweight_gpu, intzeros_T): 116 | self.act_order = self.g_idx[: self.groupsize // self.bits].sum().item() != 0 117 | assert self.bits == 4, "only 4bit is supported by ONNXRUNTIME for now." 118 | intzeros_pt = intzeros_T.T if intzeros_T.dtype == self.scales.dtype else intzeros_T.T.byte() 119 | scales_pt = self.scales.T.to(intweight_gpu.device) 120 | intweight_pt = intweight_gpu.byte() 121 | block_size = self.groupsize 122 | 123 | rows, cols = intweight_pt.shape 124 | blob_size = block_size // 2 125 | k_blocks = (rows + block_size - 1) // block_size 126 | padded_rows = k_blocks * block_size 127 | pad_len = padded_rows - rows 128 | if pad_len > 0: 129 | intweight_pt = torch.nn.functional.pad(intweight_pt, (0, 0, 0, pad_len), "constant", 0) 130 | intzeros_pt = torch.nn.functional.pad(intzeros_pt, (0, intzeros_pt.shape[-1]&1, 0, 0), "constant", 0) 131 | 132 | if intzeros_T.dtype != self.scales.dtype: 133 | intzeros_pt = (intzeros_pt[:, 0::2]) | (intzeros_pt[:, 1::2] << 4) 134 | intzeros_pt = intzeros_pt.reshape(-1) 135 | 136 | intweight_pt_T = intweight_gpu.T 137 | intweight_pt_T = (intweight_pt_T[:, 0::2]) | (intweight_pt_T[:, 1::2] << 4) 138 | intweight_pt_T = intweight_pt_T.reshape(cols, k_blocks, blob_size) 139 | 140 | scales_pt = scales_pt.reshape(-1) 141 | 142 | assert self.qweight.shape == intweight_pt_T.shape 143 | assert self.qzeros.shape == intzeros_pt.shape or self.qzeros.dtype != intzeros_pt.dtype 144 | 145 | self.scales = scales_pt.contiguous() 146 | self.qweight = intweight_pt_T.contiguous().byte() 147 | if intzeros_T.dtype != self.scales.dtype: 148 | self.qzeros = intzeros_pt.contiguous().byte() 149 | else: 150 | self.qzeros = intzeros_pt.contiguous() 151 | 152 | if DEBUG_: 153 | mat_float, _, _ = dequantize_blockwise_4bits(intweight_pt_T, scales_pt, intzeros_pt, self.g_idx.cuda(), rows, cols) 154 | print("mat_float", mat_float.shape, mat_float.dtype) 155 | 156 | def unpack(self): 157 | float_values, zero_point, scale = dequantize_blockwise_4bits( 158 | self.qweight, self.scales, self.qzeros, self.g_idx, self.infeatures, self.outfeatures 159 | ) 160 | float_values = float_values.contiguous() 161 | zero_point = zero_point.T.contiguous() 162 | scale = scale.T.contiguous() 163 | return ( 164 | float_values.to("cpu", non_blocking=True), 165 | scale.to("cpu", non_blocking=True), 166 | zero_point.to("cpu", non_blocking=True), 167 | ) 168 | 169 | def forward(self, x): 170 | out = QuantLinearTorchFunction_forward( 171 | x, self.qweight, self.scales, self.qzeros, self.g_idx if self.act_order else None, self.bits, self.groupsize, 
self.infeatures, self.outfeatures 172 | ) 173 | out = out + self.bias if self.bias is not None else out 174 | return out 175 | -------------------------------------------------------------------------------- /qllm/modeling/q_layers/quant_linear_vptq.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | from .ext_package_checker import has_ort_ops 7 | from .compress_weight import (CompressWeight, general_pack_on_row, 8 | general_unpack_on_row) 9 | if has_ort_ops(): 10 | from qllm import ort_ops 11 | 12 | 13 | # fake_layer 14 | class VQuantLinear(nn.Module, CompressWeight): 15 | 16 | def __init__(self, bits, groupsize, infeatures, outfeatures, bias, dtype=None): 17 | super().__init__() 18 | self.dtype = torch.get_default_dtype() if dtype is None else dtype 19 | if bits not in [2, 3, 4, 5, 6, 7, 8]: 20 | raise NotImplementedError("Only 2,4,5,6,7,8 bits are supported.") 21 | self.infeatures = infeatures 22 | self.outfeatures = outfeatures 23 | self.bits = bits 24 | self.act_order = None 25 | self.orig_fp_weight = None 26 | self.maxq = 2**self.bits - 1 27 | self.groupsize = groupsize if groupsize != -1 else infeatures 28 | self.pack_mode = "GPTQ" 29 | 30 | self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) 31 | self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), 32 | outfeatures // 32 * self.bits), dtype=torch.int32)) 33 | self.register_buffer("scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=self.dtype)) 34 | self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) 35 | if bias: 36 | self.register_buffer("bias", torch.zeros((outfeatures), dtype=self.dtype)) 37 | else: 38 | self.bias = None 39 | 40 | 41 | def forward(self, x): 42 | return x 43 | -------------------------------------------------------------------------------- /qllm/modeling/q_layers/triton_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import triton 4 | import triton.language as tl 5 | from transformers.models.llama.modeling_llama import LlamaRMSNorm 6 | 7 | 8 | @triton.jit 9 | def rms_norm_fwd_fused( 10 | X, # pointer to the input 11 | Y, # pointer to the output 12 | W, # pointer to the weights 13 | stride, # how much to increase the pointer when moving by 1 row 14 | N, # number of columns in X 15 | eps, # epsilon to avoid division by zero 16 | BLOCK_SIZE: tl.constexpr, 17 | ): 18 | # Map the program id to the row of X and Y it should compute. 19 | row = tl.program_id(0) 20 | Y += row * stride 21 | X += row * stride 22 | # Compute variance 23 | _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) 24 | for off in range(0, N, BLOCK_SIZE): 25 | cols = off + tl.arange(0, BLOCK_SIZE) 26 | x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) 27 | x = tl.where(cols < N, x, 0.) 
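        # Accumulate the per-lane sum of squares in fp32; padded lanes were zeroed above, so the
        # tl.sum(...) / N reduction below yields the mean square that RMSNorm scales by (no mean
        # subtraction, unlike LayerNorm).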
28 | _var += x * x 29 | var = tl.sum(_var, axis=0) / N 30 | rstd = 1 / tl.sqrt(var + eps) 31 | # Normalize and apply linear transformation 32 | for off in range(0, N, BLOCK_SIZE): 33 | cols = off + tl.arange(0, BLOCK_SIZE) 34 | mask = cols < N 35 | w = tl.load(W + cols, mask=mask) 36 | x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) 37 | x_hat = x * rstd 38 | y = x_hat * w 39 | # Write output 40 | tl.store(Y + cols, y, mask=mask) 41 | 42 | 43 | class TritonLlamaRMSNorm(nn.Module): 44 | def __init__(self, weight, eps=1e-6): 45 | """ 46 | LlamaRMSNorm is equivalent to T5LayerNorm 47 | """ 48 | super().__init__() 49 | self.weight = weight 50 | self.variance_epsilon = eps 51 | 52 | def forward(self, x): 53 | y = torch.empty_like(x) 54 | # reshape input data into 2D tensor 55 | x_arg = x.reshape(-1, x.shape[-1]) 56 | M, N = x_arg.shape 57 | # Less than 64KB per feature: enqueue fused kernel 58 | MAX_FUSED_SIZE = 65536 // x.element_size() 59 | BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) 60 | if N > BLOCK_SIZE: 61 | raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") 62 | # heuristics for number of warps 63 | num_warps = min(max(BLOCK_SIZE // 256, 1), 8) 64 | # enqueue kernel 65 | rms_norm_fwd_fused[(M,)](x_arg, y, self.weight, 66 | x_arg.stride(0), N, self.variance_epsilon, 67 | BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) 68 | return y 69 | 70 | 71 | def make_quant_norm(model): 72 | """ 73 | Replace all LlamaRMSNorm modules with TritonLlamaRMSNorm modules 74 | """ 75 | 76 | for name, m in model.named_modules(): 77 | if not isinstance(m, LlamaRMSNorm): 78 | continue 79 | 80 | norm = TritonLlamaRMSNorm(m.weight, m.variance_epsilon) 81 | 82 | if '.' in name: 83 | parent_name = name.rsplit('.', 1)[0] 84 | child_name = name[len(parent_name) + 1:] 85 | parent = model.get_submodule(parent_name) 86 | else: 87 | parent_name = '' 88 | parent = model 89 | child_name = name 90 | 91 | # print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") 92 | 93 | setattr(parent, child_name, norm) 94 | -------------------------------------------------------------------------------- /qllm/plugin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wejoncy/QLLM/df20c15920bfabfd0581d7fcccbe87e5c96cd5c7/qllm/plugin/__init__.py -------------------------------------------------------------------------------- /qllm/plugin/chatcli/README.md: -------------------------------------------------------------------------------- 1 | # Chat CLI 2 | This code is inspired by FastChat, and modified for easy to use in CLI mode. 3 | 4 | ## Usage 5 | The main function is `chat_loop`, an example of using this ChatCLI is: 6 | 7 | ``` 8 | from chatcli import chat_loop 9 | 10 | # type !!exit to exit the chat loop 11 | # type !!reset to reset the conversation 12 | chat_loop(llama_baseline, tokenizer, echo=True) 13 | ``` 14 | 15 | There are two command: 16 | * `!!exit` means to exit chat loop 17 | * `!!reset` means to reset the chat context and start a new conversation. 
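You can also choose how prompts are read and printed by passing a `ChatIO` implementation through the `chatio` argument: `SimpleChatIO` for a plain terminal, or `DistChatIO` for multi-process (MPI) runs. A minimal sketch, reusing the `llama_baseline` and `tokenizer` objects from the example above:

```
from chatcli import chat_loop, SimpleChatIO

# echo=True re-prints your input, which is convenient inside jupyter
chat_loop(llama_baseline, tokenizer, chatio=SimpleChatIO(echo=True))
```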
18 | 19 | ## Chat loop 20 | Please refer to the `chat_loop` function: 21 | 22 | ``` 23 | def chat_loop( 24 | model, 25 | tokenizer, 26 | max_new_tokens: int = 512, 27 | generate_stream_func = generate_stream, 28 | generate_func = None, 29 | chatio: ChatIO = None, 30 | debug: bool = False, 31 | echo: bool = False, 32 | ): 33 | ``` 34 | 35 | Mostly, you only need to input `model` and `tokenizer`, leave the other args as default. 36 | 37 | The `generete_stream_func` is used to stream output tokens, while the `generete_func` will wait all tokens are generated and then output them. These two functions are responsible of doing generation work, if you need to modify the generation process, please refer to `generation.py` to modify these functions. 38 | 39 | The `echo` option is used for jupyter, because jupyter is not echo user input to the screen. 40 | 41 | When set `echo=True`, it will print user input to the screen, then it's better to see your input history. 42 | 43 | ## Conversation 44 | We use a class `Conversation` to save the chat context of one conversation. 45 | 46 | For different model, the conversation may be different, for example using different role name. 47 | 48 | Currently we only include `Llama2` conversation. 49 | -------------------------------------------------------------------------------- /qllm/plugin/chatcli/__init__.py: -------------------------------------------------------------------------------- 1 | from .generation import generate, generate_stream 2 | from .inference import chat_loop 3 | from .chatio import DistChatIO, SimpleChatIO 4 | -------------------------------------------------------------------------------- /qllm/plugin/chatcli/chatio.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class ChatIO(abc.ABC): 5 | @abc.abstractmethod 6 | def prompt_for_input(self, role: str) -> str: 7 | """Prompt for input from a role.""" 8 | 9 | @abc.abstractmethod 10 | def prompt_for_output(self, role: str): 11 | """Prompt for output from a role.""" 12 | 13 | @abc.abstractmethod 14 | def stream_output(self, output_stream): 15 | """Stream output.""" 16 | 17 | class SimpleChatIO(ChatIO): 18 | def __init__(self, multiline: bool = False, echo=False): 19 | self.multiline = multiline 20 | self.echo = echo 21 | 22 | def prompt_for_input(self, role) -> str: 23 | if not self.multiline: 24 | inputs = input(f"{role}: ") 25 | else: 26 | prompt_data = [] 27 | line = input(f"{role} [ctrl-d/z on empty line to end]: ") 28 | while True: 29 | prompt_data.append(line.strip()) 30 | try: 31 | line = input() 32 | except EOFError as e: 33 | break 34 | inputs = "\n".join(prompt_data) 35 | if self.echo: 36 | print(f'{role}: {inputs}') 37 | return inputs 38 | 39 | def prompt_for_output(self, role: str): 40 | print(f"{role}: ", end="", flush=True) 41 | 42 | def stream_output(self, output_stream): 43 | pre = 0 44 | for outputs in output_stream: 45 | output_text = outputs["text"] 46 | output_text = output_text.strip().split(" ") 47 | now = len(output_text) - 1 48 | if now > pre: 49 | print(" ".join(output_text[pre:now]), end=" ", flush=True) 50 | pre = now 51 | print(" ".join(output_text[pre:]), flush=True) 52 | return " ".join(output_text) 53 | 54 | def output(self, message): 55 | output = message['text'] 56 | print(output, flush=True) 57 | return output 58 | 59 | 60 | class DistChatIO(ChatIO): 61 | def __init__(self, multiline: bool = False): 62 | self.multiline = multiline 63 | from mpi4py import MPI 64 | self.comm = MPI.COMM_WORLD 65 | 
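        # Only rank 0 talks to the terminal: prompt_for_input reads on rank 0 and broadcasts the
        # text to every other rank, and _print suppresses output on ranks other than 0.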
self.rank = self.comm.Get_rank() 66 | 67 | def _prompt_for_input(self, role) -> str: 68 | if not self.multiline: 69 | return input(f"{role}: ") 70 | 71 | prompt_data = [] 72 | line = input(f"{role} [ctrl-d/z on empty line to end]: ") 73 | while True: 74 | prompt_data.append(line.strip()) 75 | try: 76 | line = input() 77 | except EOFError as e: 78 | break 79 | return "\n".join(prompt_data) 80 | 81 | def prompt_for_input(self, role): 82 | if self.rank == 0: 83 | inputs = self._prompt_for_input(role) 84 | else: 85 | inputs = "" 86 | inputs = self.comm.bcast(inputs, root=0) 87 | return inputs 88 | 89 | def prompt_for_output(self, role: str): 90 | self._print(f"{role}: ", end="", flush=True) 91 | 92 | def _print(self, *args, **kwargs): 93 | if self.rank == 0: 94 | print(*args, **kwargs) 95 | 96 | def stream_output(self, output_stream): 97 | pre = 0 98 | for outputs in output_stream: 99 | output_text = outputs["text"] 100 | output_text = output_text.strip().split(" ") 101 | now = len(output_text) - 1 102 | if now > pre: 103 | self._print(" ".join(output_text[pre:now]), end=" ", flush=True) 104 | pre = now 105 | self._print(" ".join(output_text[pre:]), flush=True) 106 | outputs = " ".join(output_text) 107 | return outputs 108 | 109 | def output(self, message): 110 | output = message['text'] 111 | self._print(output, flush=True) 112 | return output 113 | 114 | 115 | -------------------------------------------------------------------------------- /qllm/plugin/chatcli/conversation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import List 3 | 4 | @dataclasses.dataclass 5 | class Conversation: 6 | """A class that manages prompt templates and keeps all conversation history.""" 7 | 8 | # The name of this template 9 | name: str 10 | # The system prompt 11 | system: str 12 | # Two roles 13 | roles: List[str] 14 | # All messages. Each item is (role, message). 15 | messages: List[List[str]] 16 | # The number of few shot examples 17 | offset: int 18 | # Separators 19 | sep: str 20 | sep2: str = None 21 | # Stop criteria (the default one is EOS token) 22 | stop_str: str = None 23 | # Stops generation if meeting any token in this list 24 | stop_token_ids: List[int] = None 25 | 26 | def get_prompt(self) -> str: 27 | """Get the prompt for generation.""" 28 | seps = [self.sep, self.sep2] 29 | ret = "" 30 | for i, (role, message) in enumerate(self.messages): 31 | if message: 32 | if i == 0: 33 | ret += self.system + message 34 | else: 35 | ret += role + " " + message + seps[i % 2] 36 | else: 37 | ret += role 38 | return ret 39 | 40 | def append_message(self, role: str, message: str): 41 | """Append a new message.""" 42 | self.messages.append([role, message]) 43 | 44 | def update_last_message(self, message: str): 45 | """Update the last output. 46 | 47 | The last message is typically set to be None when constructing the prompt, 48 | so we need to update it in-place after getting the response from a model. 
49 | """ 50 | self.messages[-1][1] = message 51 | 52 | def copy(self): 53 | return Conversation( 54 | name=self.name, 55 | system=self.system, 56 | roles=self.roles, 57 | messages=[[x, y] for x, y in self.messages], 58 | offset=self.offset, 59 | sep=self.sep, 60 | sep2=self.sep2, 61 | stop_str=self.stop_str, 62 | stop_token_ids=self.stop_token_ids, 63 | ) 64 | 65 | CONV_MAP = { 66 | 'llama2': Conversation( 67 | name="llama-2", 68 | system="", 69 | roles=("[INST]", "[/INST]"), 70 | messages=(), 71 | offset=0, 72 | sep=" ", 73 | sep2=" ", 74 | stop_token_ids=[2], 75 | ) 76 | } 77 | 78 | def get_conv(name: str): 79 | assert name in CONV_MAP, f'not find name in conversations.' 80 | return CONV_MAP[name].copy() 81 | -------------------------------------------------------------------------------- /qllm/plugin/chatcli/generation.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import transformers 4 | 5 | def generate_stream(model, tokenizer, prompt: str, device, max_new_tokens: int, context_len: int, echo: bool = False, stream_interval=2): 6 | stop_token_ids = [model.config.eos_token_id] 7 | device = model.device 8 | 9 | inputs = tokenizer(prompt) 10 | 11 | lhs_tokens = torch.tensor(inputs.input_ids, dtype=torch.int64, device=device).unsqueeze(0) 12 | 13 | past_kvs = transformers.DynamicCache() 14 | output_ids = list(inputs.input_ids) 15 | input_echo_len = len(output_ids) 16 | 17 | # check max_new_tokens 18 | remain_tokens = context_len - input_echo_len 19 | max_new_tokens = min(remain_tokens, max_new_tokens) 20 | 21 | for i in range(max_new_tokens): 22 | with torch.no_grad(): 23 | lhs_results = model(lhs_tokens, past_key_values=past_kvs, use_cache=True) 24 | 25 | logits = lhs_results.logits 26 | past_kvs = lhs_results.past_key_values 27 | 28 | # greedy search 29 | lhs_tokens = torch.argmax( 30 | lhs_results.logits[:, -1, :], dim=1, keepdim=True) 31 | 32 | token = lhs_tokens[0].item() 33 | output_ids.append(token) 34 | 35 | if token in stop_token_ids: 36 | stoped = True 37 | else: 38 | stoped = False 39 | 40 | if i % stream_interval == 0 or i == max_new_tokens - 1 or stoped: 41 | if echo: 42 | tmp_output_ids = output_ids 43 | else: 44 | tmp_output_ids = output_ids[input_echo_len:] 45 | 46 | output = tokenizer.decode( 47 | tmp_output_ids, 48 | skip_special_tokens=True, 49 | spaces_between_special_tokens=True, 50 | clean_up_tokenization_spaces=True 51 | ) 52 | 53 | yield { 54 | 'text': output, 55 | } 56 | 57 | if stoped: 58 | break 59 | 60 | yield { 61 | 'text': output 62 | } 63 | 64 | 65 | def generate(model, tokenizer, prompt: str, max_new_tokens:int, context_len: int, echo: bool=False): 66 | stop_token_ids = [model.config.eos_token_id] 67 | device = model.device 68 | 69 | inputs = tokenizer(prompt) 70 | 71 | lhs_tokens = torch.tensor(inputs.input_ids, dtype=torch.int64, device=device).unsqueeze(0) 72 | 73 | past_kvs = transformers.DynamicCache() 74 | output_ids = list(inputs.input_ids) 75 | input_echo_len = len(output_ids) 76 | 77 | # check max_new_tokens 78 | remain_tokens = context_len - input_echo_len 79 | max_new_tokens = min(remain_tokens, max_new_tokens) 80 | 81 | for i in range(max_new_tokens): 82 | with torch.no_grad(): 83 | lhs_results = model(lhs_tokens, past_key_values=past_kvs, use_cache=True) 84 | 85 | logits = lhs_results.logits 86 | past_kvs = lhs_results.past_key_values 87 | 88 | # greedy search 89 | lhs_tokens = torch.argmax( 90 | lhs_results.logits[:, -1, :], dim=1, keepdim=True) 91 | 92 | token = 
lhs_tokens[0].item() 93 | output_ids.append(token) 94 | 95 | if token in stop_token_ids: 96 | stoped = True 97 | else: 98 | stoped = False 99 | 100 | if stoped: 101 | break 102 | 103 | if echo: 104 | tmp_output_ids = output_ids 105 | else: 106 | tmp_output_ids = output_ids[input_echo_len:] 107 | 108 | output = tokenizer.decode( 109 | tmp_output_ids, 110 | skip_special_tokens=True, 111 | spaces_between_special_tokens=True, 112 | clean_up_tokenization_spaces=True 113 | ) 114 | 115 | return {'text': output} 116 | -------------------------------------------------------------------------------- /qllm/plugin/conversation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .chatcli import chat_loop, generate 3 | 4 | 5 | def loop_in_chat_completion(tokenizer, llm:nn.Module): 6 | llm = llm.cuda() 7 | 8 | chat_loop( 9 | llm, 10 | tokenizer, 11 | generate_func=generate, 12 | max_new_tokens=512, 13 | ) -------------------------------------------------------------------------------- /qllm/plugin/perplexity_utils.py: -------------------------------------------------------------------------------- 1 | #from https://github.com/AutoGPTQ/AutoGPTQ/blob/main/auto_gptq/utils/perplexity_utils.py 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | from datasets import load_dataset 7 | from tqdm import tqdm 8 | 9 | 10 | class Perplexity: 11 | """ 12 | A class for calculating the perplexity of a language model. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | model, 18 | tokenizer, 19 | dataset_path="wikitext", 20 | dataset_name=None, 21 | split="test", 22 | text_column="text", 23 | ): 24 | """ 25 | Calculate perplexity using the same method as seen in llama.cpp. 26 | 27 | Parameters 28 | ---------- 29 | model : AutoModelForCausalLM 30 | The language model for which the perplexity is calculated. 31 | tokenizer : AutoTokenizer 32 | The tokenizer corresponding to the model. 33 | device : str, optional 34 | The device to run the calculations on. If auto, the device that your model uses 35 | will be the device used for these calculations. Default is 'auto'. 36 | dataset_path : str, optional 37 | The path to the dataset on the Hugging Face dataset hub. Default is 'wikitext'. 38 | dataset_name : str, optional 39 | The name of the dataset. Default is None. 40 | split : str, optional 41 | The split of the dataset to use. Default is 'test'. 42 | text_column : str, optional 43 | The name of the column in the dataset that contains the text data. Default is 'text'. 44 | """ 45 | self._model = model 46 | self._tokenizer = tokenizer 47 | self._dataset_path = dataset_path 48 | self._dataset_name = dataset_name 49 | self._split = split 50 | self._text_column = text_column 51 | self._text = self._prepare_data() 52 | 53 | def _get_device(self): 54 | if torch.backends.mps.is_available(): 55 | return "mps" 56 | elif torch.cuda.is_available(): 57 | return "cuda:0" 58 | else: 59 | return "cpu" 60 | 61 | def _prepare_data(self): 62 | """ 63 | Prepares the dataset by loading and formatting. 64 | 65 | Returns 66 | ------- 67 | str 68 | The formatted dataset as a single string. 
69 | """ 70 | if self._dataset_path == "wikitext": 71 | self._dataset_name = "wikitext-2-raw-v1" 72 | 73 | # Load the dataset 74 | data = load_dataset(self._dataset_path, self._dataset_name, split=self._split) 75 | # Format the text column of the dataset 76 | text_list = [" \n" if s == "" else s for s in data[self._text_column]] 77 | return "".join(text_list) 78 | 79 | @staticmethod 80 | def softmax(logits): 81 | """ 82 | Static method for applying the softmax function. 83 | 84 | Parameters 85 | ---------- 86 | logits : np.ndarray 87 | The input to the softmax function. 88 | 89 | Returns 90 | ------- 91 | np.ndarray 92 | The output of the softmax function. 93 | """ 94 | e_x = np.exp(logits - np.max(logits)) 95 | return e_x / e_x.sum(axis=0) 96 | 97 | def calculate_perplexity(self, n_ctx=512, n_batch=512): 98 | """ 99 | Calculates the perplexity of the language model. 100 | 101 | Parameters 102 | ---------- 103 | n_ctx : int 104 | The context size. 105 | n_batch : int 106 | The batch size. 107 | 108 | Returns 109 | ------- 110 | list 111 | The list of perplexity scores calculated. 112 | """ 113 | # Tokenize the text 114 | self._tokenizer.model_max_length = sys.maxsize 115 | tokens = self._tokenizer(self._text, truncation=False, return_tensors="pt").input_ids.to(self._model.device) 116 | 117 | nll = 0.0 # Negative log likelihood 118 | count = 0 # Counter for processed tokens 119 | curr_ppl = 0 120 | all_perplexity = [] 121 | 122 | with tqdm(range(len(tokens[0]) // n_ctx), desc="Perplexity: - ") as progress: 123 | for i in progress: 124 | # Process each batch of tokens 125 | nll, count = self._process_batch(i, n_ctx, n_batch, tokens, nll, count) 126 | 127 | # Calculate and display the current perplexity 128 | curr_ppl = np.exp(nll / count) 129 | all_perplexity.append(curr_ppl) 130 | progress.set_description(f"Perplexity: {curr_ppl:.4f}") 131 | 132 | return all_perplexity 133 | 134 | def _process_batch(self, i, n_ctx, n_batch, tokens, nll, count): 135 | """ 136 | Processes each batch of tokens. 137 | 138 | Parameters 139 | ---------- 140 | i : int 141 | The batch index. 142 | n_ctx : int 143 | The context size. 144 | n_batch : int 145 | The batch size. 146 | tokens : torch.Tensor 147 | The tokenized text. 148 | nll : float 149 | The current negative log likelihood. 150 | count : int 151 | The current count of processed tokens. 152 | 153 | Returns 154 | ------- 155 | float 156 | The updated negative log likelihood. 157 | int 158 | The updated count of processed tokens. 159 | """ 160 | start = i * n_ctx 161 | end = start + n_ctx 162 | 163 | num_batches = (n_ctx + n_batch - 1) // n_batch 164 | 165 | logits = [] 166 | 167 | for j in range(num_batches): 168 | batch_start = start + j * n_batch 169 | batch_size = min(end - batch_start, n_batch) 170 | 171 | token_org = tokens[0][batch_start].item() 172 | 173 | if j == 0: 174 | # Replace the first token with the BOS token 175 | tokens[0][batch_start] = self._tokenizer.bos_token_id 176 | 177 | # Compute the logits for the current batch of tokens 178 | batch_logits = self._compute_batch_logits(tokens, batch_start, batch_size) 179 | 180 | tokens[0][batch_start] = token_org 181 | 182 | logits.append(batch_logits) 183 | 184 | # We rely on the fact that attention in the forward pass only looks at previous 185 | # tokens here, so the logits returned for each token are an accurate representation 186 | # of what the model would have predicted at that point. 
187 | # 188 | # Example, we have a context window of 512, we will compute perplexity for each of the 189 | # last 256 tokens. Then, we split the input up into context window size chunks to 190 | # process the entire prompt. 191 | 192 | for j in range(min(512, n_ctx // 2), n_ctx - 1): 193 | tok_logits = logits[0][0][j].cpu().numpy() 194 | # Compute the probability of the next token 195 | prob = self.softmax(tok_logits)[tokens[0][start + j + 1]] 196 | 197 | # Update the negative log likelihood and the count of processed tokens 198 | nll += -np.log(prob, where=prob > 0) 199 | count += 1 200 | 201 | return nll, count 202 | 203 | def _compute_batch_logits(self, tokens, batch_start, batch_size): 204 | """ 205 | Computes the logits for a batch of tokens. 206 | 207 | Parameters 208 | ---------- 209 | tokens : torch.Tensor 210 | The tokenized text. 211 | batch_start : int 212 | The start index of the batch. 213 | batch_size : int 214 | The size of the batch. 215 | 216 | Returns 217 | ------- 218 | torch.Tensor 219 | The logits for the batch of tokens. 220 | """ 221 | # Compute the logits without keeping track of gradients 222 | with torch.no_grad(): 223 | outputs = self._model(tokens[:, batch_start : batch_start + batch_size]) 224 | return outputs.logits.detach() 225 | -------------------------------------------------------------------------------- /qllm/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | from .awq.quant_awq import AWQQuant 2 | from .gptq.quant_gptq import GPTQQuant 3 | from .hqq.quant_hqq import HQQQuant 4 | from .vptq.quant_vptq import VPTQQuant 5 | from .config_builder import build_config 6 | 7 | def get_quantizer(config): 8 | if config.quant_method == "gptq": 9 | return GPTQQuant(config) 10 | elif config.quant_method == "awq": 11 | return AWQQuant(config) 12 | elif config.quant_method == "hqq": 13 | return HQQQuant(config) 14 | elif config.quant_method == "vptq": 15 | return VPTQQuant(config) 16 | else: 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /qllm/quantization/awq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wejoncy/QLLM/df20c15920bfabfd0581d7fcccbe87e5c96cd5c7/qllm/quantization/awq/__init__.py -------------------------------------------------------------------------------- /qllm/quantization/awq/quant_awq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | import functools 4 | from collections import defaultdict 5 | 6 | from ...utils.comm_utils import clear_memory 7 | from ...utils.modelutils import set_op_by_name 8 | 9 | from ..quant_frame_base import QuantFrameBase 10 | from ...utils import find_layers 11 | from ._awq_quantizer import (InternalAWQuantizer, pseudo_quantize_tensor, 12 | USE_ACCUMULATE_BATCH, ScaledActivation) 13 | 14 | 15 | def scale_activations(module): 16 | param = next(module.parameters()) 17 | dtype = param.dtype 18 | device = param.device 19 | if 'mptblock' in str(module.__class__.__name__).lower(): 20 | if isinstance(module.ffn.act, ScaledActivation): 21 | return 22 | c = module.ffn.up_proj.out_features 23 | act = ScaledActivation( 24 | module.ffn.act, 25 | torch.ones(c, dtype=dtype, device=device) 26 | ) 27 | set_op_by_name(module, "ffn.act", act) 28 | elif 'falcon' in str(module.__class__).lower(): 29 | if isinstance(module.mlp.act, ScaledActivation): 30 | return 31 | 
c = module.mlp.dense_h_to_4h.out_features 32 | act = ScaledActivation( 33 | module.mlp.act, 34 | torch.ones(c, dtype=dtype, device=device) 35 | ) 36 | set_op_by_name(module, "mlp.act", act) 37 | 38 | 39 | class AWQQuant(QuantFrameBase): 40 | def __init__(self, config) -> None: 41 | super().__init__() 42 | self.auto_scale = True 43 | self.auto_clip = True 44 | self.quant_config = config 45 | 46 | def hijack_internal_block(self, named_linears, layer_block, inps, layer_kwargs): 47 | dev = next(layer_block.parameters()).device 48 | # firstly, get input features of all linear layers 49 | if "mixtral" in (layer_block).__class__.__name__.lower(): 50 | named_linears.pop("block_sparse_moe.gate", None) 51 | named_linears = {**named_linears, "block_sparse_moe": layer_block.block_sparse_moe} 52 | def cache_input_hook(m, x, y, name, feat_dict): 53 | x = x[0] 54 | x = x.detach().cpu() 55 | feat_dict[name].append(x) 56 | 57 | input_feat = defaultdict(list) 58 | handles = [] 59 | for name in named_linears: 60 | handles.append(named_linears[name].register_forward_hook( 61 | functools.partial(cache_input_hook, name=name, feat_dict=input_feat))) 62 | # in case multi-gpu 63 | # get output as next layer's input 64 | if USE_ACCUMULATE_BATCH == -1: 65 | inps = inps.to(dev) 66 | outs = layer_block(inps, **layer_kwargs)[0] 67 | else: 68 | outs = [] 69 | for start in range(0, len(inps), USE_ACCUMULATE_BATCH): 70 | end = min(start + USE_ACCUMULATE_BATCH, len(inps)) 71 | single_x = inps[start:end].to(dev) 72 | outs.append(layer_block(single_x, **layer_kwargs)[0]) 73 | for key in input_feat: 74 | input_feat[key] = [torch.cat(input_feat[key], dim=0)] 75 | outs = torch.concat(outs, dim=0) 76 | for h in handles: 77 | h.remove() 78 | # now solve for scaling and clipping 79 | input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()} 80 | 81 | # Clear GPU memory 82 | clear_memory() 83 | return outs, input_feat 84 | 85 | def _apply_quant(self, model, named_linears, quantizers, state_dict_prefix, version="GEMM"): 86 | bits = self.quant_config.to_meta.bits 87 | groupsize = self.quant_config.to_meta.group_size 88 | for name, linear_layer in named_linears.items(): 89 | # NOTE: small regression in perplexity if linear layer uses .cpu().float() 90 | # NOTE: use the original dtype of the model 91 | linear_layer = linear_layer.cuda() 92 | 93 | linear_layer.weight.data, scales, zeros = pseudo_quantize_tensor( 94 | linear_layer.weight.data, 95 | n_bit=bits, 96 | q_config=self.quant_config, 97 | get_scale_zp=True, 98 | ) 99 | # get_op_name(model, linear_layer) 100 | layer_key = f"{state_dict_prefix}.{name}" 101 | if zeros is not None: 102 | zeros = zeros.cpu() 103 | quantizers[layer_key] = (None, scales.cpu(), zeros, None, bits, groupsize) 104 | linear_layer.cpu() 105 | clear_memory(scales, zeros) 106 | 107 | @torch.no_grad() 108 | def do_quantize(self, model, dataloader, model_prefix, dev): 109 | inps, attention_layers, layer_kwargs = self.hijack_block_inputs(model, dataloader, model_prefix, dev) 110 | run_batch = len(dataloader) if USE_ACCUMULATE_BATCH == -1 else USE_ACCUMULATE_BATCH 111 | if layer_kwargs.get('attention_mask', None) is not None: 112 | layer_kwargs['attention_mask'] = layer_kwargs['attention_mask'].expand(run_batch, -1, -1, -1) 113 | print('Ready.') 114 | 115 | quantizers = {} 116 | # solve layer by layer 117 | for i in tqdm.tqdm(range(len(attention_layers)), desc="Running AWQ..."): 118 | layer = attention_layers[i] 119 | layer = layer.cuda() 120 | named_linears = find_layers(layer, self.quant_layers) 121 | 
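            # hijack_internal_block runs this block on the cached calibration inputs while forward hooks
            # record every linear layer's input activations; the AWQ scale/clip search below consumes them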
inps, input_feat = self.hijack_internal_block(named_linears, layer, inps, layer_kwargs) 122 | 123 | in_quantizer = InternalAWQuantizer() 124 | in_quantizer.configure(self.quant_config.to_meta.bits, self.quant_config, self.auto_scale, self.auto_clip) 125 | 126 | in_quantizer.fast_quant_layer(layer_kwargs, input_feat, layer, attention_layers, i, model.__class__.__name__) 127 | self._apply_quant(model, named_linears, quantizers, f"{model_prefix}.{i}") 128 | 129 | layer = layer.cpu() 130 | # Haotian: check activation replacement 131 | clear_memory(input_feat) 132 | # real_quantize_model_weight(attention_layers, self.quant_config.to_meta.bits, self.quant_config) 133 | 134 | return quantizers 135 | -------------------------------------------------------------------------------- /qllm/quantization/config_builder.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from dataclasses import dataclass, asdict 3 | import json 4 | import typing 5 | 6 | @dataclass 7 | class MetaConfig: 8 | bits: int 9 | group_size: int 10 | quant_method: str 11 | 12 | 13 | @dataclass 14 | class MetaInterface: 15 | to_dict = asdict 16 | dict = asdict 17 | @property 18 | def to_meta(self): 19 | bits = -1 20 | if hasattr(self, "bits"): 21 | bits = self.bits 22 | elif hasattr(self, "w_bit"): 23 | bits = self.w_bit 24 | if hasattr(self, "group_size"): 25 | group_size = self.group_size 26 | elif hasattr(self, "q_group_size"): 27 | group_size = self.q_group_size 28 | return MetaConfig(bits, group_size, self.quant_method) 29 | 30 | 31 | @dataclass 32 | class GPTQConfig(MetaInterface): 33 | damp_percent: float 34 | group_size: int 35 | desc_act: str 36 | bits: int 37 | sym: bool 38 | allow_mix_bits: bool 39 | true_sequential: bool 40 | static_groups: bool = False 41 | version: str = "" 42 | quant_method: str = "gptq" 43 | 44 | 45 | @dataclass 46 | class AWQConfig(MetaInterface): 47 | q_group_size: int 48 | w_bit: int 49 | zero_point: bool 50 | version: str = "" 51 | quant_method: str = "awq" 52 | 53 | 54 | @dataclass 55 | class HQQConfig(MetaInterface): 56 | group_size: int 57 | bits: int 58 | version: str = "" 59 | quant_method: str = "hqq" 60 | 61 | 62 | def dataclass_from_dict(klass, d): 63 | try: 64 | fieldtypes = {f.name:f.type for f in dataclasses.fields(klass)} 65 | return klass(**{f:dataclass_from_dict(fieldtypes[f], d[f]) for f in d}) 66 | except: 67 | return d # Not a dataclass field 68 | 69 | @dataclass 70 | class HessianConfig(MetaInterface): 71 | batch_size : int = 2 72 | devset_size : int = 32 # 3072 73 | iter_size : int = 16 74 | ctx_size : int = 8192 75 | chunk_size : int = 256 76 | base_model : str = None 77 | act_save_rate : int = 50 78 | sample_proc : int = 4 79 | scratch_path: str = None 80 | save_activations: bool = False 81 | save_path: str = None 82 | 83 | @dataclass 84 | class VPTQLayerConfig(MetaInterface): 85 | bias: bool = dataclasses.field(default=False) 86 | enable_norm: bool = dataclasses.field(default=True) 87 | enable_perm: bool = dataclasses.field(default=True) 88 | group_num: int = dataclasses.field(default=1) 89 | outlier_size: int = dataclasses.field(default=0) 90 | group_size: int = dataclasses.field(default=-1) 91 | vector_lens: tuple = (-1, 8) 92 | num_centroids: tuple = (-1, 65536) 93 | num_res_centroids: tuple = (-1, 256) 94 | 95 | @dataclass 96 | class VPTQConfig(MetaInterface): 97 | model_name: str = dataclasses.field(default="meta-llama/Meta-Llama-3.1-8B-Instruct") 98 | seq_len: int = dataclasses.field(default=8192) 99 | 
quant_step: int = dataclasses.field(default=1) 100 | percdamp: float = dataclasses.field(default=0.01) 101 | blocksize: int = dataclasses.field(default=128) 102 | output_dir: str = dataclasses.field(default="outputs") 103 | seed: int = dataclasses.field(default=0) 104 | save_model: bool = dataclasses.field(default=False) 105 | # disable_actorder: bool = dataclasses.field(default=False) 106 | hessian_path: typing.Optional[str] = dataclasses.field(default=None) 107 | inv_hessian_path: typing.Optional[str] = dataclasses.field(default=None) 108 | num_gpus: int = dataclasses.field(default=1) 109 | # eval_nsamples: int = dataclasses.field(default=128) 110 | save_qlinear: bool = dataclasses.field(default=False) 111 | absorb_perm: bool = dataclasses.field(default=True) 112 | 113 | npercent : int = 0 114 | kmeans_mode : str= "hessian" 115 | norm_dim : int = 1 116 | ktol : float = 1e-5 117 | kiter :int = 100 118 | 119 | hessian_config: HessianConfig = dataclasses.field(default_factory=HessianConfig) 120 | layer_config: VPTQLayerConfig = dataclasses.field(default_factory=VPTQLayerConfig) 121 | version: str = "" 122 | quant_method: str = "vptq" 123 | 124 | @classmethod 125 | def from_dict(cls, config: dict): 126 | return dataclass_from_dict(cls, config) 127 | 128 | @dataclass 129 | class VPTQInferConfig(MetaInterface): 130 | group_size: int = 8 131 | bits: int = 2 132 | version: str = "" 133 | quant_method: str = "vptq" 134 | config_for_layers: typing.Dict[str, dict] = dataclasses.field(default_factory=dict) 135 | 136 | 137 | def build_config(args): 138 | if args.quant_method == 'gptq': 139 | config = GPTQConfig( 140 | damp_percent=args.percdamp, 141 | group_size=args.groupsize, 142 | desc_act=args.act_order, 143 | bits=args.wbits, 144 | sym=args.sym, 145 | allow_mix_bits=args.allow_mix_bits, 146 | true_sequential=args.true_sequential, 147 | static_groups=args.static_groups, 148 | ) 149 | elif args.quant_method == 'awq': 150 | config = AWQConfig(args.groupsize, args.wbits, not args.sym) 151 | elif args.quant_method == "hqq": 152 | config = HQQConfig(args.groupsize, args.wbits) 153 | elif args.quant_method == "vptq": 154 | with open(args.quant_config, 'r') as fp: 155 | dict_config = json.load(fp) 156 | config = VPTQConfig.from_dict(dict_config) 157 | config.model_name = args.load + args.model # one of them is empty 158 | 159 | return config 160 | -------------------------------------------------------------------------------- /qllm/quantization/gptq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wejoncy/QLLM/df20c15920bfabfd0581d7fcccbe87e5c96cd5c7/qllm/quantization/gptq/__init__.py -------------------------------------------------------------------------------- /qllm/quantization/gptq/_gptq_quantizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class InternalGPTQQuantizer(nn.Module): 6 | 7 | def __init__(self, shape=1): 8 | super(InternalGPTQQuantizer, self).__init__() 9 | self.register_buffer('maxq', torch.tensor(0)) 10 | self.register_buffer('scale', torch.zeros(shape)) 11 | self.register_buffer('zero', torch.zeros(shape)) 12 | 13 | def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False): 14 | 15 | self.maxq = torch.tensor(2**bits - 1) 16 | self.perchannel = perchannel 17 | self.sym = sym 18 | self.mse = mse 19 | self.norm = norm 20 | self.grid = grid 21 | self.maxshrink 
= maxshrink 22 | if trits: 23 | self.maxq = torch.tensor(-1) 24 | self.scale = torch.zeros_like(self.scale) 25 | 26 | def _quantize(self, x, scale, zero, maxq): 27 | if maxq < 0: 28 | return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero 29 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 30 | return scale * (q - zero) 31 | 32 | def find_params(self, x, weight=False): 33 | dev = x.device 34 | self.maxq = self.maxq.to(dev) 35 | 36 | shape = x.shape 37 | if self.perchannel: 38 | if weight: 39 | x = x.flatten(1) 40 | else: 41 | if len(shape) == 4: 42 | x = x.permute([1, 0, 2, 3]) 43 | x = x.flatten(1) 44 | if len(shape) == 3: 45 | x = x.reshape((-1, shape[-1])).t() 46 | if len(shape) == 2: 47 | x = x.t() 48 | else: 49 | x = x.flatten().unsqueeze(0) 50 | 51 | tmp = torch.zeros(x.shape[0], device=dev) 52 | xmin = torch.minimum(x.min(1)[0], tmp) 53 | xmax = torch.maximum(x.max(1)[0], tmp) 54 | 55 | if self.sym: 56 | xmax = torch.maximum(torch.abs(xmin), xmax) 57 | tmp = xmin < 0 58 | if torch.any(tmp): 59 | xmin[tmp] = -xmax[tmp] 60 | tmp = (xmin == 0) & (xmax == 0) 61 | xmin[tmp] = -1 62 | xmax[tmp] = +1 63 | 64 | if self.maxq < 0: 65 | self.scale = xmax 66 | self.zero = xmin 67 | else: 68 | self.scale = (xmax - xmin) / self.maxq 69 | if self.sym: 70 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 71 | else: 72 | self.zero = torch.round(-xmin / self.scale) 73 | 74 | if self.mse: 75 | best = torch.full([x.shape[0]], float('inf'), device=dev) 76 | for i in range(int(self.maxshrink * self.grid)): 77 | p = 1 - i / self.grid 78 | xmin1 = p * xmin 79 | xmax1 = p * xmax 80 | scale1 = (xmax1 - xmin1) / self.maxq 81 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 82 | q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 83 | q -= x 84 | q.abs_() 85 | q.pow_(self.norm) 86 | err = torch.sum(q, 1) 87 | tmp = err < best 88 | if torch.any(tmp): 89 | best[tmp] = err[tmp] 90 | self.scale[tmp] = scale1[tmp] 91 | self.zero[tmp] = zero1[tmp] 92 | if not self.perchannel: 93 | if weight: # noqa:SIM108 94 | tmp = shape[0] 95 | else: 96 | tmp = shape[1] if len(shape) != 3 else shape[2] 97 | self.scale = self.scale.repeat(tmp) 98 | self.zero = self.zero.repeat(tmp) 99 | 100 | if weight: 101 | shape = [-1] + [1] * (len(shape) - 1) 102 | self.scale = self.scale.reshape(shape) 103 | self.zero = self.zero.reshape(shape) 104 | return 105 | if len(shape) == 4: 106 | self.scale = self.scale.reshape((1, -1, 1, 1)) 107 | self.zero = self.zero.reshape((1, -1, 1, 1)) 108 | if len(shape) == 3: 109 | self.scale = self.scale.reshape((1, 1, -1)) 110 | self.zero = self.zero.reshape((1, 1, -1)) 111 | if len(shape) == 2: 112 | self.scale = self.scale.unsqueeze(0) 113 | self.zero = self.zero.unsqueeze(0) 114 | 115 | def quantize(self, x): 116 | if self.ready(): 117 | return self._quantize(x, self.scale, self.zero, self.maxq) 118 | 119 | return x 120 | 121 | def enabled(self): 122 | return self.maxq > 0 123 | 124 | def ready(self): 125 | return torch.all(self.scale != 0) 126 | -------------------------------------------------------------------------------- /qllm/quantization/gptq/quant_gptq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | from texttable import Texttable 4 | 5 | from ..quant_frame_base import QuantFrameBase 6 | from .gptq import GPTQ, Observer 7 | from ...utils import find_layers, gen_conditions 8 | from ...utils.logger import get_logger 9 | from . 
import sequential_layes_gptq_config 10 | logger = get_logger('qllm') 11 | 12 | class ObserverHelper: 13 | def __init__(self, config) -> None: 14 | self.observer = Observer() 15 | self.quant_config = config 16 | 17 | def submit(self, name, layerid, gptq, error): 18 | if self.quant_config.allow_mix_bits: 19 | self.observer.submit(name=name, layerid=layerid, gptq=gptq, error=error) 20 | return True 21 | return False 22 | 23 | def post_quant(self, quantizers, state_dict_prefix): 24 | if not self.quant_config.allow_mix_bits: 25 | return 26 | logger.debug(self.observer.print()) 27 | conditions = gen_conditions(self.quant_config.bits, self.quant_config.group_size) 28 | for item in tqdm.tqdm(self.observer.items(), desc="Optimizing with mix bits/groupsize"): 29 | name = item[0] 30 | layerid = item[1] 31 | gptq = item[2]['gptq'] 32 | error = item[2]['error'] 33 | target = error / 2 34 | 35 | table = Texttable() 36 | table.header(['wbits', 'groupsize', 'error']) 37 | table.set_cols_dtype(['i', 'i', 'f']) 38 | table.add_row([self.quant_config.bits, self.quant_config.group_size, error]) 39 | 40 | logger.debug('Optimizing {} {} ..'.format(name, layerid)) 41 | for wbits, groupsize in conditions: 42 | 43 | if error < target: 44 | # if error dropped 50%, skip 45 | break 46 | 47 | gptq.quantizer.configure(wbits, perchannel=True, sym=self.quant_config.sym, mse=False) 48 | 49 | scale, zero, g_idx, error = gptq.fasterquant( 50 | percdamp=self.quant_config.damp_percent, 51 | groupsize=groupsize, 52 | actorder=self.quant_config.desc_act, 53 | static_groups=self.quant_config.static_groups, 54 | name=name, 55 | ) 56 | 57 | table.add_row([wbits, groupsize, error]) 58 | quantizers[f'{state_dict_prefix}.{layerid}.{name}'] = ( 59 | gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize) 60 | 61 | logger.debug(table.draw()) 62 | logger.debug('\n') 63 | gptq.layer.to('cpu') 64 | gptq.free() 65 | 66 | 67 | class GPTQQuant(QuantFrameBase): 68 | 69 | def __init__(self, config) -> None: 70 | super().__init__() 71 | self.quant_config = config 72 | 73 | def hijack_internal_block(self, gptq, subset, layer_block, inps, layer_kwargs): 74 | dev = next(layer_block.parameters()).device 75 | 76 | def add_batch(name): 77 | def tmp(_, inp, out): 78 | gptq[name].add_batch(inp[0].data, out.data) 79 | return tmp 80 | 81 | handles = [] 82 | for name in subset: 83 | handles.append(subset[name].register_forward_hook(add_batch(name))) 84 | for j in range(len(inps)): 85 | _ = layer_block(inps[j].unsqueeze(0).to(dev), **layer_kwargs) 86 | for h in handles: 87 | h.remove() 88 | 89 | @torch.inference_mode() 90 | def do_quantize(self, model, dataloader, model_prefix, dev): 91 | inps, attention_layers, layer_input_args = self.hijack_block_inputs(model, dataloader, model_prefix, dev) 92 | outs = torch.zeros_like(inps) 93 | print('Ready.') 94 | 95 | quantizers = {} 96 | observer_helper = ObserverHelper(self.quant_config) 97 | for i in tqdm.tqdm(range(len(attention_layers)), desc="running GPTQ"): 98 | self.hook_before_qlayer(i, self.quant_config) 99 | 100 | block_layer = attention_layers[i].to(dev) 101 | named_linear_layers = find_layers(block_layer, self.quant_layers) 102 | 103 | sequential = [list(named_linear_layers.keys())] 104 | # filter out the layers that shouldnt be quantized 105 | true_sequential = sequential_layes_gptq_config.auto_detect_sequential_layers( 106 | sequential, model.__class__.__name__) 107 | if not self.quant_config.true_sequential: 108 | sequential_tmp = [sum(true_sequential, [])] 109 | if 
len(sequential_tmp[0]) != len(sequential[0]): 110 | # if true_sequential is not the same as sequential, we need to re-order the layers 111 | sequential[0] = sorted(sequential_tmp[0], key=lambda x: sequential[0].index(x)) 112 | for names in sequential: 113 | subset = {n: named_linear_layers[n] for n in names} 114 | gptq = {} 115 | for name in subset: 116 | gptq[name] = GPTQ(subset[name], allow_mix_bits=self.quant_config.allow_mix_bits) 117 | gptq[name].quantizer.configure( 118 | self.quant_config.bits, perchannel=True, sym=self.quant_config.sym, mse=False 119 | ) 120 | 121 | self.hijack_internal_block(gptq, subset, block_layer, inps, layer_input_args) 122 | 123 | for name in subset: 124 | scale, zero, g_idx, error = gptq[name].fasterquant( 125 | percdamp=self.quant_config.damp_percent, 126 | groupsize=self.quant_config.group_size, 127 | actorder=self.quant_config.desc_act, 128 | static_groups=self.quant_config.static_groups, 129 | name=name, 130 | ) 131 | quantizers[f"{model_prefix}.{i}.{name}"] = ( 132 | gptq[name].quantizer.cpu(), 133 | scale.cpu(), 134 | zero.cpu(), 135 | g_idx.cpu(), 136 | self.quant_config.bits, 137 | self.quant_config.group_size, 138 | ) 139 | 140 | if not observer_helper.submit(name=name, layerid=i, gptq=gptq[name], error=error): 141 | gptq[name].free() 142 | 143 | # [ TODO ] 144 | # I am supposing layer's weight should be quantized and modified, we are statisting the error 145 | # accumulated from the previous layers and compensate next layer 146 | for j in range(len(dataloader)): 147 | outs[j] = block_layer(inps[j].unsqueeze(0).to(dev), **layer_input_args)[0].cpu() 148 | 149 | attention_layers[i] = block_layer.cpu() 150 | del block_layer 151 | del gptq 152 | torch.cuda.empty_cache() 153 | 154 | inps, outs = outs, inps 155 | 156 | observer_helper.post_quant(quantizers, model_prefix) 157 | return quantizers 158 | -------------------------------------------------------------------------------- /qllm/quantization/gptq/sequential_layes_gptq_config.py: -------------------------------------------------------------------------------- 1 | # Description: This file contains the sequential layers for each model. 
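# Each entry maps a HuggingFace model class name to groups of linear-layer names;
# layers within one inner list are quantized together, and the groups are processed in order
# (attention input projections, attention output projection, then MLP/expert projections).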
2 | 3 | true_sequential_layers_for_model = dict( 4 | BaiChuanForCausalLM=[ 5 | ["self_attn.W_pack"], 6 | ["self_attn.o_proj"], 7 | ["mlp.up_proj", "mlp.gate_proj"], 8 | ["mlp.down_proj"] 9 | ], 10 | BloomForCausalLM=[ 11 | ["self_attention.query_key_value"], 12 | ["self_attention.dense"], 13 | ["mlp.dense_h_to_4h"], 14 | ["mlp.dense_4h_to_h"] 15 | ], 16 | CodeGenForCausalLM=[ 17 | ["attn.qkv_proj"], 18 | ["attn.out_proj"], 19 | ["mlp.fc_in"], 20 | ["mlp.fc_out"] 21 | ], 22 | DeciLMForCausalLM=[ 23 | ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], 24 | ["self_attn.o_proj"], 25 | ["mlp.up_proj", "mlp.gate_proj"], 26 | ["mlp.down_proj"] 27 | ], 28 | GPT2ForCausalLM=[ 29 | ["attn.qkv_proj"], 30 | ["attn.out_proj"], 31 | ["mlp.fc_in"], 32 | ["mlp.fc_out"] 33 | ], 34 | GPTBigCodeForCausalLM=[ 35 | ["attn.c_attn"], 36 | ["attn.c_proj"], 37 | ["mlp.c_fc"], 38 | ["mlp.c_proj"] 39 | ], 40 | GPTNeoXForCausalLM=[ 41 | ["attention.query_key_value"], 42 | ["attention.dense"], 43 | ["mlp.dense_h_to_4h"], 44 | ["mlp.dense_4h_to_h"] 45 | ], 46 | GPTJForCausalLM=[ 47 | ["attn.k_proj", "attn.v_proj", "attn.q_proj"], 48 | ["attn.out_proj"], 49 | ["mlp.fc_in"], 50 | ["mlp.fc_out"] 51 | ], 52 | InternLMForCausalLM=[ 53 | ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], 54 | ["self_attn.o_proj"], 55 | ["mlp.up_proj", "mlp.gate_proj"], 56 | ["mlp.down_proj"], 57 | ], 58 | LlamaForCausalLM=[ 59 | ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], 60 | ["self_attn.o_proj"], 61 | ["mlp.up_proj", "mlp.gate_proj"], 62 | ["mlp.down_proj"] 63 | ], 64 | MistralForCausalLM=[ 65 | ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], 66 | ["self_attn.o_proj"], 67 | ["mlp.up_proj", "mlp.gate_proj"], 68 | ["mlp.down_proj"], 69 | ], 70 | MixtralForCausalLM = [ 71 | ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], 72 | ["self_attn.o_proj"], 73 | [ 74 | "block_sparse_moe.experts.0.w1", 75 | "block_sparse_moe.experts.1.w1", 76 | "block_sparse_moe.experts.2.w1", 77 | "block_sparse_moe.experts.3.w1", 78 | "block_sparse_moe.experts.4.w1", 79 | "block_sparse_moe.experts.5.w1", 80 | "block_sparse_moe.experts.6.w1", 81 | "block_sparse_moe.experts.7.w1", 82 | "block_sparse_moe.experts.0.w3", 83 | "block_sparse_moe.experts.1.w3", 84 | "block_sparse_moe.experts.2.w3", 85 | "block_sparse_moe.experts.3.w3", 86 | "block_sparse_moe.experts.4.w3", 87 | "block_sparse_moe.experts.5.w3", 88 | "block_sparse_moe.experts.6.w3", 89 | "block_sparse_moe.experts.7.w3", 90 | ], 91 | [ 92 | "block_sparse_moe.experts.0.w2", 93 | "block_sparse_moe.experts.1.w2", 94 | "block_sparse_moe.experts.2.w2", 95 | "block_sparse_moe.experts.3.w2", 96 | "block_sparse_moe.experts.4.w2", 97 | "block_sparse_moe.experts.5.w2", 98 | "block_sparse_moe.experts.6.w2", 99 | "block_sparse_moe.experts.7.w2", 100 | ] 101 | ], 102 | MOSSForCausalLM=[ 103 | ["attn.qkv_proj"], 104 | ["attn.out_proj"], 105 | ["mlp.fc_in"], 106 | ["mlp.fc_out"] 107 | ], 108 | OPTForCausalLM=[ 109 | ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], 110 | ["self_attn.out_proj"], 111 | ["fc1"], 112 | ["fc2"] 113 | ], 114 | QwenForCausalLM=[ 115 | ["attn.c_attn"], 116 | ["attn.c_proj"], 117 | ["mlp.w1", "mlp.w2"], 118 | ["mlp.c_proj"] 119 | ], 120 | RWForCausalLM=[ 121 | ["self_attention.query_key_value"], 122 | ["self_attention.dense"], 123 | ["mlp.dense_h_to_4h"], 124 | ["mlp.dense_4h_to_h"] 125 | ], 126 | StableLMEpochForCausalLM=[ 127 | ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], 128 | ["self_attn.o_proj"], 129 | 
["mlp.up_proj", "mlp.gate_proj"], 130 | ["mlp.down_proj"] 131 | ], 132 | XverseForCausalLM=[ 133 | ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], 134 | ["self_attn.o_proj"], 135 | ["mlp.up_proj", "mlp.gate_proj"], 136 | ["mlp.down_proj"] 137 | ], 138 | YiForCausalLM=[ 139 | ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], 140 | ["self_attn.o_proj"], 141 | ["mlp.up_proj", "mlp.gate_proj"], 142 | ["mlp.down_proj"] 143 | ], 144 | Qwen2ForCausalLM=[ 145 | ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], 146 | ["self_attn.o_proj"], 147 | ["mlp.up_proj", "mlp.gate_proj"], 148 | ["mlp.down_proj"], 149 | ], 150 | ) 151 | 152 | 153 | def auto_detect_sequential_layers(flatten_layers, model_type): 154 | if model_type in true_sequential_layers_for_model: 155 | return true_sequential_layers_for_model[model_type] 156 | top_layers = [] 157 | layers = [flatten_layers[0][0]] 158 | for i in range(1, len(flatten_layers[0])): 159 | if flatten_layers[0][i].split('.')[0] == layers[-1].split('.')[0]: 160 | layers.append(flatten_layers[0][i]) 161 | else: 162 | top_layers.append(layers) 163 | layers = [flatten_layers[0][i]] 164 | top_layers.append(layers) 165 | 166 | # filter out o_projection 167 | top_layers.insert(1, []) 168 | top_layers[1].append(top_layers[0][-1]) 169 | top_layers[0].pop(-1) 170 | top_layers.append([]) 171 | top_layers[-1].append(top_layers[-2][-1]) 172 | top_layers[-2].pop(-1) 173 | return top_layers -------------------------------------------------------------------------------- /qllm/quantization/hqq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wejoncy/QLLM/df20c15920bfabfd0581d7fcccbe87e5c96cd5c7/qllm/quantization/hqq/__init__.py -------------------------------------------------------------------------------- /qllm/quantization/hqq/_hqq_quantizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class InternalHQQQuantizer(nn.Module): 7 | 8 | def __init__(self, layer: nn.Module): 9 | super(InternalHQQQuantizer, self).__init__() 10 | self.scale_quant_params = None 11 | self.zero_quant_params = None 12 | self.layer = layer 13 | 14 | 15 | def configure(self, bits: int, channel_wise: bool = True, group_size: int = 64, 16 | optimize: bool = False, round_zero: bool = False, axis: int = 0): 17 | self.bits = bits 18 | assert bits in [2, 3, 4, 8], "bits=" + str(bits) + " not supported." 19 | 20 | self.channel_wise = channel_wise 21 | self.group_size = group_size 22 | self.optimize = optimize 23 | self.round_zero = round_zero 24 | self.axis = axis 25 | 26 | # Proximal solver || W - dequantize(quantize(W))||_p^p 27 | 28 | @torch.inference_mode() 29 | def optimize_weights_proximal(self, tensor, scale, zero, min_max, axis=0, device='cuda', 30 | opt_params={'lp_norm': 0.7, 'beta': 1e1, 'kappa': 1.01, 'iters': 20}, verbose=False): # noqa:B006 31 | lp_norm, beta, kappa, iters = opt_params['lp_norm'], opt_params['beta'], opt_params['kappa'], opt_params['iters'] 32 | 33 | dtype = torch.float16 if (device == 'cuda') else torch.float32 34 | W_f = tensor.to(dtype).to(device) 35 | scale = scale.to(dtype).to(device) 36 | zero = zero.to(dtype).to(device) 37 | 38 | if (lp_norm == 1): 39 | shrink_op = lambda x, beta: torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1. 
/ beta) 40 | else: 41 | shrink_op = lambda x, beta, p=lp_norm: torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - (1. / beta) * torch.pow(torch.abs(x), p - 1)) 42 | 43 | best_error = 1e4 44 | for i in range(iters): 45 | W_q = torch.round(W_f * scale + zero).clamp(min_max[0], min_max[1]) 46 | W_r = (W_q - zero) / scale 47 | W_e = shrink_op(W_f - W_r, beta) 48 | zero = torch.mean(W_q - (W_f - W_e) * scale, axis=axis, keepdim=True) 49 | beta *= kappa 50 | 51 | current_error = float(torch.abs(W_f - W_r).mean()) 52 | if (verbose): 53 | print(i, np.round(current_error, 6)) 54 | if (current_error < best_error): 55 | best_error = current_error 56 | else: 57 | break 58 | 59 | scale = scale.to(tensor.device) 60 | zero = zero.to(tensor.device) 61 | del W_f, W_q, W_r, W_e 62 | torch.cuda.empty_cache() 63 | 64 | return scale, zero 65 | 66 | def quantize(self): 67 | tensor = self.layer.weight 68 | nbits = self.bits 69 | channel_wise = self.channel_wise 70 | group_size = self.group_size 71 | optimize = self.optimize 72 | round_zero = self.round_zero 73 | axis = self.axis 74 | 75 | assert axis in [0, 1], "axis should be either 0 or 1" 76 | if (group_size is not None): 77 | assert tensor.shape[axis] % group_size == 0, "group_size should be divisble by the total tensor dimensions." 78 | 79 | W = tensor.float() 80 | shape = W.shape 81 | 82 | # Reshape for grouping 83 | if ((group_size is not None) and channel_wise): 84 | W = W.reshape([-1, group_size]) if (axis == 1) else W.reshape([group_size, -1]) 85 | 86 | # Get min/max values 87 | if (channel_wise is False): 88 | _min, _max = W.min(), W.max() 89 | optimize = False 90 | else: 91 | _min = W.min(axis=axis, keepdim=True)[0] 92 | _max = W.max(axis=axis, keepdim=True)[0] 93 | 94 | max_v = 2**nbits - 1 95 | min_v = 0 96 | min_max = [min_v, max_v] 97 | 98 | # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on. 
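        # Concretely, with min_max = [0, 2**nbits - 1]:
        #   W_q = round(W * scale + zero), clamped to min_max
        #   W  ~ (W_q - zero) / scale
        # `scale` is inverted (scale = 1.0 / scale) further below before being returned.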
99 | scale = (max_v / (_max - _min)).clamp(max=2e4) # clamp to avoid half-precision problems 100 | zero = -_min * scale 101 | 102 | if (round_zero): zero = torch.round(zero) 103 | 104 | # Fine-tune weights 105 | if (optimize): 106 | scale, zero = self.optimize_weights_proximal(tensor=W, scale=scale, zero=zero, min_max=min_max, axis=axis) 107 | #Quantize 108 | W_q = torch.round(W*scale + zero).clamp(min_max[0], min_max[1]) 109 | self.layer.weight.data = ((W_q- zero)/scale).reshape(shape).type(tensor.dtype) 110 | scale = 1.0/scale 111 | # cleanup 112 | del W, _min, _max 113 | torch.cuda.empty_cache() 114 | 115 | if axis == 1: 116 | scale = scale.reshape(shape[0], -1) 117 | zero = zero.reshape(shape[0], -1) 118 | else: 119 | scale = scale.reshape(-1, shape[-1]) 120 | zero = zero.reshape(-1, shape[-1]) 121 | return scale.cpu(), zero.cpu() 122 | 123 | def free(self): 124 | del self.layer 125 | torch.cuda.empty_cache() 126 | -------------------------------------------------------------------------------- /qllm/quantization/hqq/quant_hqq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | 4 | from ..quant_frame_base import QuantFrameBase 5 | from ._hqq_quantizer import InternalHQQQuantizer 6 | from ...utils import find_layers 7 | from ...utils.logger import get_logger 8 | logger = get_logger('qllm') 9 | 10 | class HQQQuant(QuantFrameBase): 11 | def __init__(self, config) -> None: 12 | super().__init__() 13 | self.quant_config = config 14 | 15 | 16 | @torch.inference_mode() 17 | def do_quantize(self, model, dataloader, model_prefix, dev): 18 | dataloader = [] 19 | _, attention_layers, layer_input_args = self.hijack_block_inputs(model, dataloader, model_prefix, dev) 20 | print('Ready.') 21 | bits, groupsize = self.quant_config.to_meta.bits, self.quant_config.to_meta.group_size 22 | quantizers = {} 23 | for i in tqdm.tqdm(range(len(attention_layers)), desc="running HQQ"): 24 | block_layer = attention_layers[i].to(dev) 25 | named_linear_layers = find_layers(block_layer, self.quant_layers) 26 | 27 | # [ TODO ] how to filter out the layers, which won't be quantized or harness the quality 28 | sequential = [list(named_linear_layers.keys())] 29 | for names in sequential: 30 | subset = {n: named_linear_layers[n] for n in names} 31 | gptq = {} 32 | for name in subset: 33 | gptq[name] = InternalHQQQuantizer(subset[name]) 34 | gptq[name].configure(bits, channel_wise=True, group_size=groupsize, 35 | optimize=True, round_zero=True, axis=1) 36 | scale, zero = gptq[name].quantize() 37 | quantizers[f'{model_prefix}.{i}.{name}'] = ( 38 | gptq[name], scale.cpu(), zero.cpu(), None, bits, groupsize) 39 | 40 | gptq[name].free() 41 | 42 | 43 | attention_layers[i] = block_layer.cpu() 44 | del block_layer 45 | del gptq 46 | torch.cuda.empty_cache() 47 | 48 | return quantizers 49 | 50 | -------------------------------------------------------------------------------- /qllm/quantization/quant_frame_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from ..utils import comm_utils, find_layers 4 | from ..utils.logger import get_logger 5 | from ..utils.modelutils import get_op_by_name 6 | 7 | logger = get_logger() 8 | 9 | class QuantFrameBase: 10 | def __init__(self) -> None: 11 | self.rec_use_cache = False 12 | self.quant_layers = [torch.nn.Linear] 13 | self.swap_device = torch.device('cpu') 14 | 15 | def set_tokenizer(self, tokenizer): 16 | pass 17 | 18 | @torch.no_grad() 
19 | def prepare(self, model): 20 | print('Starting ...') 21 | self.rec_use_cache = getattr(model.config, 'use_cache', False) 22 | 23 | if hasattr(model, 'hf_device_map') and len(model.hf_device_map) > 1: 24 | pass 25 | else: 26 | model = model.cpu() 27 | model.config.use_cache = False 28 | return model 29 | 30 | @torch.no_grad() 31 | def extract_prefix(self, model): 32 | ''' 33 | heristicly extract the prefix of the state_dict 34 | support encoder-decoder model and decoder-only model 35 | a model usually has a state_dict like this: 36 | x.embed 37 | x.model x.encoder x.decoder 38 | x.lm_head 39 | x.dropout 40 | x.loss 41 | ''' 42 | 43 | state_dict_prefix = None 44 | prefix_list = [] #encoder/decoder 45 | for name in model.state_dict(): 46 | if '.0.' not in name: 47 | continue 48 | state_dict_prefix = name.split('.0.')[0] 49 | prefix_list.append(state_dict_prefix) 50 | 51 | prefix_list = list(set(prefix_list)) 52 | min_len = min([len(i) for i in prefix_list]) 53 | prefix_list = [i for i in prefix_list if len(i) == min_len] 54 | if len(prefix_list) > 1: 55 | raise ValueError(f"Multiple prefix found: {prefix_list}, encoder-decoder model is not supported") 56 | assert prefix_list, "state_dict_prefix not found" 57 | return prefix_list 58 | 59 | @torch.no_grad() 60 | def extract_layers(self, model, model_prefix): 61 | attention_layers = None 62 | pre_layers_of_attention = [] # enmbedding layer, norm layer 63 | # find the attention layers, and the pre layers of attention layers 64 | transformer_model = get_op_by_name(model, '.'.join(model_prefix.split('.')[:-1])) 65 | for _, layer in transformer_model.named_children(): 66 | if type(layer) in [torch.nn.ModuleList]: 67 | attention_layers = layer 68 | continue 69 | else: 70 | pre_layers_of_attention.append(layer) 71 | assert attention_layers is not None, "attention_layers not found" 72 | return attention_layers, pre_layers_of_attention 73 | 74 | def hijack_block_inputs(self, model, dataloader, model_prefix, dev): 75 | inps = [] 76 | layer_input_args = {} 77 | swap_device = self.swap_device 78 | 79 | class Catcher(nn.Module): 80 | def __init__(self, module): 81 | super().__init__() 82 | self.module = module 83 | 84 | def forward(self, inp, **kwargs): 85 | inps.append(inp.to(swap_device)) 86 | layer_input_args.update(kwargs) 87 | raise ValueError 88 | 89 | attention_layers, pre_layers_of_attention = self.extract_layers(model, model_prefix) 90 | for layer in pre_layers_of_attention: 91 | if isinstance(layer, torch.nn.Embedding) and hasattr(layer, '_old_forward'): 92 | layer._hf_hook.execution_device = dev 93 | layer = layer.to(dev) 94 | attention_layers[0] = attention_layers[0].to(dev) 95 | attention_layers[0] = Catcher(attention_layers[0]) 96 | for batch in dataloader: 97 | try: # noqa:SIM105 98 | model(batch[0].to(dev)) 99 | except ValueError: 100 | pass 101 | attention_layers[0] = attention_layers[0].module 102 | attention_layers[0] = attention_layers[0].cpu() 103 | for layer in pre_layers_of_attention: 104 | layer = layer.cpu() 105 | comm_utils.clear_memory() 106 | 107 | # allow dataloader is None 108 | inps = torch.cat(inps, dim=0) if inps else torch.tensor([]) 109 | return inps, attention_layers, layer_input_args 110 | 111 | def hook_before_qlayer(self, layer_id, config): 112 | mix_qlayer_conf = {} 113 | if str(layer_id + 1) in mix_qlayer_conf: 114 | layer_key = str(layer_id + 1) 115 | self.quant_config.wbits = mix_qlayer_conf[layer_key].get("wbits", self.quant_config.wbits) 116 | self.quant_config.groupsize = 
mix_qlayer_conf[layer_key].get("groupsize", self.quant_config.groupsize) 117 | 118 | def do_quantize(self, model, dataloader, model_prefix, dev): 119 | raise NotImplementedError 120 | 121 | def quantize(self, model, dataloader, dev): 122 | model = self.prepare(model) 123 | quantizers = {} 124 | state_dict_prefix:list = self.extract_prefix(model) 125 | for prefix in state_dict_prefix: 126 | quantizers.update(self.do_quantize(model, dataloader, prefix, dev)) 127 | 128 | model.config.use_cache = self.rec_use_cache 129 | model.quant_config = self.quant_config 130 | return quantizers 131 | -------------------------------------------------------------------------------- /qllm/quantization/vptq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wejoncy/QLLM/df20c15920bfabfd0581d7fcccbe87e5c96cd5c7/qllm/quantization/vptq/__init__.py -------------------------------------------------------------------------------- /qllm/quantization/vptq/_vptq_quantizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import logging 4 | from ...utils.logger import get_logger 5 | 6 | logger = get_logger("qllm") 7 | 8 | class InternalVPTQQuantizer(nn.Module): 9 | 10 | def __init__(self): 11 | super(InternalVPTQQuantizer, self).__init__() 12 | 13 | 14 | def quantize_layer(self, tasks, args, quant_args, name2hessian=None, dev=None): 15 | """ 16 | Quantize the given layers in tasks. 17 | Args: 18 | task_id: Task ID 19 | tasks: List of layers to quantize 20 | args: Command line arguments 21 | quant_args: Quantization arguments 22 | input_queues: Input queue 23 | output_queues: Output queue 24 | name2hessian: Dictionary mapping layer names to Hessians 25 | """ 26 | if dev is not None: 27 | torch.cuda.set_device(dev) 28 | try: 29 | from vptq.layer_quantizer import layer_quantizer 30 | from vptq.quantize_executer import setup_logging 31 | except ImportError: 32 | logger.warning("Please install vptq by 'pip install -U vptq'") 33 | raise 34 | 35 | layer, layer_idx = tasks 36 | vptq_logger = setup_logging(f"{args.output_dir}/logs/", str(dev).replace(':', '_'), debug=False) 37 | 38 | dtype = next(iter(layer.parameters())).dtype 39 | layer, qlinear_args = layer_quantizer( 40 | args, quant_args, layer, layer_idx, vptq_logger, dev, dtype, name2hessian=name2hessian 41 | ) 42 | layer = layer.to(dtype).cpu() 43 | return layer 44 | -------------------------------------------------------------------------------- /qllm/quantization/vptq/inv_hessian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from argparse import ArgumentParser 3 | import os 4 | import tqdm 5 | from ...utils.logger import get_logger 6 | 7 | # load Hessian from files 8 | def load_hessian(hessian_path, pbar=None, logger=None): 9 | if logger is None and pbar is None: 10 | print(f'load Hessian from {hessian_path}') 11 | elif pbar is not None: 12 | pbar.set_postfix_str(f'load Hessian ...{hessian_path[-10:]}') 13 | else: 14 | logger.info(f'load Hessian from {hessian_path}') 15 | H_data = torch.load(f'{hessian_path}', weights_only=True, map_location='cpu') 16 | 17 | # convert H to sym matrix 18 | def flat_to_sym(V, N): 19 | A = torch.zeros(N, N, dtype=V.dtype, device=V.device) 20 | idxs = torch.tril_indices(N, N, device=V.device) 21 | A[idxs.unbind()] = V 22 | A[idxs[1, :], idxs[0, :]] = V 23 | return A 24 | 25 | def regularize_H(H, n, sigma_reg): 26 | 
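        # normalize H by the mean of its diagonal, then add sigma_reg to the diagonal as damping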
H.div_(torch.diag(H).mean()) 27 | idx = torch.arange(n) 28 | H[idx, idx] += sigma_reg 29 | return H 30 | 31 | def basic_preprocess(H, mu, n): 32 | H.add_(mu[None, :] * mu[:, None]) 33 | H = regularize_H(H, n, 1e-2) 34 | return H, mu 35 | 36 | H = flat_to_sym(H_data['flatH'], H_data['n']) 37 | mu = H_data['mu'] 38 | n = H_data['n'] 39 | H, mu = basic_preprocess(H, mu, n) 40 | 41 | return H, mu 42 | 43 | def main(args): 44 | logger = get_logger("qllm") 45 | # create folder 46 | os.makedirs(args.store_inv_hessian_dir, exist_ok=True) 47 | 48 | percdamp = 0.01 49 | hessian_files = [f for f in os.listdir( 50 | args.load_hessian_dir) if f.endswith('.pt')] 51 | 52 | for hessian_file in (pbar := tqdm.tqdm(hessian_files, desc="Inverting Hessian")): 53 | hessian_path = os.path.join(args.load_hessian_dir, hessian_file) 54 | hessian, mu = load_hessian(hessian_path, pbar=pbar, logger=logger) 55 | dev = 'cuda' 56 | hessian = hessian.to(dev) 57 | 58 | zero_idx = torch.diag(hessian) == 0 59 | hessian[zero_idx, zero_idx] = 1 60 | 61 | # get permutation 62 | perm = torch.argsort(torch.diag(hessian), descending=True).to(dev) 63 | if args.enable_perm: 64 | hessian = hessian[perm][:, perm] 65 | 66 | # add damping 67 | damp = percdamp * torch.mean(torch.diag(hessian)) 68 | diag = torch.arange(hessian.shape[0], device=dev) 69 | hessian[diag, diag] += damp 70 | 71 | # inverse Hessian 72 | hessian = torch.linalg.cholesky(hessian) 73 | hessian = torch.cholesky_inverse(hessian) 74 | hessian = torch.linalg.cholesky(hessian, upper=True) 75 | inv_hessian = hessian 76 | 77 | # Saving the inverted Hessian to the specified directory with the same file name 78 | save_path = os.path.join(args.store_inv_hessian_dir, hessian_file) 79 | if args.enable_perm is False: 80 | perm = torch.arange(inv_hessian.shape[0]) 81 | 82 | torch.save({'invH': inv_hessian.to('cpu'), 83 | 'perm': perm.to('cpu'), 84 | 'zero_idx': zero_idx.to('cpu')}, save_path) 85 | 86 | pbar.set_postfix_str(f'Saved inverted Hessian to {save_path}') 87 | 88 | if __name__ == "__main__": 89 | parser = ArgumentParser() 90 | parser.add_argument('--load_hessian_dir', type=str, default=None, 91 | help='Directory containing Hessian .pt files') 92 | parser.add_argument('--store_inv_hessian_dir', type=str, default=None, 93 | help='Directory to save inverted Hessian .pt files') 94 | parser.add_argument('--enable_perm', action='store_true', 95 | help='Enable permutation of Hessian') 96 | args = parser.parse_args() 97 | 98 | -------------------------------------------------------------------------------- /qllm/quantization/vptq/merge_hessian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import tqdm 4 | import argparse 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description='Merge Hessian components across multiple groups') 8 | parser.add_argument('--base-dir', type=str, required=True, 9 | help='Base directory path (e.g., "./Hessians-Qwen2-57B-A14B-Instruct-6144-8k-seed-")') 10 | parser.add_argument('--save-dir', type=str, required=True, 11 | help='Directory to save merged results') 12 | parser.add_argument('--groups', type=int, nargs='+', required=True, 13 | help='Group numbers to merge (e.g., 4 5)') 14 | return parser.parse_args() 15 | 16 | def merge_and_save_hessian(base_dir, groups, save_dir, entry): 17 | """ 18 | Merges Hessian components across multiple groups and saves the merged result. 
19 | Args: 20 | base_dir: Base directory path 21 | groups: List of group numbers 22 | save_dir: Directory to save results 23 | entry: File name to process 24 | """ 25 | if not os.path.exists(save_dir): 26 | os.makedirs(save_dir) 27 | 28 | total_flatH = None 29 | total_mu = None 30 | total_ct = 0 31 | 32 | for group in groups: 33 | full_path = os.path.join(f'{base_dir}{group}', entry) 34 | if full_path.endswith('.txt'): continue 35 | data = torch.load(full_path, weights_only=False) 36 | 37 | if total_flatH is None: 38 | total_flatH = torch.zeros_like(data['flatH']) 39 | total_mu = torch.zeros_like(data['mu']) 40 | 41 | total_flatH += data['flatH'] 42 | total_mu += data['mu'] * data['ct'] 43 | total_ct += data['ct'] 44 | 45 | average_mu = total_mu / total_ct if total_ct > 0 else total_mu 46 | 47 | merged_data = { 48 | 'flatH': total_flatH / len(groups), 49 | 'mu': average_mu, 50 | 'n': data['n'], 51 | 'ct': total_ct 52 | } 53 | 54 | save_path = os.path.join(save_dir, entry) 55 | torch.save(merged_data, save_path) 56 | # print(f"Merged data saved to {save_path}") 57 | 58 | def main(args): 59 | # Use the first group to get the list of files to process 60 | first_group_dir = f'{args.base_dir}{args.groups[0]}' 61 | for entry in (pbar := tqdm.tqdm(os.listdir(first_group_dir), desc="Merging Hessian")): 62 | if not entry.endswith('.pt'): continue 63 | pbar.set_postfix_str(f'Processing {entry}') 64 | merge_and_save_hessian( 65 | args.base_dir, 66 | args.groups, 67 | args.save_dir, 68 | entry 69 | ) 70 | # print('----') 71 | 72 | if __name__ == "__main__": 73 | args = parse_args() 74 | main(args) -------------------------------------------------------------------------------- /qllm/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from .auto_model_quantization import AutoModelQuantization 5 | from .args_config import FakeArgs 6 | 7 | def define_basic_args(): 8 | parser = argparse.ArgumentParser(description=""" 9 | A general tool to quantize LLMs with the GPTQ/AWQ/HQQ/VPTQ quantization methods. 10 | You can easily quantize your model and save it to a checkpoint, which is compatible with \ 11 | [vLLM](https://github.com/vllm-project/vllm). 12 | You can also test the quantized model with a conversation plugin.
13 | 14 | A typical usage is: 15 | python -m qllm --model meta-llama/Llama-2-7b-chat-hf --quant_method=awq \ 16 | --dataset=pileval --nsamples=16 --use_plugin --save ./Llama-2-7b-chat-hf_awq_q4/ \ 17 | --export_onnx ./onnx_models/ 18 | 19 | quant_method can be one of ['awq', 'gptq', 'hqq', 'vptq'] """, 20 | formatter_class=argparse.RawTextHelpFormatter) 21 | default_args = FakeArgs() 22 | parser.add_argument('--quant_method', type=str, default=default_args.quant_method, 23 | choices=["gptq", "awq", "hqq", "vptq"], help='the quantization method') 24 | parser.add_argument('--model', type=str, default="", 25 | help='float/float16/bfloat16 model to load, such as [mosaicml/mpt-7b]') 26 | parser.add_argument('--tokenizer', type=str, default="", help='default same as [model]') 27 | parser.add_argument('--dataset', type=str, default=default_args.dataset, 28 | choices=['wikitext2', 'ptb', 'c4', 'ptb-new', 'c4-new', "pileval"], help='Where to extract calibration data from.') 29 | parser.add_argument('--seed', type=int, default=default_args.seed, help='Seed for sampling the calibration data.') 30 | parser.add_argument('--nsamples', type=int, default=default_args.nsamples, help='Number of calibration data samples.') 31 | parser.add_argument('--percdamp', type=float, default=default_args.percdamp, 32 | help='Percent of the average Hessian diagonal to use for dampening.') 33 | parser.add_argument('--quant_config', type=Path, default=None, 34 | help='a JSON file to configure quantization (e.g. for vptq). pass "--quant_config help" to get an example config') 35 | parser.add_argument( 36 | '--static-groups', action='store_true', 37 | help='(gptq only.) Whether to use static groups; recommended when using `--act-order` for more efficient inference.' 38 | ) 39 | parser.add_argument('--wbits', type=int, default=default_args.wbits, 40 | choices=[2, 3, 4, 5, 6, 7, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.') 41 | parser.add_argument('--mix_qlayer_conf', type=str, default=None, 42 | help='Mixed-precision layer configuration (groupsize, wbits).') 43 | parser.add_argument('--groupsize', type=int, default=default_args.groupsize, 44 | help='Groupsize to use for quantization; -1 uses full row.') 45 | parser.add_argument('--eval', action='store_true', help='Evaluate the quantized model.') 46 | parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.') 47 | parser.add_argument('--save_safetensors', type=str, default='', 48 | help='Save quantized `.safetensors` checkpoint under this name.') 49 | parser.add_argument('--load', type=str, default='', help='Load quantized model.') 50 | parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') 51 | parser.add_argument('--act-order', action='store_true', 52 | help='Whether to apply the activation order GPTQ heuristic.') 53 | parser.add_argument('--true-sequential', action='store_true', help='Whether to run in true sequential mode.') 54 | parser.add_argument('--allow_mix_bits', action='store_true', 55 | help='Automatically upgrade layer precision when needed, for example int2 to int4, or groupsize 128 to 64.') 56 | parser.add_argument('--export_onnx', type=str, default=None, help='Directory to save the exported ONNX model to.') 57 | parser.add_argument('--use_plugin', action='store_true', help='Test with a plugin, such as the fastchat conversation plugin.') 58 | parser.add_argument( 59 | "--pack_mode", 60 | type=str, 61 | default=default_args.pack_mode, 62 | choices=["AUTO", "GEMM", "GPTQ", "ORT",
"HQQ", "MARLIN"], 63 | help="""the quantization pack mode, 64 | `GEMM` represents to use AWQ GEMM engine, same as what AutoAWQ used, 65 | `AUTO` represents that it is selected by the GPU arch, as awq GEMM needs SM75+ 66 | `GPTQ` represent using old GPTQ style DQ+GEMM, it is not fused but more general than AWQ-GEMM, 67 | `MARLIN` is fastest but require sm80+. 68 | `ORT` represents using onnxruntime packing stype, 69 | """, 70 | ) 71 | 72 | return parser 73 | 74 | 75 | def main(): 76 | parser = define_basic_args() 77 | args = parser.parse_args() 78 | print(args) 79 | 80 | model_quanter = AutoModelQuantization() 81 | model_quanter.run(args) 82 | 83 | 84 | if __name__ == '__main__': 85 | main() 86 | -------------------------------------------------------------------------------- /qllm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F401 2 | from .modelutils import find_layers, gen_conditions, torch_snr_error 3 | from .datautils import get_loaders 4 | from . import comm_utils 5 | from . import logger 6 | -------------------------------------------------------------------------------- /qllm/utils/comm_utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | 4 | 5 | def set_seed(seed): 6 | import random 7 | import numpy as np 8 | random.seed(seed) 9 | np.random.seed(seed) 10 | torch.manual_seed(seed) 11 | 12 | def clear_memory(*args): 13 | for weight in args: 14 | del weight 15 | gc.collect() 16 | torch.cuda.empty_cache() 17 | 18 | 19 | def get_Model_Size(model): 20 | param_size = 0 21 | param_sum = 0 22 | for param in model.parameters(): 23 | param_size += param.nelement() * param.element_size() 24 | param_sum += param.nelement() 25 | buffer_size = 0 26 | buffer_sum = 0 27 | for buffer in model.buffers(): 28 | buffer_size += buffer.nelement() * buffer.element_size() 29 | buffer_sum += buffer.nelement() 30 | all_size = (param_size + buffer_size) / 1024 / 1024 31 | return all_size 32 | 33 | 34 | def retrieve_onnx_inputs(model, sample_inputs): 35 | user_inputs = [] 36 | 37 | def hook_for_inputs(mod, inputs, kwargs): 38 | user_inputs.append((inputs, kwargs)) 39 | return user_inputs[0] 40 | hook_handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True) 41 | import inspect 42 | forward_params = inspect.signature(model.forward).parameters 43 | input_keys = list(forward_params.keys()) 44 | default_values = [forward_params.get(key).default for key in input_keys] 45 | model(sample_inputs[0], attention_mask=sample_inputs[1]) 46 | hook_handle.remove() 47 | user_inputs = user_inputs[0] 48 | onnx_inputs = default_values 49 | for idx, _ in enumerate(user_inputs[0]): 50 | onnx_inputs[idx] = user_inputs[0][idx] 51 | for key, value in user_inputs[1].items(): 52 | idx = input_keys.index(key) 53 | onnx_inputs[idx] = value 54 | for value in onnx_inputs: 55 | if type(value) is torch.Tensor: 56 | value.to(model.device) 57 | return input_keys, tuple(onnx_inputs) 58 | 59 | 60 | def disable_huggingface_init(): 61 | # do not init model twice as it slow initialization 62 | import torch 63 | import torch.nn.init 64 | torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x 65 | torch.nn.init.uniform_ = lambda x, *args, **kwargs: x 66 | torch.nn.init.normal_ = lambda x, *args, **kwargs: x 67 | torch.nn.init.constant_ = lambda x, *args, **kwargs: x 68 | torch.nn.init.xavier_uniform_ = lambda x, *args, **kwargs: x 69 | torch.nn.init.xavier_normal_ = lambda x, *args, 
**kwargs: x 70 | torch.nn.init.kaiming_normal_ = lambda x, *args, **kwargs: x 71 | torch.nn.init.orthogonal_ = lambda x, *args, **kwargs: x 72 | -------------------------------------------------------------------------------- /qllm/utils/datautils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import logging 4 | from typing import List, Union 5 | from datasets import load_dataset 6 | import random 7 | 8 | 9 | def get_wikitext2(nsamples, seed, seqlen, tokenizer): 10 | random.seed(seed) 11 | 12 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 13 | #testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 14 | 15 | traindata = traindata.select(indices=random.sample(range(len(traindata)), 16 | min(len(traindata), nsamples*4)), keep_in_memory=True,) 17 | traindata_text = [text for text in traindata['text'] if text != ''] 18 | 19 | trainenc = tokenizer("\n\n".join(traindata_text), return_tensors='pt') 20 | #testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 21 | testenc = None 22 | 23 | trainloader = [] 24 | for _ in range(nsamples): 25 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 26 | j = i + seqlen 27 | inp = trainenc.input_ids[:, i:j] 28 | tar = inp.clone() 29 | tar[:, :-1] = -100 30 | trainloader.append((inp, tar)) 31 | return trainloader, testenc 32 | 33 | 34 | def get_ptb(nsamples, seed, seqlen, tokenizer): 35 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') 36 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') 37 | 38 | trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') 39 | testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') 40 | 41 | import random 42 | random.seed(seed) 43 | trainloader = [] 44 | for _ in range(nsamples): 45 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 46 | j = i + seqlen 47 | inp = trainenc.input_ids[:, i:j] 48 | tar = inp.clone() 49 | tar[:, :-1] = -100 50 | trainloader.append((inp, tar)) 51 | return trainloader, testenc 52 | 53 | 54 | def get_c4(nsamples, seed, seqlen, tokenizer): 55 | traindata = load_dataset('allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') 56 | #valdata = load_dataset('allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') 57 | 58 | random.seed(seed) 59 | trainloader = [] 60 | for _ in range(nsamples): 61 | while True: 62 | i = random.randint(0, len(traindata) - 1) 63 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 64 | if trainenc.input_ids.shape[1] > seqlen: 65 | break 66 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 67 | j = i + seqlen 68 | inp = trainenc.input_ids[:, i:j] 69 | tar = inp.clone() 70 | tar[:, :-1] = -100 71 | trainloader.append((inp, tar)) 72 | 73 | #import random 74 | #random.seed(0) 75 | #valenc = [] 76 | #for _ in range(256): 77 | # while True: 78 | # i = random.randint(0, len(valdata) - 1) 79 | # tmp = tokenizer(valdata[i]['text'], return_tensors='pt') 80 | # if tmp.input_ids.shape[1] >= seqlen: 81 | # break 82 | # i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) 83 | # j = i + seqlen 84 | # valenc.append(tmp.input_ids[:, i:j]) 85 | #valenc = torch.hstack(valenc) 86 | 87 | #class TokenizerWrapper: 88 | 89 | # def __init__(self, input_ids): 90 | # self.input_ids = input_ids 91 | 92 | #valenc = 
TokenizerWrapper(valenc) 93 | valenc = None 94 | 95 | return trainloader, valenc 96 | 97 | 98 | def get_ptb_new(nsamples, seed, seqlen, tokenizer): 99 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') 100 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') 101 | 102 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') 103 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') 104 | 105 | import random 106 | random.seed(seed) 107 | trainloader = [] 108 | for _ in range(nsamples): 109 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 110 | j = i + seqlen 111 | inp = trainenc.input_ids[:, i:j] 112 | tar = inp.clone() 113 | tar[:, :-1] = -100 114 | trainloader.append((inp, tar)) 115 | return trainloader, testenc 116 | 117 | 118 | def get_c4_new(nsamples, seed, seqlen, tokenizer): 119 | traindata = load_dataset( 120 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') 121 | valdata = load_dataset('allenai/c4', data_files={ 122 | 'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') 123 | 124 | import random 125 | random.seed(seed) 126 | trainloader = [] 127 | for _ in range(nsamples): 128 | while True: 129 | i = random.randint(0, len(traindata) - 1) 130 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 131 | if trainenc.input_ids.shape[1] >= seqlen: 132 | break 133 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 134 | j = i + seqlen 135 | inp = trainenc.input_ids[:, i:j] 136 | tar = inp.clone() 137 | tar[:, :-1] = -100 138 | trainloader.append((inp, tar)) 139 | 140 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 141 | valenc = valenc.input_ids[:, :(256 * seqlen)] 142 | 143 | class TokenizerWrapper: 144 | 145 | def __init__(self, input_ids): 146 | self.input_ids = input_ids 147 | 148 | valenc = TokenizerWrapper(valenc) 149 | 150 | return trainloader, valenc 151 | 152 | 153 | def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=''): 154 | from transformers import AutoTokenizer 155 | if isinstance(tokenizer, str): 156 | try: 157 | tokenizer = AutoTokenizer.from_pretrained(tokenizer, fast=False) 158 | except: # noqa: E722 159 | tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=True) 160 | 161 | if 'wikitext2' in name: 162 | return get_wikitext2(nsamples, seed, seqlen, tokenizer) 163 | if 'ptb' in name: 164 | if 'new' in name: 165 | return get_ptb_new(nsamples, seed, seqlen, tokenizer) 166 | return get_ptb(nsamples, seed, seqlen, tokenizer) 167 | if 'c4' in name: 168 | if 'new' in name: 169 | return get_c4_new(nsamples, seed, seqlen, tokenizer) 170 | return get_c4(nsamples, seed, seqlen, tokenizer) 171 | 172 | return get_calib_dataset(data="pileval", n_samples=nsamples, block_size=seqlen, tokenizer=tokenizer) 173 | 174 | 175 | def get_calib_dataset(data: Union[str, List[str]] = "pileval", 176 | tokenizer=None, n_samples=512, block_size=512, 177 | split="train", text_column="text"): 178 | if isinstance(data, str): 179 | if data == "pileval": 180 | dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") 181 | else: 182 | dataset = load_dataset(data, split=split) 183 | 184 | dataset = dataset.shuffle(seed=42) 185 | 186 | elif isinstance(data, list): 187 | dataset = [{text_column: text} for text in data] 188 | else: 189 | raise NotImplementedError( 190 | "Either pass a string to a huggingface dataset or a list" 191 | "that is 
preprocessed with one sample of text per element.") 192 | 193 | samples = [] 194 | n_run = 0 195 | for data in dataset: 196 | line = data[text_column] 197 | line = line.strip() 198 | line_encoded = tokenizer.encode(line) 199 | if len(line_encoded) > 512: 200 | continue 201 | sample = torch.tensor([line_encoded]) 202 | if sample.numel() == 0: 203 | continue 204 | samples.append(sample) 205 | n_run += 1 206 | if n_run == n_samples: 207 | break 208 | # now concatenate all samples and split according to block size 209 | cat_samples = torch.cat(samples, dim=1) 210 | n_split = cat_samples.shape[1] // block_size 211 | logging.debug(f" * Split into {n_split} blocks") 212 | return [(cat_samples[:, i*block_size:(i+1)*block_size], None) for i in range(n_split)], None 213 | -------------------------------------------------------------------------------- /qllm/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import functools 3 | 4 | def run_once(func): 5 | result = [] 6 | @functools.wraps(func) 7 | def wrapper(*args, **kwargs): 8 | if not result: 9 | result.append(func(*args, **kwargs)) 10 | return result[0] 11 | return wrapper 12 | 13 | @run_once 14 | def get_logger(name="qllm"): 15 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 16 | logger = logging.getLogger(name) 17 | console = logging.StreamHandler() 18 | console.setFormatter(formatter) 19 | logger.addHandler(console) 20 | logger.setLevel(logging.INFO) 21 | return logger 22 | -------------------------------------------------------------------------------- /qllm/utils/onnx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wejoncy/QLLM/df20c15920bfabfd0581d7fcccbe87e5c96cd5c7/qllm/utils/onnx/__init__.py -------------------------------------------------------------------------------- /qllm/utils/onnx/exporter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from onnxruntime.transformers import large_model_exporter 3 | from pathlib import Path 4 | 5 | from ..logger import get_logger 6 | from .merge_encoder_decoder import merge_decoders 7 | 8 | logger = get_logger() 9 | 10 | 11 | def export_onnx( 12 | model: torch.nn.Module, onnx_path_str: str, sample_inputs: tuple, with_past: bool = False, opset=16 13 | ) -> Path: 14 | # since onnxruntime 1.7 15 | logger.info("Exporting onnx model ...") 16 | sample_inputs_tp = list(sample_inputs) 17 | if sample_inputs_tp[1] is None: 18 | sample_inputs_tp[1] = torch.ones_like(sample_inputs_tp[0]) 19 | # workaround: older onnxruntime versions spell this helper 'move_to_approprate_device' 20 | move_to_device = ( 21 | large_model_exporter.move_to_appropriate_device 22 | if hasattr(large_model_exporter, "move_to_appropriate_device") 23 | else large_model_exporter.move_to_approprate_device 24 | ) 25 | model = move_to_device(model, sample_inputs_tp) 26 | 27 | sample_inputs = large_model_exporter.adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device) 28 | 29 | # input_keys would be useful if the model has some special inputs 30 | input_keys, onnx_inputs, past_key_value = large_model_exporter.retrieve_onnx_inputs(model, sample_inputs, with_past) 31 | if "position_ids" in input_keys: 32 | onnx_inputs[input_keys.index("position_ids")] = torch.arange( 33 | 0, onnx_inputs[0].shape[1], dtype=torch.int64, device=onnx_inputs[0].device 34 | ).unsqueeze(0) 35 | onnx_io_tuple =
large_model_exporter.fetch_onnx_inputs_outputs_name( 36 | model, onnx_inputs, input_keys, past_key_value, with_past, False 37 | ) 38 | 39 | onnx_model_name = "decoder.onnx" 40 | onnx_path: Path = Path(onnx_path_str).absolute() 41 | onnx_path_enc = onnx_path / onnx_model_name if onnx_path.suffix != ".onnx" else onnx_path 42 | onnx_path_enc.parent.mkdir(parents=True, exist_ok=True) 43 | 44 | large_model_exporter.do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path_enc, opset) 45 | if not with_past: 46 | return onnx_path_enc 47 | 48 | onnx_io_tuple = large_model_exporter.fetch_onnx_inputs_outputs_name( 49 | model, onnx_inputs, input_keys, past_key_value, with_past, True 50 | ) 51 | # workaround for attention_mask 52 | onnx_inputs[1] = onnx_inputs[1].long() 53 | 54 | onnx_model_name = "decoder_with_past.onnx" 55 | onnx_path_dec = onnx_path_enc.parent / onnx_model_name 56 | if "position_ids" in input_keys: 57 | onnx_inputs[input_keys.index("position_ids")] = torch.tensor( 58 | [onnx_inputs[0].shape[1]], dtype=torch.int64, device=onnx_inputs[0].device 59 | ).unsqueeze(0) 60 | large_model_exporter.do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path_dec, opset) 61 | 62 | onnx_path_one_for_all = onnx_path_enc.parent / "decoder_merged.onnx" 63 | merge_decoders(onnx_path_enc, onnx_path_dec, save_path=onnx_path_one_for_all) 64 | logger.info(f"model is exported to onnx and saved to {onnx_path_one_for_all}") 65 | return onnx_path_one_for_all 66 | 67 | 68 | def verify_correcness( 69 | model: torch.nn.Module, 70 | sample_inputs: tuple, 71 | onnx_model_path: str, 72 | with_past: bool, 73 | ): 74 | import onnxruntime 75 | import numpy as np 76 | 77 | ref = model(sample_inputs[0].cuda(), torch.ones(sample_inputs[0].shape, dtype=torch.int64).cuda()) 78 | 79 | mask = np.ones(sample_inputs[0].shape, dtype=np.int64) 80 | num_layers = model.config.num_hidden_layers 81 | session_options = onnxruntime.SessionOptions() 82 | # session_options.register_custom_ops_library(onnx_ops.__file__) 83 | # onnx_path_str = Path(onnx_model_path).parent.absolute() 84 | # session = onnxruntime.InferenceSession(f'{onnx_path_str}/model.onnx', providers=['CUDAExecutionProvider'], sess_options=session_options) 85 | session = onnxruntime.InferenceSession( 86 | onnx_model_path, providers=["CUDAExecutionProvider"], sess_options=session_options 87 | ) 88 | inputs = {"input_ids": sample_inputs[0].cpu().numpy(), "attention_mask": mask} 89 | if "position_ids" in [i.name for i in session.get_inputs()]: 90 | inputs["position_ids"] = np.arange(0, inputs["input_ids"].shape[1], dtype=np.int64).reshape(1, -1) 91 | if with_past: 92 | if "merge" in str(onnx_model_path): 93 | inputs["use_cache_branch"] = np.array([0], dtype=np.bool_) 94 | kv_cache_shape = list(ref.past_key_values[0][0].shape) 95 | kv_cache_shape[-2] = 0 96 | kv_cache_shape = tuple(kv_cache_shape) 97 | for i in range(num_layers): 98 | inputs[f"past_key_values.{i}.key"] = np.zeros(kv_cache_shape, dtype=np.float16) 99 | inputs[f"past_key_values.{i}.value"] = np.zeros(kv_cache_shape, dtype=np.float16) 100 | outputs = session.run(None, inputs) 101 | err_prefill = ref.logits.cpu().numpy() - outputs[0] 102 | err_decode = np.zeros_like(err_prefill) 103 | if with_past: 104 | # session = onnxruntime.InferenceSession(f'{onnx_path_str}/model_with_past.onnx', providers=['CUDAExecutionProvider'], sess_options=session_options) 105 | mask = np.concatenate([mask, np.array([[1]])], axis=1) 106 | inputs = {"input_ids": np.array([[3]]), "attention_mask": mask} 107 | if 
"position_ids" in [i.name for i in session.get_inputs()]: 108 | inputs["position_ids"] = np.array([sample_inputs[0].shape[1]], dtype=np.int64).reshape(1, -1) 109 | if "merge" in str(onnx_model_path): 110 | inputs["use_cache_branch"] = np.array([1], dtype=np.bool_) 111 | for i in range(num_layers): 112 | inputs[f"past_key_values.{i}.key"] = ref.past_key_values[i][0].cpu().numpy() 113 | inputs[f"past_key_values.{i}.value"] = ref.past_key_values[i][1].cpu().numpy() 114 | outputs = session.run(None, inputs) 115 | 116 | ref = model( 117 | torch.tensor([[3]], device="cuda"), torch.from_numpy(mask).cuda(), past_key_values=ref.past_key_values 118 | ) 119 | err_decode = ref.logits.cpu().numpy() - outputs[0] 120 | print( 121 | "max abs err_prefill:", 122 | np.abs(err_prefill).max(), 123 | "max abs err_decode:", 124 | np.abs(err_decode).max(), 125 | "correctness check is ", 126 | "" if np.abs(err_decode).max() < 1e-2 else "not", 127 | " passed", 128 | ) 129 | -------------------------------------------------------------------------------- /qllm_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Install qllm" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%pip install qllm\n", 17 | "%pip install fschat accelerate\n" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "# quantize qwen with gptq/hqq\n", 25 | "NOTE: awq will consume more memory" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "!python -m qllm --model Qwen/Qwen2.5-3B-Instruct --quant_method gptq --eval --save ./qwen2.5-3b-instruct_4bit\n", 35 | "# or python -m qllm --model Qwen/Qwen2.5-3B-Instruct --quant_method hqq --eval --save ./qwen2.5-3b-instruct_4bit" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# API\n", 45 | "from qllm.plugin.conversation import loop_in_chat_completion\n", 46 | "from qllm.auto_model_quantization import AutoModelQuantization\n", 47 | "import transformers\n", 48 | "m='Qwen/Qwen2.5-3B-Instruct'\n", 49 | "quantizer = AutoModelQuantization()\n", 50 | "tokenizer = transformers.AutoTokenizer.from_pretrained(m, use_fast=True, trust_remote_code=True)\n", 51 | "\n", 52 | "qm = quantizer.api_quantize(m, quant_method='hqq')\n", 53 | "loop_in_chat_completion(tokenizer, qm)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "# Run Tinyllama-2 with qllm cli" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "!python -m qllm --load TheBloke/Tinyllama-2-1b-miniguanaco-GPTQ --use_plugin" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "# Run Llama-3-8B-Instruct with qllm API" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "from qllm.plugin.conversation import loop_in_chat_completion\n", 86 | "from qllm.auto_model_quantization import AutoModelQuantization\n", 87 | "import transformers\n", 88 | "\n", 89 | "quantizer = AutoModelQuantization()\n", 90 | 
"model=quantizer.from_pretrained(\"astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit\")\n", 91 | "tokenizer = transformers.AutoTokenizer.from_pretrained(\"astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit\", use_fast=True, trust_remote_code=True)\n", 92 | "\n", 93 | "loop_in_chat_completion(tokenizer, model)" 94 | ] 95 | } 96 | ], 97 | "metadata": { 98 | "accelerator": "GPU", 99 | "colab": { 100 | "gpuType": "T4", 101 | "include_colab_link": true, 102 | "provenance": [] 103 | }, 104 | "kernelspec": { 105 | "display_name": "base", 106 | "language": "python", 107 | "name": "python3" 108 | }, 109 | "language_info": { 110 | "codemirror_mode": { 111 | "name": "ipython", 112 | "version": 3 113 | }, 114 | "file_extension": ".py", 115 | "mimetype": "text/x-python", 116 | "name": "python", 117 | "nbconvert_exporter": "python", 118 | "pygments_lexer": "ipython3", 119 | "version": "3.11.5" 120 | } 121 | }, 122 | "nbformat": 4, 123 | "nbformat_minor": 2 124 | } 125 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | safetensors 2 | torch>=2.0.0 3 | datasets==2.18.0 4 | zstandard 5 | #sentencepiece 6 | transformers 7 | accelerate 8 | #triton==2.0.0 9 | texttable 10 | tqdm 11 | numpy 12 | onnx 13 | pyarrow 14 | pandas 15 | #protobuf==3.20.2 16 | --------------------------------------------------------------------------------