├── .github └── workflows │ ├── build.yaml │ ├── docs.yaml │ └── scripts │ └── github_create_release.js ├── .gitignore ├── LICENSE ├── README.md ├── awq ├── __init__.py ├── evaluation │ ├── __init__.py │ ├── eval_utils.py │ ├── humaneval_utils.py │ └── kl_divergence.py ├── models │ ├── __init__.py │ ├── _config.py │ ├── aquila.py │ ├── auto.py │ ├── baichuan.py │ ├── base.py │ ├── bloom.py │ ├── cohere.py │ ├── deepseek_v2.py │ ├── deepseek_v3.py │ ├── exaone.py │ ├── falcon.py │ ├── gemma.py │ ├── gemma2.py │ ├── gpt_bigcode.py │ ├── gpt_neox.py │ ├── gptj.py │ ├── internlm2.py │ ├── llama.py │ ├── llava.py │ ├── llava_next.py │ ├── minicpm.py │ ├── minicpm3.py │ ├── mistral.py │ ├── mixtral.py │ ├── mpt.py │ ├── opt.py │ ├── phi3.py │ ├── phi3_v.py │ ├── qwen.py │ ├── qwen2.py │ ├── qwen2_5_omni.py │ ├── qwen2_5_vl.py │ ├── qwen2vl.py │ ├── qwen3.py │ ├── qwen3_moe.py │ ├── stablelm.py │ ├── starcoder2.py │ └── yi.py ├── modules │ ├── __init__.py │ ├── act.py │ ├── fused │ │ ├── __init__.py │ │ ├── attn.py │ │ ├── block.py │ │ ├── cache.py │ │ ├── mlp.py │ │ ├── model.py │ │ ├── moe.py │ │ └── norm.py │ ├── linear │ │ ├── __init__.py │ │ ├── exllama.py │ │ ├── exllamav2.py │ │ ├── gemm.py │ │ ├── gemm_ipex.py │ │ ├── gemv.py │ │ ├── gemv_fast.py │ │ └── marlin.py │ └── triton │ │ ├── __init__.py │ │ └── gemm.py ├── quantize │ ├── __init__.py │ ├── quantizer.py │ └── scale.py └── utils │ ├── __init__.py │ ├── calib_data.py │ ├── fused_utils.py │ ├── module.py │ ├── packing_utils.py │ ├── parallel.py │ ├── quant_utils.py │ ├── qwen_vl_utils.py │ └── utils.py ├── docs ├── examples.md ├── index.md └── reference │ └── index.md ├── examples ├── README.md ├── benchmark.py ├── cli.py ├── eval.py ├── generate.py ├── quantize.py └── train.py ├── mkdocs.yml ├── scripts ├── download_wheels.sh └── runpod_quantize.py ├── setup.py └── tests ├── test_dequantization.py ├── test_ipex_cpu.py └── test_quantization.py /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Build AutoAWQ Wheels with CUDA 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*" 7 | 8 | jobs: 9 | release: 10 | # Retrieve tag and create release 11 | name: Create Release 12 | runs-on: ubuntu-latest 13 | outputs: 14 | upload_url: ${{ steps.create_release.outputs.upload_url }} 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v3 18 | 19 | - name: Extract branch info 20 | shell: bash 21 | run: | 22 | echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV 23 | 24 | - name: Create Release 25 | id: create_release 26 | uses: "actions/github-script@v6" 27 | env: 28 | RELEASE_TAG: ${{ env.release_tag }} 29 | with: 30 | github-token: "${{ secrets.GITHUB_TOKEN }}" 31 | script: | 32 | const script = require('.github/workflows/scripts/github_create_release.js') 33 | await script(github, context, core) 34 | 35 | build_wheels: 36 | name: Build AWQ 37 | runs-on: ${{ matrix.os }} 38 | needs: release 39 | 40 | strategy: 41 | matrix: 42 | os: [ubuntu-latest] 43 | pyver: ["3.9"] 44 | defaults: 45 | run: 46 | shell: pwsh 47 | env: 48 | MIN_TORCH_VER: "2.2.0" 49 | 50 | steps: 51 | - name: Free Disk Space 52 | uses: jlumbroso/free-disk-space@v1.3.0 53 | if: runner.os == 'Linux' 54 | with: 55 | tool-cache: false 56 | android: true 57 | dotnet: true 58 | haskell: true 59 | large-packages: false 60 | docker-images: true 61 | swap-storage: false 62 | 63 | - uses: actions/checkout@v3 64 | 65 | - uses: actions/setup-python@v3 66 | with: 67 | python-version: ${{ matrix.pyver }} 68 | 69 | - name: 
Setup Conda 70 | uses: conda-incubator/setup-miniconda@v3 71 | with: 72 | activate-environment: "build" 73 | python-version: ${{ matrix.pyver }} 74 | add-pip-as-python-dependency: true 75 | auto-activate-base: false 76 | 77 | - name: Install Dependencies 78 | run: | 79 | # Install torch 80 | python -m pip install --upgrade --no-cache-dir torch==$env:MIN_TORCH_VER 81 | python -m pip install build setuptools wheel 82 | 83 | # Print version information 84 | python --version 85 | python -c "import torch; print('PyTorch:', torch.__version__)" 86 | 87 | - name: Build Wheel 88 | run: | 89 | python setup.py sdist 90 | 91 | - name: Upload Assets 92 | uses: shogo82148/actions-upload-release-asset@v1 93 | with: 94 | upload_url: ${{ needs.release.outputs.upload_url }} 95 | asset_path: ./dist/*.whl 96 | -------------------------------------------------------------------------------- /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | on: 3 | push: 4 | branches: 5 | - main 6 | permissions: 7 | contents: write 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Git Credentials 14 | run: | 15 | git config user.name github-actions[bot] 16 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 17 | - uses: actions/setup-python@v4 18 | with: 19 | python-version: 3.11 20 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 21 | - uses: actions/cache@v3 22 | with: 23 | key: mkdocs-material-${{ env.cache_id }} 24 | path: .cache 25 | restore-keys: | 26 | mkdocs-material-docs 27 | - run: pip install mkdocstrings-python mkdocs-material griffe-typingdoc 28 | - run: mkdocs gh-deploy --force 29 | -------------------------------------------------------------------------------- /.github/workflows/scripts/github_create_release.js: -------------------------------------------------------------------------------- 1 | module.exports = async (github, context, core) => { 2 | try { 3 | const response = await github.rest.repos.createRelease({ 4 | draft: false, 5 | generate_release_notes: true, 6 | name: process.env.RELEASE_TAG, 7 | owner: context.repo.owner, 8 | prerelease: false, 9 | repo: context.repo.repo, 10 | tag_name: process.env.RELEASE_TAG, 11 | }); 12 | 13 | core.setOutput('upload_url', response.data.upload_url); 14 | } catch (error) { 15 | core.setFailed(error.message); 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | data/ 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | *.pyc 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | #.idea/ 165 | 166 | *.pt 167 | **/*.pt 168 | **/*.pyc 169 | *.json 170 | __pycache__ 171 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MIT HAN Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /awq/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.simplefilter("default", DeprecationWarning) 4 | 5 | _FINAL_DEV_MESSAGE = """ 6 | I have left this message as the final dev message to help you transition. 7 | 8 | Important Notice: 9 | - AutoAWQ is officially deprecated and will no longer be maintained. 10 | - The last tested configuration used Torch 2.6.0 and Transformers 4.51.3. 11 | - If future versions of Transformers break AutoAWQ compatibility, please report the issue to the Transformers project. 
12 | 13 | Alternative: 14 | - AutoAWQ has been adopted by the vLLM Project: https://github.com/vllm-project/llm-compressor 15 | 16 | For further inquiries, feel free to reach out: 17 | - X: https://x.com/casper_hansen_ 18 | - LinkedIn: https://www.linkedin.com/in/casper-hansen-804005170/ 19 | """ 20 | 21 | warnings.warn(_FINAL_DEV_MESSAGE, category=DeprecationWarning, stacklevel=1) 22 | 23 | __version__ = "0.2.9" 24 | from awq.models.auto import AutoAWQForCausalLM -------------------------------------------------------------------------------- /awq/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from awq.evaluation.eval_utils import ( 2 | evaluate_perplexity, 3 | eval_librispeech, 4 | eval_mmlu, 5 | ) 6 | from awq.evaluation.humaneval_utils import eval_humaneval 7 | from awq.evaluation.kl_divergence import eval_kl_divergence 8 | -------------------------------------------------------------------------------- /awq/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mpt import MptAWQForCausalLM 2 | from .llama import LlamaAWQForCausalLM 3 | from .opt import OptAWQForCausalLM 4 | from .falcon import FalconAWQForCausalLM 5 | from .exaone import ExaoneAWQForCausalLM 6 | from .bloom import BloomAWQForCausalLM 7 | from .gptj import GPTJAWQForCausalLM 8 | from .gpt_bigcode import GptBigCodeAWQForCausalLM 9 | from .mistral import MistralAWQForCausalLM 10 | from .gpt_neox import GPTNeoXAWQForCausalLM 11 | from .aquila import AquilaAWQForCausalLM 12 | from .yi import YiAWQForCausalLM 13 | from .qwen import QwenAWQForCausalLM 14 | from .baichuan import BaichuanAWQForCausalLM 15 | from .llava import LlavaAWQForCausalLM 16 | from .mixtral import MixtralAWQForCausalLM 17 | from .qwen2 import Qwen2AWQForCausalLM 18 | from .qwen3 import Qwen3AWQForCausalLM 19 | from .qwen3_moe import Qwen3MoeAWQForCausalLM 20 | from .gemma import GemmaAWQForCausalLM 21 | from .gemma2 import Gemma2AWQForCausalLM 22 | from .stablelm import StableLmAWQForCausalLM 23 | from .starcoder2 import Starcoder2AWQForCausalLM 24 | from .llava_next import LlavaNextAWQForCausalLM 25 | from .phi3 import Phi3AWQForCausalLM 26 | from .phi3_v import Phi3VAWQForCausalLM 27 | from .cohere import CohereAWQForCausalLM 28 | from .deepseek_v2 import DeepseekV2AWQForCausalLM 29 | from .deepseek_v3 import DeepseekV3AWQForCausalLM 30 | from .minicpm import MiniCPMAWQForCausalLM 31 | from .internlm2 import InternLM2AWQForCausalLM 32 | from .minicpm3 import MiniCPM3AWQForCausalLM 33 | from .qwen2vl import Qwen2VLAWQForCausalLM 34 | from .qwen2_5_vl import Qwen2_5_VLAWQForCausalLM 35 | from .qwen2_5_omni import Qwen2_5_OmniAWQForConditionalGeneration -------------------------------------------------------------------------------- /awq/models/_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import Dict, Optional, List 4 | from dataclasses import dataclass, field 5 | from transformers.utils.hub import PushToHubMixin, cached_file 6 | 7 | 8 | @dataclass 9 | class AwqConfig(PushToHubMixin): 10 | quant_method: str = field(default="awq") 11 | zero_point: bool = field(default=True) 12 | q_group_size: int = field(default=128) 13 | w_bit: int = field(default=4) 14 | version: str = field(default="gemm") 15 | config_file_name = "config.json" 16 | modules_to_not_convert: Optional[List] = None 17 | 18 | @classmethod 19 | def from_dict(cls, quant_config: Dict = 
{}): 20 | if not quant_config: 21 | quant_config = cls() 22 | else: 23 | quant_config = cls(**quant_config) 24 | quant_config.version = quant_config.version.lower() 25 | 26 | return quant_config 27 | 28 | @classmethod 29 | def from_pretrained(cls, save_dir: str, **kwargs): 30 | cache_dir = kwargs.pop("cache_dir", None) 31 | force_download = kwargs.pop("force_download", False) 32 | resume_download = kwargs.pop("resume_download", False) 33 | proxies = kwargs.pop("proxies", None) 34 | local_files_only = kwargs.pop("local_files_only", False) 35 | use_auth_token = kwargs.pop("use_auth_token", None) 36 | revision = kwargs.pop("revision", None) 37 | subfolder = kwargs.pop("subfolder", None) 38 | commit_hash = kwargs.pop("_commit_hash", None) 39 | 40 | if os.path.isdir(save_dir): # Local 41 | resolved_config_file = os.path.join(save_dir, cls.config_file_name) 42 | else: # Remote 43 | resolved_config_file = cached_file( 44 | save_dir, 45 | cls.config_file_name, 46 | cache_dir=cache_dir, 47 | force_download=force_download, 48 | resume_download=resume_download, 49 | proxies=proxies, 50 | use_auth_token=use_auth_token, 51 | revision=revision, 52 | local_files_only=local_files_only, 53 | subfolder=subfolder, 54 | _raise_exceptions_for_missing_entries=False, 55 | _raise_exceptions_for_connection_errors=False, 56 | _commit_hash=commit_hash, 57 | ) 58 | 59 | quant_config = None 60 | if os.path.exists(resolved_config_file): 61 | with open(resolved_config_file, "r", encoding="utf-8") as file: 62 | loaded_config = json.loads(file.read()) 63 | 64 | quant_config = loaded_config.get("quantization_config") 65 | 66 | if quant_config is not None: 67 | awq_config = cls.from_transformers_dict(cls, quant_config) 68 | quant_config = cls(**awq_config) 69 | 70 | if quant_config is None: 71 | quant_config = cls() 72 | 73 | return quant_config 74 | 75 | def to_dict(self): 76 | return { 77 | "zero_point": self.zero_point, 78 | "q_group_size": self.q_group_size, 79 | "w_bit": self.w_bit, 80 | "version": self.version, 81 | "modules_to_not_convert": self.modules_to_not_convert, 82 | } 83 | 84 | def to_transformers_dict(self): 85 | return { 86 | "quant_method": self.quant_method, 87 | "zero_point": self.zero_point, 88 | "group_size": self.q_group_size, 89 | "bits": self.w_bit, 90 | "version": self.version.lower(), 91 | "modules_to_not_convert": self.modules_to_not_convert, 92 | } 93 | 94 | def from_transformers_dict(self, transformers_dict: Dict): 95 | return { 96 | "quant_method": transformers_dict.get("quant_method"), 97 | "zero_point": transformers_dict.get("zero_point"), 98 | "q_group_size": transformers_dict.get("group_size"), 99 | "w_bit": transformers_dict.get("bits"), 100 | "version": transformers_dict.get("version"), 101 | "modules_to_not_convert": transformers_dict.get("modules_to_not_convert"), 102 | } 103 | -------------------------------------------------------------------------------- /awq/models/aquila.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.utils.fused_utils import fuse_qkv 5 | from awq.modules.fused.block import LlamaLikeBlock 6 | from awq.modules.fused.model import LlamaLikeModel 7 | from transformers.models.llama.modeling_llama import ( 8 | LlamaDecoderLayer as OldAquilaDecoderLayer, 9 | LlamaForCausalLM as OldAquilaForCausalLM, 10 | ) 11 | from awq.modules.fused.norm import FasterTransformerRMSNorm 12 | 13 | 14 | class AquilaAWQForCausalLM(BaseAWQForCausalLM): 
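    # Every *AWQForCausalLM adapter in awq/models follows the same BaseAWQForCausalLM
    # contract that this class illustrates:
    #   * layer_type / max_seq_len_key   - class name of one decoder block and the config
    #     attribute holding the maximum sequence length.
    #   * get_model_layers / move_embed  - locate the decoder stack and move the embedding
    #     (and rotary) modules to the device used for calibration.
    #   * get_act_for_scaling            - whether an intermediate MLP activation should
    #     also be scaled (returned as is_scalable plus the layer to scale and its shape).
    #   * get_layers_for_scaling         - dict(prev_op=..., layers=[...], inp=...) groups:
    #     linear layers whose inputs are scaled together, the op that feeds them, the
    #     captured calibration input, and an optional module2inspect that is run as a
    #     whole during the scale search.
    #   * fuse_layers (optional)         - swaps modules for fused kernels at load time.
    #
    # Adapters are not used directly; quantization goes through AutoAWQForCausalLM.
    # A minimal sketch of that flow (paths are placeholders, and the config values simply
    # mirror the AwqConfig defaults shown above):
    #
    #   from awq import AutoAWQForCausalLM
    #   from transformers import AutoTokenizer
    #
    #   model = AutoAWQForCausalLM.from_pretrained("path/to/model")
    #   tokenizer = AutoTokenizer.from_pretrained("path/to/model")
    #   model.quantize(
    #       tokenizer,
    #       quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"},
    #   )
    #   model.save_quantized("path/to/model-awq")
    #   tokenizer.save_pretrained("path/to/model-awq")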
15 | layer_type = "AquilaDecoderLayer" 16 | max_seq_len_key = "max_position_embeddings" 17 | 18 | @staticmethod 19 | def fuse_layers(model: OldAquilaForCausalLM): 20 | fuser = AquilaFuser(model) 21 | fuser.fuse_transformer() 22 | 23 | @staticmethod 24 | def get_model_layers(model: OldAquilaForCausalLM): 25 | return model.model.layers 26 | 27 | @staticmethod 28 | def get_act_for_scaling(module: OldAquilaDecoderLayer): 29 | return dict(is_scalable=False) 30 | 31 | @staticmethod 32 | def move_embed(model: OldAquilaForCausalLM, device: str): 33 | model.model.embed_tokens = model.model.embed_tokens.to(device) 34 | model.model.rotary_emb = model.model.rotary_emb.to(device) 35 | 36 | @staticmethod 37 | def get_layers_for_scaling( 38 | module: OldAquilaDecoderLayer, input_feat, module_kwargs 39 | ): 40 | layers = [] 41 | 42 | # attention input 43 | layers.append( 44 | dict( 45 | prev_op=module.input_layernorm, 46 | layers=[ 47 | module.self_attn.q_proj, 48 | module.self_attn.k_proj, 49 | module.self_attn.v_proj, 50 | ], 51 | inp=input_feat["self_attn.q_proj"], 52 | module2inspect=module.self_attn, 53 | kwargs=module_kwargs, 54 | ) 55 | ) 56 | 57 | # attention out 58 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 59 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 60 | layers.append( 61 | dict( 62 | prev_op=module.self_attn.v_proj, 63 | layers=[module.self_attn.o_proj], 64 | inp=input_feat["self_attn.o_proj"], 65 | ) 66 | ) 67 | 68 | # linear 1 69 | layers.append( 70 | dict( 71 | prev_op=module.post_attention_layernorm, 72 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 73 | inp=input_feat["mlp.gate_proj"], 74 | module2inspect=module.mlp, 75 | ) 76 | ) 77 | 78 | # linear 2 79 | layers.append( 80 | dict( 81 | prev_op=module.mlp.up_proj, 82 | layers=[module.mlp.down_proj], 83 | inp=input_feat["mlp.down_proj"], 84 | ) 85 | ) 86 | 87 | return layers 88 | 89 | 90 | class AquilaFuser: 91 | def __init__(self, model: OldAquilaForCausalLM): 92 | self.model = model 93 | 94 | self.aquila_blocks: List[Tuple[str, OldAquilaDecoderLayer]] = [ 95 | (name, module) 96 | for name, module in self.model.named_modules() 97 | if "AquilaDecoderLayer".lower() in module.__class__.__name__.lower() 98 | ] 99 | 100 | def fuse_transformer(self): 101 | blocks = [] 102 | 103 | module: OldAquilaDecoderLayer 104 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 105 | device = next(iter(module.state_dict().values())).device 106 | qkv = fuse_qkv( 107 | module, 108 | module.self_attn.q_proj, 109 | module.self_attn.k_proj, 110 | module.self_attn.v_proj, 111 | ) 112 | norm_1 = FasterTransformerRMSNorm( 113 | module.input_layernorm.weight, module.input_layernorm.variance_epsilon 114 | ) 115 | norm_2 = FasterTransformerRMSNorm( 116 | module.post_attention_layernorm.weight, 117 | module.post_attention_layernorm.variance_epsilon, 118 | ) 119 | blocks.append( 120 | LlamaLikeBlock( 121 | hidden_size=self.model.config.hidden_size, 122 | n_heads=self.model.config.num_attention_heads, 123 | n_kv_heads=self.model.config.num_key_value_heads, 124 | qkv_layer=qkv, 125 | o_proj=module.self_attn.o_proj, 126 | mlp=module.mlp, 127 | norm_1=norm_1, 128 | norm_2=norm_2, 129 | dev=device, 130 | max_seq_len=self.model.config.max_seq_len, 131 | rope_theta=self.model.config.rope_theta, 132 | ) 133 | ) 134 | 135 | self.model.model = LlamaLikeModel( 136 | self.model.config.vocab_size, 137 | blocks, 138 | self.model.model.embed_tokens, 139 | 
self.model.model.norm, 140 | ) 141 | 142 | setattr(self.model.model, "blocks", self.model.model.blocks) 143 | -------------------------------------------------------------------------------- /awq/models/auto.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | from transformers import AutoConfig 5 | from awq.models import * 6 | from awq.models.base import BaseAWQForCausalLM 7 | 8 | 9 | AWQ_CAUSAL_LM_MODEL_MAP = { 10 | "mpt": MptAWQForCausalLM, 11 | "llama": LlamaAWQForCausalLM, 12 | "opt": OptAWQForCausalLM, 13 | "RefinedWeb": FalconAWQForCausalLM, 14 | "RefinedWebModel": FalconAWQForCausalLM, 15 | "exaone": ExaoneAWQForCausalLM, 16 | "falcon": FalconAWQForCausalLM, 17 | "bloom": BloomAWQForCausalLM, 18 | "gptj": GPTJAWQForCausalLM, 19 | "gpt_bigcode": GptBigCodeAWQForCausalLM, 20 | "mistral": MistralAWQForCausalLM, 21 | "mixtral": MixtralAWQForCausalLM, 22 | "gpt_neox": GPTNeoXAWQForCausalLM, 23 | "aquila": AquilaAWQForCausalLM, 24 | "Yi": YiAWQForCausalLM, 25 | "qwen": QwenAWQForCausalLM, 26 | "baichuan": BaichuanAWQForCausalLM, 27 | "llava": LlavaAWQForCausalLM, 28 | "qwen2": Qwen2AWQForCausalLM, 29 | "qwen3": Qwen3AWQForCausalLM, 30 | "qwen3_moe": Qwen3MoeAWQForCausalLM, 31 | "gemma": GemmaAWQForCausalLM, 32 | "gemma2": Gemma2AWQForCausalLM, 33 | "stablelm": StableLmAWQForCausalLM, 34 | "starcoder2": Starcoder2AWQForCausalLM, 35 | "llava_next": LlavaNextAWQForCausalLM, 36 | "phi3": Phi3AWQForCausalLM, 37 | "phi3_v": Phi3VAWQForCausalLM, 38 | "cohere": CohereAWQForCausalLM, 39 | "deepseek_v2": DeepseekV2AWQForCausalLM, 40 | "deepseek_v3": DeepseekV3AWQForCausalLM, 41 | "minicpm": MiniCPMAWQForCausalLM, 42 | "internlm2": InternLM2AWQForCausalLM, 43 | "minicpm3": MiniCPM3AWQForCausalLM, 44 | "qwen2_vl": Qwen2VLAWQForCausalLM, 45 | "qwen2_5_vl": Qwen2_5_VLAWQForCausalLM, 46 | "qwen2_5_omni": Qwen2_5_OmniAWQForConditionalGeneration 47 | } 48 | 49 | 50 | def check_and_get_model_type(model_dir, trust_remote_code=True, **model_init_kwargs): 51 | config = AutoConfig.from_pretrained( 52 | model_dir, trust_remote_code=trust_remote_code, **model_init_kwargs 53 | ) 54 | if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP.keys(): 55 | raise TypeError(f"{config.model_type} isn't supported yet.") 56 | model_type = config.model_type 57 | return model_type 58 | 59 | 60 | class AutoAWQForCausalLM: 61 | def __init__(self): 62 | raise EnvironmentError( 63 | "You must instantiate AutoAWQForCausalLM with\n" 64 | "AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained" 65 | ) 66 | 67 | @classmethod 68 | def from_pretrained( 69 | self, 70 | model_path, 71 | torch_dtype="auto", 72 | trust_remote_code=True, 73 | safetensors=True, 74 | device_map=None, 75 | download_kwargs=None, 76 | low_cpu_mem_usage=True, 77 | use_cache=False, 78 | **model_init_kwargs, 79 | ) -> BaseAWQForCausalLM: 80 | model_type = check_and_get_model_type( 81 | model_path, trust_remote_code, **model_init_kwargs 82 | ) 83 | return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained( 84 | model_path, 85 | model_type, 86 | torch_dtype=torch_dtype, 87 | trust_remote_code=trust_remote_code, 88 | safetensors=safetensors, 89 | device_map=device_map, 90 | download_kwargs=download_kwargs, 91 | low_cpu_mem_usage=low_cpu_mem_usage, 92 | use_cache=use_cache, 93 | **model_init_kwargs, 94 | ) 95 | 96 | @classmethod 97 | def from_quantized( 98 | self, 99 | quant_path, 100 | quant_filename="", 101 | max_seq_len=2048, 102 | trust_remote_code=True, 103 | 
fuse_layers=True, 104 | use_exllama=False, 105 | use_exllama_v2=False, 106 | use_ipex=False, 107 | batch_size=1, 108 | safetensors=True, 109 | device_map="balanced", 110 | max_memory=None, 111 | offload_folder=None, 112 | download_kwargs=None, 113 | **config_kwargs, 114 | ) -> BaseAWQForCausalLM: 115 | os.environ["AWQ_BATCH_SIZE"] = str(batch_size) 116 | model_type = check_and_get_model_type(quant_path, trust_remote_code) 117 | 118 | if config_kwargs.get("max_new_tokens") is not None: 119 | max_seq_len = config_kwargs["max_new_tokens"] 120 | logging.warning( 121 | "max_new_tokens argument is deprecated... gracefully " 122 | "setting max_seq_len=max_new_tokens." 123 | ) 124 | 125 | return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized( 126 | quant_path, 127 | model_type, 128 | quant_filename, 129 | max_seq_len, 130 | trust_remote_code=trust_remote_code, 131 | fuse_layers=fuse_layers, 132 | use_exllama=use_exllama, 133 | use_exllama_v2=use_exllama_v2, 134 | use_ipex=use_ipex, 135 | safetensors=safetensors, 136 | device_map=device_map, 137 | max_memory=max_memory, 138 | offload_folder=offload_folder, 139 | download_kwargs=download_kwargs, 140 | **config_kwargs, 141 | ) 142 | -------------------------------------------------------------------------------- /awq/models/baichuan.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.modules.fused.block import LlamaLikeBlock 5 | from awq.modules.fused.model import LlamaLikeModel 6 | from transformers.models.llama.modeling_llama import ( 7 | LlamaDecoderLayer as OldLlamaDecoderLayer, 8 | ) 9 | from awq.modules.fused.norm import FasterTransformerRMSNorm 10 | 11 | 12 | class BaichuanAWQForCausalLM(BaseAWQForCausalLM): 13 | layer_type = "BaichuanLayer" 14 | max_seq_len_key = "model_max_length" 15 | 16 | @staticmethod 17 | def fuse_layers(model): 18 | fuser = BaichuanFuser(model) 19 | fuser.fuse_transformer() 20 | 21 | @staticmethod 22 | def get_model_layers(model): 23 | return model.model.layers 24 | 25 | @staticmethod 26 | def get_act_for_scaling(module): 27 | return dict(is_scalable=False) 28 | 29 | @staticmethod 30 | def move_embed(model, device: str): 31 | model.model.embed_tokens = model.model.embed_tokens.to(device) 32 | model.model.rotary_emb = model.model.rotary_emb.to(device) 33 | 34 | @staticmethod 35 | # def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs): 36 | def get_layers_for_scaling(module, input_feat, module_kwargs): 37 | layers = [] 38 | 39 | # attention input 40 | layers.append( 41 | dict( 42 | prev_op=module.input_layernorm, 43 | layers=[module.self_attn.W_pack], 44 | inp=input_feat["self_attn.W_pack"], 45 | module2inspect=module.self_attn, 46 | kwargs=module_kwargs, 47 | ) 48 | ) 49 | 50 | # # attention out 51 | # # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 52 | # if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 53 | # layers.append(dict( 54 | # prev_op=module.self_attn.v_proj, 55 | # layers=[module.self_attn.o_proj], 56 | # inp=input_feat['self_attn.o_proj'], 57 | # )) 58 | 59 | # attention out 60 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 61 | layers.append( 62 | dict( 63 | prev_op=module.self_attn.W_pack, 64 | layers=[module.self_attn.o_proj], 65 | inp=input_feat["self_attn.o_proj"], 66 | ) 67 | ) 68 | 69 | # linear 1 70 | layers.append( 71 | dict( 72 
| prev_op=module.post_attention_layernorm, 73 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 74 | inp=input_feat["mlp.gate_proj"], 75 | module2inspect=module.mlp, 76 | ) 77 | ) 78 | 79 | # linear 2 80 | layers.append( 81 | dict( 82 | prev_op=module.mlp.up_proj, 83 | layers=[module.mlp.down_proj], 84 | inp=input_feat["mlp.down_proj"], 85 | ) 86 | ) 87 | 88 | return layers 89 | 90 | 91 | class BaichuanFuser: 92 | def __init__(self, model): 93 | self.model = model 94 | 95 | self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [ 96 | (name, module) 97 | for name, module in self.model.named_modules() 98 | if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower() 99 | ] 100 | 101 | def fuse_transformer(self): 102 | blocks = [] 103 | 104 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 105 | device = next(iter(module.state_dict().values())).device 106 | # qkv = fuse_qkv( 107 | # module, 108 | # module.self_attn.q_proj, 109 | # module.self_attn.k_proj, 110 | # module.self_attn.v_proj 111 | # ) 112 | qkv = module.self_attn.W_pack 113 | norm_1 = FasterTransformerRMSNorm( 114 | module.input_layernorm.weight, module.input_layernorm.epsilon 115 | ) 116 | norm_2 = FasterTransformerRMSNorm( 117 | module.post_attention_layernorm.weight, 118 | module.post_attention_layernorm.epsilon, 119 | ) 120 | blocks.append( 121 | LlamaLikeBlock( 122 | hidden_size=self.model.config.hidden_size, 123 | n_heads=self.model.config.num_attention_heads, 124 | n_kv_heads=self.model.config.num_attention_heads, 125 | qkv_layer=qkv, 126 | o_proj=module.self_attn.o_proj, 127 | mlp=module.mlp, 128 | norm_1=norm_1, 129 | norm_2=norm_2, 130 | dev=device, 131 | max_seq_len=self.model.config.max_seq_len, 132 | use_alibi=True, 133 | ) 134 | ) 135 | 136 | self.model.model = LlamaLikeModel( 137 | self.model.config.vocab_size, 138 | blocks, 139 | self.model.model.embed_tokens, 140 | self.model.model.norm, 141 | ) 142 | 143 | setattr(self.model.model, "blocks", self.model.model.blocks) 144 | -------------------------------------------------------------------------------- /awq/models/bloom.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | from transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomBlock 3 | 4 | 5 | class BloomAWQForCausalLM(BaseAWQForCausalLM): 6 | layer_type = "BloomBlock" 7 | 8 | @staticmethod 9 | def get_model_layers(model: BloomForCausalLM): 10 | return model.transformer.h 11 | 12 | @staticmethod 13 | def get_act_for_scaling(module: BloomBlock): 14 | return dict( 15 | is_scalable=True, 16 | scale_name="mlp.gelu_impl", 17 | scale_layer=module.mlp.gelu_impl, 18 | scale_shape=module.mlp.dense_h_to_4h.out_features, 19 | ) 20 | 21 | @staticmethod 22 | def move_embed(model: BloomForCausalLM, device: str): 23 | model.transformer.word_embeddings = model.transformer.word_embeddings.to(device) 24 | model.transformer.word_embeddings_layernorm = ( 25 | model.transformer.word_embeddings_layernorm.to(device) 26 | ) 27 | 28 | @staticmethod 29 | def get_layers_for_scaling(module: BloomBlock, input_feat, module_kwargs): 30 | layers = [] 31 | 32 | # attention input 33 | layers.append( 34 | dict( 35 | prev_op=module.input_layernorm, 36 | layers=[module.self_attention.query_key_value], 37 | inp=input_feat["self_attention.query_key_value"], 38 | module2inspect=module, 39 | kwargs=module_kwargs, 40 | ) 41 | ) 42 | # attention out 43 | # Please refer to 
https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469 44 | """ 45 | scales_list.append(_auto_get_scale( 46 | prev_op=module.self_attention.query_key_value, 47 | layers=[module.self_attention.dense], 48 | inp=input_feat['self_attention.dense'], 49 | )) 50 | """ 51 | # linear 1 52 | layers.append( 53 | dict( 54 | prev_op=module.post_attention_layernorm, 55 | layers=[module.mlp.dense_h_to_4h], 56 | inp=input_feat["mlp.dense_h_to_4h"], 57 | module2inspect=module, 58 | kwargs=module_kwargs, 59 | ) 60 | ) 61 | # linear 2 62 | layers.append( 63 | dict( 64 | prev_op=module.mlp.gelu_impl, 65 | layers=[module.mlp.dense_4h_to_h], 66 | inp=input_feat["mlp.dense_4h_to_h"], 67 | ) 68 | ) 69 | 70 | return layers 71 | -------------------------------------------------------------------------------- /awq/models/cohere.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.utils.fused_utils import fuse_qkv 5 | from awq.modules.fused.block import CohereBlock 6 | from awq.modules.fused.model import CohereModel 7 | from transformers.models.cohere.modeling_cohere import ( 8 | CohereDecoderLayer as OldCohereDecoderLayer, 9 | CohereForCausalLM as OldCohereForCausalLM, 10 | ) 11 | from awq.modules.fused.norm import FasterTransformerRMSNorm 12 | 13 | class CohereAWQForCausalLM(BaseAWQForCausalLM): 14 | layer_type = "CohereDecoderLayer" 15 | max_seq_len_key = "max_position_embeddings" 16 | 17 | @staticmethod 18 | def fuse_layers(model: OldCohereForCausalLM): 19 | fuser = CohereFuser(model) 20 | fuser.fuse_transformer() 21 | 22 | @staticmethod 23 | def get_model_layers(model: OldCohereForCausalLM): 24 | return model.model.layers 25 | 26 | @staticmethod 27 | def get_act_for_scaling(module: OldCohereDecoderLayer): 28 | return dict(is_scalable=False) 29 | 30 | @staticmethod 31 | def move_embed(model: OldCohereForCausalLM, device: str): 32 | model.model.embed_tokens = model.model.embed_tokens.to(device) 33 | model.model.rotary_emb = model.model.rotary_emb.to(device) 34 | 35 | @staticmethod 36 | def get_layers_for_scaling( 37 | module: OldCohereDecoderLayer, input_feat, module_kwargs 38 | ): 39 | layers = [] 40 | 41 | # input 42 | layers.append( 43 | dict( 44 | prev_op=module.input_layernorm, 45 | layers=[ 46 | module.self_attn.q_proj, 47 | module.self_attn.k_proj, 48 | module.self_attn.v_proj, 49 | module.mlp.gate_proj, 50 | module.mlp.up_proj, 51 | ], 52 | inp=input_feat["self_attn.q_proj"], 53 | module2inspect=module, 54 | kwargs=module_kwargs, 55 | ) 56 | ) 57 | 58 | # attention out 59 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 60 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 61 | layers.append( 62 | dict( 63 | prev_op=module.self_attn.v_proj, 64 | layers=[module.self_attn.o_proj], 65 | inp=input_feat["self_attn.o_proj"], 66 | ) 67 | ) 68 | 69 | # linear out 70 | layers.append( 71 | dict( 72 | prev_op=module.mlp.up_proj, 73 | layers=[module.mlp.down_proj], 74 | inp=input_feat["mlp.down_proj"], 75 | ) 76 | ) 77 | 78 | return layers 79 | 80 | class CohereFuser: 81 | def __init__(self, model: OldCohereForCausalLM): 82 | self.model = model 83 | 84 | self.cohere_blocks: List[Tuple[str, OldCohereDecoderLayer]] = [ 85 | (name, module) 86 | for name, module in self.model.named_modules() 87 | if "CohereDecoderLayer".lower() in module.__class__.__name__.lower() 88 | ] 89 | 90 | def 
fuse_transformer(self): 91 | blocks = [] 92 | 93 | module: OldCohereDecoderLayer 94 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 95 | device = next(iter(module.state_dict().values())).device 96 | qkv = fuse_qkv( 97 | module, 98 | module.self_attn.q_proj, 99 | module.self_attn.k_proj, 100 | module.self_attn.v_proj, 101 | ) 102 | norm_1 = module.input_layernorm 103 | # norm_2 = FasterTransformerRMSNorm( 104 | # module.post_attention_layernorm.weight, 105 | # module.post_attention_layernorm.variance_epsilon, 106 | # ) 107 | blocks.append( 108 | CohereBlock( 109 | hidden_size=self.model.config.hidden_size, 110 | n_heads=self.model.config.num_attention_heads, 111 | n_kv_heads=self.model.config.num_key_value_heads, 112 | qkv_layer=qkv, 113 | o_proj=module.self_attn.o_proj, 114 | mlp=module.mlp, 115 | norm_1=norm_1, 116 | # norm_2=norm_2, 117 | dev=device, 118 | max_seq_len=self.model.config.max_seq_len, 119 | rope_theta=self.model.config.rope_theta, 120 | ) 121 | ) 122 | 123 | self.model.model = CohereModel( 124 | self.model.config.vocab_size, 125 | blocks, 126 | self.model.model.embed_tokens, 127 | self.model.model.norm, 128 | ) 129 | setattr(self.model.model, "blocks", self.model.model.blocks) 130 | -------------------------------------------------------------------------------- /awq/models/deepseek_v2.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | 5 | 6 | class DeepseekV2AWQForCausalLM(BaseAWQForCausalLM): 7 | layer_type = "DeepseekV2DecoderLayer" 8 | max_seq_len_key = "max_position_embeddings" 9 | 10 | @staticmethod 11 | def get_model_layers(model): 12 | return model.model.layers 13 | 14 | @staticmethod 15 | def get_act_for_scaling(module): 16 | return dict(is_scalable=False) 17 | 18 | @staticmethod 19 | def move_embed(model, device: str): 20 | model.model.embed_tokens = model.model.embed_tokens.to(device) 21 | 22 | @staticmethod 23 | def get_layers_for_scaling( 24 | module, input_feat, module_kwargs 25 | ): 26 | layers = [] 27 | 28 | if hasattr(module.self_attn, "q_proj"): 29 | # attention input 30 | layers.append( 31 | dict( 32 | prev_op=module.input_layernorm, 33 | layers=[ 34 | module.self_attn.q_proj, 35 | module.self_attn.kv_a_proj_with_mqa, 36 | ], 37 | inp=input_feat["self_attn.q_proj"], 38 | module2inspect=module.self_attn, 39 | kwargs=module_kwargs, 40 | ) 41 | ) 42 | else: 43 | # attention input 44 | layers.append( 45 | dict( 46 | prev_op=module.input_layernorm, 47 | layers=[ 48 | module.self_attn.q_a_proj, 49 | module.self_attn.kv_a_proj_with_mqa, 50 | ], 51 | inp=input_feat["self_attn.q_a_proj"], 52 | module2inspect=module.self_attn, 53 | kwargs=module_kwargs, 54 | ) 55 | ) 56 | layers.append( 57 | dict( 58 | prev_op=module.self_attn.q_a_layernorm, 59 | layers=[ 60 | module.self_attn.q_b_proj, 61 | ], 62 | inp=input_feat["self_attn.q_b_proj"], 63 | ) 64 | ) 65 | 66 | # kv layernorm 67 | layers.append( 68 | dict( 69 | prev_op=module.self_attn.kv_a_layernorm, 70 | layers=[ 71 | module.self_attn.kv_b_proj, 72 | ], 73 | inp=input_feat["self_attn.kv_b_proj"], 74 | ) 75 | ) 76 | 77 | if hasattr(module.mlp, "gate"): 78 | # linear in 79 | layers.append( 80 | dict( 81 | prev_op=module.post_attention_layernorm, 82 | layers=[ 83 | w 84 | for expert in module.mlp.experts 85 | for w in [expert.gate_proj, expert.up_proj] 86 | ] + [module.mlp.shared_experts.gate_proj, module.mlp.shared_experts.up_proj], 87 | inp=input_feat["mlp"], 
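                    # Every routed expert's gate_proj/up_proj plus the shared experts are
                    # grouped under a single scale: they all consume the same MoE input,
                    # captured above under the "mlp" key, and module2inspect below lets the
                    # scale search run the full MoE forward (router + experts) at once.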
88 | module2inspect=module.mlp, 89 | ) 90 | ) 91 | 92 | # linear out 93 | for i, expert in enumerate(module.mlp.experts): 94 | layers.append( 95 | dict( 96 | prev_op=expert.up_proj, 97 | layers=[expert.down_proj], 98 | inp=input_feat[f"mlp.experts.{i}.down_proj"], 99 | ) 100 | ) 101 | layers.append( 102 | dict( 103 | prev_op=module.mlp.shared_experts.up_proj, 104 | layers=[module.mlp.shared_experts.down_proj], 105 | inp=input_feat[f"mlp.shared_experts.down_proj"], 106 | ) 107 | ) 108 | else: 109 | # linear 1 110 | layers.append( 111 | dict( 112 | prev_op=module.post_attention_layernorm, 113 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 114 | inp=input_feat["mlp.gate_proj"], 115 | module2inspect=module.mlp, 116 | ) 117 | ) 118 | 119 | # linear 2 120 | layers.append( 121 | dict( 122 | prev_op=module.mlp.up_proj, 123 | layers=[module.mlp.down_proj], 124 | inp=input_feat["mlp.down_proj"], 125 | ) 126 | ) 127 | 128 | return layers 129 | -------------------------------------------------------------------------------- /awq/models/deepseek_v3.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | 5 | 6 | class DeepseekV3AWQForCausalLM(BaseAWQForCausalLM): 7 | layer_type = "DeepseekV3DecoderLayer" 8 | max_seq_len_key = "max_position_embeddings" 9 | 10 | @staticmethod 11 | def get_model_layers(model): 12 | return model.model.layers 13 | 14 | @staticmethod 15 | def get_act_for_scaling(module): 16 | return dict(is_scalable=False) 17 | 18 | @staticmethod 19 | def move_embed(model, device: str): 20 | model.model.embed_tokens = model.model.embed_tokens.to(device) 21 | 22 | @staticmethod 23 | def get_layers_for_scaling( 24 | module, input_feat, module_kwargs 25 | ): 26 | layers = [] 27 | 28 | if hasattr(module.self_attn, "q_proj"): 29 | # attention input 30 | layers.append( 31 | dict( 32 | prev_op=module.input_layernorm, 33 | layers=[ 34 | module.self_attn.q_proj, 35 | module.self_attn.kv_a_proj_with_mqa, 36 | ], 37 | inp=input_feat["self_attn.q_proj"], 38 | module2inspect=module.self_attn, 39 | kwargs=module_kwargs, 40 | ) 41 | ) 42 | else: 43 | # attention input 44 | layers.append( 45 | dict( 46 | prev_op=module.input_layernorm, 47 | layers=[ 48 | module.self_attn.q_a_proj, 49 | module.self_attn.kv_a_proj_with_mqa, 50 | ], 51 | inp=input_feat["self_attn.q_a_proj"], 52 | module2inspect=module.self_attn, 53 | kwargs=module_kwargs, 54 | ) 55 | ) 56 | layers.append( 57 | dict( 58 | prev_op=module.self_attn.q_a_layernorm, 59 | layers=[ 60 | module.self_attn.q_b_proj, 61 | ], 62 | inp=input_feat["self_attn.q_b_proj"], 63 | ) 64 | ) 65 | 66 | # kv layernorm 67 | layers.append( 68 | dict( 69 | prev_op=module.self_attn.kv_a_layernorm, 70 | layers=[ 71 | module.self_attn.kv_b_proj, 72 | ], 73 | inp=input_feat["self_attn.kv_b_proj"], 74 | ) 75 | ) 76 | 77 | if hasattr(module.mlp, "gate"): 78 | # linear in 79 | layers.append( 80 | dict( 81 | prev_op=module.post_attention_layernorm, 82 | layers=[ 83 | w 84 | for expert in module.mlp.experts 85 | for w in [expert.gate_proj, expert.up_proj] 86 | ] + [module.mlp.shared_experts.gate_proj, module.mlp.shared_experts.up_proj], 87 | inp=input_feat["mlp"], 88 | module2inspect=module.mlp, 89 | ) 90 | ) 91 | 92 | # linear out 93 | for i, expert in enumerate(module.mlp.experts): 94 | layers.append( 95 | dict( 96 | prev_op=expert.up_proj, 97 | layers=[expert.down_proj], 98 | inp=input_feat[f"mlp.experts.{i}.down_proj"], 99 | ) 100 | ) 101 | 
layers.append( 102 | dict( 103 | prev_op=module.mlp.shared_experts.up_proj, 104 | layers=[module.mlp.shared_experts.down_proj], 105 | inp=input_feat[f"mlp.shared_experts.down_proj"], 106 | ) 107 | ) 108 | else: 109 | # linear 1 110 | layers.append( 111 | dict( 112 | prev_op=module.post_attention_layernorm, 113 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 114 | inp=input_feat["mlp.gate_proj"], 115 | module2inspect=module.mlp, 116 | ) 117 | ) 118 | 119 | # linear 2 120 | layers.append( 121 | dict( 122 | prev_op=module.mlp.up_proj, 123 | layers=[module.mlp.down_proj], 124 | inp=input_feat["mlp.down_proj"], 125 | ) 126 | ) 127 | 128 | return layers 129 | -------------------------------------------------------------------------------- /awq/models/exaone.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.utils.fused_utils import fuse_qkv 5 | from awq.modules.fused.block import LlamaLikeBlock 6 | from awq.modules.fused.model import LlamaLikeModel 7 | try: 8 | from transformers.models.exaone.modeling_exaone import ( 9 | ExaoneBlock as OldExaoneBlock, 10 | ExaoneForCausalLM as OldExaoneForCausalLM, 11 | ) 12 | except: 13 | OldExaoneBlock = None 14 | OldExaoneForCausalLM = None 15 | from awq.modules.fused.norm import FasterTransformerRMSNorm 16 | 17 | 18 | class ExaoneAWQForCausalLM(BaseAWQForCausalLM): 19 | layer_type = "ExaoneBlock" 20 | max_seq_len_key = "max_position_embeddings" 21 | 22 | @staticmethod 23 | def fuse_layers(model: OldExaoneForCausalLM): 24 | fuser = LlamaFuser(model) 25 | fuser.fuse_transformer() 26 | 27 | @staticmethod 28 | def get_model_layers(model: OldExaoneForCausalLM): 29 | return model.transformer.h 30 | 31 | @staticmethod 32 | def get_act_for_scaling(module: OldExaoneBlock): 33 | return dict(is_scalable=False) 34 | 35 | @staticmethod 36 | def move_embed(model: OldExaoneForCausalLM, device: str): 37 | model.transformer.wte = model.transformer.wte.to(device) 38 | model.transformer.rotary = model.transformer.rotary.to(device) 39 | 40 | @staticmethod 41 | def get_layers_for_scaling(module: OldExaoneBlock, input_feat, module_kwargs): 42 | layers = [] 43 | 44 | # attention input 45 | layers.append( 46 | dict( 47 | prev_op=module.ln_1, 48 | layers=[ 49 | module.attn.attention.q_proj, 50 | module.attn.attention.k_proj, 51 | module.attn.attention.v_proj, 52 | ], 53 | inp=input_feat["attn.attention.q_proj"], 54 | module2inspect=module.attn.attention, 55 | kwargs=module_kwargs, 56 | ) 57 | ) 58 | 59 | # attention out 60 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 61 | if module.attn.attention.v_proj.weight.shape == module.attn.attention.out_proj.weight.shape: 62 | layers.append( 63 | dict( 64 | prev_op=module.attn.attention.v_proj, 65 | layers=[module.attn.attention.out_proj], 66 | inp=input_feat["attn.attention.out_proj"], 67 | ) 68 | ) 69 | 70 | # linear 1 71 | layers.append( 72 | dict( 73 | prev_op=module.ln_2, 74 | layers=[module.mlp.c_fc_0, module.mlp.c_fc_1], 75 | inp=input_feat["mlp.c_fc_0"], 76 | module2inspect=module.mlp, 77 | ) 78 | ) 79 | 80 | # linear 2 81 | layers.append( 82 | dict( 83 | prev_op=module.mlp.c_fc_1, 84 | layers=[module.mlp.c_proj], 85 | inp=input_feat["mlp.c_proj"], 86 | ) 87 | ) 88 | 89 | return layers 90 | 91 | 92 | class LlamaFuser: 93 | def __init__(self, model: OldExaoneForCausalLM): 94 | self.model = model 95 | 96 | self.llama_blocks: List[Tuple[str, 
OldExaoneBlock]] = [ 97 | (name, module) 98 | for name, module in self.model.named_modules() 99 | if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower() 100 | ] 101 | 102 | def fuse_transformer(self): 103 | blocks = [] 104 | 105 | module: OldExaoneBlock 106 | for module in tqdm.tqdm(self.model.transformer.h, desc="Fusing layers..."): 107 | device = next(iter(module.state_dict().values())).device 108 | qkv = fuse_qkv( 109 | module, 110 | module.attn.attention.q_proj, 111 | module.attn.attention.k_proj, 112 | module.attn.attention.v_proj, 113 | ) 114 | norm_1 = FasterTransformerRMSNorm( 115 | module.ln_1.weight, module.ln_1.eps 116 | ) 117 | norm_2 = FasterTransformerRMSNorm( 118 | module.ln_2.weight, 119 | module.ln_2.eps, 120 | ) 121 | blocks.append( 122 | LlamaLikeBlock( 123 | hidden_size=self.model.config.hidden_size, 124 | n_heads=self.model.config.num_attention_heads, 125 | n_kv_heads=self.model.config.num_key_value_heads, 126 | qkv_layer=qkv, 127 | o_proj=module.attn.attention.out_proj, 128 | mlp=module.mlp, 129 | norm_1=norm_1, 130 | norm_2=norm_2, 131 | dev=device, 132 | max_seq_len=self.model.config.max_seq_len, 133 | rope_theta=self.model.config.rope_theta, 134 | ) 135 | ) 136 | 137 | self.model.transformer = LlamaLikeModel( 138 | self.model.config.vocab_size, 139 | blocks, 140 | self.model.transformer.wte, 141 | self.model.transformer.ln_f, 142 | ) 143 | setattr(self.model.transformer, "blocks", self.model.transformer.blocks) 144 | -------------------------------------------------------------------------------- /awq/models/falcon.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | from transformers.models.falcon.modeling_falcon import ( 3 | FalconDecoderLayer as OldFalconDecoderLayer, 4 | FalconForCausalLM, 5 | FalconAttention, 6 | ) 7 | 8 | 9 | class FalconAWQForCausalLM(BaseAWQForCausalLM): 10 | layer_type = "FalconDecoderLayer" 11 | 12 | @staticmethod 13 | def fuse_layers(model: FalconForCausalLM): 14 | fuser = FalconFuser(model) 15 | 16 | # TODO: Implement correctly fused modules for Falcon 40B and Falcon 180B 17 | if model.config.num_attention_heads == 71: 18 | fuser.fuse_transformer() 19 | 20 | @staticmethod 21 | def get_model_layers(model: FalconForCausalLM): 22 | return model.transformer.h 23 | 24 | @staticmethod 25 | def get_act_for_scaling(module: OldFalconDecoderLayer): 26 | return dict( 27 | is_scalable=True, 28 | scale_name="mlp.act", 29 | scale_layer=module.mlp.act, 30 | scale_shape=module.mlp.dense_h_to_4h.out_features, 31 | ) 32 | 33 | @staticmethod 34 | def move_embed(model: FalconForCausalLM, device): 35 | model.transformer.word_embeddings = model.transformer.word_embeddings.to(device) 36 | model.transformer.rotary_emb = model.transformer.rotary_emb.to(device) 37 | 38 | @staticmethod 39 | def get_layers_for_scaling( 40 | module: OldFalconDecoderLayer, input_feat, module_kwargs 41 | ): 42 | layers = [] 43 | 44 | # Falcon 7B (older architecture) 45 | if module.config.num_attention_heads == 71: 46 | # linear 1 + attention 47 | layers.append( 48 | dict( 49 | prev_op=module.input_layernorm, 50 | layers=[ 51 | module.mlp.dense_h_to_4h, 52 | module.self_attention.query_key_value, 53 | ], 54 | inp=input_feat["self_attention.query_key_value"], 55 | module2inspect=module, 56 | kwargs=module_kwargs, 57 | ) 58 | ) 59 | 60 | # Falcon 40B (newer architecture) 61 | else: 62 | # linear 1 + attention 63 | layers.append( 64 | dict( 65 | prev_op=module.ln_attn, 66 | 
layers=[module.self_attention.query_key_value], 67 | inp=input_feat["self_attention.query_key_value"], 68 | module2inspect=module, 69 | kwargs=module_kwargs, 70 | ) 71 | ) 72 | 73 | # linear 2 74 | layers.append( 75 | dict( 76 | prev_op=module.ln_mlp, 77 | layers=[module.mlp.dense_h_to_4h], 78 | inp=input_feat["mlp.dense_h_to_4h"], 79 | module2inspect=module, 80 | kwargs=module_kwargs, 81 | ) 82 | ) 83 | 84 | return layers 85 | 86 | 87 | from awq.modules.fused.model import FalconModel 88 | from awq.modules.fused.block import FalconDecoderLayer 89 | 90 | 91 | class FalconFuser: 92 | def __init__(self, model: FalconForCausalLM): 93 | self.model = model 94 | 95 | def fuse_transformer(self): 96 | blocks = [] 97 | 98 | module: OldFalconDecoderLayer 99 | for module in self.model.transformer.h: 100 | if module.config.num_attention_heads == 71: 101 | input_layernorm = module.input_layernorm 102 | ln_attn = None 103 | ln_mlp = None 104 | new_decoder_arch = False 105 | else: 106 | input_layernorm = None 107 | ln_attn = module.ln_attn 108 | ln_mlp = module.ln_mlp 109 | new_decoder_arch = True 110 | 111 | blocks.append( 112 | FalconDecoderLayer( 113 | hidden_size=module.config.hidden_size, 114 | n_heads=module.config.num_attention_heads, 115 | qkv_layer=module.self_attention.query_key_value, 116 | o_proj=module.self_attention.dense, 117 | mlp=module.mlp, 118 | dev=next(iter(module.state_dict().values())).device, 119 | max_seq_len=self.model.config.max_seq_len, 120 | input_layernorm=input_layernorm, 121 | ln_attn=ln_attn, 122 | ln_mlp=ln_mlp, 123 | new_decoder_arch=new_decoder_arch, 124 | ) 125 | ) 126 | 127 | self.model.transformer = FalconModel( 128 | self.model.config.vocab_size, 129 | blocks, 130 | self.model.transformer.word_embeddings, 131 | self.model.transformer.ln_f, 132 | ) 133 | 134 | setattr(self.model.transformer, "blocks", self.model.transformer.blocks) 135 | -------------------------------------------------------------------------------- /awq/models/gemma.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import torch 3 | from typing import List, Tuple 4 | from .base import BaseAWQForCausalLM 5 | from awq.utils.fused_utils import fuse_qkv 6 | from awq.modules.fused.block import LlamaLikeBlock 7 | from awq.modules.fused.model import LlamaLikeModel 8 | from transformers.models.gemma.modeling_gemma import ( 9 | GemmaDecoderLayer as OldGemmaDecoderLayer, 10 | GemmaForCausalLM as OldGemmaForCausalLM, 11 | ) 12 | from awq.modules.fused.norm import FasterTransformerRMSNorm 13 | 14 | 15 | class GemmaAWQForCausalLM(BaseAWQForCausalLM): 16 | layer_type = "GemmaDecoderLayer" 17 | max_new_tokens_key = "max_position_embeddings" 18 | 19 | @staticmethod 20 | def fuse_layers(model: OldGemmaDecoderLayer): 21 | fuser = GemmaFuser(model) 22 | fuser.fuse_transformer() 23 | 24 | @staticmethod 25 | def get_model_layers(model: OldGemmaForCausalLM): 26 | return model.model.layers 27 | 28 | @staticmethod 29 | def get_act_for_scaling(module: OldGemmaDecoderLayer): 30 | return dict(is_scalable=False) 31 | 32 | @staticmethod 33 | def move_embed(model: OldGemmaForCausalLM, device: str): 34 | model.model.embed_tokens = model.model.embed_tokens.to(device) 35 | 36 | @staticmethod 37 | def get_layers_for_scaling(module: OldGemmaDecoderLayer, input_feat, module_kwargs): 38 | layers = [] 39 | 40 | # attention input 41 | layers.append( 42 | dict( 43 | prev_op=module.input_layernorm, 44 | layers=[ 45 | module.self_attn.q_proj, 46 | module.self_attn.k_proj, 47 | 
module.self_attn.v_proj, 48 | ], 49 | inp=input_feat["self_attn.q_proj"], 50 | module2inspect=module.self_attn, 51 | kwargs=module_kwargs, 52 | ) 53 | ) 54 | 55 | # attention out 56 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 57 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 58 | layers.append( 59 | dict( 60 | prev_op=module.self_attn.v_proj, 61 | layers=[module.self_attn.o_proj], 62 | inp=input_feat["self_attn.o_proj"], 63 | ) 64 | ) 65 | 66 | # linear 1 67 | layers.append( 68 | dict( 69 | prev_op=module.post_attention_layernorm, 70 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 71 | inp=input_feat["mlp.gate_proj"], 72 | module2inspect=module.mlp, 73 | ) 74 | ) 75 | 76 | # linear 2 77 | layers.append( 78 | dict( 79 | prev_op=module.mlp.up_proj, 80 | layers=[module.mlp.down_proj], 81 | inp=input_feat["mlp.down_proj"], 82 | ) 83 | ) 84 | 85 | return layers 86 | 87 | 88 | class GemmaFuser: 89 | def __init__(self, model: OldGemmaForCausalLM): 90 | self.model = model 91 | 92 | self.Gemma_blocks: List[Tuple[str, OldGemmaDecoderLayer]] = [ 93 | (name, module) 94 | for name, module in self.model.named_modules() 95 | if "GemmaDecoderLayer".lower() in module.__class__.__name__.lower() 96 | ] 97 | 98 | def fuse_transformer(self): 99 | blocks = [] 100 | 101 | module: OldGemmaDecoderLayer 102 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 103 | device = next(iter(module.state_dict().values())).device 104 | qkv = fuse_qkv( 105 | module, 106 | module.self_attn.q_proj, 107 | module.self_attn.k_proj, 108 | module.self_attn.v_proj, 109 | ) 110 | with torch.no_grad(): 111 | # GemmaRMSNorm is different from Llama's in that it multiplies 112 | # (1 + weight) to the output, instead of just weight. 
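                # FasterTransformerRMSNorm multiplies the normalized activations by its
                # weight only, so folding the +1 into the stored weight (w -> w + 1) keeps
                # the fused norm numerically equivalent to GemmaRMSNorm's x_hat * (1 + w).
                # The in-place update runs once, under no_grad; the embedding rescale
                # further below folds Gemma's sqrt(hidden_size) input scaling the same way.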
113 | module.input_layernorm.weight += 1 114 | module.post_attention_layernorm.weight += 1 115 | norm_1 = FasterTransformerRMSNorm( 116 | module.input_layernorm.weight, module.input_layernorm.eps 117 | ) 118 | norm_2 = FasterTransformerRMSNorm( 119 | module.post_attention_layernorm.weight, 120 | module.post_attention_layernorm.eps, 121 | ) 122 | blocks.append( 123 | LlamaLikeBlock( 124 | hidden_size=self.model.config.hidden_size, 125 | n_heads=self.model.config.num_attention_heads, 126 | n_kv_heads=self.model.config.num_key_value_heads, 127 | qkv_layer=qkv, 128 | o_proj=module.self_attn.o_proj, 129 | mlp=module.mlp, 130 | norm_1=norm_1, 131 | norm_2=norm_2, 132 | dev=device, 133 | max_seq_len=self.model.config.max_seq_len, 134 | rope_theta=self.model.config.rope_theta, 135 | head_dim=self.model.config.head_dim, 136 | ) 137 | ) 138 | 139 | with torch.no_grad(): 140 | # Normalize Gemma's embedding layer 141 | self.model.model.embed_tokens.weight *= self.model.config.hidden_size**0.5 142 | 143 | self.model.model = LlamaLikeModel( 144 | self.model.config.vocab_size, 145 | blocks, 146 | self.model.model.embed_tokens, 147 | self.model.model.norm, 148 | ) 149 | setattr(self.model.model, "blocks", self.model.model.blocks) 150 | -------------------------------------------------------------------------------- /awq/models/gpt_bigcode.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | from transformers.models.gpt_bigcode.modeling_gpt_bigcode import ( 3 | GPTBigCodeForCausalLM, 4 | GPTBigCodeBlock as OldGptBigCodeBlock, 5 | ) 6 | 7 | 8 | class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM): 9 | layer_type = "GPTBigCodeBlock" 10 | max_seq_len_key = "n_positions" 11 | 12 | @staticmethod 13 | def get_model_layers(model: GPTBigCodeForCausalLM): 14 | return model.transformer.h 15 | 16 | @staticmethod 17 | def get_act_for_scaling(module: OldGptBigCodeBlock): 18 | return dict( 19 | is_scalable=True, 20 | scale_name="mlp.act", 21 | scale_layer=module.mlp.act, 22 | scale_shape=module.mlp.c_fc.out_features, 23 | ) 24 | 25 | @staticmethod 26 | def move_embed(model: GPTBigCodeForCausalLM, device): 27 | model.transformer.wte = model.transformer.wte.to(device) 28 | model.transformer.wpe = model.transformer.wpe.to(device) 29 | model.transformer.drop = model.transformer.drop.to(device) 30 | 31 | @staticmethod 32 | def get_layers_for_scaling(module: OldGptBigCodeBlock, input_feat, module_kwargs): 33 | layers = [] 34 | 35 | # attention input 36 | layers.append( 37 | dict( 38 | prev_op=module.ln_1, 39 | layers=[module.attn.c_attn], 40 | inp=input_feat["attn.c_attn"], 41 | module2inspect=module.attn, 42 | kwargs=module_kwargs, 43 | ) 44 | ) 45 | 46 | # linear 1 47 | layers.append( 48 | dict( 49 | prev_op=module.ln_2, 50 | layers=[module.mlp.c_fc], 51 | inp=input_feat["mlp.c_fc"], 52 | module2inspect=module.mlp, 53 | ) 54 | ) 55 | 56 | # linear 2 57 | layers.append( 58 | dict( 59 | prev_op=module.mlp.act, 60 | layers=[module.mlp.c_proj], 61 | inp=input_feat["mlp.c_proj"], 62 | ) 63 | ) 64 | 65 | return layers 66 | -------------------------------------------------------------------------------- /awq/models/gpt_neox.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | from transformers.models.gpt_neox.modeling_gpt_neox import ( 3 | GPTNeoXLayer, 4 | GPTNeoXForCausalLM, 5 | ) 6 | 7 | 8 | class GPTNeoXAWQForCausalLM(BaseAWQForCausalLM): 9 | layer_type = "GPTNeoXDecoderLayer" 
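    # Like Bloom, GPT-J and Falcon, GPT-NeoX uses a non-gated (GELU) MLP, so
    # get_act_for_scaling below marks mlp.act as scalable and the final scaling group
    # pairs that activation with mlp.dense_4h_to_h.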
10 | max_seq_len_key = "max_position_embeddings" 11 | 12 | @staticmethod 13 | def get_model_layers(model: GPTNeoXForCausalLM): 14 | return model.gpt_neox.layers 15 | 16 | @staticmethod 17 | def get_act_for_scaling(module: GPTNeoXLayer): 18 | return dict( 19 | is_scalable=True, 20 | scale_name="mlp.act", 21 | scale_layer=module.mlp.act, 22 | scale_shape=module.mlp.dense_h_to_4h.out_features, 23 | ) 24 | 25 | @staticmethod 26 | def move_embed(model: GPTNeoXForCausalLM, device: str): 27 | model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device) 28 | model.gpt_neox.rotary_emb = model.gpt_neox.rotary_emb.to(device) 29 | 30 | @staticmethod 31 | def get_layers_for_scaling(module: GPTNeoXLayer, input_feat, module_kwargs): 32 | layers = [] 33 | 34 | # attention input 35 | layers.append( 36 | dict( 37 | prev_op=module.input_layernorm, 38 | layers=[module.attention.query_key_value], 39 | inp=input_feat["attention.query_key_value"], 40 | ) 41 | ) 42 | 43 | # attention out 44 | # Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469 45 | """ 46 | layers.append(dict( 47 | prev_op=module.attention.query_key_value, 48 | layers=[module.attention.dense], 49 | inp=input_feat['attention.dense'], 50 | )) 51 | """ 52 | 53 | # linear 1 54 | layers.append( 55 | dict( 56 | prev_op=module.post_attention_layernorm, 57 | layers=[module.mlp.dense_h_to_4h], 58 | inp=input_feat["mlp.dense_h_to_4h"], 59 | ) 60 | ) 61 | 62 | # linear 2 63 | layers.append( 64 | dict( 65 | prev_op=module.mlp.act, 66 | layers=[module.mlp.dense_4h_to_h], 67 | inp=input_feat["mlp.dense_4h_to_h"], 68 | ) 69 | ) 70 | 71 | return layers 72 | -------------------------------------------------------------------------------- /awq/models/gptj.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJBlock 3 | 4 | 5 | class GPTJAWQForCausalLM(BaseAWQForCausalLM): 6 | layer_type = "GPTJBlock" 7 | max_seq_len_key = "n_positions" 8 | 9 | @staticmethod 10 | def get_model_layers(model: GPTJForCausalLM): 11 | return model.transformer.h 12 | 13 | @staticmethod 14 | def get_act_for_scaling(module: GPTJBlock): 15 | return dict( 16 | is_scalable=True, 17 | scale_name="mlp.act", 18 | scale_layer=module.mlp.act, 19 | scale_shape=module.mlp.fc_in.out_features, 20 | ) 21 | 22 | @staticmethod 23 | def move_embed(model: GPTJForCausalLM, device: str): 24 | model.transformer.wte = model.transformer.wte.to(device) 25 | 26 | @staticmethod 27 | def get_layers_for_scaling(module: GPTJBlock, input_feat, module_kwargs): 28 | layers = [] 29 | 30 | # attention input + linear 1 31 | layers.append( 32 | dict( 33 | prev_op=module.ln_1, 34 | layers=[ 35 | module.attn.q_proj, 36 | module.attn.k_proj, 37 | module.attn.v_proj, 38 | module.mlp.fc_in, 39 | ], 40 | inp=input_feat["attn.q_proj"], 41 | module2inspect=module, 42 | kwargs=module_kwargs, 43 | ) 44 | ) 45 | 46 | # attention out 47 | layers.append( 48 | dict( 49 | prev_op=module.attn.v_proj, 50 | layers=[module.attn.out_proj], 51 | inp=input_feat["attn.out_proj"], 52 | ) 53 | ) 54 | 55 | # linear 2 56 | layers.append( 57 | dict( 58 | prev_op=module.mlp.act, 59 | layers=[module.mlp.fc_out], 60 | inp=input_feat["mlp.fc_out"], 61 | ) 62 | ) 63 | 64 | return layers 65 | -------------------------------------------------------------------------------- /awq/models/internlm2.py: 
-------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | 5 | 6 | class InternLM2AWQForCausalLM(BaseAWQForCausalLM): 7 | layer_type = "InternLM2DecoderLayer" 8 | max_seq_len_key = "max_position_embeddings" 9 | 10 | @staticmethod 11 | def get_model_layers(model): 12 | return model.model.layers 13 | 14 | @staticmethod 15 | def get_act_for_scaling(module): 16 | return dict( 17 | is_scalable=True, 18 | scale_name="feed_forward.w2", 19 | scale_layer=module.feed_forward.w2, 20 | scale_shape=module.feed_forward.w2.out_features, 21 | ) 22 | 23 | @staticmethod 24 | def move_embed(model, device: str): 25 | model.model.tok_embeddings = model.model.tok_embeddings.to(device) 26 | 27 | @staticmethod 28 | def get_layers_for_scaling(module, input_feat, module_kwargs): 29 | layers = [] 30 | 31 | # attention input 32 | layers.append( 33 | dict( 34 | prev_op=module.attention_norm, 35 | layers=[ 36 | module.attention.wqkv, 37 | ], 38 | inp=input_feat["attention.wqkv"], 39 | module2inspect=module.attention, 40 | kwargs=module_kwargs, 41 | ) 42 | ) 43 | 44 | # attention out 45 | layers.append( 46 | dict( 47 | prev_op=module.attention.wqkv, 48 | layers=[module.attention.wo], 49 | inp=input_feat["attention.wo"], 50 | ) 51 | ) 52 | 53 | # feed forward input 54 | layers.append( 55 | dict( 56 | prev_op=module.ffn_norm, 57 | layers=[ 58 | module.feed_forward.w1, 59 | module.feed_forward.w3, 60 | ], 61 | inp=input_feat["feed_forward.w1"], 62 | module2inspect=module.feed_forward, 63 | kwargs=module_kwargs, 64 | ) 65 | ) 66 | 67 | # feed forward output 68 | layers.append( 69 | dict( 70 | prev_op=module.feed_forward.w1, 71 | layers=[module.feed_forward.w2], 72 | inp=input_feat["feed_forward.w2"], 73 | ) 74 | ) 75 | 76 | return layers 77 | -------------------------------------------------------------------------------- /awq/models/llama.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.utils.fused_utils import fuse_qkv 5 | from awq.modules.fused.block import LlamaLikeBlock 6 | from awq.modules.fused.model import LlamaLikeModel 7 | from transformers.models.llama.modeling_llama import ( 8 | LlamaDecoderLayer as OldLlamaDecoderLayer, 9 | LlamaForCausalLM as OldLlamaForCausalLM, 10 | ) 11 | from awq.modules.fused.norm import FasterTransformerRMSNorm 12 | 13 | 14 | class LlamaAWQForCausalLM(BaseAWQForCausalLM): 15 | layer_type = "LlamaDecoderLayer" 16 | max_seq_len_key = "max_position_embeddings" 17 | 18 | @staticmethod 19 | def fuse_layers(model: OldLlamaForCausalLM): 20 | fuser = LlamaFuser(model) 21 | fuser.fuse_transformer() 22 | 23 | @staticmethod 24 | def get_model_layers(model: OldLlamaForCausalLM): 25 | return model.model.layers 26 | 27 | @staticmethod 28 | def get_act_for_scaling(module: OldLlamaDecoderLayer): 29 | return dict(is_scalable=False) 30 | 31 | @staticmethod 32 | def move_embed(model: OldLlamaForCausalLM, device: str): 33 | model.model.embed_tokens = model.model.embed_tokens.to(device) 34 | model.model.rotary_emb = model.model.rotary_emb.to(device) 35 | 36 | @staticmethod 37 | def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs): 38 | layers = [] 39 | 40 | # attention input 41 | layers.append( 42 | dict( 43 | prev_op=module.input_layernorm, 44 | layers=[ 45 | module.self_attn.q_proj, 46 | module.self_attn.k_proj, 47 | 
module.self_attn.v_proj, 48 | ], 49 | inp=input_feat["self_attn.q_proj"], 50 | module2inspect=module.self_attn, 51 | kwargs=module_kwargs, 52 | ) 53 | ) 54 | 55 | # attention out 56 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 57 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 58 | layers.append( 59 | dict( 60 | prev_op=module.self_attn.v_proj, 61 | layers=[module.self_attn.o_proj], 62 | inp=input_feat["self_attn.o_proj"], 63 | ) 64 | ) 65 | 66 | # linear 1 67 | layers.append( 68 | dict( 69 | prev_op=module.post_attention_layernorm, 70 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 71 | inp=input_feat["mlp.gate_proj"], 72 | module2inspect=module.mlp, 73 | ) 74 | ) 75 | 76 | # linear 2 77 | layers.append( 78 | dict( 79 | prev_op=module.mlp.up_proj, 80 | layers=[module.mlp.down_proj], 81 | inp=input_feat["mlp.down_proj"], 82 | ) 83 | ) 84 | 85 | return layers 86 | 87 | 88 | class LlamaFuser: 89 | def __init__(self, model: OldLlamaForCausalLM): 90 | self.model = model 91 | 92 | self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [ 93 | (name, module) 94 | for name, module in self.model.named_modules() 95 | if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower() 96 | ] 97 | 98 | def fuse_transformer(self): 99 | blocks = [] 100 | 101 | module: OldLlamaDecoderLayer 102 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 103 | device = next(iter(module.state_dict().values())).device 104 | qkv = fuse_qkv( 105 | module, 106 | module.self_attn.q_proj, 107 | module.self_attn.k_proj, 108 | module.self_attn.v_proj, 109 | ) 110 | norm_1 = FasterTransformerRMSNorm( 111 | module.input_layernorm.weight, module.input_layernorm.variance_epsilon 112 | ) 113 | norm_2 = FasterTransformerRMSNorm( 114 | module.post_attention_layernorm.weight, 115 | module.post_attention_layernorm.variance_epsilon, 116 | ) 117 | blocks.append( 118 | LlamaLikeBlock( 119 | hidden_size=self.model.config.hidden_size, 120 | n_heads=self.model.config.num_attention_heads, 121 | n_kv_heads=self.model.config.num_key_value_heads, 122 | qkv_layer=qkv, 123 | o_proj=module.self_attn.o_proj, 124 | mlp=module.mlp, 125 | norm_1=norm_1, 126 | norm_2=norm_2, 127 | dev=device, 128 | max_seq_len=self.model.config.max_seq_len, 129 | rope_theta=self.model.config.rope_theta, 130 | ) 131 | ) 132 | 133 | self.model.model = LlamaLikeModel( 134 | self.model.config.vocab_size, 135 | blocks, 136 | self.model.model.embed_tokens, 137 | self.model.model.norm, 138 | ) 139 | setattr(self.model.model, "blocks", self.model.model.blocks) 140 | -------------------------------------------------------------------------------- /awq/models/llava.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.utils.fused_utils import fuse_qkv 5 | from awq.modules.fused.block import LlamaLikeBlock 6 | from awq.modules.fused.model import LlamaLikeModel 7 | from transformers.models.llama.modeling_llama import ( 8 | LlamaDecoderLayer as OldLlamaDecoderLayer, 9 | ) 10 | from transformers.models.llava.modeling_llava import ( 11 | LlavaForConditionalGeneration as OldLlavaForConditionalGeneration, 12 | ) 13 | from awq.modules.fused.norm import FasterTransformerRMSNorm 14 | 15 | 16 | class LlavaAWQForCausalLM(BaseAWQForCausalLM): 17 | layer_type = "LlamaDecoderLayer" 18 | max_seq_len_key = "max_position_embeddings" 19 | 20 | 
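    # Optional inference-time fusion: LlavaFuser (defined below) rebuilds the language model's
    # decoder stack from fused LlamaLikeBlock modules wrapped in a LlamaLikeModel.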
@staticmethod 21 | def fuse_layers(model: OldLlavaForConditionalGeneration): 22 | fuser = LlavaFuser(model) 23 | fuser.fuse_transformer() 24 | 25 | @staticmethod 26 | def get_model_layers(model: OldLlavaForConditionalGeneration): 27 | return model.language_model.model.layers 28 | 29 | @staticmethod 30 | def get_act_for_scaling(module: OldLlamaDecoderLayer): 31 | return dict(is_scalable=False) 32 | 33 | @staticmethod 34 | def move_embed(model: OldLlavaForConditionalGeneration, device: str): 35 | model.language_model.model.embed_tokens = model.get_input_embeddings().to( 36 | device 37 | ) 38 | if hasattr(model.language_model.model, "rotary_emb"): 39 | model.language_model.model.rotary_emb = model.language_model.model.rotary_emb.to(device) 40 | 41 | @staticmethod 42 | def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs): 43 | layers = [] 44 | 45 | # attention input 46 | layers.append( 47 | dict( 48 | prev_op=module.input_layernorm, 49 | layers=[ 50 | module.self_attn.q_proj, 51 | module.self_attn.k_proj, 52 | module.self_attn.v_proj, 53 | ], 54 | inp=input_feat["self_attn.q_proj"], 55 | module2inspect=module.self_attn, 56 | kwargs=module_kwargs, 57 | ) 58 | ) 59 | 60 | # attention out 61 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 62 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 63 | layers.append( 64 | dict( 65 | prev_op=module.self_attn.v_proj, 66 | layers=[module.self_attn.o_proj], 67 | inp=input_feat["self_attn.o_proj"], 68 | ) 69 | ) 70 | 71 | # linear 1 72 | layers.append( 73 | dict( 74 | prev_op=module.post_attention_layernorm, 75 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 76 | inp=input_feat["mlp.gate_proj"], 77 | module2inspect=module.mlp, 78 | ) 79 | ) 80 | 81 | # linear 2 82 | layers.append( 83 | dict( 84 | prev_op=module.mlp.up_proj, 85 | layers=[module.mlp.down_proj], 86 | inp=input_feat["mlp.down_proj"], 87 | ) 88 | ) 89 | 90 | return layers 91 | 92 | 93 | class LlavaFuser: 94 | def __init__(self, model: OldLlavaForConditionalGeneration): 95 | self.model = model.language_model 96 | 97 | self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [ 98 | (name, module) 99 | for name, module in self.model.named_modules() 100 | if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower() 101 | ] 102 | 103 | def fuse_transformer(self): 104 | blocks = [] 105 | 106 | module: OldLlamaDecoderLayer 107 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 108 | device = next(iter(module.state_dict().values())).device 109 | qkv = fuse_qkv( 110 | module, 111 | module.self_attn.q_proj, 112 | module.self_attn.k_proj, 113 | module.self_attn.v_proj, 114 | ) 115 | norm_1 = FasterTransformerRMSNorm( 116 | module.input_layernorm.weight, module.input_layernorm.variance_epsilon 117 | ) 118 | norm_2 = FasterTransformerRMSNorm( 119 | module.post_attention_layernorm.weight, 120 | module.post_attention_layernorm.variance_epsilon, 121 | ) 122 | if hasattr(self.model.config, "max_seq_len"): 123 | max_seq_len = self.model.config.max_seq_len 124 | else: 125 | max_seq_len = self.model.config.max_position_embeddings 126 | blocks.append( 127 | LlamaLikeBlock( 128 | hidden_size=self.model.config.hidden_size, 129 | n_heads=self.model.config.num_attention_heads, 130 | n_kv_heads=self.model.config.num_key_value_heads, 131 | qkv_layer=qkv, 132 | o_proj=module.self_attn.o_proj, 133 | mlp=module.mlp, 134 | norm_1=norm_1, 135 | norm_2=norm_2, 136 | dev=device, 137 | 
max_seq_len=max_seq_len, 138 | rope_theta=self.model.config.rope_theta, 139 | ) 140 | ) 141 | 142 | self.model.model = LlamaLikeModel( 143 | self.model.config.vocab_size, 144 | blocks, 145 | self.model.model.embed_tokens, 146 | self.model.model.norm, 147 | ) 148 | setattr(self.model.model, "blocks", self.model.model.blocks) 149 | -------------------------------------------------------------------------------- /awq/models/llava_next.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.utils.fused_utils import fuse_qkv 5 | from awq.modules.fused.block import LlamaLikeBlock 6 | from awq.modules.fused.model import LlamaLikeModel 7 | from transformers.models.llama.modeling_llama import ( 8 | LlamaDecoderLayer as OldLlamaDecoderLayer, 9 | ) 10 | from transformers.models.llava_next.modeling_llava_next import LlavaNextForConditionalGeneration 11 | from awq.modules.fused.norm import FasterTransformerRMSNorm 12 | 13 | 14 | class LlavaNextAWQForCausalLM(BaseAWQForCausalLM): 15 | layer_type = "LlamaDecoderLayer" 16 | max_seq_len_key = "max_position_embeddings" 17 | 18 | @staticmethod 19 | def fuse_layers(model: LlavaNextForConditionalGeneration): 20 | pass 21 | 22 | @staticmethod 23 | def get_model_layers(model: LlavaNextForConditionalGeneration): 24 | return model.language_model.model.layers 25 | 26 | @staticmethod 27 | def get_act_for_scaling(module: OldLlamaDecoderLayer): 28 | return dict(is_scalable=False) 29 | 30 | @staticmethod 31 | def move_embed(model: LlavaNextForConditionalGeneration, device: str): 32 | model.language_model.model.embed_tokens = model.get_input_embeddings().to( 33 | device 34 | ) 35 | model.language_model.model.rotary_emb = model.language_model.model.rotary_emb.to(device) 36 | 37 | @staticmethod 38 | def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs): 39 | layers = [] 40 | 41 | # attention input 42 | layers.append( 43 | dict( 44 | prev_op=module.input_layernorm, 45 | layers=[ 46 | module.self_attn.q_proj, 47 | module.self_attn.k_proj, 48 | module.self_attn.v_proj, 49 | ], 50 | inp=input_feat["self_attn.q_proj"], 51 | module2inspect=module.self_attn, 52 | kwargs=module_kwargs, 53 | ) 54 | ) 55 | 56 | # attention out 57 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 58 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 59 | layers.append( 60 | dict( 61 | prev_op=module.self_attn.v_proj, 62 | layers=[module.self_attn.o_proj], 63 | inp=input_feat["self_attn.o_proj"], 64 | ) 65 | ) 66 | 67 | # linear 1 68 | layers.append( 69 | dict( 70 | prev_op=module.post_attention_layernorm, 71 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 72 | inp=input_feat["mlp.gate_proj"], 73 | module2inspect=module.mlp, 74 | ) 75 | ) 76 | 77 | # linear 2 78 | layers.append( 79 | dict( 80 | prev_op=module.mlp.up_proj, 81 | layers=[module.mlp.down_proj], 82 | inp=input_feat["mlp.down_proj"], 83 | ) 84 | ) 85 | 86 | return layers 87 | 88 | 89 | class LlavaNextFuser: 90 | def __init__(self, model: LlavaNextForConditionalGeneration): 91 | self.model = model.language_model 92 | 93 | self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [ 94 | (name, module) 95 | for name, module in self.model.named_modules() 96 | if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower() 97 | ] 98 | 99 | def fuse_transformer(self): 100 | blocks = [] 101 | 102 | module: 
OldLlamaDecoderLayer 103 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 104 | device = next(iter(module.state_dict().values())).device 105 | qkv = fuse_qkv( 106 | module, 107 | module.self_attn.q_proj, 108 | module.self_attn.k_proj, 109 | module.self_attn.v_proj, 110 | ) 111 | norm_1 = FasterTransformerRMSNorm( 112 | module.input_layernorm.weight, module.input_layernorm.variance_epsilon 113 | ) 114 | norm_2 = FasterTransformerRMSNorm( 115 | module.post_attention_layernorm.weight, 116 | module.post_attention_layernorm.variance_epsilon, 117 | ) 118 | if hasattr(self.model.config, "max_seq_len"): 119 | max_seq_len = self.model.config.max_seq_len 120 | else: 121 | max_seq_len = self.model.config.max_position_embeddings 122 | blocks.append( 123 | LlamaLikeBlock( 124 | hidden_size=self.model.config.hidden_size, 125 | n_heads=self.model.config.num_attention_heads, 126 | n_kv_heads=self.model.config.num_key_value_heads, 127 | qkv_layer=qkv, 128 | o_proj=module.self_attn.o_proj, 129 | mlp=module.mlp, 130 | norm_1=norm_1, 131 | norm_2=norm_2, 132 | dev=device, 133 | max_seq_len=max_seq_len, 134 | rope_theta=self.model.config.rope_theta, 135 | ) 136 | ) 137 | 138 | self.model.model = LlamaLikeModel( 139 | self.model.config.vocab_size, 140 | blocks, 141 | self.model.model.embed_tokens, 142 | self.model.model.norm, 143 | ) 144 | setattr(self.model.model, "blocks", self.model.model.blocks) 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /awq/models/minicpm.py: -------------------------------------------------------------------------------- 1 | 2 | from .base import BaseAWQForCausalLM 3 | 4 | 5 | class MiniCPMAWQForCausalLM(BaseAWQForCausalLM): 6 | layer_type = "MiniCPMDecoderLayer" 7 | max_seq_len_key = "seq_length" 8 | 9 | @staticmethod 10 | def get_model_layers(model): 11 | return model.model.layers 12 | 13 | @staticmethod 14 | def get_act_for_scaling(module): 15 | return dict(is_scalable=False) 16 | 17 | @staticmethod 18 | def move_embed(model, device: str): 19 | model.model.embed_tokens = model.model.embed_tokens.to(device) 20 | 21 | @staticmethod 22 | def get_layers_for_scaling(module, input_feat, module_kwargs): 23 | layers = [] 24 | 25 | 26 | 27 | # # mlp 28 | layers.append( 29 | dict( 30 | prev_op=module.input_layernorm, 31 | layers=[ 32 | module.self_attn.q_proj, 33 | module.self_attn.k_proj, 34 | module.self_attn.v_proj, 35 | ], 36 | inp=input_feat["self_attn.q_proj"], 37 | module2inspect=module.self_attn, 38 | kwargs=module_kwargs, 39 | ) 40 | ) 41 | 42 | # linear 2 43 | layers.append( 44 | dict( 45 | prev_op=module.mlp.up_proj, 46 | layers=[module.mlp.down_proj], 47 | inp=input_feat["mlp.down_proj"], 48 | ) 49 | ) 50 | 51 | layers.append( 52 | dict( 53 | prev_op=module.post_attention_layernorm, 54 | layers=[module.mlp.gate_proj,module.mlp.up_proj], 55 | inp=input_feat["mlp.gate_proj"], 56 | module2inspect=module.mlp 57 | ) 58 | ) 59 | 60 | return layers 61 | 62 | 63 | -------------------------------------------------------------------------------- /awq/models/minicpm3.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | 3 | class MiniCPM3AWQForCausalLM(BaseAWQForCausalLM): 4 | layer_type = "MiniCPMDecoderLayer" 5 | max_seq_len_key = "max_position_embeddings" 6 | 7 | @staticmethod 8 | def get_model_layers(model): 9 | print(model.model.layers) 10 | return model.model.layers 11 | 12 | @staticmethod 13 | def 
get_act_for_scaling(module): 14 | return dict(is_scalable=False) 15 | 16 | @staticmethod 17 | def move_embed(model, device: str): 18 | model.model.embed_tokens = model.model.embed_tokens.to(device) 19 | 20 | @staticmethod 21 | def get_layers_for_scaling(module, input_feat, module_kwargs): 22 | layers = [] 23 | 24 | # mlp 25 | layers.append( 26 | dict( 27 | prev_op=module.self_attn.q_a_layernorm, 28 | layers=[ 29 | module.self_attn.q_b_proj, 30 | 31 | ], 32 | inp=input_feat["self_attn.q_b_proj"], 33 | module2inspect=module.self_attn.q_b_proj, 34 | kwargs=module_kwargs, 35 | ) 36 | ) 37 | 38 | layers.append( 39 | dict( 40 | prev_op=module.self_attn.kv_a_layernorm, 41 | layers=[ 42 | module.self_attn.kv_b_proj, 43 | ], 44 | inp=input_feat["self_attn.kv_b_proj"], 45 | module2inspect=module.self_attn.kv_b_proj, 46 | kwargs=module_kwargs, 47 | ) 48 | ) 49 | 50 | 51 | # linear 2 52 | layers.append( 53 | dict( 54 | prev_op=module.mlp.up_proj, 55 | layers=[module.mlp.down_proj], 56 | inp=input_feat["mlp.down_proj"], 57 | ) 58 | ) 59 | 60 | layers.append( 61 | dict( 62 | prev_op=module.post_attention_layernorm, 63 | layers=[module.mlp.gate_proj,module.mlp.up_proj], 64 | inp=input_feat["mlp.gate_proj"], 65 | module2inspect=module.mlp 66 | ) 67 | ) 68 | 69 | return layers -------------------------------------------------------------------------------- /awq/models/mistral.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.utils.fused_utils import fuse_qkv 5 | from awq.modules.fused.block import LlamaLikeBlock 6 | from awq.modules.fused.model import LlamaLikeModel 7 | from transformers.models.mistral.modeling_mistral import ( 8 | MistralDecoderLayer as OldMistralDecoderLayer, 9 | MistralForCausalLM as OldMistralForCausalLM, 10 | ) 11 | from awq.modules.fused.norm import FasterTransformerRMSNorm 12 | 13 | 14 | class MistralAWQForCausalLM(BaseAWQForCausalLM): 15 | layer_type = "MistralDecoderLayer" 16 | max_seq_len_key = "max_position_embeddings" 17 | 18 | @staticmethod 19 | def fuse_layers(model: OldMistralForCausalLM): 20 | fuser = MistralFuser(model) 21 | fuser.fuse_transformer() 22 | 23 | @staticmethod 24 | def get_model_layers(model: OldMistralForCausalLM): 25 | return model.model.layers 26 | 27 | @staticmethod 28 | def get_act_for_scaling(module: OldMistralDecoderLayer): 29 | return dict(is_scalable=False) 30 | 31 | @staticmethod 32 | def move_embed(model: OldMistralForCausalLM, device: str): 33 | model.model.embed_tokens = model.model.embed_tokens.to(device) 34 | 35 | @staticmethod 36 | def get_layers_for_scaling( 37 | module: OldMistralDecoderLayer, input_feat, module_kwargs 38 | ): 39 | layers = [] 40 | 41 | # attention input 42 | layers.append( 43 | dict( 44 | prev_op=module.input_layernorm, 45 | layers=[ 46 | module.self_attn.q_proj, 47 | module.self_attn.k_proj, 48 | module.self_attn.v_proj, 49 | ], 50 | inp=input_feat["self_attn.q_proj"], 51 | module2inspect=module.self_attn, 52 | kwargs=module_kwargs, 53 | ) 54 | ) 55 | 56 | # attention out 57 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 58 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 59 | layers.append( 60 | dict( 61 | prev_op=module.self_attn.v_proj, 62 | layers=[module.self_attn.o_proj], 63 | inp=input_feat["self_attn.o_proj"], 64 | ) 65 | ) 66 | 67 | # linear 1 68 | layers.append( 69 | dict( 70 | 
prev_op=module.post_attention_layernorm, 71 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 72 | inp=input_feat["mlp.gate_proj"], 73 | module2inspect=module.mlp, 74 | ) 75 | ) 76 | 77 | # linear 2 78 | layers.append( 79 | dict( 80 | prev_op=module.mlp.up_proj, 81 | layers=[module.mlp.down_proj], 82 | inp=input_feat["mlp.down_proj"], 83 | ) 84 | ) 85 | 86 | return layers 87 | 88 | 89 | class MistralFuser: 90 | def __init__(self, model: OldMistralForCausalLM): 91 | self.model = model 92 | 93 | self.mistral_blocks: List[Tuple[str, OldMistralDecoderLayer]] = [ 94 | (name, module) 95 | for name, module in self.model.named_modules() 96 | if "MistralDecoderLayer".lower() in module.__class__.__name__.lower() 97 | ] 98 | 99 | def fuse_transformer(self): 100 | blocks = [] 101 | 102 | module: OldMistralDecoderLayer 103 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 104 | device = next(iter(module.state_dict().values())).device 105 | qkv = fuse_qkv( 106 | module, 107 | module.self_attn.q_proj, 108 | module.self_attn.k_proj, 109 | module.self_attn.v_proj, 110 | ) 111 | norm_1 = FasterTransformerRMSNorm( 112 | module.input_layernorm.weight, module.input_layernorm.variance_epsilon 113 | ) 114 | norm_2 = FasterTransformerRMSNorm( 115 | module.post_attention_layernorm.weight, 116 | module.post_attention_layernorm.variance_epsilon, 117 | ) 118 | blocks.append( 119 | LlamaLikeBlock( 120 | hidden_size=self.model.config.hidden_size, 121 | n_heads=self.model.config.num_attention_heads, 122 | n_kv_heads=self.model.config.num_key_value_heads, 123 | qkv_layer=qkv, 124 | o_proj=module.self_attn.o_proj, 125 | mlp=module.mlp, 126 | norm_1=norm_1, 127 | norm_2=norm_2, 128 | dev=device, 129 | max_seq_len=self.model.config.max_seq_len, 130 | rope_theta=self.model.config.rope_theta, 131 | ) 132 | ) 133 | 134 | self.model.model = LlamaLikeModel( 135 | self.model.config.vocab_size, 136 | blocks, 137 | self.model.model.embed_tokens, 138 | self.model.model.norm, 139 | ) 140 | setattr(self.model.model, "blocks", self.model.model.blocks) 141 | -------------------------------------------------------------------------------- /awq/models/mpt.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM 3 | 4 | 5 | class MptAWQForCausalLM(BaseAWQForCausalLM): 6 | layer_type = "MPTBlock" 7 | max_seq_len_key = "max_seq_len" 8 | 9 | @staticmethod 10 | def fuse_layers(model: MptForCausalLM): 11 | fuser = MptFuser(model) 12 | fuser.fuse_transformer() 13 | 14 | @staticmethod 15 | def get_model_layers(model: MptForCausalLM): 16 | return model.transformer.blocks 17 | 18 | @staticmethod 19 | def get_act_for_scaling(module: OldMptBlock): 20 | return dict( 21 | is_scalable=True, 22 | scale_name="ffn.act", 23 | scale_layer=module.ffn.act, 24 | scale_shape=module.ffn.up_proj.out_features, 25 | ) 26 | 27 | @staticmethod 28 | def move_embed(model: MptForCausalLM, device: str): 29 | model.transformer.wte = model.transformer.wte.to(device) 30 | model.transformer.emb_drop = model.transformer.emb_drop.to(device) 31 | 32 | @staticmethod 33 | def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs): 34 | layers = [] 35 | 36 | if module_kwargs.get("output_attentions") is not None: 37 | module_kwargs.pop("output_attentions") 38 | 39 | # attention input 40 | layers.append( 41 | dict( 42 | prev_op=module.norm_1, 43 | layers=[module.attn.Wqkv], 44 | 
inp=input_feat["attn.Wqkv"], 45 | module2inspect=module.attn, 46 | kwargs=module_kwargs, 47 | ) 48 | ) 49 | 50 | # attention output 51 | layers.append( 52 | dict( 53 | prev_op=module.attn.Wqkv, 54 | layers=[module.attn.out_proj], 55 | inp=input_feat["attn.out_proj"], 56 | ) 57 | ) 58 | 59 | # linear 1 60 | layers.append( 61 | dict( 62 | prev_op=module.norm_2, 63 | layers=[module.ffn.up_proj], 64 | inp=input_feat["ffn.up_proj"], 65 | module2inspect=module.ffn, 66 | ) 67 | ) 68 | 69 | # linear 2 70 | layers.append( 71 | dict( 72 | prev_op=module.ffn.act, 73 | layers=[module.ffn.down_proj], 74 | inp=input_feat["ffn.down_proj"], 75 | ) 76 | ) 77 | 78 | return layers 79 | 80 | 81 | from typing import List, Tuple 82 | from awq.utils.utils import set_module_name 83 | from awq.modules.fused.block import MPTBlock 84 | from awq.modules.fused.model import MPTModel 85 | 86 | 87 | class MptFuser: 88 | def __init__(self, model: MptForCausalLM): 89 | self.model = model 90 | 91 | self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [ 92 | (name, module) 93 | for name, module in self.model.named_modules() 94 | if "mptblock" in module.__class__.__name__.lower() 95 | ] 96 | 97 | def fuse_transformer(self): 98 | blocks = [] 99 | 100 | module: OldMptBlock 101 | for module in self.model.transformer.blocks: 102 | blocks.append( 103 | MPTBlock( 104 | self.model.config.d_model, 105 | self.model.config.n_heads, 106 | module.attn.Wqkv, 107 | module.attn.out_proj, 108 | module.ffn, 109 | module.norm_1, 110 | module.norm_2, 111 | next(iter(module.state_dict().values())).device, 112 | self.model.config.max_seq_len, 113 | ) 114 | ) 115 | 116 | self.model.transformer = MPTModel( 117 | self.model.config.vocab_size, 118 | blocks, 119 | self.model.transformer.wte, 120 | self.model.transformer.norm_f, 121 | ) 122 | 123 | setattr(self.model.transformer, "blocks", self.model.transformer.blocks) 124 | -------------------------------------------------------------------------------- /awq/models/opt.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTDecoderLayer 3 | 4 | 5 | class OptAWQForCausalLM(BaseAWQForCausalLM): 6 | layer_type = "OPTDecoderLayer" 7 | max_seq_len_key = "max_position_embeddings" 8 | 9 | @staticmethod 10 | def get_model_layers(model: OPTForCausalLM): 11 | return model.model.decoder.layers 12 | 13 | @staticmethod 14 | def get_act_for_scaling(module: OPTDecoderLayer): 15 | return dict(is_scalable=False) 16 | 17 | @staticmethod 18 | def move_embed(model: OPTForCausalLM, device: str): 19 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device) 20 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to( 21 | device 22 | ) 23 | 24 | @staticmethod 25 | def get_layers_for_scaling(module: OPTDecoderLayer, input_feat, module_kwargs): 26 | layers = [] 27 | 28 | # attention input 29 | layers.append( 30 | dict( 31 | prev_op=module.self_attn_layer_norm, 32 | layers=[ 33 | module.self_attn.q_proj, 34 | module.self_attn.k_proj, 35 | module.self_attn.v_proj, 36 | ], 37 | inp=input_feat["self_attn.q_proj"], 38 | module2inspect=module.self_attn, 39 | kwargs=module_kwargs, 40 | ) 41 | ) 42 | 43 | # attention out 44 | layers.append( 45 | dict( 46 | prev_op=module.self_attn.v_proj, 47 | layers=[module.self_attn.out_proj], 48 | inp=input_feat["self_attn.out_proj"], 49 | ) 50 | ) 51 | 52 | # linear 1 53 | layers.append( 54 | dict( 55 | 
prev_op=module.final_layer_norm, 56 | layers=[module.fc1], 57 | inp=input_feat["fc1"], 58 | ) 59 | ) 60 | 61 | # linear 2 62 | layers.append( 63 | dict( 64 | prev_op=module.fc1, 65 | layers=[module.fc2], 66 | inp=input_feat["fc2"], 67 | ) 68 | ) 69 | 70 | return layers 71 | -------------------------------------------------------------------------------- /awq/models/phi3.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.utils.fused_utils import fuse_qkv 5 | from awq.modules.fused.block import Phi3Block 6 | from awq.modules.fused.model import Phi3Model as AWQPhi3Model 7 | from transformers.models.phi3.modeling_phi3 import ( 8 | Phi3DecoderLayer as OldPhi3DecoderLayer, 9 | Phi3ForCausalLM as OldPhi3ForCausalLM, 10 | ) 11 | from awq.modules.fused.norm import FasterTransformerRMSNorm 12 | 13 | 14 | class Phi3AWQForCausalLM(BaseAWQForCausalLM): 15 | layer_type = "Phi3DecoderLayer" 16 | max_seq_len_key = "max_position_embeddings" 17 | 18 | @staticmethod 19 | def fuse_layers(model: OldPhi3ForCausalLM): 20 | fuser = Phi3Fuser(model) 21 | fuser.fuse_transformer() 22 | 23 | @staticmethod 24 | def get_model_layers(model: OldPhi3ForCausalLM): 25 | return model.model.layers 26 | 27 | @staticmethod 28 | def get_act_for_scaling(module: OldPhi3DecoderLayer): 29 | return dict(is_scalable=False) 30 | 31 | @staticmethod 32 | def move_embed(model: OldPhi3ForCausalLM, device: str): 33 | model.model.embed_tokens = model.model.embed_tokens.to(device) 34 | 35 | @staticmethod 36 | def get_layers_for_scaling(module: OldPhi3DecoderLayer, input_feat, module_kwargs): 37 | layers = [] 38 | 39 | # attention input 40 | layers.append( 41 | dict( 42 | prev_op=module.input_layernorm, 43 | layers=[module.self_attn.qkv_proj], 44 | inp=input_feat["self_attn.qkv_proj"], 45 | module2inspect=module.self_attn, 46 | kwargs=module_kwargs, 47 | ) 48 | ) 49 | 50 | # attention out 51 | layers.append( 52 | dict( 53 | prev_op=module.self_attn.qkv_proj, 54 | layers=[module.self_attn.o_proj], 55 | inp=input_feat["self_attn.o_proj"], 56 | ) 57 | ) 58 | 59 | # linear 1 60 | layers.append( 61 | dict( 62 | prev_op=module.post_attention_layernorm, 63 | layers=[module.mlp.gate_up_proj], 64 | inp=input_feat["mlp.gate_up_proj"], 65 | module2inspect=module.mlp, 66 | ) 67 | ) 68 | 69 | # linear 2 70 | layers.append( 71 | dict( 72 | prev_op=module.mlp.gate_up_proj, 73 | layers=[module.mlp.down_proj], 74 | inp=input_feat["mlp.down_proj"], 75 | ) 76 | ) 77 | 78 | return layers 79 | 80 | 81 | class Phi3Fuser: 82 | def __init__(self, model: OldPhi3ForCausalLM): 83 | self.model = model 84 | 85 | self.phi3_blocks: List[Tuple[str, OldPhi3DecoderLayer]] = [ 86 | (name, module) 87 | for name, module in self.model.named_modules() 88 | if "Phi3DecoderLayer".lower() in module.__class__.__name__.lower() 89 | ] 90 | 91 | def fuse_transformer(self): 92 | blocks = [] 93 | 94 | module: OldPhi3DecoderLayer 95 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 96 | device = next(iter(module.state_dict().values())).device 97 | qkv = module.self_attn.qkv_proj 98 | norm_1 = FasterTransformerRMSNorm( 99 | module.input_layernorm.weight, module.input_layernorm.variance_epsilon 100 | ) 101 | norm_2 = FasterTransformerRMSNorm( 102 | module.post_attention_layernorm.weight, 103 | module.post_attention_layernorm.variance_epsilon, 104 | ) 105 | blocks.append( 106 | Phi3Block( 107 | 
hidden_size=self.model.config.hidden_size, 108 | n_heads=self.model.config.num_attention_heads, 109 | n_kv_heads=self.model.config.num_key_value_heads, 110 | qkv_layer=qkv, 111 | o_proj=module.self_attn.o_proj, 112 | mlp=module.mlp, 113 | norm_1=norm_1, 114 | norm_2=norm_2, 115 | dev=device, 116 | max_seq_len=self.model.config.max_position_embeddings, 117 | rope_theta=self.model.config.rope_theta, 118 | rope_scaling=self.model.config.rope_scaling, 119 | ) 120 | ) 121 | 122 | self.model.model = AWQPhi3Model( 123 | self.model.config.vocab_size, 124 | blocks, 125 | self.model.model.embed_tokens, 126 | self.model.model.norm, 127 | ) 128 | setattr(self.model.model, "blocks", self.model.model.blocks) -------------------------------------------------------------------------------- /awq/models/phi3_v.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.modules.fused.block import Phi3Block 5 | from awq.modules.fused.model import Phi3Model as AWQPhi3Model 6 | from transformers.models.phi3.modeling_phi3 import ( 7 | Phi3DecoderLayer as OldPhi3DecoderLayer 8 | ) 9 | from awq.modules.fused.norm import FasterTransformerRMSNorm 10 | 11 | 12 | class Phi3VAWQForCausalLM(BaseAWQForCausalLM): 13 | layer_type = "Phi3DecoderLayer" 14 | max_seq_len_key = "max_position_embeddings" 15 | modules_to_not_convert = ["vision_embed_tokens"] 16 | 17 | @staticmethod 18 | def get_model_layers(model): 19 | return model.model.layers 20 | 21 | @staticmethod 22 | def get_act_for_scaling(module: OldPhi3DecoderLayer): 23 | return dict(is_scalable=False) 24 | 25 | @staticmethod 26 | def move_embed(model, device: str): 27 | model.model.embed_tokens = model.model.embed_tokens.to(device) 28 | 29 | @staticmethod 30 | def get_layers_for_scaling(module: OldPhi3DecoderLayer, input_feat, module_kwargs): 31 | layers = [] 32 | 33 | # attention input 34 | layers.append( 35 | dict( 36 | prev_op=module.input_layernorm, 37 | layers=[module.self_attn.qkv_proj], 38 | inp=input_feat["self_attn.qkv_proj"], 39 | module2inspect=module.self_attn, 40 | kwargs=module_kwargs, 41 | ) 42 | ) 43 | 44 | # attention out 45 | layers.append( 46 | dict( 47 | prev_op=module.self_attn.qkv_proj, 48 | layers=[module.self_attn.o_proj], 49 | inp=input_feat["self_attn.o_proj"], 50 | ) 51 | ) 52 | 53 | # linear 1 54 | layers.append( 55 | dict( 56 | prev_op=module.post_attention_layernorm, 57 | layers=[module.mlp.gate_up_proj], 58 | inp=input_feat["mlp.gate_up_proj"], 59 | module2inspect=module.mlp, 60 | ) 61 | ) 62 | 63 | # linear 2 64 | layers.append( 65 | dict( 66 | prev_op=module.mlp.gate_up_proj, 67 | layers=[module.mlp.down_proj], 68 | inp=input_feat["mlp.down_proj"], 69 | ) 70 | ) 71 | 72 | return layers 73 | -------------------------------------------------------------------------------- /awq/models/qwen.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | 3 | 4 | class QwenAWQForCausalLM(BaseAWQForCausalLM): 5 | layer_type = "QWenBlock" 6 | max_seq_len_key = "seq_length" 7 | 8 | @staticmethod 9 | def get_model_layers(model): 10 | return model.transformer.h 11 | 12 | @staticmethod 13 | def get_act_for_scaling(module): 14 | return dict(is_scalable=False) 15 | 16 | @staticmethod 17 | def move_embed(model, device: str): 18 | model.transformer.wte = model.transformer.wte.to(device) 19 | model.transformer.rotary_emb = 
model.transformer.rotary_emb.to(device) 20 | 21 | @staticmethod 22 | def get_layers_for_scaling(module, input_feat, module_kwargs): 23 | layers = [] 24 | 25 | # attention 26 | layers.append( 27 | dict( 28 | prev_op=module.ln_1, 29 | layers=[module.attn.c_attn], 30 | inp=input_feat["attn.c_attn"], 31 | module2inspect=module.attn, 32 | kwargs=module_kwargs, 33 | ) 34 | ) 35 | 36 | # mlp 37 | layers.append( 38 | dict( 39 | prev_op=module.ln_2, 40 | layers=[module.mlp.w2, module.mlp.w1], 41 | inp=input_feat["mlp.w2"], 42 | module2inspect=module.mlp, 43 | ) 44 | ) 45 | 46 | # linear 2 47 | layers.append( 48 | dict( 49 | prev_op=module.mlp.w1, 50 | layers=[module.mlp.c_proj], 51 | inp=input_feat["mlp.c_proj"], 52 | ) 53 | ) 54 | 55 | return layers 56 | -------------------------------------------------------------------------------- /awq/models/qwen2.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.utils.fused_utils import fuse_qkv 5 | from awq.modules.fused.block import LlamaLikeBlock 6 | from awq.modules.fused.model import LlamaLikeModel 7 | from transformers.models.qwen2.modeling_qwen2 import ( 8 | Qwen2DecoderLayer as OldQwen2DecoderLayer, 9 | Qwen2ForCausalLM as OldQwen2ForCausalLM, 10 | ) 11 | from awq.modules.fused.norm import FasterTransformerRMSNorm 12 | 13 | 14 | class Qwen2AWQForCausalLM(BaseAWQForCausalLM): 15 | layer_type = "Qwen2DecoderLayer" 16 | max_seq_len_key = "max_position_embeddings" 17 | 18 | @staticmethod 19 | def fuse_layers(model: OldQwen2ForCausalLM): 20 | fuser = Qwen2Fuser(model) 21 | fuser.fuse_transformer() 22 | 23 | @staticmethod 24 | def get_model_layers(model: OldQwen2ForCausalLM): 25 | return model.model.layers 26 | 27 | @staticmethod 28 | def get_act_for_scaling(module: OldQwen2DecoderLayer): 29 | return dict(is_scalable=False) 30 | 31 | @staticmethod 32 | def move_embed(model: OldQwen2ForCausalLM, device: str): 33 | model.model.embed_tokens = model.model.embed_tokens.to(device) 34 | model.model.rotary_emb = model.model.rotary_emb.to(device) 35 | 36 | @staticmethod 37 | def get_layers_for_scaling(module: OldQwen2DecoderLayer, input_feat, module_kwargs): 38 | layers = [] 39 | 40 | # attention input 41 | layers.append( 42 | dict( 43 | prev_op=module.input_layernorm, 44 | layers=[ 45 | module.self_attn.q_proj, 46 | module.self_attn.k_proj, 47 | module.self_attn.v_proj, 48 | ], 49 | inp=input_feat["self_attn.q_proj"], 50 | module2inspect=module.self_attn, 51 | kwargs=module_kwargs, 52 | ) 53 | ) 54 | 55 | # attention out 56 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 57 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 58 | layers.append( 59 | dict( 60 | prev_op=module.self_attn.v_proj, 61 | layers=[module.self_attn.o_proj], 62 | inp=input_feat["self_attn.o_proj"], 63 | ) 64 | ) 65 | 66 | # linear 1 67 | layers.append( 68 | dict( 69 | prev_op=module.post_attention_layernorm, 70 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 71 | inp=input_feat["mlp.gate_proj"], 72 | module2inspect=module.mlp, 73 | ) 74 | ) 75 | 76 | # linear 2 77 | layers.append( 78 | dict( 79 | prev_op=module.mlp.up_proj, 80 | layers=[module.mlp.down_proj], 81 | inp=input_feat["mlp.down_proj"], 82 | ) 83 | ) 84 | 85 | return layers 86 | 87 | 88 | class Qwen2Fuser: 89 | def __init__(self, model: OldQwen2ForCausalLM): 90 | self.model = model 91 | 92 | self.qwen2_blocks: 
List[Tuple[str, OldQwen2DecoderLayer]] = [ 93 | (name, module) 94 | for name, module in self.model.named_modules() 95 | if "Qwen2DecoderLayer".lower() in module.__class__.__name__.lower() 96 | ] 97 | 98 | def fuse_transformer(self): 99 | blocks = [] 100 | 101 | module: OldQwen2DecoderLayer 102 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 103 | device = next(iter(module.state_dict().values())).device 104 | qkv = fuse_qkv( 105 | module, 106 | module.self_attn.q_proj, 107 | module.self_attn.k_proj, 108 | module.self_attn.v_proj, 109 | ) 110 | norm_1 = FasterTransformerRMSNorm( 111 | module.input_layernorm.weight, module.input_layernorm.variance_epsilon 112 | ) 113 | norm_2 = FasterTransformerRMSNorm( 114 | module.post_attention_layernorm.weight, 115 | module.post_attention_layernorm.variance_epsilon, 116 | ) 117 | blocks.append( 118 | LlamaLikeBlock( 119 | hidden_size=self.model.config.hidden_size, 120 | n_heads=self.model.config.num_attention_heads, 121 | n_kv_heads=self.model.config.num_key_value_heads, 122 | qkv_layer=qkv, 123 | o_proj=module.self_attn.o_proj, 124 | mlp=module.mlp, 125 | norm_1=norm_1, 126 | norm_2=norm_2, 127 | dev=device, 128 | max_seq_len=self.model.config.max_seq_len, 129 | rope_theta=self.model.config.rope_theta, 130 | ) 131 | ) 132 | 133 | self.model.model = LlamaLikeModel( 134 | self.model.config.vocab_size, 135 | blocks, 136 | self.model.model.embed_tokens, 137 | self.model.model.norm, 138 | ) 139 | setattr(self.model.model, "blocks", self.model.model.blocks) 140 | -------------------------------------------------------------------------------- /awq/models/qwen2_5_omni.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | from typing_extensions import TYPE_CHECKING 3 | 4 | if TYPE_CHECKING: 5 | from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( 6 | Qwen2_5OmniDecoderLayer 7 | ) 8 | from transformers import Qwen2_5OmniForConditionalGeneration 9 | 10 | 11 | class Qwen2_5_OmniAWQForConditionalGeneration(BaseAWQForCausalLM): 12 | layer_type = "Qwen2_5OmniDecoderLayer" 13 | max_seq_len_key = "max_position_embeddings" 14 | modules_to_not_convert = ["visual"] 15 | @staticmethod 16 | def get_model_layers(model: "Qwen2_5OmniForConditionalGeneration"): 17 | return model.thinker.model.layers 18 | 19 | @staticmethod 20 | def get_act_for_scaling(module: "Qwen2_5OmniForConditionalGeneration"): 21 | return dict(is_scalable=False) 22 | 23 | @staticmethod 24 | def move_embed(model: "Qwen2_5OmniForConditionalGeneration", device: str): 25 | model.thinker.model.embed_tokens = model.thinker.model.embed_tokens.to(device) 26 | model.thinker.visual = model.thinker.visual.to(device) 27 | model.thinker.audio_tower = model.thinker.audio_tower.to(device) 28 | 29 | model.thinker.visual.rotary_pos_emb = model.thinker.visual.rotary_pos_emb.to(device) 30 | model.thinker.model.rotary_emb = model.thinker.model.rotary_emb.to(device) 31 | 32 | for layer in model.thinker.model.layers: 33 | layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(device) 34 | 35 | @staticmethod 36 | def get_layers_for_scaling( 37 | module: "Qwen2_5OmniDecoderLayer", input_feat, module_kwargs 38 | ): 39 | layers = [] 40 | 41 | # attention input 42 | layers.append( 43 | dict( 44 | prev_op=module.input_layernorm, 45 | layers=[ 46 | module.self_attn.q_proj, 47 | module.self_attn.k_proj, 48 | module.self_attn.v_proj, 49 | ], 50 | inp=input_feat["self_attn.q_proj"], 51 | 
module2inspect=module.self_attn, 52 | kwargs=module_kwargs, 53 | ) 54 | ) 55 | 56 | # attention out 57 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 58 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 59 | layers.append( 60 | dict( 61 | prev_op=module.self_attn.v_proj, 62 | layers=[module.self_attn.o_proj], 63 | inp=input_feat["self_attn.o_proj"], 64 | ) 65 | ) 66 | 67 | # linear 1 68 | layers.append( 69 | dict( 70 | prev_op=module.post_attention_layernorm, 71 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 72 | inp=input_feat["mlp.gate_proj"], 73 | module2inspect=module.mlp, 74 | ) 75 | ) 76 | 77 | # linear 2 78 | layers.append( 79 | dict( 80 | prev_op=module.mlp.up_proj, 81 | layers=[module.mlp.down_proj], 82 | inp=input_feat["mlp.down_proj"], 83 | ) 84 | ) 85 | 86 | return layers 87 | -------------------------------------------------------------------------------- /awq/models/qwen2_5_vl.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | from typing_extensions import TYPE_CHECKING 3 | 4 | if TYPE_CHECKING: 5 | from transformers import Qwen2_5_VLForConditionalGeneration 6 | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( 7 | Qwen2_5_VLDecoderLayer, 8 | ) 9 | 10 | 11 | class Qwen2_5_VLAWQForCausalLM(BaseAWQForCausalLM): 12 | layer_type = "Qwen2_5_VLDecoderLayer" 13 | max_seq_len_key = "max_position_embeddings" 14 | modules_to_not_convert = ["visual"] 15 | 16 | @staticmethod 17 | def get_model_layers(model: "Qwen2_5_VLForConditionalGeneration"): 18 | return model.model.layers 19 | 20 | @staticmethod 21 | def get_act_for_scaling(module: "Qwen2_5_VLForConditionalGeneration"): 22 | return dict(is_scalable=False) 23 | 24 | @staticmethod 25 | def move_embed(model: "Qwen2_5_VLForConditionalGeneration", device: str): 26 | model.model.embed_tokens = model.model.embed_tokens.to(device) 27 | model.visual = model.visual.to(device) 28 | model.model.rotary_emb = model.model.rotary_emb.to(device) 29 | 30 | @staticmethod 31 | def get_layers_for_scaling( 32 | module: "Qwen2_5_VLDecoderLayer", input_feat, module_kwargs 33 | ): 34 | layers = [] 35 | 36 | # attention input 37 | layers.append( 38 | dict( 39 | prev_op=module.input_layernorm, 40 | layers=[ 41 | module.self_attn.q_proj, 42 | module.self_attn.k_proj, 43 | module.self_attn.v_proj, 44 | ], 45 | inp=input_feat["self_attn.q_proj"], 46 | module2inspect=module.self_attn, 47 | kwargs=module_kwargs, 48 | ) 49 | ) 50 | 51 | # attention out 52 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 53 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 54 | layers.append( 55 | dict( 56 | prev_op=module.self_attn.v_proj, 57 | layers=[module.self_attn.o_proj], 58 | inp=input_feat["self_attn.o_proj"], 59 | ) 60 | ) 61 | 62 | # linear 1 63 | layers.append( 64 | dict( 65 | prev_op=module.post_attention_layernorm, 66 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 67 | inp=input_feat["mlp.gate_proj"], 68 | module2inspect=module.mlp, 69 | ) 70 | ) 71 | 72 | # linear 2 73 | layers.append( 74 | dict( 75 | prev_op=module.mlp.up_proj, 76 | layers=[module.mlp.down_proj], 77 | inp=input_feat["mlp.down_proj"], 78 | ) 79 | ) 80 | 81 | return layers 82 | -------------------------------------------------------------------------------- /awq/models/qwen2vl.py: -------------------------------------------------------------------------------- 1 | 
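# Qwen2-VL reuses the Llama-like scaling recipe below and keeps the vision tower in full
# precision via `modules_to_not_convert = ["visual"]`. A minimal quantization sketch, assuming
# the top-level AutoAWQ API (the model id and quant_config values are illustrative, and
# vision-language models may additionally need multimodal calibration data):
#
#   from awq import AutoAWQForCausalLM
#   from transformers import AutoTokenizer
#
#   model = AutoAWQForCausalLM.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
#   model.quantize(tokenizer, quant_config={"w_bit": 4, "q_group_size": 128, "zero_point": True, "version": "GEMM"})
#   model.save_quantized("qwen2-vl-7b-instruct-awq")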
from .base import BaseAWQForCausalLM 2 | from typing_extensions import TYPE_CHECKING 3 | 4 | if TYPE_CHECKING: 5 | from transformers import Qwen2VLForConditionalGeneration 6 | from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLDecoderLayer 7 | 8 | class Qwen2VLAWQForCausalLM(BaseAWQForCausalLM): 9 | layer_type = "Qwen2VLDecoderLayer" 10 | max_seq_len_key = "max_position_embeddings" 11 | modules_to_not_convert = ["visual"] 12 | 13 | @staticmethod 14 | def get_model_layers(model: "Qwen2VLForConditionalGeneration"): 15 | return model.model.layers 16 | 17 | @staticmethod 18 | def get_act_for_scaling(module: "Qwen2VLForConditionalGeneration"): 19 | return dict(is_scalable=False) 20 | 21 | @staticmethod 22 | def move_embed(model: "Qwen2VLForConditionalGeneration", device: str): 23 | model.model.embed_tokens = model.model.embed_tokens.to(device) 24 | model.visual = model.visual.to(device) 25 | model.model.rotary_emb = model.model.rotary_emb.to(device) 26 | 27 | @staticmethod 28 | def get_layers_for_scaling(module: "Qwen2VLDecoderLayer", input_feat, module_kwargs): 29 | layers = [] 30 | 31 | # attention input 32 | layers.append( 33 | dict( 34 | prev_op=module.input_layernorm, 35 | layers=[ 36 | module.self_attn.q_proj, 37 | module.self_attn.k_proj, 38 | module.self_attn.v_proj, 39 | ], 40 | inp=input_feat["self_attn.q_proj"], 41 | module2inspect=module.self_attn, 42 | kwargs=module_kwargs, 43 | ) 44 | ) 45 | 46 | # attention out 47 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 48 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 49 | layers.append( 50 | dict( 51 | prev_op=module.self_attn.v_proj, 52 | layers=[module.self_attn.o_proj], 53 | inp=input_feat["self_attn.o_proj"], 54 | ) 55 | ) 56 | 57 | # linear 1 58 | layers.append( 59 | dict( 60 | prev_op=module.post_attention_layernorm, 61 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 62 | inp=input_feat["mlp.gate_proj"], 63 | module2inspect=module.mlp, 64 | ) 65 | ) 66 | 67 | # linear 2 68 | layers.append( 69 | dict( 70 | prev_op=module.mlp.up_proj, 71 | layers=[module.mlp.down_proj], 72 | inp=input_feat["mlp.down_proj"], 73 | ) 74 | ) 75 | 76 | return layers -------------------------------------------------------------------------------- /awq/models/qwen3.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from transformers.models.qwen3.modeling_qwen3 import ( 5 | Qwen3DecoderLayer as OldQwen3DecoderLayer, 6 | Qwen3ForCausalLM as OldQwen3ForCausalLM, 7 | ) 8 | from awq.utils.fused_utils import fuse_qkv 9 | from awq.modules.fused.block import QwenBlock 10 | from awq.modules.fused.model import LlamaLikeModel 11 | from awq.modules.fused.norm import FasterTransformerRMSNorm 12 | 13 | 14 | class Qwen3AWQForCausalLM(BaseAWQForCausalLM): 15 | layer_type = "Qwen3DecoderLayer" 16 | max_seq_len_key = "max_position_embeddings" 17 | 18 | @staticmethod 19 | def fuse_layers(model: OldQwen3ForCausalLM): 20 | fuser = Qwen3Fuser(model) 21 | fuser.fuse_transformer() 22 | 23 | @staticmethod 24 | def get_model_layers(model: OldQwen3ForCausalLM): 25 | return model.model.layers 26 | 27 | @staticmethod 28 | def get_act_for_scaling(module: OldQwen3DecoderLayer): 29 | return dict(is_scalable=False) 30 | 31 | @staticmethod 32 | def move_embed(model: OldQwen3ForCausalLM, device: str): 33 | model.model.embed_tokens = model.model.embed_tokens.to(device) 34 | 
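        # Recent transformers releases keep a single shared rotary embedding on the model,
        # so it is moved to the calibration device alongside the token embeddings.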
model.model.rotary_emb = model.model.rotary_emb.to(device) 35 | 36 | @staticmethod 37 | def get_layers_for_scaling(module: OldQwen3DecoderLayer, input_feat, module_kwargs): 38 | layers = [] 39 | 40 | # attention input 41 | layers.append( 42 | dict( 43 | prev_op=module.input_layernorm, 44 | layers=[ 45 | module.self_attn.q_proj, 46 | module.self_attn.k_proj, 47 | module.self_attn.v_proj, 48 | ], 49 | inp=input_feat["self_attn.q_proj"], 50 | module2inspect=module.self_attn, 51 | kwargs=module_kwargs, 52 | ) 53 | ) 54 | 55 | # attention out 56 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 57 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 58 | layers.append( 59 | dict( 60 | prev_op=module.self_attn.v_proj, 61 | layers=[module.self_attn.o_proj], 62 | inp=input_feat["self_attn.o_proj"], 63 | ) 64 | ) 65 | 66 | # linear 1 67 | layers.append( 68 | dict( 69 | prev_op=module.post_attention_layernorm, 70 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 71 | inp=input_feat["mlp.gate_proj"], 72 | module2inspect=module.mlp, 73 | ) 74 | ) 75 | 76 | # linear 2 77 | layers.append( 78 | dict( 79 | prev_op=module.mlp.up_proj, 80 | layers=[module.mlp.down_proj], 81 | inp=input_feat["mlp.down_proj"], 82 | ) 83 | ) 84 | 85 | return layers 86 | 87 | class Qwen3Fuser: 88 | def __init__(self, model: OldQwen3ForCausalLM): 89 | self.model = model 90 | 91 | self.qwen3_blocks: List[Tuple[str, OldQwen3DecoderLayer]] = [ 92 | (name, module) 93 | for name, module in self.model.named_modules() 94 | if "Qwen3DecoderLayer".lower() in module.__class__.__name__.lower() 95 | ] 96 | 97 | def fuse_transformer(self): 98 | blocks = [] 99 | 100 | module: OldQwen3DecoderLayer 101 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 102 | device = next(iter(module.state_dict().values())).device 103 | qkv = fuse_qkv( 104 | module, 105 | module.self_attn.q_proj, 106 | module.self_attn.k_proj, 107 | module.self_attn.v_proj, 108 | ) 109 | norm_1 = FasterTransformerRMSNorm( 110 | module.input_layernorm.weight, module.input_layernorm.variance_epsilon 111 | ) 112 | norm_2 = FasterTransformerRMSNorm( 113 | module.post_attention_layernorm.weight, 114 | module.post_attention_layernorm.variance_epsilon, 115 | ) 116 | blocks.append( 117 | QwenBlock( 118 | hidden_size=self.model.config.hidden_size, 119 | n_heads=self.model.config.num_attention_heads, 120 | n_kv_heads=self.model.config.num_key_value_heads, 121 | qkv_layer=qkv, 122 | o_proj=module.self_attn.o_proj, 123 | mlp=module.mlp, 124 | norm_1=norm_1, 125 | norm_2=norm_2, 126 | dev=device, 127 | max_seq_len=self.model.config.max_seq_len, 128 | rope_theta=self.model.config.rope_theta, 129 | q_norm=module.self_attn.q_norm, 130 | k_norm=module.self_attn.k_norm, 131 | head_dim=self.model.config.head_dim, 132 | ) 133 | ) 134 | 135 | self.model.model = LlamaLikeModel( 136 | self.model.config.vocab_size, 137 | blocks, 138 | self.model.model.embed_tokens, 139 | self.model.model.norm, 140 | ) 141 | setattr(self.model.model, "blocks", self.model.model.blocks) 142 | 143 | -------------------------------------------------------------------------------- /awq/models/qwen3_moe.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | 5 | 6 | class Qwen3MoeAWQForCausalLM(BaseAWQForCausalLM): 7 | layer_type = "Qwen3MoeDecoderLayer" 8 | max_seq_len_key = "max_position_embeddings" 9 | 10 | 
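    # MoE note: get_layers_for_scaling() below scales every expert's gate_proj/up_proj as one
    # group behind post_attention_layernorm, then emits a separate down_proj entry per expert.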
@staticmethod 11 | def get_model_layers(model): 12 | return model.model.layers 13 | 14 | @staticmethod 15 | def get_act_for_scaling(module): 16 | return dict(is_scalable=False) 17 | 18 | @staticmethod 19 | def move_embed(model, device: str): 20 | model.model.embed_tokens = model.model.embed_tokens.to(device) 21 | model.model.rotary_emb = model.model.rotary_emb.to(device) 22 | 23 | @staticmethod 24 | def get_layers_for_scaling(module, input_feat, module_kwargs): 25 | layers = [] 26 | 27 | # attention input 28 | layers.append( 29 | dict( 30 | prev_op=module.input_layernorm, 31 | layers=[ 32 | module.self_attn.q_proj, 33 | module.self_attn.k_proj, 34 | module.self_attn.v_proj, 35 | ], 36 | inp=input_feat["self_attn.q_proj"], 37 | module2inspect=module.self_attn, 38 | kwargs=module_kwargs, 39 | ) 40 | ) 41 | 42 | # attention out 43 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 44 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 45 | layers.append( 46 | dict( 47 | prev_op=module.self_attn.v_proj, 48 | layers=[module.self_attn.o_proj], 49 | inp=input_feat["self_attn.o_proj"], 50 | ) 51 | ) 52 | 53 | if hasattr(module.mlp, "gate"): 54 | # linear in 55 | layers.append( 56 | dict( 57 | prev_op=module.post_attention_layernorm, 58 | layers=[ 59 | w 60 | for expert in module.mlp.experts 61 | for w in [expert.gate_proj, expert.up_proj] 62 | ], 63 | inp=input_feat["mlp"], 64 | module2inspect=module.mlp, 65 | ) 66 | ) 67 | 68 | # linear out 69 | for i, expert in enumerate(module.mlp.experts): 70 | layers.append( 71 | dict( 72 | prev_op=expert.up_proj, 73 | layers=[expert.down_proj], 74 | inp=input_feat[f"mlp.experts.{i}.down_proj"], 75 | ) 76 | ) 77 | 78 | else: 79 | # linear 1 80 | layers.append( 81 | dict( 82 | prev_op=module.post_attention_layernorm, 83 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 84 | inp=input_feat["mlp.gate_proj"], 85 | module2inspect=module.mlp, 86 | ) 87 | ) 88 | 89 | # linear 2 90 | layers.append( 91 | dict( 92 | prev_op=module.mlp.up_proj, 93 | layers=[module.mlp.down_proj], 94 | inp=input_feat["mlp.down_proj"], 95 | ) 96 | ) 97 | 98 | return layers 99 | 100 | 101 | -------------------------------------------------------------------------------- /awq/models/stablelm.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.utils.fused_utils import fuse_qkv 5 | from awq.modules.fused.block import LlamaLikeBlock 6 | from awq.modules.fused.model import LlamaLikeModel 7 | from transformers.models.stablelm import StableLmForCausalLM as OldStableLmForCausalLM 8 | from transformers.models.stablelm.modeling_stablelm import ( 9 | StableLmDecoderLayer as OldStableLmDecoderLayer, 10 | ) 11 | from awq.modules.fused.norm import FasterTransformerRMSNorm 12 | 13 | 14 | class StableLmAWQForCausalLM(BaseAWQForCausalLM): 15 | layer_type = "StableLmDecoderLayer" 16 | max_seq_len_key = "max_position_embeddings" 17 | 18 | @staticmethod 19 | def fuse_layers(model: OldStableLmForCausalLM): 20 | fuser = StableLmFuser(model) 21 | fuser.fuse_transformer() 22 | 23 | @staticmethod 24 | def get_model_layers(model: OldStableLmForCausalLM): 25 | return model.model.layers 26 | 27 | @staticmethod 28 | def get_act_for_scaling(module: OldStableLmForCausalLM): 29 | return dict(is_scalable=False) 30 | 31 | @staticmethod 32 | def move_embed(model: OldStableLmForCausalLM, device: str): 33 | 
model.model.embed_tokens = model.model.embed_tokens.to(device) 34 | model.model.rotary_emb = model.model.rotary_emb.to(device) 35 | 36 | @staticmethod 37 | def get_layers_for_scaling( 38 | module: OldStableLmDecoderLayer, input_feat, module_kwargs 39 | ): 40 | layers = [] 41 | 42 | # attention input 43 | layers.append( 44 | dict( 45 | prev_op=module.input_layernorm, 46 | layers=[ 47 | module.self_attn.q_proj, 48 | module.self_attn.k_proj, 49 | module.self_attn.v_proj, 50 | ], 51 | inp=input_feat["self_attn.q_proj"], 52 | module2inspect=module.self_attn, 53 | kwargs=module_kwargs, 54 | ) 55 | ) 56 | 57 | # attention out 58 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 59 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 60 | layers.append( 61 | dict( 62 | prev_op=module.self_attn.v_proj, 63 | layers=[module.self_attn.o_proj], 64 | inp=input_feat["self_attn.o_proj"], 65 | ) 66 | ) 67 | 68 | # linear 1 69 | layers.append( 70 | dict( 71 | prev_op=module.post_attention_layernorm, 72 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 73 | inp=input_feat["mlp.gate_proj"], 74 | module2inspect=module.mlp, 75 | ) 76 | ) 77 | 78 | # linear 2 79 | layers.append( 80 | dict( 81 | prev_op=module.mlp.up_proj, 82 | layers=[module.mlp.down_proj], 83 | inp=input_feat["mlp.down_proj"], 84 | ) 85 | ) 86 | 87 | return layers 88 | 89 | 90 | class StableLmFuser: 91 | def __init__(self, model: OldStableLmForCausalLM): 92 | self.model = model 93 | 94 | self.stablelm_blocks: List[Tuple[str, OldStableLmDecoderLayer]] = [ 95 | (name, module) 96 | for name, module in self.model.named_modules() 97 | if "StableLmDecoderLayer".lower() in module.__class__.__name__.lower() 98 | ] 99 | 100 | def fuse_transformer(self): 101 | blocks = [] 102 | 103 | module: OldStableLmDecoderLayer 104 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 105 | device = next(iter(module.state_dict().values())).device 106 | qkv = fuse_qkv( 107 | module, 108 | module.self_attn.q_proj, 109 | module.self_attn.k_proj, 110 | module.self_attn.v_proj, 111 | ) 112 | norm_1 = module.input_layernorm 113 | norm_2 = module.post_attention_layernorm 114 | blocks.append( 115 | LlamaLikeBlock( 116 | hidden_size=self.model.config.hidden_size, 117 | n_heads=self.model.config.num_attention_heads, 118 | n_kv_heads=self.model.config.num_key_value_heads, 119 | qkv_layer=qkv, 120 | o_proj=module.self_attn.o_proj, 121 | mlp=module.mlp, 122 | norm_1=norm_1, 123 | norm_2=norm_2, 124 | dev=device, 125 | max_seq_len=self.model.config.max_seq_len, 126 | rope_theta=self.model.config.rope_theta, 127 | partial_rotary_factor=self.model.config.partial_rotary_factor, 128 | ) 129 | ) 130 | 131 | self.model.model = LlamaLikeModel( 132 | self.model.config.vocab_size, 133 | blocks, 134 | self.model.model.embed_tokens, 135 | self.model.model.norm, 136 | ) 137 | setattr(self.model.model, "blocks", self.model.model.blocks) 138 | -------------------------------------------------------------------------------- /awq/models/starcoder2.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.utils.fused_utils import fuse_qkv 5 | from awq.modules.fused.block import LlamaLikeBlock 6 | from awq.modules.fused.model import LlamaLikeModel 7 | from transformers.models.starcoder2.modeling_starcoder2 import ( 8 | Starcoder2ForCausalLM as OldStarcoder2ForCausalLM, 9 | 
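    # The stock transformers classes are aliased as Old* so they are easy to
    # tell apart from AutoAWQ's fused replacements.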
Starcoder2DecoderLayer as OldStarcoder2DecoderLayer, 10 | ) 11 | from awq.modules.fused.norm import FasterTransformerRMSNorm 12 | 13 | 14 | class Starcoder2AWQForCausalLM(BaseAWQForCausalLM): 15 | layer_type = "Starcoder2DecoderLayer" 16 | max_seq_len_key = "max_position_embeddings" 17 | 18 | @staticmethod 19 | def fuse_layers(model: OldStarcoder2ForCausalLM): 20 | fuser = Starcoder2Fuser(model) 21 | fuser.fuse_transformer() 22 | 23 | @staticmethod 24 | def get_model_layers(model: OldStarcoder2ForCausalLM): 25 | return model.model.layers 26 | 27 | @staticmethod 28 | def get_act_for_scaling(module: OldStarcoder2DecoderLayer): 29 | return dict( 30 | is_scalable=True, 31 | scale_name="mlp.act", 32 | scale_layer=module.mlp.act, 33 | scale_shape=module.mlp.c_fc.out_features, 34 | ) 35 | # return dict(is_scalable=False) 36 | 37 | @staticmethod 38 | def move_embed(model: OldStarcoder2ForCausalLM, device): 39 | model.model.embed_tokens = model.model.embed_tokens.to(device) 40 | model.model.rotary_emb = model.model.rotary_emb.to(device) 41 | 42 | @staticmethod 43 | def get_layers_for_scaling(module: OldStarcoder2DecoderLayer, input_feat, module_kwargs): 44 | layers = [] 45 | 46 | # attention input 47 | layers.append( 48 | dict( 49 | prev_op=module.input_layernorm, 50 | layers=[ 51 | module.self_attn.q_proj, 52 | module.self_attn.k_proj, 53 | module.self_attn.v_proj, 54 | ], 55 | inp=input_feat["self_attn.q_proj"], 56 | module2inspect=module.self_attn, 57 | kwargs=module_kwargs, 58 | ) 59 | ) 60 | 61 | # attention out 62 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 63 | layers.append( 64 | dict( 65 | prev_op=module.self_attn.v_proj, 66 | layers=[module.self_attn.o_proj], 67 | inp=input_feat["self_attn.o_proj"], 68 | ) 69 | ) 70 | 71 | # linear 1 72 | layers.append( 73 | dict( 74 | prev_op=module.post_attention_layernorm, 75 | layers=[module.mlp.c_fc], 76 | inp=input_feat["mlp.c_fc"], 77 | module2inspect=module.mlp, 78 | ) 79 | ) 80 | 81 | # linear 2 82 | layers.append( 83 | dict( 84 | prev_op=module.mlp.act, 85 | layers=[module.mlp.c_proj], 86 | inp=input_feat["mlp.c_proj"], 87 | ) 88 | ) 89 | 90 | return layers 91 | 92 | class Starcoder2Fuser: 93 | def __init__(self, model: OldStarcoder2ForCausalLM): 94 | self.model = model 95 | 96 | self.starcoder2_blocks: List[Tuple[str, OldStarcoder2DecoderLayer]] = [ 97 | (name, module) 98 | for name, module in self.model.named_modules() 99 | if "Starcoder2DecoderLayer".lower() in module.__class__.__name__.lower() 100 | ] 101 | 102 | def fuse_transformer(self): 103 | blocks = [] 104 | 105 | module: OldStarcoder2DecoderLayer 106 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 107 | device = next(iter(module.state_dict().values())).device 108 | qkv = fuse_qkv( 109 | module, 110 | module.self_attn.q_proj, 111 | module.self_attn.k_proj, 112 | module.self_attn.v_proj, 113 | ) 114 | # SC2 use normal LayerNorm 115 | norm_1 = module.input_layernorm 116 | norm_2 = module.post_attention_layernorm 117 | blocks.append( 118 | LlamaLikeBlock( 119 | hidden_size=self.model.config.hidden_size, 120 | n_heads=self.model.config.num_attention_heads, 121 | n_kv_heads=self.model.config.num_key_value_heads, 122 | qkv_layer=qkv, 123 | o_proj=module.self_attn.o_proj, 124 | mlp=module.mlp, 125 | norm_1=norm_1, 126 | norm_2=norm_2, 127 | dev=device, 128 | max_seq_len=self.model.config.max_seq_len, 129 | rope_theta=self.model.config.rope_theta, 130 | ) 131 | ) 132 | 133 | self.model.model = LlamaLikeModel( 134 | 
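            # Swap the HF decoder stack for the fused blocks, reusing the
            # original embedding table and final norm.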
self.model.config.vocab_size, 135 | blocks, 136 | self.model.model.embed_tokens, 137 | self.model.model.norm, 138 | ) 139 | setattr(self.model.model, "blocks", self.model.model.blocks) -------------------------------------------------------------------------------- /awq/models/yi.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from typing import List, Tuple 3 | from .base import BaseAWQForCausalLM 4 | from awq.utils.fused_utils import fuse_qkv 5 | from awq.modules.fused.block import LlamaLikeBlock 6 | from awq.modules.fused.model import LlamaLikeModel 7 | from awq.modules.fused.norm import FasterTransformerRMSNorm 8 | 9 | 10 | class YiAWQForCausalLM(BaseAWQForCausalLM): 11 | layer_type = "YiDecoderLayer" 12 | max_seq_len_key = "max_position_embeddings" 13 | 14 | @staticmethod 15 | def fuse_layers(model): 16 | fuser = YiFuser(model) 17 | fuser.fuse_transformer() 18 | 19 | @staticmethod 20 | def get_model_layers(model): 21 | return model.model.layers 22 | 23 | @staticmethod 24 | def get_act_for_scaling(module): 25 | return dict(is_scalable=False) 26 | 27 | @staticmethod 28 | def move_embed(model, device: str): 29 | model.model.embed_tokens = model.model.embed_tokens.to(device) 30 | model.model.rotary_emb = model.model.rotary_emb.to(device) 31 | 32 | @staticmethod 33 | def get_layers_for_scaling(module, input_feat, module_kwargs): 34 | layers = [] 35 | 36 | # attention input 37 | layers.append( 38 | dict( 39 | prev_op=module.ln1, 40 | layers=[ 41 | module.self_attn.q_proj, 42 | module.self_attn.k_proj, 43 | module.self_attn.v_proj, 44 | ], 45 | inp=input_feat["self_attn.q_proj"], 46 | module2inspect=module.self_attn, 47 | kwargs=module_kwargs, 48 | ) 49 | ) 50 | 51 | # attention out 52 | # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 53 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 54 | layers.append( 55 | dict( 56 | prev_op=module.self_attn.v_proj, 57 | layers=[module.self_attn.o_proj], 58 | inp=input_feat["self_attn.o_proj"], 59 | ) 60 | ) 61 | 62 | # linear 1 63 | layers.append( 64 | dict( 65 | prev_op=module.ln2, 66 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 67 | inp=input_feat["mlp.gate_proj"], 68 | module2inspect=module.mlp, 69 | ) 70 | ) 71 | 72 | # linear 2 73 | layers.append( 74 | dict( 75 | prev_op=module.mlp.up_proj, 76 | layers=[module.mlp.down_proj], 77 | inp=input_feat["mlp.down_proj"], 78 | ) 79 | ) 80 | 81 | return layers 82 | 83 | 84 | class YiFuser: 85 | def __init__(self, model): 86 | self.model = model 87 | 88 | self.yi_blocks: List[Tuple[str, object]] = [ 89 | (name, module) 90 | for name, module in self.model.named_modules() 91 | if "YiDecoderLayer".lower() in module.__class__.__name__.lower() 92 | ] 93 | 94 | def fuse_transformer(self): 95 | blocks = [] 96 | 97 | for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): 98 | device = next(iter(module.state_dict().values())).device 99 | qkv = fuse_qkv( 100 | module, 101 | module.self_attn.q_proj, 102 | module.self_attn.k_proj, 103 | module.self_attn.v_proj, 104 | ) 105 | norm_1 = FasterTransformerRMSNorm( 106 | module.ln1.weight, module.ln1.variance_epsilon 107 | ) 108 | norm_2 = FasterTransformerRMSNorm( 109 | module.ln2.weight, module.ln2.variance_epsilon 110 | ) 111 | blocks.append( 112 | LlamaLikeBlock( 113 | hidden_size=self.model.config.hidden_size, 114 | n_heads=self.model.config.num_attention_heads, 115 | n_kv_heads=self.model.config.num_key_value_heads, 116 | 
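                    # qkv packs the separate q/k/v projections into one module
                    # (see fuse_qkv above) so the fused attention block projects
                    # q, k and v in a single call.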
qkv_layer=qkv, 117 | o_proj=module.self_attn.o_proj, 118 | mlp=module.mlp, 119 | norm_1=norm_1, 120 | norm_2=norm_2, 121 | dev=device, 122 | max_seq_len=self.model.config.max_seq_len, 123 | rope_theta=self.model.config.rope_theta, 124 | ) 125 | ) 126 | 127 | self.model.model = LlamaLikeModel( 128 | self.model.config.vocab_size, 129 | blocks, 130 | self.model.model.embed_tokens, 131 | self.model.model.norm, 132 | ) 133 | setattr(self.model.model, "blocks", self.model.model.blocks) 134 | -------------------------------------------------------------------------------- /awq/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/casper-hansen/AutoAWQ/88e4c76b20755db275574e6a03c83c84ba3bece5/awq/modules/__init__.py -------------------------------------------------------------------------------- /awq/modules/act.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ScaledActivation(nn.Module): 5 | def __init__(self, module, scales): 6 | super().__init__() 7 | self.act = module 8 | self.scales = nn.Parameter(scales.data) 9 | 10 | def forward(self, x): 11 | return self.act(x) / self.scales.view(1, 1, -1).to(x.device) 12 | -------------------------------------------------------------------------------- /awq/modules/fused/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/casper-hansen/AutoAWQ/88e4c76b20755db275574e6a03c83c84ba3bece5/awq/modules/fused/__init__.py -------------------------------------------------------------------------------- /awq/modules/fused/cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class WindowedCache: 5 | def __init__( 6 | self, cache_batch_size, n_heads, n_kv_heads, head_dim, max_seq_len, device 7 | ): 8 | """ 9 | The window size is the same as the max_seq_len. The window will 10 | automatically roll once max_seq_len is exceeded. 11 | """ 12 | size = ( 13 | cache_batch_size, 14 | max_seq_len, 15 | n_kv_heads if n_kv_heads != 0 else n_heads, 16 | head_dim, 17 | ) 18 | self.v = torch.zeros( 19 | size, 20 | device=device, 21 | dtype=torch.float16, 22 | ) 23 | self.k = torch.zeros( 24 | size, 25 | device=device, 26 | dtype=torch.float16, 27 | ) 28 | self.max_seq_len = max_seq_len 29 | 30 | def get_kv(self, batch_size, start_pos, seqlen): 31 | """ 32 | Gets the key-value store in correct shapes. 33 | NOTE: This function is a legacy function. It is only available to showcase 34 | how to accurately retrieve the KV-cache but is not currently used. 35 | """ 36 | xv = self.v[:batch_size, : start_pos + seqlen] 37 | xk = self.k[:batch_size, : start_pos + seqlen] 38 | 39 | return xv, xk 40 | 41 | def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen): 42 | """ 43 | Updates the values in the key-value store. 44 | """ 45 | self.v[:batch_size, start_pos : start_pos + seqlen, :, :] = values_store 46 | self.k[:batch_size, start_pos : start_pos + seqlen, :, :] = keys_store 47 | 48 | def roll_kv_n_steps(self, start_pos, n=100): 49 | """ 50 | Roll cache n to the left. 
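        Returns the updated start position (start_pos - n) so the caller can
        keep writing into the slots that were freed at the end of the window.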
51 | """ 52 | n = min(n, self.max_seq_len) 53 | # Roll cache to the left 54 | self.v = torch.roll(self.v, shifts=-n, dims=2) 55 | self.k = torch.roll(self.k, shifts=-n, dims=2) 56 | 57 | # Zero out the new part 58 | self.v[:, :, -n:, :] = 0 59 | self.k[:, :, -n:, :] = 0 60 | 61 | return start_pos - n 62 | 63 | def to(self, device): 64 | self.k = self.k.to(device) 65 | self.v = self.v.to(device) 66 | 67 | def increase_batch_size(self, to_bsz): 68 | """Dynamically allocate new kv when batch size changes.""" 69 | self.v = torch.zeros( 70 | to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device 71 | ) 72 | self.k = torch.zeros( 73 | to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device 74 | ) 75 | 76 | def decrease_batch_size(self, to_bsz): 77 | """Dynamically remove part of cache if batch size changes.""" 78 | self.v = self.v[:to_bsz, :, :, :] 79 | self.k = self.k[:to_bsz, :, :, :] 80 | -------------------------------------------------------------------------------- /awq/modules/fused/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from awq.modules.linear.gemm import WQLinear_GEMM 4 | from awq.modules.linear.gemv import WQLinear_GEMV 5 | 6 | try: 7 | import awq_ext # with CUDA kernels 8 | 9 | AWQ_INSTALLED = True 10 | except: 11 | AWQ_INSTALLED = False 12 | 13 | 14 | class QuantFusedMLP(nn.Module): 15 | def __init__( 16 | self, 17 | gate_proj, 18 | down_proj, 19 | up_proj, 20 | activation=F.silu, 21 | ): 22 | super().__init__() 23 | 24 | self.register_buffer("gate_proj_qweight", gate_proj.qweight) 25 | self.register_buffer("gate_proj_scales", gate_proj.scales) 26 | self.register_buffer("gate_proj_qzeros", gate_proj.qzeros) 27 | self.register_buffer("up_proj_qweight", up_proj.qweight) 28 | self.register_buffer("up_proj_scales", up_proj.scales) 29 | self.register_buffer("up_proj_qzeros", up_proj.qzeros) 30 | 31 | self.in_features = gate_proj.in_features 32 | self.intermediate_size = gate_proj.out_features 33 | self.out_features = down_proj.out_features 34 | self.w_bit = gate_proj.w_bit 35 | self.down_proj = down_proj 36 | 37 | if isinstance(down_proj, WQLinear_GEMV): 38 | self.linear = awq_ext.gemv_forward_cuda 39 | self.group_size = down_proj.group_size 40 | else: 41 | self.linear = awq_ext.gemm_forward_cuda 42 | self.group_size = 8 43 | 44 | self.activation = activation 45 | 46 | def forward(self, x, routing_weights=None): 47 | out_shape = x.shape[:-1] + (self.intermediate_size,) 48 | x = x.reshape(-1, x.shape[-1]) 49 | gate_output = self.linear( 50 | x, 51 | self.gate_proj_qweight, 52 | self.gate_proj_scales, 53 | self.gate_proj_qzeros, 54 | self.group_size, 55 | ) 56 | up_output = self.linear( 57 | x, 58 | self.up_proj_qweight, 59 | self.up_proj_scales, 60 | self.up_proj_qzeros, 61 | self.group_size, 62 | ) 63 | x = self.activation(gate_output) * up_output 64 | x = x.reshape(out_shape) 65 | x = self.down_proj(x) 66 | 67 | if routing_weights is not None: 68 | x = routing_weights * x 69 | 70 | return x 71 | 72 | 73 | class QuantLlamaMLP(QuantFusedMLP): 74 | r""" 75 | QuantLlamaMLP class kept for backward compatibilty, in the future, users 76 | should always use `QuantFusedMLP` class instead. 
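    It simply forwards its arguments to the `QuantFusedMLP` constructor with the
    default SiLU activation, so both classes behave identically.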
77 | """ 78 | 79 | def __init__(self, gate_proj, down_proj, up_proj): 80 | super().__init__(gate_proj, down_proj, up_proj) 81 | -------------------------------------------------------------------------------- /awq/modules/fused/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | try: 5 | import awq_ext # with CUDA kernels 6 | 7 | AWQ_INSTALLED = True 8 | except: 9 | AWQ_INSTALLED = False 10 | 11 | try: 12 | import intel_extension_for_pytorch as ipex # with IPEX kernels 13 | 14 | IPEX_INSTALLED = True 15 | except: 16 | IPEX_INSTALLED = False 17 | 18 | 19 | class FasterTransformerRMSNorm(nn.Module): 20 | def __init__(self, weight, eps=1e-6): 21 | super().__init__() 22 | self.weight = weight 23 | self.variance_epsilon = eps 24 | 25 | def forward(self, x): 26 | if IPEX_INSTALLED: 27 | output = ipex.llm.functional.rms_norm(x, self.weight, self.variance_epsilon) 28 | else: 29 | assert AWQ_INSTALLED, ( 30 | "AWQ kernels could not be loaded. " 31 | "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels" 32 | ) 33 | output = torch.empty_like(x) 34 | awq_ext.layernorm_forward_cuda( 35 | x, self.weight, output, self.variance_epsilon 36 | ) 37 | 38 | return output 39 | -------------------------------------------------------------------------------- /awq/modules/linear/__init__.py: -------------------------------------------------------------------------------- 1 | from .exllama import WQLinear_Exllama, exllama_post_init 2 | from .exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init 3 | from .gemm import WQLinear_GEMM 4 | from .gemm_ipex import WQLinear_IPEX, ipex_post_init 5 | from .gemv import WQLinear_GEMV 6 | from .marlin import WQLinear_Marlin, marlin_post_init 7 | from .gemv_fast import WQLinear_GEMVFast 8 | -------------------------------------------------------------------------------- /awq/modules/linear/exllama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from awq.utils.module import try_import 4 | from awq.utils.packing_utils import unpack_reorder_pack 5 | 6 | exl_ext, msg = try_import("exl_ext") 7 | 8 | # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension 9 | none_tensor = torch.empty((1, 1), device="meta") 10 | 11 | 12 | class WQLinear_Exllama(nn.Module): 13 | def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): 14 | super().__init__() 15 | 16 | if w_bit not in [4]: 17 | raise NotImplementedError("Only 4-bit are supported for Exllama kernels") 18 | 19 | self.q4 = None 20 | 21 | self.w_bit = w_bit 22 | self.in_features = in_features 23 | self.out_features = out_features 24 | self.group_size = group_size if group_size != -1 else in_features 25 | 26 | ################################################################################## 27 | ## These shapes are only for compatibility with the state_dict of WQLinear_GEMM ## 28 | self.register_buffer( 29 | "qweight", 30 | torch.zeros( 31 | (in_features, out_features // (32 // self.w_bit)), 32 | dtype=torch.int32, 33 | device=dev, 34 | ), 35 | ) 36 | self.register_buffer( 37 | "qzeros", 38 | torch.zeros( 39 | (in_features // self.group_size, out_features // (32 // self.w_bit)), 40 | dtype=torch.int32, 41 | device=dev, 42 | ), 43 | ) 44 | ################################################################################## 45 | 46 | self.register_buffer( 47 | "scales", 48 | torch.zeros( 49 | 
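                # one fp16 scale per quantization group and output channel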
(in_features // self.group_size, out_features), 50 | dtype=torch.float16, 51 | device=dev, 52 | ), 53 | ) 54 | if bias: 55 | self.register_buffer( 56 | "bias", 57 | torch.zeros( 58 | (out_features), 59 | dtype=torch.float16, 60 | device=dev, 61 | ), 62 | ) 63 | else: 64 | self.bias = None 65 | 66 | def post_init(self): 67 | assert self.qweight.device.type == "cuda" 68 | assert self.qweight.device.index is not None 69 | 70 | self.qweight, self.qzeros = unpack_reorder_pack( 71 | self.qweight, self.qzeros, self.w_bit 72 | ) 73 | self.q4 = exl_ext.make_q4( 74 | self.qweight, 75 | self.qzeros, 76 | self.scales, 77 | none_tensor, # g_idx 78 | self.qweight.device.index, # device index 79 | ) 80 | 81 | @classmethod 82 | def from_linear( 83 | cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None 84 | ): 85 | awq_linear = cls( 86 | w_bit, 87 | group_size, 88 | linear.in_features, 89 | linear.out_features, 90 | linear.bias is not None, 91 | linear.weight.device, 92 | ) 93 | if init_only: # just prepare for loading sd 94 | return awq_linear 95 | 96 | raise NotImplementedError("Only inference is supported for Exllama kernels") 97 | 98 | def forward(self, x): 99 | assert self.q4 is not None, ( 100 | "module.post_init() must be called before module.forward(). " 101 | "Use exllama_post_init() on the whole model." 102 | ) 103 | if exl_ext is None: 104 | raise ModuleNotFoundError("External ExLlama kernels are not properly installed." + msg) 105 | 106 | input_dtype = x.dtype 107 | out_shape = x.shape[:-1] + (self.out_features,) 108 | 109 | if input_dtype != torch.float16: 110 | x = x.to(dtype=torch.float16) 111 | 112 | x = x.view(-1, x.shape[-1]) 113 | 114 | out = torch.empty( 115 | (x.shape[0], self.out_features), 116 | dtype=torch.float16, 117 | device=x.device, 118 | ) 119 | exl_ext.q4_matmul(x, self.q4, out) 120 | 121 | if input_dtype != torch.float16: 122 | out = out.to(dtype=input_dtype) 123 | 124 | if self.bias is not None: 125 | out.add_(self.bias) 126 | 127 | return out.view(out_shape) 128 | 129 | 130 | def exllama_post_init(model): 131 | for _, submodule in model.named_modules(): 132 | if isinstance(submodule, WQLinear_Exllama): 133 | submodule.post_init() 134 | 135 | return model 136 | -------------------------------------------------------------------------------- /awq/modules/linear/gemm_ipex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .gemm import WQLinear_GEMM 4 | from awq.utils.packing_utils import dequantize_gemm 5 | 6 | try: 7 | from intel_extension_for_pytorch.llm.quantization import IPEXWeightOnlyQuantizedLinear 8 | assert hasattr(IPEXWeightOnlyQuantizedLinear, "from_weight"), "The minimum version for ipex is at least 2.4" 9 | IPEX_INSTALLED = True 10 | except: 11 | IPEX_INSTALLED = False 12 | 13 | 14 | class WQLinear_IPEX(WQLinear_GEMM): 15 | 16 | def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False): 17 | nn.Module.__init__(self) 18 | assert IPEX_INSTALLED, \ 19 | "Please install IPEX package with `pip install intel_extension_for_pytorch`." 20 | assert w_bit == 4, "Only 4 bit are supported for now." 21 | 22 | self.use_bf16 = True # Intel platform support bf16 even without amx. 
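        # Quantized tensors are stored as int32 buffers holding 32 // w_bit
        # (here 8) packed 4-bit values each, so qweight and qzeros use
        # out_features // 8 columns while scales keep the full out_features width.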
23 | 24 | self.in_features = in_features 25 | self.out_features = out_features 26 | self.w_bit = w_bit 27 | self.group_size = group_size if group_size != -1 else in_features 28 | self.scale_dtype = torch.float32 29 | self.training = training 30 | 31 | # quick sanity check (make sure aligment) 32 | assert self.in_features % self.group_size == 0 33 | assert out_features % (32 // self.w_bit) == 0 34 | self.pack_num = 32 // self.w_bit 35 | 36 | self.init_ipex = False 37 | 38 | self.register_buffer( 39 | "qzeros", 40 | torch.zeros( 41 | (in_features // self.group_size, out_features // self.pack_num), 42 | dtype=torch.int32, 43 | device=dev, 44 | ), 45 | ) 46 | self.register_buffer( 47 | "scales", 48 | torch.zeros( 49 | (in_features // self.group_size, out_features), 50 | dtype=torch.bfloat16 if self.use_bf16 else torch.float32, 51 | device=dev, 52 | )) 53 | if bias: 54 | self.register_buffer( 55 | "bias", 56 | torch.zeros((out_features), dtype=torch.bfloat16 if self.use_bf16 else torch.float32, device=dev), 57 | ) 58 | else: 59 | self.register_buffer( 60 | "bias", 61 | None, 62 | ) 63 | qweight = torch.zeros((in_features, out_features // self.pack_num), dtype=torch.int32, device=dev) 64 | self.register_buffer("qweight", qweight) 65 | 66 | def post_init(self): 67 | device_type = self.qweight.device.type 68 | if device_type != "meta": 69 | assert device_type in ("cpu", "xpu") 70 | 71 | def init_ipex_linear(self): 72 | if not self.training: 73 | self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, \ 74 | self.in_features, self.out_features, None, self.bias, \ 75 | self.group_size, None, quant_method=1, dtype=0) 76 | 77 | @classmethod 78 | def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None): 79 | awq_linear = cls( 80 | w_bit, 81 | group_size, 82 | linear.in_features, 83 | linear.out_features, 84 | linear.bias is not None, 85 | linear.weight.device, 86 | ) 87 | if init_only: # just prepare for loading sd 88 | return awq_linear 89 | 90 | raise NotImplementedError("Only inference is supported for IPEX kernels") 91 | 92 | def forward(self, x): 93 | assert IPEX_INSTALLED, ( 94 | "IPEX kernels could not be loaded. 
" 95 | "Please install with `pip install intel_extension_for_pytorch` and " 96 | "refer to the detial https://github.com/intel/intel-extension-for-pytorch/tree/main") 97 | 98 | if not self.init_ipex: 99 | self.init_ipex_linear() 100 | self.init_ipex = True 101 | 102 | if hasattr(self, "ipex_linear"): 103 | with torch.no_grad(): 104 | outputs = self.ipex_linear(x) 105 | else: 106 | outputs = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size).to(x.dtype) 107 | outputs = torch.matmul(x, outputs) 108 | 109 | return outputs 110 | 111 | def backward(self, grad_output): 112 | weights = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size).to(grad_output.dtype) 113 | batch_size = grad_output.shape[0] 114 | grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) 115 | 116 | return grad_input, None, None, None, None, None, None, None 117 | 118 | def extra_repr(self) -> str: 119 | return ("in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format( 120 | self.in_features, 121 | self.out_features, 122 | self.bias is not None, 123 | self.w_bit, 124 | self.group_size, 125 | )) 126 | 127 | 128 | def ipex_post_init(model): 129 | for _, submodule in model.named_modules(): 130 | if isinstance(submodule, WQLinear_IPEX): 131 | submodule.post_init() 132 | 133 | return model 134 | -------------------------------------------------------------------------------- /awq/modules/triton/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/casper-hansen/AutoAWQ/88e4c76b20755db275574e6a03c83c84ba3bece5/awq/modules/triton/__init__.py -------------------------------------------------------------------------------- /awq/quantize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/casper-hansen/AutoAWQ/88e4c76b20755db275574e6a03c83c84ba3bece5/awq/quantize/__init__.py -------------------------------------------------------------------------------- /awq/quantize/scale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Tuple, List 4 | from awq.utils.utils import get_best_device 5 | from awq.modules.act import ScaledActivation 6 | from awq.utils.module import get_op_by_name, set_op_by_name 7 | from transformers.models.bloom.modeling_bloom import BloomGelu 8 | from transformers.models.llama.modeling_llama import LlamaRMSNorm 9 | from transformers.models.gemma.modeling_gemma import GemmaRMSNorm 10 | from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm 11 | from transformers.models.cohere.modeling_cohere import CohereLayerNorm 12 | from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation 13 | 14 | allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm, Gemma2RMSNorm, CohereLayerNorm] 15 | allowed_act_fns = [ 16 | nn.GELU, 17 | BloomGelu, 18 | NewGELUActivation, 19 | PytorchGELUTanh, 20 | GELUActivation, 21 | ] 22 | 23 | 24 | @torch.no_grad() 25 | def apply_clip(module, clip_list: Tuple[str, torch.Tensor]): 26 | for name, max_val in clip_list: 27 | layer: nn.Linear = get_op_by_name(module, name) 28 | layer.to(get_best_device()) 29 | max_val = max_val.to(layer.weight.device) 30 | org_shape = layer.weight.shape 31 | layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1) 32 | layer.weight.data = 
torch.clamp(layer.weight.data, -max_val, max_val) 33 | layer.weight.data = layer.weight.data.reshape(org_shape) 34 | layer.cpu() 35 | 36 | 37 | def apply_scale(module, scales_list, input_feat_dict=None): 38 | for prev_op_name, layer_names, scales in scales_list: 39 | prev_op = get_op_by_name(module, prev_op_name) 40 | layers = [get_op_by_name(module, name) for name in layer_names] 41 | 42 | best_device = get_best_device() 43 | prev_op.to(best_device) 44 | for layer in layers: 45 | layer.to(best_device) 46 | scales.to(best_device) 47 | 48 | if ( 49 | isinstance(prev_op, nn.Linear) 50 | and type(layers) == list 51 | and isinstance(layers[0], nn.Linear) 52 | ): 53 | scale_fc_fcs(prev_op, layers, scales) 54 | 55 | elif isinstance(prev_op, nn.Linear): 56 | assert len(layers) == 1 57 | scale_fc_fc(prev_op, layers[0], scales) 58 | 59 | elif ( 60 | any(isinstance(prev_op, t) for t in allowed_norms) 61 | or "rmsnorm" in str(prev_op.__class__).lower() 62 | ): 63 | scale_ln_fcs(prev_op, layers, scales) 64 | 65 | elif any(isinstance(prev_op, t) for t in allowed_act_fns): 66 | new_module = ScaledActivation(prev_op, scales) 67 | set_op_by_name(module, prev_op_name, new_module) 68 | scale_gelu_fc(prev_op, layers[0], scales) 69 | 70 | else: 71 | raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!") 72 | 73 | # apply the scaling to input feat if given; prepare it for clipping 74 | if input_feat_dict is not None: 75 | for layer_name in layer_names: 76 | # Skip the modules that are not quantized 77 | if layer_name in input_feat_dict: 78 | inp = input_feat_dict[layer_name] 79 | inp.div_(scales.view(1, -1).to(inp.device)) 80 | 81 | prev_op.cpu() 82 | for layer in layers: 83 | layer.cpu() 84 | scales.cpu() 85 | 86 | 87 | @torch.no_grad() 88 | def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor): 89 | if not isinstance(fcs, list): 90 | fcs = [fcs] 91 | 92 | scales = scales.to(ln.weight.device) 93 | 94 | # GemmaRMSNorm is different from Llama's in that it multiplies 95 | # (1 + weight) to the output, instead of just weight. 
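    # Shifting by +1 lets the division act on the effective weight (1 + w);
    # subtracting 1 afterwards restores Gemma's storage convention.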
96 | if isinstance(ln, GemmaRMSNorm) or isinstance(ln, Gemma2RMSNorm): 97 | ln.weight += 1 98 | ln.weight.div_(scales) 99 | ln.weight -= 1 100 | else: 101 | ln.weight.div_(scales) 102 | 103 | if hasattr(ln, "bias") and ln.bias is not None: 104 | ln.bias.div_(scales) 105 | 106 | for fc in fcs: 107 | fc.weight.mul_(scales.view(1, -1)) 108 | 109 | for p in ln.parameters(): 110 | assert torch.isnan(p).sum() == 0 111 | for fc in fcs: 112 | for p in fc.parameters(): 113 | assert torch.isnan(p).sum() == 0 114 | 115 | 116 | @torch.no_grad() 117 | def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor): 118 | assert isinstance(fc1, nn.Linear) 119 | assert isinstance(fc2, nn.Linear) 120 | 121 | scales = scales.to(fc1.weight.device) 122 | 123 | fc1.weight[-scales.size(0) :].div_(scales.view(-1, 1)) 124 | if fc1.bias is not None: 125 | fc1.bias.div_(scales.view(-1)) 126 | 127 | fc2.weight.mul_(scales.view(1, -1)) 128 | 129 | for p in fc1.parameters(): 130 | assert torch.isnan(p).sum() == 0 131 | for p in fc2.parameters(): 132 | assert torch.isnan(p).sum() == 0 133 | 134 | 135 | @torch.no_grad() 136 | def scale_fc_fcs(fc1: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor): 137 | if not isinstance(fcs, list): 138 | fcs = [fcs] 139 | 140 | scales = scales.to(fc1.weight.device) 141 | 142 | fc1.weight[-scales.size(0) :].div_(scales.view(-1, 1)) 143 | if fc1.bias is not None: 144 | fc1.bias.div_(scales.view(-1)) 145 | 146 | for fc in fcs: 147 | fc.weight.mul_(scales.view(1, -1)) 148 | 149 | for p in fc1.parameters(): 150 | assert torch.isnan(p).sum() == 0 151 | for fc in fcs: 152 | for p in fc.parameters(): 153 | assert torch.isnan(p).sum() == 0 154 | 155 | 156 | @torch.no_grad() 157 | def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor): 158 | assert any(isinstance(gelu, t) for t in allowed_act_fns) 159 | assert isinstance(fc, nn.Linear) 160 | 161 | fc.weight.mul_(scales.view(1, -1).to(fc.weight.device)) 162 | 163 | for p in fc.parameters(): 164 | assert torch.isnan(p).sum() == 0 165 | -------------------------------------------------------------------------------- /awq/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/casper-hansen/AutoAWQ/88e4c76b20755db275574e6a03c83c84ba3bece5/awq/utils/__init__.py -------------------------------------------------------------------------------- /awq/utils/calib_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | from typing import List, Union 4 | from datasets import load_dataset 5 | 6 | 7 | def get_calib_dataset( 8 | data: Union[str, List[str], List[List[int]]] = "pileval", 9 | tokenizer=None, 10 | n_samples=128, 11 | max_seq_len=512, 12 | split="train", 13 | text_column="text", 14 | ): 15 | if isinstance(data, str): 16 | if data == "pileval": 17 | dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") 18 | else: 19 | dataset = load_dataset(data, split=split) 20 | 21 | dataset = dataset.shuffle(seed=42) 22 | 23 | elif isinstance(data, list): 24 | if isinstance(data[0], str): 25 | dataset = [{text_column: text} for text in data] 26 | elif isinstance(data[0][0], int): 27 | dataset = data 28 | else: 29 | raise NotImplementedError( 30 | "Either pass a string to a huggingface dataset or a list" 31 | "that is preprocessed with one sample of text per element" 32 | " or a list of list of int for tokenized words." 
33 | ) 34 | else: 35 | raise NotImplementedError( 36 | "Either pass a string to a huggingface dataset or a list" 37 | "that is preprocessed with one sample of text per element" 38 | " or a list of list of int for tokenized words." 39 | ) 40 | 41 | samples = [] 42 | n_run = 0 43 | for data in dataset: 44 | if isinstance(data, list): 45 | line_encoded = data 46 | else: 47 | line = data[text_column] 48 | line = line.strip() 49 | line_encoded = tokenizer.encode(line) 50 | if len(line_encoded) > max_seq_len: 51 | continue 52 | sample = torch.tensor([line_encoded]) 53 | if sample.numel() == 0: 54 | continue 55 | samples.append(sample) 56 | n_run += 1 57 | if n_run == n_samples: 58 | break 59 | # now concatenate all samples and split according to max sequence length 60 | cat_samples = torch.cat(samples, dim=1) 61 | n_split = cat_samples.shape[1] // max_seq_len 62 | logging.debug(f" * Split into {n_split} blocks") 63 | return [ 64 | cat_samples[:, i * max_seq_len : (i + 1) * max_seq_len] for i in range(n_split) 65 | ] 66 | -------------------------------------------------------------------------------- /awq/utils/module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import importlib 3 | 4 | def try_import(module_name): 5 | try: 6 | module = importlib.import_module(module_name) 7 | return module, "" 8 | except Exception as ex: 9 | return None, str(ex) 10 | 11 | def get_named_linears(module): 12 | return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)} 13 | 14 | 15 | def get_op_by_name(module, op_name): 16 | # get the op by its name relative to the module 17 | for name, m in module.named_modules(): 18 | if name == op_name: 19 | return m 20 | raise ValueError(f"Cannot find op {op_name} in module {module}") 21 | 22 | 23 | def set_op_by_name(layer, name, new_module): 24 | levels = name.split(".") 25 | if len(levels) > 1: 26 | mod_ = layer 27 | for l_idx in range(len(levels) - 1): 28 | if levels[l_idx].isdigit(): 29 | mod_ = mod_[int(levels[l_idx])] 30 | else: 31 | mod_ = getattr(mod_, levels[l_idx]) 32 | setattr(mod_, levels[-1], new_module) 33 | else: 34 | setattr(layer, name, new_module) 35 | 36 | 37 | def get_op_name(module, op): 38 | # get the name of the op relative to the module 39 | for name, m in module.named_modules(): 40 | if m is op: 41 | return name 42 | raise ValueError(f"Cannot find op {op} in module {module}") 43 | 44 | 45 | def append_str_prefix(x, prefix): 46 | if isinstance(x, str): 47 | return prefix + x 48 | elif isinstance(x, tuple): 49 | return tuple([append_str_prefix(y, prefix) for y in x]) 50 | elif isinstance(x, list): 51 | return [append_str_prefix(y, prefix) for y in x] 52 | else: 53 | return x 54 | 55 | 56 | def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert): 57 | if modules_to_not_convert is None: 58 | return linear_layers 59 | 60 | filtered_layers = {} 61 | for name, linear_layer in linear_layers.items(): 62 | if not any(key in name for key in modules_to_not_convert): 63 | filtered_layers[name] = linear_layer 64 | return filtered_layers 65 | -------------------------------------------------------------------------------- /awq/utils/packing_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7] 5 | AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] 6 | 7 | 8 | def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int): 9 | shifts = torch.arange(0, 32, 
bits, device=qzeros.device) 10 | 11 | # unpacking columnwise 12 | iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( 13 | torch.int8 # smallest dtype available 14 | ) 15 | iweights = iweights.view(iweights.shape[0], -1) 16 | 17 | # unpacking columnwise 18 | if qzeros is not None: 19 | izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to( 20 | torch.int8 # smallest dtype available 21 | ) 22 | izeros = izeros.view(izeros.shape[0], -1) 23 | else: 24 | izeros = qzeros 25 | 26 | return iweights, izeros 27 | 28 | 29 | def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int): 30 | reverse_order_tensor = torch.arange( 31 | iweights.shape[-1], 32 | dtype=torch.int32, 33 | device=izeros.device, 34 | ) 35 | reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) 36 | reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER] 37 | reverse_order_tensor = reverse_order_tensor.view(-1) 38 | 39 | if izeros is not None: 40 | izeros = izeros[:, reverse_order_tensor] 41 | iweights = iweights[:, reverse_order_tensor] 42 | 43 | return iweights, izeros 44 | 45 | 46 | def pack_exllama(iweights: torch.Tensor, izeros: torch.Tensor, bits: int): 47 | shifts = torch.arange(0, 32, bits, device=iweights.device) 48 | 49 | # packing rowwise 50 | iweights = iweights.view(iweights.shape[0] // (32 // bits), 32 // bits, -1) 51 | qweight = ( 52 | torch.bitwise_left_shift(iweights, shifts[None, :, None]) 53 | .sum(dim=1) 54 | .to(torch.int32) 55 | ) 56 | 57 | # packing columnwise 58 | izeros = izeros.view(-1, izeros.shape[1] // (32 // bits), 32 // bits) 59 | qzeros = ( 60 | torch.bitwise_left_shift(izeros, shifts[None, None, :]) 61 | .sum(dim=-1) 62 | .to(torch.int32) 63 | ) 64 | 65 | return qweight, qzeros 66 | 67 | 68 | def unpack_reorder_pack(qweight, qzeros, bits): 69 | # Unpack the qweight and qzeros tensors 70 | iweight, izeros = unpack_awq(qweight, qzeros, bits) 71 | # Reverse the order of the iweight and izeros tensors 72 | iweight, izeros = reverse_awq_order(iweight, izeros, bits) 73 | 74 | # overflow checks 75 | iweight = torch.bitwise_and(iweight, (2**bits) - 1) 76 | izeros = torch.bitwise_and(izeros, (2**bits) - 1) 77 | 78 | # Subtract 1 from the izeros tensor (exllama adds 1 during inference) 79 | # We can remove it if we remove the +1 in the exllama code 80 | izeros = izeros - 1 81 | # Pack the qweight and qzeros tensors 82 | qweight, qzeros = pack_exllama(iweight, izeros, bits) 83 | 84 | return qweight, qzeros 85 | 86 | 87 | def dequantize_gemm(qweight, qzeros, scales, bits, group_size): 88 | # Unpack the qweight and qzeros tensors 89 | iweight, izeros = unpack_awq(qweight, qzeros, bits) 90 | # Reverse the order of the iweight and izeros tensors 91 | iweight, izeros = reverse_awq_order(iweight, izeros, bits) 92 | 93 | # overflow checks 94 | iweight = torch.bitwise_and(iweight, (2**bits) - 1) 95 | izeros = torch.bitwise_and(izeros, (2**bits) - 1) 96 | 97 | # fp16 weights 98 | scales = scales.repeat_interleave(group_size, dim=0) 99 | izeros = izeros.repeat_interleave(group_size, dim=0) 100 | iweight = (iweight - izeros) * scales 101 | 102 | return iweight 103 | -------------------------------------------------------------------------------- /awq/utils/parallel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import gc 4 | import logging 5 | 6 | 7 | def auto_parallel(args): 8 | model_size = args.model_path.split("-")[-1] 9 | if 
model_size.endswith("m"): 10 | model_gb = 1 11 | else: 12 | model_gb = float(model_size[:-1]) 13 | if model_gb < 20: 14 | n_gpu = 1 15 | elif model_gb < 50: 16 | n_gpu = 4 17 | else: 18 | n_gpu = 8 19 | args.parallel = n_gpu > 1 20 | cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) 21 | if isinstance(cuda_visible_devices, str): 22 | cuda_visible_devices = cuda_visible_devices.split(",") 23 | else: 24 | cuda_visible_devices = list(range(8)) 25 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( 26 | [str(dev) for dev in cuda_visible_devices[:n_gpu]] 27 | ) 28 | logging.debug("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"]) 29 | return cuda_visible_devices 30 | -------------------------------------------------------------------------------- /awq/utils/quant_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List 3 | 4 | 5 | Q_BITS = 4 6 | STORAGE_BITS = 32 7 | PACK_NUM = STORAGE_BITS // Q_BITS 8 | 9 | ORDINAL_PACK_ORDER = [0, 1, 2, 3, 4, 5, 6, 7] 10 | AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7] 11 | REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] 12 | 13 | 14 | def pack(imatrix: torch.Tensor, direction: str = "column"): 15 | """ 16 | Packs a 4-bit integer matrix into a packed 32-bit integer matrix. 17 | Args: 18 | imatrix (torch.Tensor): matrix of integers 19 | direction (str): direction of packing, either "column" or "row" 20 | 21 | Returns: 22 | qmatrix (torch.Tensor): packed matrix of integers 23 | """ 24 | shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=imatrix.device) 25 | 26 | imatrix = imatrix.to(torch.int8) 27 | imatrix = torch.bitwise_and(imatrix, 0x0F) # eventually correct overflow 28 | 29 | if direction == "column": 30 | imatrix = imatrix.view(-1, imatrix.shape[1] // PACK_NUM, PACK_NUM) 31 | qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1) 32 | 33 | elif direction == "row": 34 | imatrix = imatrix.view(imatrix.shape[0] // PACK_NUM, PACK_NUM, -1) 35 | qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1) 36 | 37 | qmatrix = qmatrix.to(torch.int32) 38 | 39 | return qmatrix 40 | 41 | 42 | def unpack(qmatrix: torch.Tensor, direction: str = "column"): 43 | """ 44 | Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix. 45 | 46 | Args: 47 | qmatrix (torch.Tensor): matrix of packed integers 48 | direction (str): direction of unpacking, either "column" or "row" 49 | 50 | Returns: 51 | imatrix (torch.Tensor): matrix of integers 52 | """ 53 | shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=qmatrix.device) 54 | 55 | if direction == "column": 56 | imatrix = torch.bitwise_right_shift( 57 | qmatrix[:, :, None], shifts[None, None, :] 58 | ).view(qmatrix.shape[0], -1) 59 | 60 | elif direction == "row": 61 | imatrix = torch.bitwise_right_shift( 62 | qmatrix[:, None, :], shifts[None, :, None] 63 | ).view(-1, qmatrix.shape[-1]) 64 | 65 | imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow 66 | 67 | return imatrix 68 | 69 | 70 | def quantize(fmatrix, scales, zeros, group_size): 71 | """ 72 | Quantizes a matrix of 16-bit floats into a matrix of 4-bit integers. 
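    Every block of `group_size` rows shares one scale and zero point, and the
    result is round(fmatrix / scale + zero) masked down to 4 bits.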
73 | 74 | Args: 75 | fmatrix (torch.Tensor): matrix of 16-bit floats 76 | scales (torch.Tensor): matrix of 16-bit floats 77 | zeros (torch.Tensor): matrix of 4-bit integers 78 | group_size (int): group size 79 | 80 | Returns: 81 | imatrix (torch.Tensor): matrix of 4-bit integers 82 | """ 83 | zeros = zeros.to(torch.int8) & 0x0F 84 | 85 | imatrix = torch.round( 86 | ( 87 | fmatrix / scales.repeat_interleave(group_size, dim=0) 88 | + zeros.repeat_interleave(group_size, dim=0) 89 | ) 90 | ) 91 | 92 | imatrix = imatrix.to(torch.int8) & 0x0F 93 | 94 | return imatrix 95 | 96 | 97 | def dequantize(imatrix, scales, zeros, group_size): 98 | """ 99 | Dequantizes a 4-bit integer matrix into a float matrix. 100 | 101 | Args: 102 | imatrix (torch.Tensor): matrix of 4-bit integers 103 | scales (torch.Tensor): matrix of 16-bit floats 104 | zeros (torch.Tensor): matrix of 4-bit integers 105 | group_size (int): group size 106 | 107 | Returns: 108 | fmatrix (torch.Tensor): matrix of 16-bit floats 109 | """ 110 | zeros = zeros.to(torch.int8) & 0x0F 111 | imatrix = imatrix.to(torch.int8) & 0x0F 112 | 113 | fmatrix = ( 114 | imatrix - zeros.repeat_interleave(group_size, dim=0) 115 | ) * scales.repeat_interleave(group_size, dim=0) 116 | 117 | fmatrix = fmatrix.to(torch.float16) 118 | 119 | return fmatrix 120 | 121 | 122 | def apply_order( 123 | imatrix: torch.Tensor, 124 | direction: str = "column", 125 | order: List[int] = ORDINAL_PACK_ORDER, 126 | ): 127 | """ 128 | Applies the order to a 4-bit integer matrix. 129 | 130 | Args: 131 | imatrix (torch.Tensor): matrix of integers 132 | direction (str): direction of applying order, either "column" or "row" 133 | order (List[int]): order to apply, default is ordinal packing order 134 | 135 | Returns: 136 | imatrix (torch.Tensor): matrix of integers 137 | """ 138 | if direction == "column": 139 | imatrix = imatrix.view(-1, PACK_NUM)[:, order].view(imatrix.shape) 140 | elif direction == "row": 141 | imatrix = imatrix.view(PACK_NUM, -1)[order, :].view(imatrix.shape) 142 | 143 | return imatrix 144 | 145 | 146 | def awq_to_exllama(qweight, qzeros): 147 | # awq uses column packing for both weights and zeros 148 | izeros = unpack(qzeros, direction="column") 149 | iweights = unpack(qweight, direction="column") 150 | 151 | # Reverse the order of the iweight and izeros tensors 152 | izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER) 153 | iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER) 154 | # Subtract 1 from the izeros tensor (exllama adds 1 during inference) 155 | izeros = izeros - 1 156 | # exllama uses row packing for weights and column packing for zeros 157 | qzeros = pack(izeros, direction="column") 158 | qweight = pack(iweights, direction="row") 159 | 160 | return qweight, qzeros 161 | -------------------------------------------------------------------------------- /awq/utils/utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import importlib 3 | import torch 4 | import accelerate 5 | 6 | 7 | ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None 8 | try: 9 | import triton as tl 10 | triton_available = True 11 | except ImportError: 12 | triton_available = False 13 | 14 | 15 | 16 | def get_module_by_name_suffix(model, module_name: str): 17 | for name, module in model.named_modules(): 18 | if name.endswith(module_name): 19 | return module 20 | 21 | 22 | def simple_dispatch_model(model, device_map): 23 | from 
accelerate.hooks import add_hook_to_module, AlignDevicesHook 24 | 25 | if "" in device_map: 26 | d = device_map[""] 27 | model = model.to(torch.device(d)) 28 | model.hf_device_map = device_map 29 | return model 30 | 31 | tied_params = accelerate.utils.modeling.find_tied_parameters(model) 32 | if set(device_map.values()) == {"cpu"} or set(device_map.values()) == { 33 | "cpu", 34 | "disk", 35 | }: 36 | main_device = "cpu" 37 | else: 38 | main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0] 39 | 40 | cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"] 41 | prev_hook = None 42 | for idx, (n, d) in enumerate(cpu_offload_group): 43 | m = get_module_by_name_suffix(model, n) 44 | _, prev_hook = accelerate.cpu_offload_with_hook( 45 | m, execution_device=main_device, prev_module_hook=prev_hook 46 | ) 47 | # set first cpu offload module's prev_module_hook to the last cpu offload module's hook 48 | if len(cpu_offload_group) > 1: 49 | get_module_by_name_suffix( 50 | model, cpu_offload_group[0][0] 51 | )._hf_hook.prev_module_hook = prev_hook 52 | 53 | for n, d in device_map.items(): 54 | m = get_module_by_name_suffix(model, n) 55 | if d != "cpu": 56 | d = torch.device(d) 57 | hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True) 58 | add_hook_to_module(m, hook) 59 | accelerate.utils.modeling.retie_parameters(model, tied_params) 60 | model.hf_device_map = device_map 61 | 62 | return model 63 | 64 | 65 | def set_module_name(model, name, value): 66 | if "." in name: 67 | parent_name = name.rsplit(".", 1)[0] 68 | child_name = name[len(parent_name) + 1 :] 69 | parent = model.get_submodule(parent_name) 70 | else: 71 | parent_name = "" 72 | parent = model 73 | child_name = name 74 | 75 | setattr(parent, child_name, value) 76 | 77 | 78 | def clear_memory(weight=None): 79 | if weight is not None: 80 | del weight 81 | # gc.collect() 82 | # torch.cuda.empty_cache() 83 | 84 | 85 | def compute_memory_used_pct(device): 86 | memory_used = torch.cuda.max_memory_allocated(device) / (1024**3) 87 | memory_pct = ( 88 | memory_used 89 | / (torch.cuda.get_device_properties(device).total_memory / (1024**3)) 90 | * 100 91 | ) 92 | return memory_pct 93 | 94 | 95 | def get_best_device(): 96 | if torch.backends.mps.is_available(): 97 | return "mps" 98 | elif torch.cuda.is_available(): 99 | return "cuda:0" 100 | elif torch.xpu.is_available(): 101 | return "xpu:0" 102 | else: 103 | return "cpu" 104 | 105 | 106 | def get_lowest_memory_device_index(): 107 | device = None 108 | curr_device_memory_pct = 0 109 | for device_index in range(torch.cuda.device_count()): 110 | device_memory_pct = compute_memory_used_pct(device_index) 111 | if device is None or device_memory_pct < curr_device_memory_pct: 112 | device = device_index 113 | curr_device_memory_pct = device_memory_pct 114 | 115 | return device 116 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # AutoAWQ 2 | 3 | AutoAWQ pushes ease of use and fast inference speed into one package. In the following documentation, 4 | you will learn how to quantize and run inference. 
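
For orientation, a minimal end-to-end flow looks like the sketch below. The model and output paths are only examples borrowed from the bundled `examples/quantize.py`; any supported checkpoint works.

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "Qwen/Qwen2.5-14B-Instruct"   # example checkpoint
quant_path = "Qwen2.5-14B-Instruct-awq"    # where the quantized weights go
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Quantize once and save the result...
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

# ...then reload the quantized model for inference.
model = AutoAWQForCausalLM.from_quantized(quant_path, device_map="auto")
```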
5 | 6 | Example inference speed (RTX 4090, Ryzen 9 7950X, 64 tokens): 7 | 8 | - Vicuna 7B (GEMV kernel): 198.848 tokens/s 9 | - Mistral 7B (GEMM kernel): 156.317 tokens/s 10 | - Mistral 7B (ExLlamaV2 kernel): 188.865 tokens/s 11 | - Mixtral 46.7B (GEMM kernel): 93 tokens/s (2x 4090) 12 | 13 | ## Installation notes 14 | 15 | - Install: `pip install autoawq`. 16 | - Your torch version must match the build version, i.e. you cannot use torch 2.0.1 with a wheel that was built with 2.2.0. 17 | - For AMD GPUs, inference will run through ExLlamaV2 kernels without fused layers. You need to pass the following arguments to run with AMD GPUs: 18 | ```python 19 | model = AutoAWQForCausalLM.from_quantized( 20 | ..., 21 | fuse_layers=False, 22 | use_exllama_v2=True 23 | ) 24 | ``` 25 | - For CPU device, you should install intel_extension_for_pytorch with `pip install intel_extension_for_pytorch`. And the latest version of torch is required since "intel_extension_for_pytorch(IPEX)" was built with the latest version of torch(now IPEX 2.4 was build with torch 2.4). If you build IPEX from source code, then you need to ensure the consistency of the torch version. And you should use "use_ipex=True" for CPU device. 26 | ```python 27 | model = AutoAWQForCausalLM.from_quantized( 28 | ..., 29 | use_ipex=True 30 | ) 31 | ``` 32 | 33 | ## Supported models 34 | 35 | We support modern LLMs. You can find a list of supported Huggingface `model_types` in `awq/models`. -------------------------------------------------------------------------------- /docs/reference/index.md: -------------------------------------------------------------------------------- 1 | # Auto and Base model classes in AutoAWQ 2 | 3 | View the documentation of the main classes of AutoAWQ models below. 4 | 5 | ::: awq.models.auto.AutoAWQForCausalLM 6 | ::: awq.models.base.BaseAWQForCausalLM 7 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # AutoAWQ examples 2 | 3 | Please see the docs for more thorough examples. In this folder, you will only find the 4 | very basic examples of quantization, inference, and training. 
-------------------------------------------------------------------------------- /examples/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from awq import AutoAWQForCausalLM 3 | from transformers import AutoTokenizer 4 | 5 | def main(): 6 | parser = argparse.ArgumentParser(description="CLI for model quantization and saving") 7 | parser.add_argument("--hf_model_path", type=str, required=True, help="Path to the Hugging Face model") 8 | parser.add_argument("--quant_name", type=str, required=True, help="Name of the quantized model") 9 | parser.add_argument("--local_save_path", type=str, required=True, help="Path to save the quantized model") 10 | 11 | # Quantization config arguments 12 | parser.add_argument("--zero_point", action="store_true", help="Enable zero point for quantization") 13 | parser.add_argument("--no-zero_point", action="store_false", dest="zero_point", help="Disable zero point for quantization") 14 | parser.add_argument("--q_group_size", type=int, default=128, help="Quantization group size") 15 | parser.add_argument("--w_bit", type=int, default=4, help="Weight bit width") 16 | parser.add_argument("--version", type=str, default="GEMM", help="Quantization version") 17 | 18 | # Model config arguments 19 | parser.add_argument("--device_map", type=str, default=None, help="Device map for loading the pretrained model") 20 | 21 | # Quantize parameters 22 | parser.add_argument("--max_calib_samples", type=int, default=128, help="Number of calibration samples.") 23 | parser.add_argument("--max_calib_seq_len", type=int, default=512, help="Calibration sample sequence length.") 24 | 25 | args = parser.parse_args() 26 | 27 | quant_config = { 28 | "zero_point": args.zero_point, 29 | "q_group_size": args.q_group_size, 30 | "w_bit": args.w_bit, 31 | "version": args.version 32 | } 33 | 34 | print(f"Loading model from: {args.hf_model_path}") 35 | model = AutoAWQForCausalLM.from_pretrained( 36 | args.hf_model_path, 37 | device_map=args.device_map, 38 | ) 39 | tokenizer = AutoTokenizer.from_pretrained(args.hf_model_path, trust_remote_code=True) 40 | 41 | print(f"Quantizing model with config: {quant_config}") 42 | model.quantize( 43 | tokenizer, 44 | quant_config=quant_config, 45 | max_calib_samples=args.max_calib_samples, 46 | max_calib_seq_len=args.max_calib_seq_len, 47 | ) 48 | 49 | print(f"Saving quantized model to: {args.local_save_path}") 50 | model.save_quantized(args.local_save_path) 51 | tokenizer.save_pretrained(args.local_save_path) 52 | 53 | print(f"Quantized model '{args.quant_name}' saved successfully.") 54 | 55 | if __name__ == "__main__": 56 | main() -------------------------------------------------------------------------------- /examples/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from lm_eval import evaluator 3 | from awq import AutoAWQForCausalLM 4 | from transformers import AutoTokenizer 5 | from awq.evaluation import ( 6 | evaluate_perplexity, 7 | eval_librispeech, 8 | eval_mmlu, 9 | eval_humaneval, 10 | eval_kl_divergence, 11 | ) 12 | 13 | def run_eval( 14 | model_path, quant_file, device, tasks, task_batch_size, task_n_shot, 15 | task_use_pretrained, pretrained_safetensors 16 | ): 17 | """ 18 | Post-quantization evaluation: perplexity (wikitext), MMLU, HumanEval, LibriSpeech, KL divergence, or any other EleutherAI Evaluation Harness task. 19 | """ 20 | tasks = tasks.split(',') 21 | 22 | # Load model (single-task mmlu/librispeech runs load the model themselves) 23 | if not (len(tasks) == 1 and tasks[0] in ("mmlu", "librispeech")): 24 | if task_use_pretrained: 25 | model = AutoAWQForCausalLM.from_pretrained(model_path, safetensors=pretrained_safetensors) 26 | else: 27 | model = AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=False) 28 | 29 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 30 | 31 | # Run the requested evaluation 32 | if len(tasks) == 1 and tasks[0] == 'wikitext': 33 | evaluate_perplexity(model.model, tokenizer) 34 | 35 | elif len(tasks) == 1 and tasks[0] == 'librispeech': 36 | eval_librispeech(model_path) 37 | 38 | elif len(tasks) == 1 and tasks[0] == 'mmlu': 39 | eval_mmlu(model_path, task_n_shot, task_batch_size, device, task_use_pretrained) 40 | 41 | elif len(tasks) == 1 and tasks[0] == 'humaneval': 42 | eval_humaneval(model, tokenizer) 43 | 44 | elif len(tasks) == 1 and tasks[0] == 'kldiv': 45 | eval_kl_divergence(model.model, model.model, tokenizer, seqlen=1024) 46 | 47 | else: 48 | # Run any remaining task(s) through the lm-eval harness 49 | results = evaluator.simple_evaluate( 50 | model=model, 51 | tasks=tasks, 52 | batch_size=task_batch_size, 53 | no_cache=True, 54 | num_fewshot=task_n_shot, 55 | ) 56 | 57 | print(evaluator.make_table(results)) 58 | 59 | if __name__ == '__main__': 60 | """ 61 | - Run perplexity of quantized model: 62 | python examples/eval.py --model_path casperhansen/mistral-7b-instruct-v0.1-awq 63 | 64 | - Run perplexity of an unquantized FP16 model: 65 | python examples/eval.py --use_pretrained --model_path lmsys/vicuna-7b-v1.5 66 | 67 | - Run MMLU of quantized model: 68 | python examples/eval.py --model_path TheBloke/zephyr-7B-beta-AWQ --tasks mmlu --n_shot 1 --batch_size 4 69 | """ 70 | 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument('--model_path', type=str, help='Path to hf model') 73 | parser.add_argument('--quant_file', default='', type=str, help='Path to quantized AWQ model file') 74 | parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to') 75 | parser.add_argument("--use_pretrained", default=False, action='store_true', 76 | help="Pass '--use_pretrained' to use a pretrained model running FP16") 77 | parser.add_argument("--pretrained_safetensors", default=False, action='store_true', 78 | help="Load safetensors for FP16 model") 79 | parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. ' 80 | 'Separate tasks by comma for multiple tasks. '
81 | 'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md') 82 | parser.add_argument('--batch_size', type=int, default=1) 83 | parser.add_argument('--n_shot', type=int, default=0) 84 | args = parser.parse_args() 85 | 86 | run_eval( 87 | args.model_path, args.quant_file, args.device, 88 | args.tasks, args.batch_size, args.n_shot, args.use_pretrained, 89 | args.pretrained_safetensors 90 | ) 91 | -------------------------------------------------------------------------------- /examples/generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from awq import AutoAWQForCausalLM 3 | from transformers import AutoTokenizer, TextStreamer 4 | from awq.utils.utils import get_best_device 5 | 6 | device = get_best_device() 7 | model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4" 8 | tokenizer = AutoTokenizer.from_pretrained(model_id) 9 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 10 | 11 | model = AutoAWQForCausalLM.from_quantized( 12 | model_id, 13 | torch_dtype=torch.float16, 14 | low_cpu_mem_usage=True, 15 | device_map="auto", 16 | ) 17 | 18 | prompt = [ 19 | {"role": "system", "content": "You are a helpful assistant, that responds as a pirate."}, 20 | {"role": "user", "content": \ 21 | "You're standing on the surface of the Earth. "\ 22 | "You walk one mile south, one mile west and one mile north. "\ 23 | "You end up exactly where you started. Where are you?"}, 24 | ] 25 | inputs = tokenizer.apply_chat_template( 26 | prompt, 27 | tokenize=True, 28 | add_generation_prompt=True, 29 | return_tensors="pt", 30 | return_dict=True, 31 | ).to(device) 32 | 33 | outputs = model.generate( 34 | **inputs, 35 | do_sample=True, 36 | max_new_tokens=256, 37 | streamer=streamer, 38 | ) 39 | -------------------------------------------------------------------------------- /examples/quantize.py: -------------------------------------------------------------------------------- 1 | from awq import AutoAWQForCausalLM 2 | from transformers import AutoTokenizer 3 | 4 | model_path = 'Qwen/Qwen2.5-14B-Instruct' 5 | quant_path = 'Qwen2.5-14B-Instruct-awq' 6 | quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } 7 | 8 | # Load model 9 | model = AutoAWQForCausalLM.from_pretrained(model_path) 10 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 11 | 12 | # Quantize 13 | model.quantize(tokenizer, quant_config=quant_config) 14 | 15 | # Save quantized model 16 | model.save_quantized(quant_path) 17 | tokenizer.save_pretrained(quant_path) 18 | 19 | print(f'Model is quantized and saved at "{quant_path}"') -------------------------------------------------------------------------------- /examples/train.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | from awq import AutoAWQForCausalLM 3 | from transformers import ( 4 | AutoTokenizer, 5 | TrainingArguments, 6 | Trainer, 7 | DataCollatorForLanguageModeling 8 | ) 9 | from peft import get_peft_model, LoraConfig, TaskType 10 | 11 | def prepare_split(tokenizer): 12 | data = datasets.load_dataset("mhenrichsen/alpaca_2k_test", split="train") 13 | prompt_template = "[INST] {prompt} [/INST] {output}" 14 | 15 | def format_prompt(x): 16 | return prompt_template.format( 17 | prompt=x["instruction"], 18 | output=x["output"] 19 | ) 20 | 21 | data = data.map( 22 | lambda x: {"text": format_prompt(x)}, 23 | ).select_columns(["text"]) 24 | data = 
data.map(lambda x: tokenizer(x["text"]), batched=True) 25 | 26 | return data 27 | 28 | model_path = "TheBloke/Mistral-7B-v0.1-AWQ" 29 | 30 | # Load model 31 | model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False) 32 | tokenizer = AutoTokenizer.from_pretrained(model_path) 33 | tokenizer.pad_token = tokenizer.eos_token 34 | 35 | # Prepare data 36 | data_train = prepare_split(tokenizer) 37 | 38 | # Config Lora 39 | lora_config = LoraConfig( 40 | r=4, 41 | lora_alpha=8, 42 | lora_dropout=0.5, 43 | bias="none", 44 | task_type=TaskType.CAUSAL_LM, 45 | inference_mode=False 46 | ) 47 | 48 | model = get_peft_model(model.model, lora_config) 49 | 50 | model.print_trainable_parameters() 51 | 52 | training_arguments = TrainingArguments( 53 | output_dir="./output", 54 | per_device_train_batch_size=1, 55 | optim="adamw_torch", 56 | num_train_epochs=1, 57 | learning_rate=1e-4, 58 | evaluation_strategy="no", 59 | save_strategy="epoch", 60 | save_steps=100, 61 | logging_steps=50, 62 | eval_steps=None, 63 | load_best_model_at_end=False 64 | ) 65 | 66 | trainer = Trainer( 67 | model=model, 68 | train_dataset=data_train, 69 | args=training_arguments, 70 | data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), 71 | ) 72 | 73 | trainer.train() 74 | trainer.save_model("output") -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: AutoAWQ 2 | repo_name: casper-hansen/AutoAWQ 3 | repo_url: https://github.com/casper-hansen/AutoAWQ 4 | 5 | nav: 6 | - index.md 7 | - Examples: examples.md 8 | - Reference: 9 | - reference/index.md 10 | 11 | markdown_extensions: 12 | toc: 13 | permalink: true 14 | markdown.extensions.codehilite: 15 | guess_lang: false 16 | admonition: null 17 | codehilite: null 18 | extra: null 19 | pymdownx.superfences: 20 | custom_fences: 21 | - name: mermaid 22 | class: mermaid 23 | format: !!python/name:pymdownx.superfences.fence_code_format '' 24 | pymdownx.tabbed: 25 | alternate_style: true 26 | pymdownx.tilde: null 27 | attr_list: null 28 | md_in_html: null 29 | 30 | plugins: 31 | search: null 32 | mkdocstrings: 33 | handlers: 34 | python: 35 | paths: [.] 
36 | options: 37 | extensions: 38 | - griffe_typingdoc 39 | show_root_heading: true 40 | show_if_no_docstring: true 41 | inherited_members: true 42 | members_order: source 43 | separate_signature: true 44 | unwrap_annotated: true 45 | filters: 46 | - '!^_' 47 | merge_init_into_class: true 48 | docstring_section_style: spacy 49 | signature_crossrefs: true 50 | show_symbol_type_heading: true 51 | show_symbol_type_toc: true 52 | 53 | theme: 54 | name: material 55 | palette: 56 | - media: '(prefers-color-scheme: light)' 57 | scheme: default 58 | primary: teal 59 | accent: amber 60 | toggle: 61 | icon: material/lightbulb 62 | name: Switch to dark mode 63 | - media: '(prefers-color-scheme: dark)' 64 | scheme: slate 65 | primary: teal 66 | accent: amber 67 | toggle: 68 | icon: material/lightbulb-outline 69 | name: Switch to light mode 70 | features: 71 | - search.suggest 72 | - search.highlight 73 | - content.tabs.link 74 | - navigation.indexes 75 | - content.tooltips 76 | - navigation.path 77 | - content.code.annotate 78 | - content.code.copy 79 | - content.code.select 80 | - navigation.tabs 81 | icon: 82 | repo: fontawesome/brands/github-alt -------------------------------------------------------------------------------- /scripts/download_wheels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set variables 4 | AWQ_VERSION="0.2.9" 5 | RELEASE_URL="https://github.com/casper-hansen/AutoAWQ/archive/refs/tags/v${AWQ_VERSION}.tar.gz" 6 | 7 | # Create a directory to download the wheels 8 | mkdir -p dist 9 | 10 | # Download the tar.gz file to dist directory 11 | wget -O "dist/v${AWQ_VERSION}.tar.gz" $RELEASE_URL -------------------------------------------------------------------------------- /scripts/runpod_quantize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import runpod 4 | 5 | # Load environment variables 6 | HF_TOKEN = os.environ.get('HF_TOKEN') 7 | runpod.api_key = os.environ.get('RUNPOD_API_KEY') 8 | 9 | # RunPod Parameters 10 | # get more by running print(runpod.get_gpus()) 11 | template_name = f"AutoAWQ Pod {int(time.time())}" 12 | docker_image = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04" 13 | gpu_ids = { 14 | "MI300X": "AMD Instinct MI300X OAM", 15 | "B200": "NVIDIA B200", 16 | "H200": "NVIDIA H200", 17 | "H100": "NVIDIA H100 80GB HBM3", 18 | "A100": "NVIDIA A100-SXM4-80GB", 19 | "A6000": "NVIDIA RTX A6000", 20 | "4090": "NVIDIA GeForce RTX 4090", 21 | } 22 | env_variables = { 23 | "HF_TOKEN": HF_TOKEN, 24 | } 25 | gpu_id = gpu_ids["H200"] 26 | num_gpus = 1 27 | system_memory_gb = 100 28 | system_storage_gb = 300 # fp16 model is downloaded here 29 | volume_storage_gb = 100 # quantized model is saved here 30 | 31 | # Quantization Parameters 32 | hf_model_path = "meta-llama/Llama-3.2-3B-Instruct" 33 | quant_name = "Llama-3.2-3B-Instruct-awq".lower() 34 | local_save_path = f"/workspace/{quant_name}" 35 | hf_upload_path = f"casperhansen/{quant_name}" 36 | INSTALL_TRANSFORMERS_MAIN = False 37 | USE_HF_TRANSFER = True 38 | 39 | if USE_HF_TRANSFER: 40 | env_variables["HF_HUB_ENABLE_HF_TRANSFER"] = "1" 41 | 42 | cli_args = dict( 43 | hf_model_path = hf_model_path, 44 | quant_name = quant_name, 45 | local_save_path = local_save_path, 46 | zero_point = True, 47 | q_group_size = 128, 48 | w_bit = 4, 49 | version = "GEMM", 50 | ) 51 | cli_args = " ".join([f"--{k}" if isinstance(v, bool) else f"--{k} {v}" for k,v in cli_args.items()]) 52 | 53 | commands 
= [ 54 | "cd /workspace", 55 | "pip install requests", 56 | "git clone https://github.com/casper-hansen/AutoAWQ.git", 57 | "cd AutoAWQ", 58 | "pip install -e .", 59 | "pip install -U git+https://github.com/huggingface/transformers.git" if INSTALL_TRANSFORMERS_MAIN else "", 60 | "pip install hf-transfer" if USE_HF_TRANSFER else "", 61 | "huggingface-cli login --token $HF_TOKEN", 62 | f"python examples/cli.py {cli_args}", 63 | f"huggingface-cli upload {hf_upload_path} {local_save_path} ./", 64 | "runpodctl stop pod $RUNPOD_POD_ID", 65 | ] 66 | commands = [cmd for cmd in commands if cmd] 67 | commands = " && ".join(commands) 68 | 69 | docker_command = "bash -c '" + commands + "'" 70 | 71 | template = runpod.create_template( 72 | name=template_name, 73 | image_name=docker_image, 74 | docker_start_cmd=docker_command, 75 | container_disk_in_gb=system_storage_gb, 76 | volume_in_gb=volume_storage_gb, 77 | volume_mount_path="/workspace", 78 | ports="8888/http,22/tcp", 79 | ) 80 | 81 | pod = runpod.create_pod( 82 | name=template_name, 83 | image_name=docker_image, 84 | template_id=template["id"], 85 | gpu_type_id=gpu_id, 86 | gpu_count=num_gpus, 87 | min_memory_in_gb=system_memory_gb, 88 | volume_in_gb=volume_storage_gb, 89 | container_disk_in_gb=system_storage_gb, 90 | env=env_variables, 91 | volume_mount_path="/workspace", 92 | cloud_type="SECURE", 93 | ) 94 | 95 | print(pod) 96 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from setuptools import setup, find_packages 4 | 5 | AUTOAWQ_VERSION = "0.2.9" 6 | 7 | common_setup_kwargs = { 8 | "version": AUTOAWQ_VERSION, 9 | "name": "autoawq", 10 | "author": "Casper Hansen", 11 | "license": "MIT", 12 | "python_requires": ">=3.8.0", 13 | "description": "AutoAWQ implements the AWQ algorithm for 4-bit quantization with a 2x speedup during inference.", 14 | "long_description": (Path(__file__).parent / "README.md").read_text( 15 | encoding="UTF-8" 16 | ), 17 | "long_description_content_type": "text/markdown", 18 | "url": "https://github.com/casper-hansen/AutoAWQ", 19 | "keywords": ["awq", "autoawq", "quantization", "transformers"], 20 | "platforms": ["linux", "windows"], 21 | "classifiers": [ 22 | "Environment :: GPU :: NVIDIA CUDA :: 11.8", 23 | "Environment :: GPU :: NVIDIA CUDA :: 12", 24 | "License :: OSI Approved :: MIT License", 25 | "Natural Language :: English", 26 | "Programming Language :: Python :: 3.9", 27 | "Programming Language :: Python :: 3.10", 28 | "Programming Language :: Python :: 3.11", 29 | "Programming Language :: Python :: 3.12", 30 | "Programming Language :: C++", 31 | ], 32 | } 33 | 34 | requirements = [ 35 | "torch", 36 | "triton", 37 | "transformers>=4.45.0", 38 | "tokenizers>=0.12.1", 39 | "typing_extensions>=4.8.0", 40 | "accelerate", 41 | "datasets>=2.20", 42 | "zstandard", 43 | "huggingface_hub>=0.26.5", 44 | ] 45 | 46 | setup( 47 | packages=find_packages(), 48 | install_requires=requirements, 49 | extras_require={ 50 | "eval": ["lm_eval==0.4.1", "tabulate", "protobuf", "evaluate", "scipy"], 51 | "dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"], 52 | "cpu": ["intel-extension-for-pytorch>=2.4.0"], 53 | "kernels": ["autoawq-kernels", "flash-attn>=2.2.0"], 54 | }, 55 | **common_setup_kwargs, 56 | ) 57 | -------------------------------------------------------------------------------- /tests/test_dequantization.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.manual_seed(0) 4 | torch.cuda.manual_seed(0) 5 | torch.cuda.manual_seed_all(0) 6 | 7 | import awq_ext 8 | from awq.utils.packing_utils import dequantize_gemm 9 | 10 | in_features = 4096 11 | out_features = 1792 12 | w_bit = 4 13 | group_size = 128 14 | 15 | MAX_INT32 = 0x7fffffff 16 | MIN_INT32 = -MAX_INT32 - 1 17 | 18 | qweight = torch.randint( 19 | MIN_INT32, 20 | MAX_INT32, 21 | (in_features, out_features // (32 // w_bit)), 22 | dtype=torch.int32, 23 | device="cuda", 24 | ) 25 | 26 | qzeros = torch.randint( 27 | MIN_INT32, 28 | MAX_INT32, 29 | (in_features // group_size, out_features // (32 // w_bit)), 30 | dtype=torch.int32, 31 | device="cuda", 32 | ) 33 | 34 | scales = torch.randn( 35 | (in_features // group_size, out_features), 36 | dtype=torch.float16, 37 | device="cuda", 38 | ) 39 | 40 | with torch.no_grad(): 41 | cuda_out = awq_ext.dequantize_weights_cuda( 42 | qweight, 43 | scales, 44 | qzeros, 45 | 0, 46 | 0, 47 | 0, 48 | False 49 | ) 50 | torch_out = dequantize_gemm( 51 | qweight, 52 | qzeros, 53 | scales, 54 | w_bit, 55 | group_size 56 | ) 57 | 58 | assert(torch.allclose(cuda_out, torch_out, rtol=0.0001)) -------------------------------------------------------------------------------- /tests/test_ipex_cpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from awq.utils.packing_utils import dequantize_gemm 3 | from intel_extension_for_pytorch.nn.modules.weight_only_quantization import WeightOnlyQuantizedLinear 4 | 5 | assert hasattr(WeightOnlyQuantizedLinear, "from_weight"), "The minimum version for ipex is at least 2.4" 6 | torch.manual_seed(0) 7 | 8 | in_features = 256 9 | out_features = 128 10 | w_bit = 4 11 | group_size = 32 12 | torch_dtype = torch.bfloat16 13 | 14 | MAX_INT32 = 0x7fffffff 15 | MIN_INT32 = -MAX_INT32 - 1 16 | 17 | qweight = torch.randint( 18 | MIN_INT32, 19 | MAX_INT32, 20 | (in_features, out_features // (32 // w_bit)), 21 | dtype=torch.int32, 22 | device="cpu", 23 | ) 24 | 25 | qzeros = torch.randint( 26 | MIN_INT32, 27 | MAX_INT32, 28 | (in_features // group_size, out_features // (32 // w_bit)), 29 | dtype=torch.int32, 30 | device="cpu", 31 | ) 32 | 33 | scales = torch.randn( 34 | (in_features // group_size, out_features), 35 | dtype=torch_dtype, 36 | device="cpu", 37 | ) 38 | 39 | with torch.no_grad(): 40 | fp_weight = dequantize_gemm( 41 | qweight, 42 | qzeros, 43 | scales, 44 | w_bit, 45 | group_size 46 | ) 47 | 48 | ipex_linear = WeightOnlyQuantizedLinear.from_weight(qweight, scales, qzeros, \ 49 | in_features, out_features, None, None, \ 50 | group_size, None, 0, 1) 51 | 52 | 53 | input = torch.rand(1, in_features, dtype=torch_dtype) 54 | torch_out = torch.matmul(input, fp_weight) 55 | 56 | ipex_dst = ipex_linear(input) 57 | results = torch.amax(ipex_dst - torch_out) 58 | 59 | assert(torch.allclose(ipex_dst, torch_out, rtol=0.06)) -------------------------------------------------------------------------------- /tests/test_quantization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def test_per_channel_mean(inp, max_chunk_memory=1024*1024*1024, atol=1e-5, rtol=1e-5): 5 | # Original method 6 | x_mean_original = inp.abs().view(-1, inp.shape[-1]).mean(0) 7 | 8 | # New method with chunking 9 | inp_flat = inp.cpu().abs().view(-1, inp.shape[-1]) 10 | num_elements = inp_flat.size(0) 11 | num_channels = 
inp_flat.size(1) 12 | element_size_bytes = inp_flat.element_size() * 2 13 | 14 | chunk_size = int(max_chunk_memory // (element_size_bytes * num_channels)) 15 | chunk_size = min(chunk_size, num_elements) 16 | 17 | x_sum = torch.zeros(num_channels, dtype=torch.float32, device=inp.device) 18 | 19 | for i in range(0, num_elements, chunk_size): 20 | end = min(i + chunk_size, num_elements) 21 | chunk_sum = inp_flat[i:end].to(torch.float32).sum(dim=0) 22 | x_sum += chunk_sum.to(inp.device) 23 | 24 | x_mean_new = (x_sum / num_elements).to(inp.dtype) 25 | 26 | # Compare results 27 | are_close = torch.allclose(x_mean_original, x_mean_new, atol=atol, rtol=rtol) 28 | max_diff = torch.max(torch.abs(x_mean_original - x_mean_new)).item() 29 | 30 | print(f"Results are close: {are_close}") 31 | print(f"Maximum difference: {max_diff}") 32 | 33 | return are_close 34 | 35 | 36 | def pseudo_quantize_tensor(w: torch.Tensor, group_size=128, w_bit=4): 37 | org_w_shape = w.shape 38 | if group_size > 0: 39 | assert org_w_shape[-1] % group_size == 0 40 | w = w.reshape(-1, group_size) 41 | assert w.dim() == 2 42 | assert torch.isnan(w).sum() == 0 43 | 44 | # zero point quantization 45 | max_val = w.amax(dim=1, keepdim=True) 46 | min_val = w.amin(dim=1, keepdim=True) 47 | max_int = 2**w_bit - 1 48 | min_int = 0 49 | scales = (max_val - min_val).clamp(min=1e-5) / max_int 50 | zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) 51 | w = ( 52 | torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros 53 | ) * scales 54 | zeros = zeros.view(org_w_shape[0], -1) 55 | 56 | assert torch.isnan(scales).sum() == 0 57 | assert torch.isnan(w).sum() == 0 58 | 59 | scales = scales.view(org_w_shape[0], -1) 60 | w = w.reshape(org_w_shape) 61 | 62 | return w, scales, zeros 63 | 64 | 65 | def test_loss_computation(fp16_output, int_w_output, max_chunk_memory=1024*1024*1024, atol=1e-5, rtol=1e-5): 66 | # Original method 67 | loss_original = (fp16_output - int_w_output).float().pow(2).mean().item() 68 | 69 | # New method with chunking 70 | @torch.no_grad() 71 | def _compute_loss(fp16_output, int_w_output, device, max_chunk_memory): 72 | loss = 0.0 73 | fp16_output_flat = fp16_output.view(-1) 74 | int_w_output_flat = int_w_output.view(-1) 75 | num_elements = fp16_output_flat.size(0) 76 | element_size_bytes = fp16_output.element_size() 77 | 78 | chunk_size = max_chunk_memory // (element_size_bytes * 2) 79 | chunk_size = min(chunk_size, num_elements) 80 | 81 | fp16_chunks = torch.split(fp16_output_flat, chunk_size) 82 | int_w_chunks = torch.split(int_w_output_flat, chunk_size) 83 | 84 | for fp16_chunk, int_w_chunk in zip(fp16_chunks, int_w_chunks): 85 | chunk_loss = (fp16_chunk.to(device) - int_w_chunk.to(device)).float().pow(2).sum().item() 86 | loss += chunk_loss 87 | 88 | loss /= num_elements 89 | return loss 90 | 91 | loss_new = _compute_loss(fp16_output, int_w_output, fp16_output.device, max_chunk_memory) 92 | 93 | # Compare results 94 | are_close = np.isclose(loss_original, loss_new, atol=atol, rtol=rtol) 95 | diff = abs(loss_original - loss_new) 96 | 97 | print(f"Results are close: {are_close}") 98 | print(f"Difference: {diff}") 99 | 100 | return are_close 101 | 102 | fp16_output = torch.randn(1000, 1000, 512) 103 | int_w_output = pseudo_quantize_tensor(fp16_output)[0] 104 | test_result = test_loss_computation(fp16_output, int_w_output) 105 | 106 | inp = torch.randn(1000, 1000, 512) 107 | test_result = test_per_channel_mean(inp) --------------------------------------------------------------------------------