├── .gitignore ├── README.md ├── pack_quantized_model.py ├── quant.py └── src ├── data_utils.py ├── dist_utils.py ├── gptq.py ├── gptq_loop.py ├── linalg_utils.py ├── loading_utils.py ├── model_utils.py └── quant_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## MoE-Quant 2 | --- 3 | 4 | This repository provides code for [GPTQ](https://arxiv.org/abs/2210.17323) quantization of the [DeepSeekV3](https://huggingface.co/deepseek-ai/DeepSeek-V3)/[DeepSeekR1](https://huggingface.co/deepseek-ai/DeepSeek-R1) model family. 5 | 6 | ### News 🔥 7 | 8 | - [2025/06] The quantized DeepSeek-R1-0528 model is on the 🤗 hub. 9 | 10 | ### Features 11 | 12 | In order to quantize a large model (671B parameters) with the `GPTQ` algorithm in a reasonable time, we introduce several optimizations: 13 | 14 | 1) **Fast `triton` kernel for `GPTQ`**: 15 | Since one has to quantize a lot (really a lot: ~45k) of linear layers, a faster `GPTQ` procedure is a critical optimization. The provided `triton` implementation achieves a ~10x speedup relative to the default `torch` implementation. 16 | 2) **Expert parallelism**: We shard the MLP experts across all devices so that the Hessians required for `GPTQ` calibration fit into VRAM. Each process stores only a fraction of the expert layers and the corresponding Hessians. 17 | 3) **Data parallelism**: To accelerate forward propagation, we split the calibration data uniformly across processes. 18 | 19 | **The total runtime of the algorithm to quantize DeepSeek-V3/R1 is 2 hours on a server with `8xH100` GPUs (for 512 calibration sequences of length 4096).** 20 | 21 | Currently we support conversion of the `GPTQ`-quantized model into the [compressed_tensors](https://github.com/neuralmagic/compressed-tensors) format, which is supported in HuggingFace transformers and vLLM. 22 | 23 | At the moment, only 4-bit symmetric quantization with configurable quantization group sizes is supported. 24 | We plan to support other bit widths and quantization formats (`AWQ`, `AutoGPTQ`) in the future.
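For readers who want to see what the fast `triton` kernel replaces, below is a minimal, illustrative `torch` sketch of the column-by-column GPTQ error-propagation loop, mirroring the block structure of `src/gptq_loop.py`. It assumes the usual uniform quantization convention `q = clamp(round(w / scale) + zero, 0, maxq)` and `w_hat = scale * (q - zero)`; the function name is ours and the snippet is not part of the repository.

```python
import torch

def gptq_loop_reference(weight, hessian_inv_cho, scale, zero, maxq, block_size=128):
    """weight, scale, zero: (C, R) transposed tensors; hessian_inv_cho: (C, C) upper
    Cholesky factor of the inverse Hessian with each row divided by its diagonal."""
    n_cols, n_rows = weight.shape
    qweight = torch.empty_like(weight)
    for i1 in range(0, n_cols, block_size):
        i2 = min(i1 + block_size, n_cols)
        err = torch.zeros(i2 - i1, n_rows, dtype=weight.dtype, device=weight.device)
        for j in range(i1, i2):
            # quantize one column and record the dequantization error
            qweight[j] = torch.clamp(torch.round(weight[j] / scale[j]) + zero[j], 0, maxq)
            err[j - i1] = (qweight[j] - zero[j]) * scale[j] - weight[j]
            # push the error onto the not-yet-quantized columns of the current block
            weight[j + 1 : i2] += hessian_inv_cho[j, j + 1 : i2, None] * err[j - i1]
        # propagate the accumulated block error to all remaining columns at once
        weight[i2:] += hessian_inv_cho[i1:i2, i2:].T @ err
    return qweight
```

In the repository, the per-column quantize/error step and the rank-1 update are implemented as fused `triton` kernels and replayed through a CUDA graph, which is what makes this loop fast enough to run over ~45k layers.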
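As a usage hint for the `compressed_tensors` checkpoints listed in the next section, here is one possible way to serve such a model with vLLM (library versions as in the Environment section below). The parallelism and context-length settings are placeholders to adapt to your hardware; this is a sketch rather than an officially supported launch command.

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g",  # any of the packed checkpoints below
    trust_remote_code=True,
    tensor_parallel_size=8,   # e.g. a single 8x H100/A100 node
    max_model_len=32768,      # adjust to the context length you need
)
outputs = llm.generate(
    ["Explain GPTQ quantization in two sentences."],
    SamplingParams(temperature=0.6, max_tokens=256),
)
print(outputs[0].outputs[0].text)
```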
25 | 26 | 27 | ### GPTQ-quantized models on 🤗 28 | 29 | --- 30 | #### DeepSeek-R1 31 | 32 | | Models | Experts Quantized | Attention blocks quantized | Size (Gb) | 33 | | ------ | --------- | --------- | --------- | 34 | | [ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g](https://huggingface.co/ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g) | ✅ | ✅ | 325 GB | 35 | | [ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g-experts](https://huggingface.co/ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g-experts)| ✅ | ❌ | 346 GB | 36 | 37 | These models easily fit onto single 8x `A100/H100` node with context long enough for most of the applications of interest, including reasoning chains. 38 | 39 | **Evaluation results on OpenLLM Leaderboard V1 tasks** 40 | 41 | | | Recovery (%) | Average Score | ARC-Challenge
acc_norm, 25-shot | GSM8k
exact_match, 5-shot | HellaSwag
acc_norm, 10-shot | MMLU
acc, 5-shot | TruthfulQA
mc2, 0-shot | WinoGrande
acc, 5-shot | 42 | | :------------------------------------------: | :----------: | :-----------: | :--------------------------------: | :--------------------------: | :----------------------------: | :-----------------: | :-----------------------: | :-----------------------: | 43 | | deepseek/DeepSeek-R1 | 100.00 | 81.04 | 72.53 | 95.91 | 89.30 | 87.22 | 59.28 | 82.00 | 44 | | cognitivecomputations/DeepSeek-R1-AWQ | 100.07 | 81.10 | 73.12 | 95.15 | 89.07 | 86.86 | 60.09 | 82.32 | 45 | | ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g | 99.86 | 80.93 | 72.70 | 95.68 | 89.25 | 86.83 | 58.77 | 82.32 | 46 | | ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g-experts | 100.30 | 81.28 | 72.53 | 95.68 | 89.36 | 86.99 | 59.77 | 83.35 | 47 | 48 | **Evaluation results on reasoning tasks (AIME-24, GPQA-Diamond, MATH-500)** 49 | 50 | | | Recovery (%) | Average Score | AIME 2024
pass@1 | MATH-500
pass@1 | GPQA Diamond
pass@1 | 51 | | -------------------------------------------- | :----------: | :-----------: | :-----------------: | :----------------: | :--------------------: | 52 | | deepseek/DeepSeek-R1 | 100.00 | 82.99 | 78.33 | 97.24 | 73.38 | 53 | | cognitivecomputations/DeepSeek-R1-AWQ | 94.29 | 78.25 | 70.67 | 93.64 | 70.46 | 54 | | ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g | 96.52 | 80.10 | 72.96 | 97.09 | 70.26 | 55 | | ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g-experts | **98.81** | 82.00 | 77.00 | 97.08 | 71.92 | 56 | 57 | --- 58 | #### DeepSeek-R1-0528 59 | 60 | | Models | Experts Quantized | Attention blocks quantized | Size (Gb) | 61 | | ------ | --------- | --------- | --------- | 62 | | [ISTA-DASLab/DeepSeek-R1-0528-GPTQ-4b-128g-experts](https://huggingface.co/ISTA-DASLab/DeepSeek-R1-0528-GPTQ-4b-128g-experts)| ✅ | ❌ | 346 GB | 63 | 64 | **Evaluation results on reasoning tasks (AIME-24, GPQA-Diamond, MATH-500)** 65 | 66 | | | Recovery (%) | Average Score | AIME 2024
pass@1 | MATH-500
pass@1 | GPQA Diamond
pass@1 | 67 | | ------------------------------------------- | :----------: | :-----------: | :-----------------: | :----------------: | :--------------------: | 68 | | deepseek/DeepSeek-R1-0528 | 100.00 | 88.61 | 88.66 | 97.52 | 79.65 | 69 | | ISTA-DASLab/DeepSeek-R1-0528-GPTQ-4b-128g-experts | 99.82 | 88.45 | 87.33 | 97.40 | 80.61 | 70 | 71 | ### Usage 72 | 73 | **Model quantization** 74 | 75 | ```shell 76 | torchrun --nnodes=1 --nproc-per-node=$NUM_GPUS --master_port 29501 quant.py \ 77 | --model_name_or_path $MODEL_PATH \ 78 | --dataset_name_or_path $DATASET \ 79 | --num_calibration_samples 512 \ 80 | --max_sequence_length 4096 \ 81 | --bits 4 \ 82 | --group_size 128 \ 83 | --rel_damp 0.1 \ 84 | --sym \ 85 | --offload_activations \ 86 | --quantization_order $QUANTIZATION_ORDER \ 87 | --quantization_scale $QUANTIZATION_SCALE \ 88 | --quantize_only_experts \ 89 | --tie_gptq_handles \ 90 | --dtype bfloat16 \ 91 | --save_dir $SAVE_DIR 92 | ``` 93 | 94 | Above: 95 | * `--model_name_or_path` - **exact path** to model weights, e.g. (`$HF_HOME/hub/models/models--deepseek-ai--DeepSeek-V3-0324/snapshots/commit_hash/`) 96 | * `--dataset_name_or_path` - dataset used for calibration. We provide 3 choices: `open-thoughts`, `open-platypus`, `fineweb-edu` 97 | * `--num_calibration_samples` - number of calibration samples 98 | * `--max_sequence_length` - maximum length of calibration samples (longer samples are truncated to this value) 99 | * `--quantization_order` - `default` or `activation`, we recommend using the latter for best results 100 | * `--quantization_scale` - `absmax` or `mse`, we recommend using the latter for best results 101 | * `--quantize_only_experts` - quantize only *non-shared* experts. Yields potentially better accuracy at the cost of slightly higher memory overhead. 102 | * `--tie_gptq_handles` - reuse the same Hessian for `up` and `gate` projections to reduce memory overhead during quantization 103 | * `--save_dir` - directory to save the quantized model 104 | 105 | The script above produces a directory with quantization metadata for each quantized layer, i.e. `quantized_weight`, `scale`, and `zero`. 106 | 107 | **Model packing** 108 | 109 | To convert the model into the `compressed_tensors` format, run the `pack_quantized_model.py` script: 110 | 111 | ```shell 112 | python pack_quantized_model.py \ 113 | --model_name_or_path $MODEL_PATH \ 114 | --quantized_model_path $QUANTIZED_MODEL_PATH \ 115 | --packed_model_path $QUANTIZED_MODEL_PATH-packed \ 116 | --dtype bfloat16 117 | ``` 118 | 119 | Above: 120 | * `--model_name_or_path` - **exact path** to model weights 121 | * `--quantized_model_path` - path to quantized weights (output of `quant.py`) 122 | * `--packed_model_path` - path to the model in `compressed_tensors` format, ready for inference in HF and vLLM. 123 | 124 | ### Environment 125 | 126 | This code was tested with the following versions of libraries: 127 | * `torch 2.5.1` 128 | * `transformers 4.50.0` 129 | * `vllm 0.8.2` 130 | 131 | ### Performance benchmarking 132 | We follow the standard vLLM performance benchmarking with the ShareGPT dataset and observe the following metrics (lower is better): 133 | 134 | | | Time to First Token
Median TTFT (ms) ↓ | Time per Output Token
Median TPOT (ms) ↓ | Inter-token Latency
Median ITL (ms) ↓ | 135 | | -------------------------------------------- | :-------------------------------------: | :---------------------------------------: | :------------------------------------: | 136 | | cognitivecomputations/DeepSeek-R1-AWQ | 1585.45 | 55.41 | 43.06 | 137 | | ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g-experts | 1344.68 | 41.49 | 36.33 | 138 | | ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g | 815.19 | 44.65 | 37.88 | 139 | 140 | The GPTQ models are faster than the AWQ model across all metrics because GPTQ uses fewer bits per parameter than AWQ. More specifically, AWQ has to use a smaller group size of 64 (vs. 128 in GPTQ) to preserve accuracy, and it stores zero-points due to asymmetric quantization. 141 | 142 | ### Contributors 143 | 144 | Denis Kuznedelev (Yandex), Eldar Kurtić (Red Hat AI & ISTA), Jiale Chen (ISTA), Michael Goin (Red Hat AI), Elias Frantar (ISTA), Dan Alistarh (Red Hat AI & ISTA). 145 | 146 | ### Citation 147 | 148 | ``` 149 | @article{gptq, 150 | title={{GPTQ}: Accurate Post-training Compression for Generative Pretrained Transformers}, 151 | author={Elias Frantar and Saleh Ashkboos and Torsten Hoefler and Dan Alistarh}, 152 | year={2022}, 153 | journal={arXiv preprint arXiv:2210.17323} 154 | } 155 | ``` 156 | -------------------------------------------------------------------------------- /pack_quantized_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import json 4 | import shutil 5 | import argparse 6 | from collections import defaultdict 7 | from typing import Optional, Any 8 | 9 | from tqdm import tqdm 10 | import torch 11 | from safetensors.torch import save_file 12 | from accelerate import init_empty_weights 13 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM 14 | from compressed_tensors.compressors import pack_to_int32 15 | 16 | from src import quant_utils 17 | from src import loading_utils 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | # Model params 23 | parser.add_argument( 24 | "--model_name_or_path", 25 | type=str, 26 | required=True, 27 | help="The name or path to the DeepSeek model", 28 | ) 29 | parser.add_argument( 30 | "--quantized_model_path", 31 | type=str, 32 | required=True, 33 | help="Path to quantized model." 34 | ) 35 | parser.add_argument( 36 | "--packed_model_path", 37 | type=str, 38 | required=True, 39 | help="Path to save the packed model." 40 | ) 41 | # Misc params 42 | parser.add_argument( 43 | "--dtype", 44 | default="float16", 45 | type=str, 46 | choices=["float16", "bfloat16"], 47 | help="Torch dtype used."
48 | ) 49 | args = parser.parse_args() 50 | return args 51 | 52 | 53 | def is_subset(set1: set, set2: set): 54 | return set1 <= set2 55 | 56 | 57 | def pack_weight( 58 | weight: dict[torch.Tensor], 59 | bits: int, 60 | sym: bool, 61 | group_size: Optional[int] = None, 62 | ) -> dict[torch.Tensor]: 63 | compressed_data = {} 64 | qweight, scale, zero = weight['qweight'], weight['scale'], weight['zero'] 65 | group_size = group_size or qweight.shape[-1] 66 | qweight_shifted = qweight.to(torch.int8) - zero.repeat_interleave(group_size, dim=-1).to(torch.int8) 67 | qweight_packed = pack_to_int32(qweight_shifted, bits) 68 | compressed_data = { 69 | "weight_packed": qweight_packed, 70 | "weight_shape": torch.tensor(qweight.shape), 71 | "weight_scale": scale 72 | } 73 | if not sym: 74 | compressed_data["weight_zero_point"] = weight['zero'] 75 | return compressed_data 76 | 77 | 78 | def prepare_quantization_config(args: argparse.Namespace) -> dict[str, Any]: 79 | ignored_modules = ["lm_head"] 80 | if args.quantize_only_experts: 81 | ignored_modules += ["re:.*self_attn.*", "re:.*shared_experts.*", "re:.*mlp\.(gate|up|gate_up|down)_proj.*"] 82 | return { 83 | "config_groups": { 84 | "group_0": { 85 | "input_activations": None, 86 | "output_activations": None, 87 | "targets": [ 88 | "Linear" 89 | ], 90 | "weights": { 91 | "actorder": None, 92 | "block_structure": None, 93 | "dynamic": False, 94 | "group_size": args.group_size, 95 | "num_bits": args.bits, 96 | "observer": "minmax", 97 | "observer_kwargs": {}, 98 | "strategy": "group", 99 | "symmetric": True, 100 | "type": "int" 101 | } 102 | } 103 | }, 104 | "format": "pack-quantized", 105 | "ignore": ignored_modules, 106 | "kv_cache_scheme": None, 107 | "quant_method": "compressed-tensors", 108 | "quantization_status": "compressed" 109 | } 110 | 111 | 112 | def main(): 113 | args = parse_args() 114 | 115 | dtype = getattr(torch, args.dtype) 116 | 117 | # Load DeepSeek model 118 | config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True) 119 | if hasattr(config, "quantization_config"): 120 | delattr(config, "quantization_config") 121 | 122 | with init_empty_weights(): 123 | model = AutoModelForCausalLM.from_config( 124 | config=config, 125 | trust_remote_code=True, 126 | torch_dtype=torch.bfloat16 127 | ).eval() 128 | model.config.use_cache = False 129 | 130 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True) 131 | 132 | # Load quantization metadata 133 | metadata = torch.load(os.path.join(args.quantized_model_path, "metadata.pt")) 134 | args.bits = metadata["bits"] 135 | args.group_size = metadata["group_size"] 136 | args.quantize_only_experts = metadata["quantize_only_experts"] 137 | # Currently we do not support asymmetric quantization 138 | args.sym = True 139 | 140 | num_output_shards = len(model.model.layers) + 2 141 | current_output_shard_id = 1 142 | quantized_layer_names = defaultdict(list) 143 | for layer_name in sorted(os.listdir(args.quantized_model_path)): 144 | if os.path.isdir(os.path.join(args.quantized_model_path, layer_name)): 145 | block_idx = int(layer_name.split(".")[2]) 146 | quantized_layer_names[block_idx].append(layer_name) 147 | safetensors_index = {} 148 | # Prepare directory to save packed weights 149 | os.makedirs(args.packed_model_path, exist_ok=True) 150 | 151 | # Load initial weight shard 152 | weight_dir = args.model_name_or_path 153 | current_input_shard_id = 1 154 | weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors" 155 | 156 | 
param_buffer = loading_utils.load_param_shard(weight_dir, weight_path) 157 | 158 | # Save embeddings 159 | current_output_shard_path = f"model-{current_output_shard_id:05}-of-{num_output_shards:05}.safetensors" 160 | save_file( 161 | {"model.embed_tokens.weight": param_buffer["model.embed_tokens.weight"]}, 162 | os.path.join(args.packed_model_path, current_output_shard_path) 163 | ) 164 | safetensors_index["model.embed_tokens.weight"] = current_output_shard_path 165 | del param_buffer["model.embed_tokens.weight"] 166 | 167 | # Process blocks 168 | for block_idx, block in tqdm( 169 | enumerate(model.model.layers), 170 | desc="Processing transformer blocks", 171 | total=len(model.model.layers) 172 | ): 173 | current_output_shard_id += 1 174 | prefix = f"model.layers.{block_idx}." 175 | block_keys_with_prefix = set(f"{prefix}{k}" for k in block.state_dict()) 176 | 177 | while not is_subset(block_keys_with_prefix, set(param_buffer.keys())): 178 | current_input_shard_id += 1 179 | weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors" 180 | param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path)) 181 | 182 | block_state_dict = {k: param_buffer[k] for k in param_buffer if k.startswith(prefix)} 183 | quant_utils.dequantize_state_dict(block_state_dict, dtype) 184 | 185 | for layer_name in quantized_layer_names[block_idx]: 186 | weight_state_dict = torch.load( 187 | os.path.join(args.quantized_model_path, layer_name, "quantized_weight.pt"), 188 | weights_only=True, 189 | map_location="cpu" 190 | ) 191 | packed_weight_state_dict = pack_weight(weight_state_dict, args.bits, args.sym, args.group_size) 192 | block_state_dict.pop(f"{layer_name}.weight") 193 | block_state_dict.pop(f"{layer_name}.weight_scale_inv", None) 194 | block_state_dict.update({f"{layer_name}.{k}": v for k, v in packed_weight_state_dict.items()}) 195 | 196 | # Save block 197 | current_output_shard_path = f"model-{current_output_shard_id:05}-of-{num_output_shards:05}.safetensors" 198 | save_file( 199 | block_state_dict, 200 | os.path.join(args.packed_model_path, current_output_shard_path) 201 | ) 202 | for k in block_state_dict: 203 | safetensors_index[k] = current_output_shard_path 204 | 205 | for k in block_keys_with_prefix: 206 | param_buffer.pop(k, None) 207 | 208 | del block_state_dict 209 | gc.collect() 210 | 211 | # Load final shard 212 | if current_input_shard_id < 163: 213 | current_input_shard_id = 163 214 | weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors" 215 | param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path)) 216 | 217 | # Save lm head 218 | current_output_shard_id += 1 219 | current_output_shard_path = f"model-{current_output_shard_id:05}-of-{num_output_shards:05}.safetensors" 220 | save_file( 221 | { 222 | "lm_head.weight": param_buffer["lm_head.weight"], 223 | "model.norm.weight": param_buffer["model.norm.weight"] 224 | }, 225 | os.path.join(args.packed_model_path, current_output_shard_path) 226 | ) 227 | safetensors_index["lm_head.weight"] = current_output_shard_path 228 | safetensors_index["model.norm.weight"] = current_output_shard_path 229 | # Save safetensors index 230 | with open(os.path.join(args.packed_model_path, "model.safetensors.index.json"), "w") as f: 231 | json.dump({"metadata": {}, "weight_map": safetensors_index}, f) 232 | # Add quantization metadata 233 | config.quantization_config = prepare_quantization_config(args) 234 | # Save configs 235 | config.save_pretrained(args.packed_model_path) 236 | 
model.generation_config.save_pretrained(args.packed_model_path) 237 | # Save tokenizer 238 | tokenizer.save_pretrained(args.packed_model_path) 239 | # Copy modeling script 240 | shutil.copy( 241 | os.path.join(args.model_name_or_path, "modeling_deepseek.py"), 242 | args.packed_model_path 243 | ) 244 | 245 | 246 | if __name__ == "__main__": 247 | main() -------------------------------------------------------------------------------- /quant.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import re 4 | import argparse 5 | 6 | from tqdm import tqdm 7 | import torch 8 | import torch.distributed as dist 9 | from accelerate import init_empty_weights 10 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM 11 | 12 | try: 13 | import wandb 14 | wandb_enabled = True 15 | except: 16 | wandb_enabled = False 17 | 18 | 19 | from src import dist_utils, data_utils, model_utils, quant_utils, loading_utils, gptq 20 | 21 | 22 | ROUTED_EXPERTS_REGEX = ".*mlp.experts.\d+.(down|gate|up)_proj$" 23 | TIED_FFN_GROUPS = ("gate_proj", "up_proj") 24 | 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser() 28 | # Model params 29 | parser.add_argument( 30 | "--model_name_or_path", 31 | type=str, 32 | required=True, 33 | help="The name or path to the DeepSeek model", 34 | ) 35 | # Data params 36 | parser.add_argument( 37 | "--dataset_name_or_path", 38 | type=str, 39 | required=True, 40 | help="The name or path to calibration dataset", 41 | ) 42 | parser.add_argument("--num_calibration_samples", default=128, type=int, help="Number of samples for calibration.") 43 | parser.add_argument("--max_sequence_length", default=8192, type=int, help="Calibration sequence length.") 44 | # Quantization params 45 | parser.add_argument( 46 | "--bits", 47 | type=int, 48 | default=4, 49 | choices=[4], 50 | help="Quantization bitwidth.", 51 | ) 52 | parser.add_argument( 53 | "--group_size", 54 | type=int, 55 | default=None, 56 | help="How many weight columns (input features) are quantized with the same statistics, default = all of them", 57 | ) 58 | parser.add_argument("--sym", action="store_true", help="Whether to use symmetric quantization") 59 | parser.add_argument("--rel_damp", type=float, default=1e-2) 60 | parser.add_argument("--block_size", type=int, default=128) 61 | parser.add_argument("--quantization_scale", type=str, default="absmax", choices=["absmax", "mse"]) 62 | parser.add_argument("--quantization_order", type=str, default="default", choices=["default", "activation"]) 63 | parser.add_argument( 64 | "--quantize_only_experts", 65 | default=False, 66 | action="store_true", 67 | help="Whether to quantize only routed (non-shared) experts.", 68 | ) 69 | # Save params 70 | parser.add_argument("--save_dir", type=str, default=None, help="where to save quantized model.") 71 | # Logging params 72 | parser.add_argument("--log_wandb", default=False, action="store_true", help="Log to W&B") 73 | parser.add_argument("--log_error", default=False, action="store_true", help="Whether to log relative L2 error") 74 | # Misc params 75 | parser.add_argument("--offload_activations", action="store_true", help="whether to offload activations to CPU.") 76 | parser.add_argument("--tie_gptq_handles", action="store_true", help="whether to reuse hessian between gate and up projections.") 77 | parser.add_argument("--resume", action="store_true", help="whether to resume quantization from latest checkpoint.") 78 | parser.add_argument("--seed", default=0, 
type=int, help="Random seed.") 79 | parser.add_argument( 80 | "--dtype", default="float16", type=str, choices=["float16s", "bfloat16"], help="Torch dtype used." 81 | ) 82 | args = parser.parse_args() 83 | 84 | return args 85 | 86 | 87 | def is_subset(set1: set, set2: set): 88 | return set1 <= set2 89 | 90 | 91 | def get_resume_block_idx(save_dir: os.PathLike) -> int: 92 | resume_block_idx = 0 93 | if os.path.exists(save_dir): 94 | for layer_name in os.listdir(save_dir): 95 | block_idx = int(layer_name.split(".")[2]) 96 | resume_block_idx = max(resume_block_idx, block_idx) 97 | return resume_block_idx 98 | 99 | 100 | def main(): 101 | args = parse_args() 102 | # Distributed init 103 | if dist.is_available(): 104 | dist.init_process_group(backend="nccl", init_method="env://") 105 | world_size = dist_utils.get_world_size() 106 | rank = dist_utils.get_rank() 107 | # init device 108 | device = f"cuda:{rank}" 109 | torch.set_grad_enabled(False) 110 | torch.cuda.set_device(device) 111 | offload_device = "cpu" if args.offload_activations else None 112 | dtype = getattr(torch, args.dtype) 113 | # Init W&B logger 114 | if args.log_wandb and dist_utils.is_main(): 115 | assert wandb_enabled, "wandb not installed. try `pip install wandb`" 116 | wandb.init(config=args) 117 | 118 | # Load DeepSeek model 119 | config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True) 120 | # Sanity check 121 | assert config.architectures == ["DeepseekV3ForCausalLM"], "Only DeepseekV3 is supported!" 122 | if hasattr(config, "quantization_config"): 123 | delattr(config, "quantization_config") 124 | config.ep_size = world_size 125 | 126 | with init_empty_weights(): 127 | model = AutoModelForCausalLM.from_config( 128 | config=config, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=dtype 129 | ).eval() 130 | model.config.use_cache = False 131 | 132 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True) 133 | 134 | # Prepare calibration dataset 135 | calibration_dataset = data_utils.prepare_calibration_dataset( 136 | args.dataset_name_or_path, tokenizer, args.max_sequence_length, args.num_calibration_samples, args.seed 137 | ) 138 | 139 | # Take slices (if running on multiple workers) 140 | num_seq_per_rank = len(calibration_dataset) // world_size 141 | calibration_dataset = calibration_dataset[rank * num_seq_per_rank : (rank + 1) * num_seq_per_rank] 142 | dist_utils.barrier(device_ids=[rank]) 143 | 144 | # Load initial weight shard 145 | weight_dir = args.model_name_or_path 146 | current_shard_id = 1 147 | weight_path = f"model-{current_shard_id:05}-of-000163.safetensors" 148 | 149 | param_buffer = {} 150 | if dist_utils.is_main(): 151 | param_buffer = loading_utils.load_param_shard(weight_dir, weight_path) 152 | dist_utils.barrier(device_ids=[rank]) 153 | 154 | # Get resume block id 155 | resume_block_idx = 0 156 | if args.resume: 157 | resume_block_idx = get_resume_block_idx(args.save_dir) 158 | 159 | # Prepare input embeddings and position ids 160 | inputs = [] 161 | position_ids = [] 162 | model.model.embed_tokens.to_empty(device=device) 163 | if dist_utils.is_main(): 164 | model.model.embed_tokens.weight.data = param_buffer["model.embed_tokens.weight"].to(device=device, dtype=dtype) 165 | if dist_utils.is_dist_available_and_initialized(): 166 | dist_utils.broadcast_parameters(model.model.embed_tokens) 167 | for i in range(num_seq_per_rank): 168 | seq_length = calibration_dataset[i].shape[1] 169 | 
inputs.append(model.model.embed_tokens(calibration_dataset[i].to(device)).to(offload_device)) 170 | position_ids.append(torch.arange(0, seq_length, dtype=torch.long, device=device).unsqueeze(0)) 171 | # Offload embeddings back to meta 172 | model.model.embed_tokens.to(device="meta") 173 | param_buffer.pop("model.embed_tokens.weight", None) 174 | 175 | for block_idx, block in tqdm( 176 | enumerate(model.model.layers), desc="Processing transformer blocks", total=len(model.model.layers) 177 | ): 178 | prefix = f"model.layers.{block_idx}." 179 | 180 | # Collect state dict keys from all processes 181 | rank_block_keys = [k for k in block.state_dict()] 182 | if dist_utils.is_main(): 183 | block_keys_with_prefix = [f"{prefix}{k}" for k in rank_block_keys] 184 | other_ranks_keys = [] 185 | for i in range(1, world_size): 186 | other_rank_keys = [None for _ in rank_block_keys] 187 | dist.recv_object_list(other_rank_keys, src=i) 188 | block_keys_with_prefix.extend([f"{prefix}{k}" for k in other_rank_keys]) 189 | other_ranks_keys.append(other_rank_keys) 190 | # Make it a set 191 | block_keys_with_prefix = set(block_keys_with_prefix) 192 | else: 193 | block_keys_with_prefix = [] 194 | other_ranks_keys = [] 195 | dist.send_object_list(rank_block_keys, dst=0) 196 | 197 | if dist_utils.is_main(): 198 | can_dequantize = True 199 | # Select weights corresponding to current block 200 | block_state_dict = {k[len(prefix) :]: v for k, v in param_buffer.items() if k.startswith(prefix)} 201 | while not (is_subset(block_keys_with_prefix, set(param_buffer.keys())) and can_dequantize): 202 | current_shard_id += 1 203 | weight_path = f"model-{current_shard_id:05}-of-000163.safetensors" 204 | param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path)) 205 | # Update weights corresponding to current block 206 | block_state_dict = {k[len(prefix) :]: v for k, v in param_buffer.items() if k.startswith(prefix)} 207 | can_dequantize = quant_utils.can_dequantize_from_fp8(block_state_dict) 208 | # Dequantize weights corresponding to current block 209 | quant_utils.dequantize_state_dict(block_state_dict, dtype) 210 | 211 | # Put block onto GPU 212 | block.to_empty(device=device) 213 | 214 | # Simply load block state dict on master and broadcast 215 | if block_idx < model.config.first_k_dense_replace: 216 | if dist_utils.is_main(): 217 | block.load_state_dict(block_state_dict) 218 | if dist_utils.is_dist_available_and_initialized(): 219 | dist_utils.broadcast_parameters(block) 220 | # Send dict with part of expets to target device 221 | else: 222 | if dist_utils.is_main(): 223 | # Load state dict on master 224 | rank_state_dict = {k: block_state_dict[k] for k in rank_block_keys} 225 | block.load_state_dict(rank_state_dict) 226 | # Send to other processes 227 | for i in range(1, world_size): 228 | rank_state_dict = {k: block_state_dict[k] for k in other_ranks_keys[i - 1]} 229 | for k in rank_state_dict: 230 | dist.send(rank_state_dict[k].to(device), dst=i) 231 | else: 232 | rank_state_dict = block.state_dict() 233 | for k in rank_state_dict: 234 | dist.recv(rank_state_dict[k], src=0) 235 | block.load_state_dict(rank_state_dict) 236 | del rank_state_dict 237 | # Clear memory before calibration 238 | torch.cuda.empty_cache() 239 | gc.collect() 240 | 241 | if block_idx >= resume_block_idx: 242 | # Hessian estimate 243 | layers = model_utils.select_layers(model, prefix, ".*", model_utils.LINEAR_LAYERS) 244 | handles = {} 245 | hooks = {} 246 | 247 | for layer_name, layer in layers.items(): 248 | 249 | def 
update_handle_hook(name): 250 | def _hook(_, inp, out): 251 | handles[name].update(inp[0]) 252 | 253 | return _hook 254 | 255 | if args.quantize_only_experts and re.search(ROUTED_EXPERTS_REGEX, layer_name) is None: 256 | continue 257 | 258 | tied_gptq_handle = None 259 | if args.tie_gptq_handles and layer_name.endswith("up_proj"): 260 | parent_name, _ = layer_name.rsplit(".", 1) 261 | tied_layer_name = f"{parent_name}.gate_proj" 262 | tied_gptq_handle = handles[tied_layer_name] 263 | 264 | handles[layer_name] = gptq.GPTQ( 265 | layer, 266 | args.group_size, 267 | args.sym, 268 | args.rel_damp, 269 | args.block_size, 270 | args.quantization_order, 271 | args.quantization_scale, 272 | is_distributed=re.search(ROUTED_EXPERTS_REGEX, layer_name) is None, 273 | tied_gptq_handle=tied_gptq_handle 274 | ) 275 | 276 | if tied_gptq_handle is None: 277 | hooks[layer_name] = layer.register_forward_hook(update_handle_hook(layer_name)) 278 | 279 | # Collect Hessians 280 | for i in range(num_seq_per_rank): 281 | block(inputs[i].to(device), position_ids=position_ids[i]) 282 | 283 | for _, h in hooks.items(): 284 | h.remove() 285 | 286 | dist_utils.barrier(device_ids=[rank]) 287 | 288 | shared_handles = {k: v for k, v in handles.items() if re.search(ROUTED_EXPERTS_REGEX, k) is None} 289 | expert_handles = {k: v for k, v in handles.items() if k not in shared_handles} 290 | 291 | # Quantized shared handles first 292 | num_issue_zero_samples = 0 293 | num_issue_nan_hessian = 0 294 | num_issue_non_invertible = 0 295 | for handle_name, handle in shared_handles.items(): 296 | dist_utils.print_on_main(f"Quantizing layer {handle_name}") 297 | qweight, scale, zero = handle.quantize(args.bits) 298 | # Construct dequantized weight 299 | dequantized_weight = quant_utils.dequantize_linear_weight(qweight, scale, zero) 300 | assert ( 301 | torch.isfinite(dequantized_weight).all().item() 302 | ), f"[rank{rank}] {handle_name} weight is broken after quantization." 303 | # Update issue tracker 304 | num_issue_zero_samples += handle.issue_zero_samples 305 | num_issue_nan_hessian += handle.issue_nan_hessian 306 | num_issue_non_invertible += handle.issue_non_invertible 307 | 308 | if args.log_error: 309 | if handle.has_hessian_issues(): 310 | dist_utils.print_on_main( 311 | "An issue occured on Hessian computation. Output error cannot be estimated." 
312 | ) 313 | else: 314 | relative_mse = quant_utils.get_relative_mse_error( 315 | dequantized_weight.float(), handle.layer.weight.float(), handle.H 316 | ) 317 | dist_utils.print_on_main(f"Relative error: {relative_mse.item():.2e}") 318 | if args.log_wandb and dist_utils.is_main(): 319 | wandb.log({f"relative_error/{handle_name}": relative_mse.item()}, step=0) 320 | 321 | if args.save_dir and dist_utils.is_main(): 322 | os.makedirs(os.path.join(args.save_dir, handle_name), exist_ok=True) 323 | torch.save( 324 | {"qweight": qweight, "scale": scale, "zero": zero}, 325 | os.path.join(args.save_dir, handle_name, f"quantized_weight.pt"), 326 | ) 327 | # Replace original weight by quantized one 328 | handle.layer.weight.data = dequantized_weight 329 | # Destroy handle 330 | handle.reset() 331 | 332 | dist_utils.print_on_main("-" * 10) 333 | dist_utils.print_on_main(f"GPTQ calibration issues for shared modules:") 334 | dist_utils.print_on_main(f"Zero Hessian: {num_issue_zero_samples}") 335 | dist_utils.print_on_main(f"Non-invertible: {num_issue_non_invertible}") 336 | dist_utils.print_on_main(f"NaN Hessian: {num_issue_nan_hessian}") 337 | dist_utils.print_on_main("-" * 10) 338 | 339 | # Quantize experts 340 | num_issue_zero_samples = 0 341 | num_issue_nan_hessian = 0 342 | num_issue_non_invertible = 0 343 | if len(expert_handles) > 0: 344 | dist_utils.print_on_main(f"Processing experts") 345 | 346 | expert_messages = None 347 | if dist_utils.is_main(): 348 | expert_messages = [None for _ in range(world_size)] 349 | rank_expert_message = "" 350 | 351 | for handle_name, handle in expert_handles.items(): 352 | rank_expert_message += f"Quantizing layer {handle_name}\n" 353 | qweight, scale, zero = handle.quantize(args.bits) 354 | # Construct dequantized weight 355 | dequantized_weight = quant_utils.dequantize_linear_weight(qweight, scale, zero) 356 | assert ( 357 | torch.isfinite(dequantized_weight).all().item() 358 | ), f"[rank{rank}] {handle_name} weight is broken after quantization." 359 | # Update issue tracker 360 | num_issue_zero_samples += handle.issue_zero_samples 361 | num_issue_nan_hessian += handle.issue_nan_hessian 362 | num_issue_non_invertible += handle.issue_non_invertible 363 | 364 | rank_expert_message += f"Tokens collected: {handle.tokens_collected}.\n" 365 | 366 | if args.log_error: 367 | if handle.has_hessian_issues(): 368 | rank_expert_message += "Hessian issue. 
Output error cannot be estimated.\n" 369 | else: 370 | relative_mse = quant_utils.get_relative_mse_error( 371 | dequantized_weight.float(), handle.layer.weight.float(), handle.H 372 | ) 373 | rank_expert_message += f"Relative error: {relative_mse.item():.2e}\n" 374 | # TODO send to main process 375 | if args.log_wandb and dist_utils.is_main(): 376 | wandb.log({f"relative_error/{handle_name}": relative_mse.item()}, step=0) 377 | 378 | if args.save_dir: 379 | os.makedirs(os.path.join(args.save_dir, handle_name), exist_ok=True) 380 | torch.save( 381 | {"qweight": qweight, "scale": scale, "zero": zero}, 382 | os.path.join(args.save_dir, handle_name, f"quantized_weight.pt"), 383 | ) 384 | # Replace original weight by quantized one 385 | handle.layer.weight.data = dequantized_weight 386 | # Destroy handle 387 | handle.reset() 388 | 389 | dist_utils.barrier(device_ids=[rank]) 390 | 391 | dist.gather_object(rank_expert_message, expert_messages) 392 | if dist_utils.is_main(): 393 | for expert_message in expert_messages: 394 | dist_utils.print_on_main(expert_message) 395 | 396 | # TODO sync data from other processes 397 | dist_utils.print_on_main("-" * 10) 398 | dist_utils.print_on_main(f"GPTQ calibration issues for expert modules:") 399 | dist_utils.print_on_main(f"Zero Hessian: {num_issue_zero_samples}") 400 | dist_utils.print_on_main(f"Non-invertible: {num_issue_non_invertible}") 401 | dist_utils.print_on_main(f"NaN Hessian: {num_issue_nan_hessian}") 402 | dist_utils.print_on_main("-" * 10) 403 | 404 | del handles 405 | del shared_handles 406 | del expert_handles 407 | del hooks 408 | torch.cuda.empty_cache() 409 | gc.collect() 410 | else: 411 | dist_utils.print_on_main(f"Block {block_idx} is already quantized. Skipping quantization.") 412 | 413 | # Update activations 414 | for i in range(num_seq_per_rank): 415 | inputs[i] = block(inputs[i].to(device), position_ids=position_ids[i])[0].to(offload_device) 416 | assert torch.isfinite(inputs[i]).all().item(), "NaN of inf encountered." 417 | 418 | # Offload block 419 | block.to(device="meta") 420 | for k in block_keys_with_prefix: 421 | param_buffer.pop(k, None) 422 | 423 | torch.cuda.empty_cache() 424 | gc.collect() 425 | 426 | # Save quantization metadata 427 | if args.save_dir: 428 | torch.save( 429 | { 430 | "bits": args.bits, 431 | "group_size": args.group_size, 432 | "quantize_only_experts": args.quantize_only_experts 433 | }, 434 | os.path.join(args.save_dir, "metadata.pt") 435 | ) 436 | 437 | dist.destroy_process_group() 438 | 439 | 440 | if __name__ == "__main__": 441 | main() 442 | -------------------------------------------------------------------------------- /src/data_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import torch 4 | from datasets import load_dataset 5 | from transformers import AutoTokenizer 6 | 7 | 8 | def prepare_open_thoughts( 9 | tokenizer: AutoTokenizer, 10 | max_sequence_length: int, 11 | num_calibration_samples: Optional[int] = None, 12 | seed: int = 42 13 | ) -> List[torch.Tensor]: 14 | train_dataset_raw = load_dataset("open-thoughts/OpenThoughts-114k", split="train") 15 | if num_calibration_samples: 16 | train_dataset_raw = train_dataset_raw.shuffle(seed=seed).select(range(num_calibration_samples)) 17 | # Preprocess the data into the format the model is trained with. 
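    # Illustrative sketch (an assumption inferred from the field accesses below, not
    # copied from the dataset card) of a single OpenThoughts-114k record:
    #   {
    #       "system": "You are a helpful assistant ...",
    #       "conversations": [
    #           {"from": "user", "value": "Solve ..."},
    #           {"from": "assistant", "value": "<think> ... </think> ..."},
    #       ],
    #   }
    # preprocess() flattens this into a single chat-template string for tokenization.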
18 | def preprocess(example): 19 | messages = [] 20 | # add system prompt 21 | messages.append({"role": "system", "content": example['system']}) 22 | # add dialogue 23 | for message in example['conversations']: 24 | messages.append({"role": message["from"], "content": message["value"]}) 25 | return {"text": tokenizer.apply_chat_template(messages, tokenize=False)} 26 | train_dataset_raw = train_dataset_raw.map(preprocess) 27 | # Tokenize the data 28 | def tokenize(sample): 29 | return tokenizer( 30 | sample["text"], 31 | padding=False, 32 | max_length=max_sequence_length, 33 | truncation=True, 34 | add_special_tokens=False, 35 | ) 36 | train_dataset = train_dataset_raw.map(tokenize, remove_columns=train_dataset_raw.column_names) 37 | train_dataset = [torch.tensor(sample['input_ids']).unsqueeze(0) for sample in train_dataset] 38 | return train_dataset 39 | 40 | 41 | def prepare_open_platypus( 42 | tokenizer: AutoTokenizer, 43 | max_sequence_length: int, 44 | num_calibration_samples: Optional[int] = None, 45 | seed: int = 42 46 | ) -> List[torch.Tensor]: 47 | train_dataset_raw = load_dataset("garage-bAInd/Open-Platypus", split="train") 48 | if num_calibration_samples: 49 | train_dataset_raw = train_dataset_raw.shuffle(seed=seed).select(range(num_calibration_samples)) 50 | # Preprocess the data into the format the model is trained with. 51 | def preprocess(example): 52 | messages = [ 53 | {"role": "user", "content": example["instruction"]}, 54 | {"role": "assistant", "content": example["output"]}, 55 | ] 56 | return {"text": tokenizer.apply_chat_template(messages, tokenize=False)} 57 | train_dataset_raw = train_dataset_raw.map(preprocess) 58 | # Tokenize the data 59 | def tokenize(sample): 60 | return tokenizer( 61 | sample["text"], 62 | padding=False, 63 | max_length=max_sequence_length, 64 | truncation=True, 65 | add_special_tokens=False, 66 | ) 67 | train_dataset = train_dataset_raw.map(tokenize, remove_columns=train_dataset_raw.column_names) 68 | train_dataset = [torch.tensor(sample['input_ids']).unsqueeze(0) for sample in train_dataset] 69 | return train_dataset 70 | 71 | 72 | def prepare_fineweb_edu( 73 | tokenizer: AutoTokenizer, 74 | max_sequence_length: int, 75 | num_calibration_samples: Optional[int] = None, 76 | seed: int = 42 77 | ) -> List[torch.Tensor]: 78 | train_dataset_raw = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT", split="train", streaming=True) 79 | train_dataset_raw = train_dataset_raw.shuffle(seed=seed, buffer_size=1_000) 80 | train_dataset = [] 81 | for i, sample in enumerate(train_dataset_raw): 82 | if i == num_calibration_samples: 83 | break 84 | tokenized_sample = tokenizer( 85 | sample["text"], 86 | max_length=max_sequence_length, 87 | truncation=True, 88 | return_tensors="pt" 89 | ) 90 | train_dataset.append(tokenized_sample['input_ids']) 91 | return train_dataset 92 | 93 | 94 | def prepare_calibration_dataset( 95 | dataset_name: str, 96 | tokenizer: AutoTokenizer, 97 | max_sequence_length: int, 98 | num_calibration_samples: Optional[int] = None, 99 | seed: int = 42 100 | ) -> List[torch.Tensor]: 101 | if dataset_name == "open-thoughts": 102 | return prepare_open_thoughts(tokenizer, max_sequence_length, num_calibration_samples, seed) 103 | if dataset_name == "open-platypus": 104 | return prepare_open_platypus(tokenizer, max_sequence_length, num_calibration_samples, seed) 105 | if dataset_name == "fineweb-edu": 106 | return prepare_fineweb_edu(tokenizer, max_sequence_length, num_calibration_samples, seed) 107 | else: 108 | raise ValueError("Unknown 
dataset") 109 | -------------------------------------------------------------------------------- /src/dist_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Any 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.distributed as dist 6 | 7 | 8 | __all__ = [ 9 | "is_dist_available_and_initialized", 10 | "get_world_size", 11 | "get_rank", 12 | "is_main", 13 | "broadcast_parameters", 14 | "gather_into_tensor", 15 | "print_on_main", 16 | "barrier" 17 | ] 18 | 19 | 20 | def is_dist_available_and_initialized(): 21 | return dist.is_available() and dist.is_initialized() 22 | 23 | 24 | def get_world_size(): 25 | if is_dist_available_and_initialized(): 26 | return dist.get_world_size() 27 | return 1 28 | 29 | 30 | def get_rank(): 31 | if is_dist_available_and_initialized(): 32 | return dist.get_rank() 33 | return 0 34 | 35 | 36 | def is_main(): 37 | return get_rank() == 0 38 | 39 | 40 | def barrier(device_ids=None): 41 | if is_dist_available_and_initialized(): 42 | dist.barrier(device_ids=device_ids) 43 | 44 | 45 | def broadcast_parameters(module: nn.Module, src: Any = 0, group: Optional[Any] = None): 46 | for param in module.parameters(): 47 | dist.broadcast(param.data, src=src, group=group) 48 | 49 | 50 | def gather_into_tensor(tensor, dim: int = 0): 51 | world_size = get_world_size() 52 | if is_main(): 53 | gathered_shape = (*tensor.shape[:dim], world_size * tensor.shape[dim], *tensor.shape[dim + 1 :]) 54 | gathered_tensor = torch.empty(gathered_shape, device=tensor.device, dtype=tensor.dtype) 55 | gathered_tensor_chunks = list(gathered_tensor.chunk(world_size, dim=dim)) 56 | else: 57 | gathered_tensor = None 58 | gathered_tensor_chunks = None 59 | dist.gather(tensor, gathered_tensor_chunks) 60 | return gathered_tensor 61 | 62 | 63 | def print_on_main(*args, **kwargs): 64 | if is_main(): 65 | print(*args, **kwargs) 66 | -------------------------------------------------------------------------------- /src/gptq.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.distributed as dist 7 | from torch import Tensor 8 | from torch.nn.modules.conv import _ConvNd 9 | 10 | from src import dist_utils, model_utils, linalg_utils, quant_utils, gptq_loop 11 | 12 | 13 | torch.backends.cuda.matmul.allow_tf32 = False 14 | torch.backends.cudnn.allow_tf32 = False 15 | 16 | 17 | class QuantizationOrder(Enum): 18 | DEFAULT = "default" 19 | ACTIVATION = "activation" 20 | 21 | 22 | class GPTQ: 23 | 24 | def __init__( 25 | self, 26 | layer: nn.Module, 27 | group_size: Optional[int] = None, 28 | sym: bool = False, 29 | rel_damp: float = 1e-2, 30 | block_size: int = None, 31 | quantization_order: str = "default", 32 | quantization_scale: str = "absmax", 33 | is_distributed: bool = False, 34 | tied_gptq_handle: Optional["GPTQ"] = None 35 | ): 36 | self._validate_layer(layer) 37 | self.layer = layer 38 | self.W = self.layer.weight 39 | self.d_row, self.d_col = model_utils.get_number_of_rows_and_cols(layer) 40 | # Quantization hyperparameters 41 | self.sym = sym 42 | self.group_size = group_size 43 | # GPTQ hyperparameters 44 | self.rel_damp = rel_damp 45 | self.block_size = block_size or self.d_col 46 | self.quantization_order = QuantizationOrder(quantization_order) 47 | self.quantization_scale = quantization_scale 48 | # backup layer properties 49 | self.W_device = self.W.device 50 
| self.W_dtype = self.W.dtype 51 | self.W_shape = self.W.shape 52 | # init hessian 53 | self.H = None 54 | self.num_samples = 0 55 | self.is_distributed = is_distributed 56 | self.tied_gptq_handle = tied_gptq_handle 57 | self.num_tied_handles = 0 58 | if tied_gptq_handle is not None: 59 | tied_gptq_handle.num_tied_handles += 1 60 | # Flags indicating issues 61 | self.issue_zero_samples = False 62 | self.issue_nan_hessian = False 63 | self.issue_non_invertible = False 64 | 65 | @staticmethod 66 | def _validate_layer(layer): 67 | assert isinstance(layer, (nn.Linear, _ConvNd)), "OBC supports only linear and convolutional layers." 68 | 69 | def has_hessian_issues(self) -> bool: 70 | return any([self.issue_zero_samples, self.issue_nan_hessian, self.issue_non_invertible]) 71 | 72 | # preparatory methods 73 | @torch.no_grad() 74 | def update(self, input: Tensor) -> None: 75 | """ 76 | Update the estimate of Hessian matrix from a batch of data. 77 | 78 | Args: 79 | input: batch of layer inputs 80 | """ 81 | # init hessian 82 | if self.H is None: 83 | self.H = torch.zeros((self.d_col, self.d_col), device=input.device, dtype=torch.float32) 84 | # input reshaping 85 | if isinstance(self.layer, nn.Linear): 86 | input = input.reshape(-1, input.shape[-1]) 87 | else: 88 | unfold = nn.Unfold( 89 | self.layer.kernel_size, 90 | dilation=self.layer.dilation, 91 | padding=self.layer.padding, 92 | stride=self.layer.stride, 93 | ) 94 | # output size (batch_size, channels * \prod kernel_size, num_patches) 95 | input = unfold(input) 96 | input = input.transpose(1, 2).flatten(0, 1) 97 | input = input.float() 98 | # get number of samples (tokens) in batch 99 | num_new_samples = input.shape[0] 100 | # hessian update 101 | beta = self.num_samples / (self.num_samples + num_new_samples) 102 | alpha = 2.0 / (self.num_samples + num_new_samples) 103 | self.H.addmm_(input.T, input, beta=beta, alpha=alpha) 104 | # update number of collected samples 105 | self.num_samples += num_new_samples 106 | 107 | @property 108 | def tokens_collected(self) -> int: 109 | return self.num_samples 110 | 111 | def reset(self) -> None: 112 | self.W = self.layer.weight 113 | if self.num_tied_handles == 0: 114 | self.H = None 115 | elif self.tied_gptq_handle: 116 | self.tied_gptq_handle.num_tied_handles -= 1 117 | if self.tied_gptq_handle.num_tied_handles == 0: 118 | self.tied_gptq_handle.H = None 119 | self.num_samples = 0 120 | torch.cuda.empty_cache() 121 | 122 | @torch.no_grad() 123 | def quantization_pre_step(self) -> None: 124 | """ 125 | Preparatory step with hessian regularization and weight reshaping. 
126 | """ 127 | # 1) Hessian preparation 128 | reduce_if_needed = True 129 | if self.H is None: 130 | if self.tied_gptq_handle: 131 | self.H = self.tied_gptq_handle.H 132 | else: 133 | self.H = torch.eye(self.d_col, device=self.W_device, dtype=torch.float32) 134 | self.issue_zero_samples = True 135 | # No need to reduce 136 | reduce_if_needed = False 137 | # synchronize Hessians 138 | if self.is_distributed and reduce_if_needed and dist_utils.is_dist_available_and_initialized(): 139 | dist.all_reduce(self.H, op=dist.ReduceOp.AVG) 140 | # Replace matrix by identity in case of NaNs 141 | if torch.isnan(self.H).any().item(): 142 | self.H = torch.eye(self.d_col, device=self.W_device, dtype=torch.float32) 143 | self.issue_nan_hessian = True 144 | # get ids of pruned channels 145 | pruned_ids = torch.diag(self.H) == 0 146 | self.H[pruned_ids, pruned_ids] = 1 147 | # 2) Weight preparation 148 | # copy weight, flatten 149 | self.W = self.W.clone().float() 150 | if isinstance(self.layer, _ConvNd): 151 | self.W = self.W.flatten(1, -1) 152 | self.W[:, pruned_ids] = 0 153 | # flag pre step as completed 154 | self.pre_step_completed = True 155 | 156 | @torch.no_grad() 157 | def _quantize(self, bits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 158 | """ 159 | Quantize the weight matrix using GPTQ 160 | """ 161 | # 1) Define constants and chunk 162 | d_row, d_col, block_size, device, dtype = self.d_row, self.d_col, self.block_size, self.W_device, self.W_dtype 163 | # 2) Get quantization group size 164 | group_size = self.group_size or d_col 165 | num_groups = d_col // group_size 166 | 167 | is_main_gptq_process = dist_utils.is_main() or not self.is_distributed 168 | 169 | if is_main_gptq_process: 170 | # Get scale, qzero 171 | scale, zero, maxq = quant_utils.get_quantization_grid( 172 | weight=self.W, 173 | group_size=self.group_size, 174 | bits=bits, 175 | symmetric=self.sym, 176 | dtype=dtype, 177 | quantization_scale=self.quantization_scale, 178 | ) 179 | # Get permutation 180 | if self.quantization_order == QuantizationOrder.ACTIVATION: 181 | perm = torch.argsort(self.H.diag(), descending=True) 182 | else: 183 | perm = torch.arange(d_col, device=device) 184 | perm_inv = torch.argsort(perm) 185 | # Permute Hessian prior to inversion (if reusing hessian from other handle, Hessian is already permuted) 186 | if not self.tied_gptq_handle: 187 | self.H = self.H[perm][:, perm] 188 | # Get hessian inverse 189 | hessian_inv = self._get_hessian_inverse() 190 | # Quantize 191 | qweight = gptq_loop.gptq_loop( 192 | weight=self.W.transpose(-2, -1)[perm], 193 | hessian_inv=hessian_inv, 194 | scale=scale.transpose(-2, -1)[perm], 195 | qzero=zero.transpose(-2, -1)[perm], 196 | maxq=maxq, 197 | dtype=dtype, 198 | gptq_block_size=block_size, 199 | )[perm_inv].transpose(-2, -1).contiguous().to(torch.uint8) 200 | # Remove scale and zero replication 201 | scale = scale[:, ::group_size].to(dtype) 202 | zero = zero[:, ::group_size].to(dtype) 203 | else: 204 | qweight = torch.empty(d_row, d_col, device=device, dtype=torch.uint8) 205 | scale = torch.empty(d_row, num_groups, device=device, dtype=dtype) 206 | zero = torch.empty(d_row, num_groups, device=device, dtype=dtype) 207 | 208 | if self.is_distributed and dist_utils.is_dist_available_and_initialized(): 209 | dist.barrier() 210 | dist.broadcast(qweight, src=0) 211 | dist.broadcast(scale, src=0) 212 | dist.broadcast(zero, src=0) 213 | 214 | return qweight, scale, zero 215 | 216 | def quantize(self, bits: int | float) -> Tensor: 217 | 
self.quantization_pre_step() 218 | return self._quantize(bits) 219 | 220 | @torch.no_grad() 221 | def _get_hessian_inverse(self): 222 | w = self.W 223 | # Get columns with all zeros 224 | zero_cols = torch.nonzero(w.eq(0).all(dim=0)) 225 | H = self.H 226 | # Regularize Hessian before quantization 227 | if not self.tied_gptq_handle: 228 | # Mask rows with zero input channels 229 | H[zero_cols, :] = 0 230 | H[:, zero_cols] = 0 231 | H[zero_cols, zero_cols] = 1 232 | # Hessian regularization 233 | damp = self.rel_damp * torch.diag(self.H).mean() 234 | self.H[range(self.d_col), range(self.d_col)] += damp 235 | # Invert 236 | try: 237 | H = linalg_utils.inv_sym(H) 238 | H_inv_cho = torch.linalg.cholesky(H, upper=True) 239 | except: 240 | H_inv_cho = torch.eye(self.d_col, device=H.device, dtype=torch.float32) 241 | # Divide Hessian inverse by diagonal (in order to not divide on it later) 242 | H_inv_cho.div_(H_inv_cho.diag()[:, None]) 243 | return H_inv_cho 244 | -------------------------------------------------------------------------------- /src/gptq_loop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | from triton import language as tl 4 | 5 | from src.quant_utils import tl_quantize, tl_dequantize 6 | 7 | torch.backends.cuda.matmul.allow_tf32 = False 8 | torch.backends.cudnn.allow_tf32 = False 9 | torch.set_float32_matmul_precision("highest") 10 | 11 | 12 | @triton.jit 13 | def quantize_error_triton_kernel( 14 | x_ptr, 15 | qx_ptr, 16 | error_ptr, 17 | scale_ptr, 18 | qzero_ptr, 19 | maxq_ptr, 20 | dtype_ptr, 21 | n_elements: int, 22 | BLOCK_SIZE: tl.constexpr, 23 | ): 24 | pid = tl.program_id(axis=0) 25 | offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 26 | mask = offsets < n_elements 27 | 28 | x = tl.load(x_ptr + offsets, mask=mask) 29 | scale = tl.load(scale_ptr + offsets, mask=mask) 30 | qzero = tl.load(qzero_ptr + offsets, mask=mask) 31 | maxq = tl.load(maxq_ptr) 32 | dtype = None if dtype_ptr is None else tl.load(dtype_ptr).dtype 33 | 34 | qx = tl_quantize(x, scale, qzero, maxq) 35 | y = tl_dequantize(qx, scale, qzero, dtype) 36 | error = y - x 37 | 38 | tl.store(x_ptr + offsets, y, mask=mask) 39 | tl.store(qx_ptr + offsets, qx, mask=mask) 40 | tl.store(error_ptr + offsets, error, mask=mask) 41 | 42 | 43 | def quantize_error_triton( 44 | x: torch.Tensor, 45 | qx: torch.Tensor, 46 | error: torch.Tensor, 47 | scale: torch.Tensor, 48 | qzero: torch.Tensor, 49 | maxq: torch.Tensor, 50 | dtype: torch.dtype = None, 51 | ) -> None: 52 | 53 | n_elements: int = x.numel() 54 | grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 55 | quantize_error_triton_kernel[grid]( 56 | x, 57 | qx, 58 | error, 59 | scale, 60 | qzero, 61 | maxq, 62 | torch.empty(0, dtype=dtype) if dtype is not None else None, 63 | n_elements, 64 | BLOCK_SIZE=128, 65 | ) 66 | 67 | 68 | @triton.jit 69 | def addvv_triton_kernel( 70 | vec_a_ptr, 71 | vec_b_ptr, 72 | mat_c_ptr, 73 | size_a: int, 74 | size_b: int, 75 | BLOCK_SIZE_B: tl.constexpr, 76 | ): 77 | pid = tl.program_id(axis=0) 78 | offset_a = pid % size_a 79 | offsets_b = pid // size_a * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B) 80 | mask = offsets_b < size_b 81 | c_ptrs = mat_c_ptr + offset_a * size_b + offsets_b 82 | 83 | a = tl.load(vec_a_ptr + offset_a) 84 | b = tl.load(vec_b_ptr + offsets_b, mask=mask) 85 | c = tl.load(c_ptrs, mask=mask) 86 | c = tl.fma(a, b, c) 87 | 88 | tl.store(c_ptrs, c, mask=mask) 89 | 90 | 91 | def addvv_triton( 92 | vec_a: torch.Tensor, 93 | vec_b: 
torch.Tensor, 94 | mat_c: torch.Tensor, 95 | ) -> None: 96 | size_a, size_b = mat_c.shape 97 | grid = lambda meta: (size_a * triton.cdiv(size_b, meta["BLOCK_SIZE_B"]),) 98 | addvv_triton_kernel[grid]( 99 | vec_a, 100 | vec_b, 101 | mat_c, 102 | size_a, 103 | size_b, 104 | BLOCK_SIZE_B=256, 105 | ) 106 | 107 | 108 | def gptq_loop_graph( 109 | weight: torch.Tensor, 110 | hessian_inv: torch.Tensor, 111 | scale: torch.Tensor, 112 | qzero: torch.Tensor, 113 | maxq: torch.Tensor, 114 | qweight: torch.Tensor = None, 115 | error_block: torch.Tensor = None, 116 | dtype: torch.dtype = None, 117 | gptq_block_size: int = 128, 118 | direct: bool = True, 119 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 120 | """ 121 | CUDA Graph wrapper for GPTQ loops 122 | """ 123 | n_columns, n_rows = weight.shape 124 | w_dtype: torch.dtype = weight.dtype 125 | device: torch.device = weight.device 126 | 127 | if direct: 128 | if qweight is None: 129 | qweight: torch.Tensor = torch.empty_like(weight) 130 | if error_block is None: 131 | error_block: torch.Tensor = torch.empty(gptq_block_size, n_rows, dtype=w_dtype, device=device) 132 | assert ( 133 | weight.is_contiguous() 134 | and hessian_inv.is_contiguous() 135 | and scale.is_contiguous() 136 | and qzero.is_contiguous() 137 | and maxq.is_contiguous() 138 | and qweight.is_contiguous() 139 | and error_block.is_contiguous() 140 | ) 141 | for i1 in range(0, n_columns, gptq_block_size): 142 | i2: int = min(i1 + gptq_block_size, n_columns) 143 | for j in range(i1, i2): 144 | quantize_error_triton( 145 | weight[j], qweight[j], error_block[j - i1], scale[j], qzero[j], maxq, dtype, 146 | ) 147 | addvv_triton(hessian_inv[j, j + 1 : i2], error_block[j - i1], weight[j + 1 : i2]) 148 | weight[i2:].addmm_(hessian_inv[i1:i2, i2:].t(), error_block[: i2 - i1], beta=1, alpha=1) 149 | return qweight, weight 150 | 151 | previous_device: torch.device = torch.device(f"cuda:{torch.cuda.current_device()}") 152 | torch.cuda.set_device(weight.device) 153 | if not hasattr(gptq_loop_graph, "graph_info"): 154 | gptq_loop_graph.graph_info = {} 155 | graph_key: tuple = n_columns, n_rows, w_dtype, dtype, gptq_block_size, device 156 | if graph_key not in gptq_loop_graph.graph_info: 157 | graph: torch.cuda.CUDAGraph = torch.cuda.CUDAGraph() 158 | graph_tensors: dict[str, torch.Tensor] = { 159 | "weight": torch.empty_like(weight.contiguous()), 160 | "hessian_inv": torch.empty_like(hessian_inv.contiguous()), 161 | "scale": torch.empty_like(scale.contiguous()), 162 | "qzero": torch.empty_like(qzero.contiguous()), 163 | "maxq": torch.empty_like(maxq.contiguous()), 164 | "qweight": torch.empty_like(weight.contiguous()), 165 | "error_block": torch.empty(gptq_block_size, n_rows, dtype=w_dtype, device=device), 166 | } 167 | n_warmups: int = 5 168 | s: torch.cuda.Stream = torch.cuda.Stream() 169 | s.wait_stream(torch.cuda.current_stream()) 170 | with torch.cuda.stream(s): 171 | for _ in range(n_warmups): 172 | gptq_loop_graph(**graph_tensors, dtype=dtype, gptq_block_size=gptq_block_size, direct=True) 173 | torch.cuda.current_stream().wait_stream(s) 174 | with torch.cuda.graph(graph): 175 | gptq_loop_graph(**graph_tensors, dtype=dtype, gptq_block_size=gptq_block_size, direct=True) 176 | gptq_loop_graph.graph_info[graph_key] = {"graph": graph, "tensors": graph_tensors} 177 | 178 | graph, graph_tensors = ( 179 | gptq_loop_graph.graph_info[graph_key]["graph"], 180 | gptq_loop_graph.graph_info[graph_key]["tensors"], 181 | ) 182 | graph_tensors["weight"].copy_(weight) 183 | 
graph_tensors["hessian_inv"].copy_(hessian_inv) 184 | graph_tensors["scale"].copy_(scale) 185 | graph_tensors["qzero"].copy_(qzero) 186 | graph_tensors["maxq"].copy_(maxq) 187 | graph.replay() 188 | weight.copy_(graph_tensors["weight"]) 189 | torch.cuda.set_device(previous_device) 190 | return graph_tensors["qweight"], weight 191 | 192 | 193 | def gptq_loop( 194 | weight: torch.Tensor, 195 | hessian_inv: torch.Tensor, 196 | scale: torch.Tensor, 197 | qzero: torch.Tensor, 198 | maxq: torch.Tensor, 199 | dtype: torch.dtype, 200 | gptq_block_size: int = 128, 201 | ) -> tuple[torch.Tensor, torch.Tensor]: 202 | """ 203 | Quantize weight tensor with GPTQ algorithm 204 | weight: (C, R), transposed weight tensor to quantize, modified in-place and returned 205 | hessian_inv: (C, C), inverse of Hessian matrix 206 | scale: (C, R), transposed scale tensor for quantization 207 | qzero: (C, R), transposed zero-point tensor for quantization 208 | maxq: (), maximum quantized value 209 | dtype: target scale dtype, fp16 or bf16 210 | gptq_block_size: block size for GPTQ loop, this is independent of the quantization group size 211 | """ 212 | if gptq_block_size <= 0: 213 | gptq_block_size = weight.size(-2) 214 | 215 | qweight, _ = gptq_loop_graph( 216 | weight=weight, 217 | hessian_inv=hessian_inv, 218 | scale=scale, 219 | qzero=qzero, 220 | maxq=maxq, 221 | dtype=dtype, 222 | gptq_block_size=gptq_block_size, 223 | direct=False, 224 | ) 225 | return qweight # (C, R) 226 | -------------------------------------------------------------------------------- /src/linalg_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | __all__ = ["inv_sym"] 6 | 7 | 8 | def inv_sym(X: Tensor): 9 | """ 10 | More efficient and stable inversion of symmetric matrices. 
11 | """ 12 | return torch.cholesky_inverse(torch.linalg.cholesky(X)) -------------------------------------------------------------------------------- /src/loading_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from safetensors import safe_open 5 | 6 | def load_param_shard(weight_dir: str, weight_path: str) -> dict[str, torch.Tensor]: 7 | param_shard = {} 8 | with safe_open(os.path.join(weight_dir, weight_path), framework="pt", device="cpu") as f: 9 | param_shard_keys = f.keys() 10 | for k in param_shard_keys: 11 | param_shard[k] = f.get_tensor(k) 12 | return param_shard 13 | -------------------------------------------------------------------------------- /src/model_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional, Union, Optional, Sequence, Tuple 2 | 3 | import re 4 | import numpy as np 5 | import torch.nn as nn 6 | 7 | 8 | LINEAR_LAYERS = (nn.Linear,) 9 | 10 | 11 | def get_number_of_rows_and_cols(layer): 12 | return layer.weight.shape[0], np.prod(layer.weight.shape[1:]) 13 | 14 | 15 | def select_layers( 16 | model: nn.Module, 17 | layer_prefix: Optional[str] = "", 18 | layer_regex: str = ".*", 19 | layer_classes: Union[nn.Module, List[nn.Module]] = nn.Module, 20 | ) -> Dict[str, nn.Module]: 21 | layers = {} 22 | for layer_name, layer in model.named_modules(): 23 | if ( 24 | isinstance(layer, layer_classes) 25 | and re.search(layer_regex, layer_name) 26 | and layer_name.startswith(layer_prefix) 27 | ): 28 | layers[layer_name] = layer 29 | return layers 30 | -------------------------------------------------------------------------------- /src/quant_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from enum import Enum 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import triton 8 | from triton import language as tl 9 | 10 | 11 | torch.backends.cuda.matmul.allow_tf32 = False 12 | torch.backends.cudnn.allow_tf32 = False 13 | torch.set_float32_matmul_precision("highest") 14 | 15 | 16 | FP8_GROUP_SIZE = 128 17 | FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz) 18 | 19 | 20 | class QuantizationScale(Enum): 21 | ABSMAX = "absmax" 22 | MSE = "mse" 23 | 24 | 25 | @triton.jit 26 | def tl_pow(x, a): 27 | return (x.abs().log() * a).exp() # TODO: triton does not have x.pow(a) or x ** a? 28 | 29 | 30 | @triton.jit 31 | def tl_round(x): 32 | return ( 33 | x + 0.5 34 | ).floor() # TODO: triton does not have round()? We might want to change to round to even number here. 
35 | 36 | 37 | @triton.jit 38 | def tl_round_fp(x, dtype): 39 | return x if dtype is None else x.cast(dtype, fp_downcast_rounding="rtne").cast(x.dtype) 40 | 41 | 42 | @triton.jit 43 | def tl_quantize(x, scale, qzero, maxq): 44 | return tl.clamp(tl_round(x / scale + qzero), 0.0, maxq) 45 | 46 | 47 | @triton.jit 48 | def tl_dequantize(qx, scale, qzero, dtype): 49 | return tl_round_fp((qx - qzero) * scale, dtype) 50 | 51 | 52 | @triton.jit 53 | def tl_dequantize_quantized(x, scale, qzero, maxq, dtype): 54 | return tl_dequantize(tl_quantize(x, scale, qzero, maxq), scale, qzero, dtype) 55 | 56 | 57 | def round_fp(x: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor: 58 | return x if dtype is None else x.to(dtype=dtype).to(x.dtype) 59 | 60 | 61 | def quantize(x: torch.Tensor, scale: torch.Tensor, qzero: torch.Tensor, maxq: torch.Tensor) -> torch.Tensor: 62 | return (x / scale + qzero).round().clamp(torch.zeros_like(maxq), maxq) 63 | 64 | 65 | def dequantize(qx: torch.Tensor, scale: torch.Tensor, qzero: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor: 66 | return round_fp((qx - qzero) * scale, dtype) 67 | 68 | 69 | def dequantize_quantized( 70 | x: torch.Tensor, scale: torch.Tensor, qzero: torch.Tensor, maxq: torch.Tensor, dtype: torch.dtype = None 71 | ) -> torch.Tensor: 72 | return dequantize(quantize(x, scale, qzero, maxq), scale, qzero, dtype) 73 | 74 | 75 | def find_quantization_meta( 76 | x: torch.Tensor, 77 | bit_width: int, 78 | symmetric: bool = False, 79 | dtype: torch.dtype = None, 80 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 81 | """ 82 | Find quantization metadata over dim=-1 83 | x: (..., C), weight 84 | bit_width: int 85 | symmetric: bool, whether to fix qzero at the middle of the range (symmetric quantization) 86 | dtype: torch.dtype, target scale dtype, fp16 or bf16 87 | """ 88 | epsilon: float = 1e-12 89 | maxq = torch.tensor(2**bit_width - 1, dtype=x.dtype, device=x.device) # () 90 | 91 | x_min = x.amin(dim=-1) 92 | x_max = x.amax(dim=-1) 93 | 94 | if symmetric: 95 | scale = (2.0 / maxq) * torch.maximum(x_min.abs(), x_max.abs()) 96 | scale = round_fp(scale + epsilon, dtype) # (...) 97 | qzero = torch.full_like(scale, ((maxq + 1.0) * 0.5).item()) # (...) 98 | else: 99 | scale = round_fp((x_max - x_min) / maxq + epsilon, dtype) # (...) 100 | qzero = (-x_min / scale).round().clamp(0, maxq) # (...)
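    # With symmetric=True and bit_width=4 this gives maxq = 15 and qzero = 8, so dequantized values take the form scale * (q - 8) for integer q in [0, 15].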
101 | return scale, qzero, maxq 102 | 103 | 104 | @triton.jit 105 | def mse_scale_triton_kernel( 106 | x_ptr, 107 | p_ptr, 108 | scale_ptr, 109 | qzero_ptr, 110 | maxq_ptr, 111 | dtype_ptr, 112 | norm: float, 113 | p_size: int, 114 | group_size: int, 115 | batch_size: int, 116 | BLOCK_SIZE_P: tl.constexpr, 117 | BLOCK_SIZE_G: tl.constexpr, 118 | BLOCK_SIZE_B: tl.constexpr, 119 | ): 120 | pid = tl.program_id(axis=0) 121 | b_offsets = pid * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B) # (R) 122 | b_mask = b_offsets < batch_size # (R) 123 | x_offsets = b_offsets[:, None] * group_size + tl.arange(0, BLOCK_SIZE_G) # (R, C) 124 | x_mask = b_mask[:, None] & (tl.arange(0, BLOCK_SIZE_G) < group_size) # (R, C) 125 | p_offsets = tl.arange(0, BLOCK_SIZE_P) # (P) 126 | p_mask = p_offsets < p_size # (P) 127 | scale_ptrs = scale_ptr + b_offsets # (R) 128 | 129 | x = tl.load(x_ptr + x_offsets, mask=x_mask)[:, None, :] # (R, 1, C) 130 | p = tl.load(p_ptr + p_offsets, mask=p_mask) # (P) 131 | scale = tl.load(scale_ptrs, mask=b_mask) # (R) 132 | qzero = tl.load(qzero_ptr + b_offsets, mask=b_mask)[:, None, None] # (R, 1, 1) 133 | maxq = tl.load(maxq_ptr) # () 134 | dtype = None if dtype_ptr is None else tl.load(dtype_ptr).dtype 135 | 136 | scale_p = tl_round_fp(scale[:, None] * p, dtype)[:, :, None] # (R, P, 1) 137 | q = tl_dequantize_quantized(x, scale_p, qzero, maxq, dtype) # (R, P, C) 138 | best_idx = tl.argmin(tl.sum(tl_pow(q - x, norm), axis=-1), axis=-1, tie_break_left=False) # (R) 139 | 140 | scale = tl_round_fp(scale * tl.load(p_ptr + best_idx), dtype) # (R) # TODO: replace with tl.gather() 141 | tl.store(scale_ptrs, scale, mask=b_mask) # (R) 142 | 143 | 144 | def mse_scale( 145 | x: torch.Tensor, 146 | p: torch.Tensor, 147 | scale: torch.Tensor, 148 | qzero: torch.Tensor, 149 | maxq: torch.Tensor, 150 | dtype: torch.dtype = None, 151 | norm: float = 2.4, 152 | ) -> torch.Tensor: 153 | """ 154 | Find the optimal scale for quantization with respect to the MSE loss 155 | x: (..., C), weight 156 | p: (P), shrinkage factors 157 | scale: (...), initial scale, modified in-place and returned 158 | qzero: (...), zero points 159 | maxq: () 160 | dtype: torch.dtype, target scale dtype, fp16 or bf16 161 | norm: float, norm for the loss 162 | debug_mode: bool, whether to use the baseline implementation without Triton 163 | """ 164 | 165 | assert ( 166 | x.is_contiguous() 167 | and p.is_contiguous() 168 | and scale.is_contiguous() 169 | and qzero.is_contiguous() 170 | and maxq.is_contiguous() 171 | ) 172 | batch_size: int = torch.tensor(x.shape[:-1]).prod().item() 173 | previous_device: torch.device = torch.device(f"cuda:{torch.cuda.current_device()}") 174 | torch.cuda.set_device(x.device) 175 | grid = lambda meta: (triton.cdiv(batch_size, meta["BLOCK_SIZE_B"]),) 176 | mse_scale_triton_kernel[grid]( 177 | x, 178 | p, 179 | scale, 180 | qzero, 181 | maxq, 182 | torch.empty(0, dtype=dtype) if dtype is not None else None, 183 | norm, 184 | p.size(-1), 185 | x.size(-1), 186 | batch_size, 187 | BLOCK_SIZE_P=torch.tensor(p.size(-1)).log2().ceil().exp2().int().item(), 188 | BLOCK_SIZE_G=torch.tensor(x.size(-1)).log2().ceil().exp2().int().item(), 189 | BLOCK_SIZE_B=1, 190 | ) 191 | torch.cuda.set_device(previous_device) 192 | return scale 193 | 194 | 195 | @torch.no_grad() 196 | def get_quantization_grid( 197 | weight: torch.Tensor, 198 | group_size: int, 199 | bits: int, 200 | symmetric: bool = False, 201 | dtype: torch.dtype = None, 202 | quantization_scale: QuantizationScale = QuantizationScale.ABSMAX, 203 | 
quant_max_shrink: float = 0.2, 204 | quant_n_grid: int = 100, 205 | quant_norm: float = 2.4, 206 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 207 | """ 208 | Get the quantization grid for the weight matrix 209 | weight: (..., (R), C) 210 | scale: (..., (R), C) 211 | qzero: (..., (R), C) 212 | maxq: () 213 | """ 214 | weight = weight.unflatten(dim=-1, sizes=(-1, group_size)) # (..., G, gs) 215 | 216 | scale, qzero, maxq = find_quantization_meta( 217 | x=weight, 218 | bit_width=bits, 219 | symmetric=symmetric, 220 | dtype=dtype, 221 | ) # (..., G), (..., G), () 222 | if quantization_scale == QuantizationScale.MSE: 223 | search_points = torch.linspace(1, quant_max_shrink, quant_n_grid, dtype=weight.dtype, device=weight.device) 224 | mse_scale( 225 | x=weight.contiguous(), # (..., G, gs) 226 | p=search_points, # (..., P) 227 | scale=scale, # (..., G) 228 | qzero=qzero, # (..., G) 229 | maxq=maxq, # () 230 | dtype=dtype, 231 | norm=quant_norm, 232 | ) 233 | 234 | scale = scale.repeat_interleave(group_size, dim=-1) # (..., C) 235 | qzero = qzero.repeat_interleave(group_size, dim=-1) # (..., C) 236 | 237 | weight = weight.flatten(start_dim=-2) # (..., C) 238 | 239 | assert weight.shape == scale.shape == qzero.shape and maxq.shape == () 240 | return scale, qzero, maxq # (..., (R), C), (..., (R), C), () 241 | 242 | 243 | def dequantize_linear_weight( 244 | qweight: torch.Tensor, 245 | scale: torch.Tensor, 246 | zero: torch.Tensor, 247 | perm: Optional[torch.Tensor] = None, 248 | ): 249 | scale = scale.view(qweight.shape[0], -1, 1) 250 | zero = zero.view(qweight.shape[0], -1, 1) 251 | num_groups = scale.shape[1] 252 | weight = dequantize(qweight.view(qweight.shape[0], num_groups, -1), scale, zero).view_as(qweight) 253 | if perm is not None: 254 | invperm = perm.argsort() 255 | weight = weight[:, invperm] 256 | return weight 257 | 258 | 259 | def get_relative_mse_error(q: torch.Tensor, w: torch.Tensor, H: Optional[torch.Tensor] = None): 260 | delta = q - w 261 | if H is None: 262 | return delta.pow(2).mean() / w.pow(2).mean() 263 | else: 264 | return (delta).mm(H).mul(delta).mean() / (w.mm(H).mul(w).mean() + 1e-6) 265 | 266 | 267 | def dequantize_weight_from_fp8(W, s): 268 | g = FP8_GROUP_SIZE 269 | # Dequantize weight 270 | d_out, d_in = W.shape 271 | # Pad weight if needed 272 | pad_out = math.ceil(d_out / g) * g - d_out 273 | pad_in = math.ceil(d_in / g) * g - d_in 274 | W = F.pad(W, (0, pad_in, 0, pad_out)) 275 | d_out_pad, d_in_pad = W.shape 276 | 277 | W = W.view(d_out_pad // g, g, d_in_pad // g, g) 278 | s = s.view(d_out_pad // g, 1, d_in_pad // g, 1) 279 | W = (W * s).view(d_out_pad, d_in_pad) 280 | 281 | # Remove padding 282 | W = W[:d_out, :d_in] 283 | return W 284 | 285 | 286 | def dequantize_state_dict(state_dict: dict[str, torch.Tensor], dtype: torch.dtype = torch.float16) -> None: 287 | state_dict_keys = list(state_dict.keys()) 288 | # Dequantize 289 | for k in state_dict_keys: 290 | if k.endswith("scale_inv"): 291 | layer_name, _ = k.rsplit(".", 1) 292 | 293 | W = state_dict[f"{layer_name}.weight"].to(dtype) 294 | s = state_dict[f"{layer_name}.weight_scale_inv"].to(dtype) 295 | 296 | state_dict[f"{layer_name}.weight"] = dequantize_weight_from_fp8(W, s) 297 | del state_dict[f"{layer_name}.weight_scale_inv"] 298 | 299 | 300 | def can_dequantize_from_fp8(state_dict: dict[str, torch.Tensor]) -> bool: 301 | for k, v in state_dict.items(): 302 | if v.dtype in FP8_DTYPES and f"{k}_scale_inv" not in state_dict: 303 | return False 304 | return True 305 | 
--------------------------------------------------------------------------------
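To make the helpers in src/quant_utils.py concrete, here is a minimal round-trip sketch (illustrative only, not part of the repository): it builds a 4-bit symmetric absmax grid with `get_quantization_grid`, fake-quantizes a toy weight matrix, and reports the relative reconstruction error via `get_relative_mse_error`. It assumes the repository root is on `PYTHONPATH` and that `triton` is installed, since src/quant_utils.py imports it at module level; the shapes and group size are arbitrary.

```python
# Illustrative sketch only; not part of the repository sources.
import torch

from src.quant_utils import dequantize_quantized, get_quantization_grid, get_relative_mse_error

torch.manual_seed(0)
W = torch.randn(256, 512)  # toy weight; the last dim must be divisible by group_size

# 4-bit symmetric grid with absmax scales (the default), group size 128
scale, qzero, maxq = get_quantization_grid(W, group_size=128, bits=4, symmetric=True)

# Fake-quantize (quantize + dequantize on the same grid) and measure the relative MSE
W_hat = dequantize_quantized(W, scale, qzero, maxq)
print(get_relative_mse_error(W_hat, W).item())
```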