├── .gitignore
├── README.md
├── pack_quantized_model.py
├── quant.py
└── src
├── data_utils.py
├── dist_utils.py
├── gptq.py
├── gptq_loop.py
├── linalg_utils.py
├── loading_utils.py
├── model_utils.py
└── quant_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # Ruff stuff:
171 | .ruff_cache/
172 |
173 | # PyPI configuration file
174 | .pypirc
175 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## MoE-Quant
2 | ---
3 |
4 | This repository provides code for [GPTQ](https://arxiv.org/abs/2210.17323) quantization of the [DeepSeek-V3](https://huggingface.co/deepseek-ai/DeepSeek-V3)/[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) model family.
5 |
6 | ### News 🔥
7 |
8 | - [2025/06] Quantized DeepSeek-R1-0528 model is on 🤗 hub.
9 |
10 | ### Features
11 |
12 | In order to quantize a large model (671B parameters) with the `GPTQ` algorithm in a reasonable time, we introduce several optimizations:
13 |
14 | 1) **Fast `triton` kernel for `GPTQ`**:
15 | Since one has to quantize a very large number (~45k) of linear layers, a faster `GPTQ` procedure is a critical optimization. The provided `triton` implementation achieves a ~10x speedup relative to the default `torch` implementation.
16 | 2) **Expert parallelism**: We shard the MLP experts across all devices so that the Hessians required for `GPTQ` calibration fit into VRAM. Each process stores only a fraction of the expert layers and the corresponding Hessians.
17 | 3) **Data parallelism**: To accelerate forward propagation we split the calibration data uniformly across processes (see the sketch below).
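For illustration, the data-parallel split boils down to each rank keeping a contiguous slice of the calibration set, as in the sketch below (a simplified version of the slicing done in `quant.py`; `take_rank_slice` is a hypothetical helper name):

```python
import torch
import torch.distributed as dist

def take_rank_slice(calibration_dataset: list[torch.Tensor]) -> list[torch.Tensor]:
    """Keep only this rank's contiguous slice of the calibration samples,
    so that Hessian-collection forward passes run in parallel across GPUs."""
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    num_seq_per_rank = len(calibration_dataset) // world_size
    return calibration_dataset[rank * num_seq_per_rank : (rank + 1) * num_seq_per_rank]
```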
18 |
19 | **The total runtime of the algorithm to quantize DeepSeek-V3/R1 is 2 hours on a server with `8xH100` (for 512 calibration sequences of length 4096).**
20 |
21 | Currently we support conversion of the `GPTQ`-quantized model into the [compressed_tensors](https://github.com/neuralmagic/compressed-tensors) format, which is supported by HuggingFace transformers and vLLM.
22 |
23 | At the moment only 4-bit symmetric quantization with configurable quantization group sizes is supported.
24 | We plan to implement other bit widths and quantization formats (`AWQ`, `AutoGPTQ`) in the future.
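For reference, here is a minimal sketch of 4-bit symmetric group-wise quantization (illustrative only: it assumes `absmax` scales and a group size that divides the number of input features, and it ignores the edge cases and the `mse` scale search handled in `src/quant_utils.py`; the function name is hypothetical):

```python
import torch

def quantize_symmetric_groupwise(weight: torch.Tensor, bits: int = 4, group_size: int = 128):
    maxq = 2 ** bits - 1                                        # 15 for 4 bits
    zero = 2 ** (bits - 1)                                      # fixed mid-point (8) for symmetric quantization
    w = weight.reshape(weight.shape[0], -1, group_size)         # (rows, num_groups, group_size)
    scale = w.abs().amax(dim=-1, keepdim=True) / (maxq - zero)  # absmax scale per group
    q = torch.clamp(torch.round(w / scale) + zero, 0, maxq)     # integer codes in [0, 15]
    w_hat = (q - zero) * scale                                  # dequantized approximation of `weight`
    return q.reshape_as(weight).to(torch.uint8), scale.squeeze(-1), w_hat.reshape_as(weight)
```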
25 |
26 |
27 | ### GPTQ-quantized models on 🤗
28 |
29 | ---
30 | #### DeepSeek-R1
31 |
32 | | Models | Experts Quantized | Attention blocks quantized | Size (GB) |
33 | | ------ | --------- | --------- | --------- |
34 | | [ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g](https://huggingface.co/ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g) | ✅ | ✅ | 325 GB |
35 | | [ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g-experts](https://huggingface.co/ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g-experts)| ✅ | ❌ | 346 GB |
36 |
37 | These models easily fit onto a single node with 8x `A100/H100` GPUs, with a context length sufficient for most applications of interest, including reasoning chains.
38 |
39 | **Evaluation results on OpenLLM Leaderboard V1 tasks**
40 |
41 | |  | Recovery (%) | Average Score | ARC-Challenge<br>acc_norm, 25-shot | GSM8k<br>exact_match, 5-shot | HellaSwag<br>acc_norm, 10-shot | MMLU<br>acc, 5-shot | TruthfulQA<br>mc2, 0-shot | WinoGrande<br>acc, 5-shot |
42 | | :------------------------------------------: | :----------: | :-----------: | :--------------------------------: | :--------------------------: | :----------------------------: | :-----------------: | :-----------------------: | :-----------------------: |
43 | | deepseek/DeepSeek-R1 | 100.00 | 81.04 | 72.53 | 95.91 | 89.30 | 87.22 | 59.28 | 82.00 |
44 | | cognitivecomputations/DeepSeek-R1-AWQ | 100.07 | 81.10 | 73.12 | 95.15 | 89.07 | 86.86 | 60.09 | 82.32 |
45 | | ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g | 99.86 | 80.93 | 72.70 | 95.68 | 89.25 | 86.83 | 58.77 | 82.32 |
46 | | ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g-experts | 100.30 | 81.28 | 72.53 | 95.68 | 89.36 | 86.99 | 59.77 | 83.35 |
47 |
48 | **Evaluation results on reasoning tasks (AIME-24, GPQA-Diamond, MATH-500)**
49 |
50 | |  | Recovery (%) | Average Score | AIME 2024<br>pass@1 | MATH-500<br>pass@1 | GPQA Diamond<br>pass@1 |
51 | | -------------------------------------------- | :----------: | :-----------: | :-----------------: | :----------------: | :--------------------: |
52 | | deepseek/DeepSeek-R1 | 100.00 | 82.99 | 78.33 | 97.24 | 73.38 |
53 | | cognitivecomputations/DeepSeek-R1-AWQ | 94.29 | 78.25 | 70.67 | 93.64 | 70.46 |
54 | | ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g | 96.52 | 80.10 | 72.96 | 97.09 | 70.26 |
55 | | ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g-experts | **98.81** | 82.00 | 77.00 | 97.08 | 71.92 |
56 |
57 | ---
58 | #### DeepSeek-R1-0528
59 |
60 | | Models | Experts Quantized | Attention blocks quantized | Size (GB) |
61 | | ------ | --------- | --------- | --------- |
62 | | [ISTA-DASLab/DeepSeek-R1-0528-GPTQ-4b-128g-experts](https://huggingface.co/ISTA-DASLab/DeepSeek-R1-0528-GPTQ-4b-128g-experts)| ✅ | ❌ | 346 GB |
63 |
64 | **Evaluation results on reasoning tasks (AIME-24, GPQA-Diamond, MATH-500)**
65 |
66 | |  | Recovery (%) | Average Score | AIME 2024<br>pass@1 | MATH-500<br>pass@1 | GPQA Diamond<br>pass@1 |
67 | | ------------------------------------------- | :----------: | :-----------: | :-----------------: | :----------------: | :--------------------: |
68 | | deepseek/DeepSeek-R1-0528 | 100.00 | 88.61 | 88.66 | 97.52 | 79.65 |
69 | | ISTA-DASLab/DeepSeek-R1-0528-GPTQ-4b-128g-experts | 99.82 | 88.45 | 87.33 | 97.40 | 80.61 |
70 |
71 | ### Usage
72 |
73 | **Model quantization**
74 |
75 | ```shell
76 | torchrun --nnodes=1 --nproc-per-node=$NUM_GPUS --master_port 29501 quant.py \
77 | --model_name_or_path $MODEL_PATH \
78 | --dataset_name_or_path $DATASET \
79 | --num_calibration_samples 512 \
80 | --max_sequence_length 4096 \
81 | --bits 4 \
82 | --group_size 128 \
83 | --rel_damp 0.1 \
84 | --sym \
85 | --offload_activations \
86 | --quantization_order $QUANTIZATION_ORDER \
87 | --quantization_scale $QUANTIZATION_SCALE \
88 | --quantize_only_experts \
89 | --tie_gptq_handles \
90 | --dtype bfloat16 \
91 |     --save_dir $SAVE_DIR
92 | ```
93 |
94 | Above:
95 | * `--model_name_or_path` - **exact path** to the model weights, e.g. `$HF_HOME/hub/models/models--deepseek-ai--DeepSeek-V3-0324/snapshots/commit_hash/`
96 | * `--dataset_name_or_path` - dataset used for calibration. We provide 3 choices: `open-thoughts`, `open-platypus`, `fineweb-edu`
97 | * `--num_calibration_samples` - number of calibration samples
98 | * `--max_sequence_length` - maximal length of calibration samples (longer samples are truncated to this length)
99 | * `--quantization_order` - `default` or `activation`, we recommend using the latter for best results
100 | * `--quantization_scale` - `absmax` or `mse`, we recommend using the latter for best results
101 | * `--quantize_only_experts` - quantize only *non-shared* experts. Yields potentially better accuracy at the cost of slightly higher memory overhead.
102 | * `--tie_gptq_handles` - reuse the same Hessian for the `up` and `gate` projections (they receive the same input) to reduce memory overhead during quantization
103 | * `--save_dir` - directory to save the model
104 |
105 | The script above produces a directory with quantization metadata for each quantized layer, i.e. `quantized_weight`, `scale`, and `zero`.
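For example, a saved layer can be inspected as follows (a sketch; the `save_dir` path and the layer directory name below are placeholders):

```python
import os
import torch

save_dir = "/path/passed/to/--save_dir"               # placeholder
layer_dir = "model.layers.3.mlp.experts.0.down_proj"  # hypothetical layer name
state = torch.load(os.path.join(save_dir, layer_dir, "quantized_weight.pt"), map_location="cpu")
# Integer weight codes plus per-group scale and zero-point
print(state["qweight"].shape, state["scale"].shape, state["zero"].shape)
```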
106 |
107 | **Model packing**
108 |
109 | To convert the model into the `compressed_tensors` format, run the `pack_quantized_model.py` script:
110 |
111 | ```shell
112 | python pack_quantized_model.py \
113 | --model_name_or_path $MODEL_PATH \
114 | --quantized_model_path $QUANTIZED_MODEL_PATH \
115 | --packed_model_path $QUANTIZED_MODEL_PATH-packed \
116 | --dtype bfloat16
117 | ```
118 |
119 | Above:
120 | * `--model_name_or_path` - **exact path** to model weights
121 | * `--quantized_model_path` - path to quantized weights (output of `quant.py`)
122 | * `--packed_model_path` - path to model in `compressed_tensors` format ready for inference in HF and vLLM.
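The packed checkpoint can then be served like any other `compressed_tensors` model, for instance with vLLM (a sketch, assuming an 8-GPU node; the model path below is a placeholder):

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/packed/model",   # directory produced by pack_quantized_model.py
    tensor_parallel_size=8,
    trust_remote_code=True,
)
outputs = llm.generate(["The capital of Austria is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```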
123 |
124 | ### Environment
125 |
126 | This code was tested with the following versions of libraries:
127 | * `torch 2.5.1`
128 | * `transformers 4.50.0`
129 | * `vllm 0.8.2`
130 |
131 | ### Performance benchmarking
132 | We follow the standard vLLM performance benchmarking procedure with the ShareGPT dataset and observe the following metrics (lower is better):
133 |
134 | |  | Time to First Token<br>Median TTFT (ms) ↓ | Time per Output Token<br>Median TPOT (ms) ↓ | Inter-token Latency<br>Median ITL (ms) ↓ |
135 | | -------------------------------------------- | :-------------------------------------: | :---------------------------------------: | :------------------------------------: |
136 | | cognitivecomputations/DeepSeek-R1-AWQ | 1585.45 | 55.41 | 43.06 |
137 | | ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g-experts | 1344.68 | 41.49 | 36.33 |
138 | | ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g | 815.19 | 44.65 | 37.88 |
139 |
140 | GPTQ models are faster than AWQ models across all metrics because GPTQ uses fewer bits per parameter than AWQ. More specifically, to preserve accuracy AWQ has to use a smaller group size of 64 (vs. 128 in GPTQ) and to store zero-points due to asymmetric quantization.
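As a rough, illustrative estimate (assuming 16-bit scales and 4-bit packed zero-points): 4-bit symmetric GPTQ with group size 128 stores about 4 + 16/128 ≈ 4.13 bits per parameter, whereas 4-bit asymmetric AWQ with group size 64 stores about 4 + (16 + 4)/64 ≈ 4.31 bits per parameter.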
141 |
142 | ### Contributors
143 |
144 | Denis Kuznedelev (Yandex), Eldar Kurtić (Red Hat AI & ISTA), Jiale Chen (ISTA), Michael Goin (Red Hat AI), Elias Frantar (ISTA), Dan Alistarh (Red Hat AI & ISTA).
145 |
146 | ### Citation
147 |
148 | ```
149 | @article{gptq,
150 | title={{GPTQ}: Accurate Post-training Compression for Generative Pretrained Transformers},
151 | author={Elias Frantar and Saleh Ashkboos and Torsten Hoefler and Dan Alistarh},
152 | year={2022},
153 | journal={arXiv preprint arXiv:2210.17323}
154 | }
155 | ```
156 |
--------------------------------------------------------------------------------
/pack_quantized_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import json
4 | import shutil
5 | import argparse
6 | from collections import defaultdict
7 | from typing import Optional, Any
8 |
9 | from tqdm import tqdm
10 | import torch
11 | from safetensors.torch import save_file
12 | from accelerate import init_empty_weights
13 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
14 | from compressed_tensors.compressors import pack_to_int32
15 |
16 | from src import quant_utils
17 | from src import loading_utils
18 |
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser()
22 | # Model params
23 | parser.add_argument(
24 | "--model_name_or_path",
25 | type=str,
26 | required=True,
27 | help="The name or path to the DeepSeek model",
28 | )
29 | parser.add_argument(
30 | "--quantized_model_path",
31 | type=str,
32 | required=True,
33 | help="Path to quantized model."
34 | )
35 | parser.add_argument(
36 | "--packed_model_path",
37 | type=str,
38 | required=True,
39 |         help="Path where the packed model will be saved."
40 | )
41 | # Misc params
42 | parser.add_argument(
43 | "--dtype",
44 | default="float16",
45 | type=str,
46 | choices=["float16", "bfloat16"],
47 | help="Torch dtype used."
48 | )
49 | args = parser.parse_args()
50 | return args
51 |
52 |
53 | def is_subset(set1: set, set2: set):
54 | return set1 <= set2
55 |
56 |
57 | def pack_weight(
58 |     weight: dict[str, torch.Tensor],
59 | bits: int,
60 | sym: bool,
61 | group_size: Optional[int] = None,
62 | ) -> dict[str, torch.Tensor]:
63 | compressed_data = {}
64 | qweight, scale, zero = weight['qweight'], weight['scale'], weight['zero']
65 | group_size = group_size or qweight.shape[-1]
66 | qweight_shifted = qweight.to(torch.int8) - zero.repeat_interleave(group_size, dim=-1).to(torch.int8)
67 | qweight_packed = pack_to_int32(qweight_shifted, bits)
68 | compressed_data = {
69 | "weight_packed": qweight_packed,
70 | "weight_shape": torch.tensor(qweight.shape),
71 | "weight_scale": scale
72 | }
73 | if not sym:
74 | compressed_data["weight_zero_point"] = weight['zero']
75 | return compressed_data
76 |
77 |
78 | def prepare_quantization_config(args: argparse.Namespace) -> dict[str, Any]:
79 | ignored_modules = ["lm_head"]
80 | if args.quantize_only_experts:
81 |         ignored_modules += ["re:.*self_attn.*", "re:.*shared_experts.*", r"re:.*mlp\.(gate|up|gate_up|down)_proj.*"]
82 | return {
83 | "config_groups": {
84 | "group_0": {
85 | "input_activations": None,
86 | "output_activations": None,
87 | "targets": [
88 | "Linear"
89 | ],
90 | "weights": {
91 | "actorder": None,
92 | "block_structure": None,
93 | "dynamic": False,
94 | "group_size": args.group_size,
95 | "num_bits": args.bits,
96 | "observer": "minmax",
97 | "observer_kwargs": {},
98 | "strategy": "group",
99 | "symmetric": True,
100 | "type": "int"
101 | }
102 | }
103 | },
104 | "format": "pack-quantized",
105 | "ignore": ignored_modules,
106 | "kv_cache_scheme": None,
107 | "quant_method": "compressed-tensors",
108 | "quantization_status": "compressed"
109 | }
110 |
111 |
112 | def main():
113 | args = parse_args()
114 |
115 | dtype = getattr(torch, args.dtype)
116 |
117 | # Load DeepSeek model
118 | config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)
119 | if hasattr(config, "quantization_config"):
120 | delattr(config, "quantization_config")
121 |
122 | with init_empty_weights():
123 | model = AutoModelForCausalLM.from_config(
124 | config=config,
125 | trust_remote_code=True,
126 | torch_dtype=torch.bfloat16
127 | ).eval()
128 | model.config.use_cache = False
129 |
130 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
131 |
132 | # Load quantization metadata
133 | metadata = torch.load(os.path.join(args.quantized_model_path, "metadata.pt"))
134 | args.bits = metadata["bits"]
135 | args.group_size = metadata["group_size"]
136 | args.quantize_only_experts = metadata["quantize_only_experts"]
137 | # Currently we do not support asymmetric quantization
138 | args.sym = True
139 |
140 | num_output_shards = len(model.model.layers) + 2
141 | current_output_shard_id = 1
142 | quantized_layer_names = defaultdict(list)
143 | for layer_name in sorted(os.listdir(args.quantized_model_path)):
144 | if os.path.isdir(os.path.join(args.quantized_model_path, layer_name)):
145 | block_idx = int(layer_name.split(".")[2])
146 | quantized_layer_names[block_idx].append(layer_name)
147 | safetensors_index = {}
148 | # Prepare directory to save packed weights
149 | os.makedirs(args.packed_model_path, exist_ok=True)
150 |
151 | # Load initial weight shard
152 | weight_dir = args.model_name_or_path
153 | current_input_shard_id = 1
154 | weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors"
155 |
156 | param_buffer = loading_utils.load_param_shard(weight_dir, weight_path)
157 |
158 | # Save embeddings
159 | current_output_shard_path = f"model-{current_output_shard_id:05}-of-{num_output_shards:05}.safetensors"
160 | save_file(
161 | {"model.embed_tokens.weight": param_buffer["model.embed_tokens.weight"]},
162 | os.path.join(args.packed_model_path, current_output_shard_path)
163 | )
164 | safetensors_index["model.embed_tokens.weight"] = current_output_shard_path
165 | del param_buffer["model.embed_tokens.weight"]
166 |
167 | # Process blocks
168 | for block_idx, block in tqdm(
169 | enumerate(model.model.layers),
170 | desc="Processing transformer blocks",
171 | total=len(model.model.layers)
172 | ):
173 | current_output_shard_id += 1
174 | prefix = f"model.layers.{block_idx}."
175 | block_keys_with_prefix = set(f"{prefix}{k}" for k in block.state_dict())
176 |
177 | while not is_subset(block_keys_with_prefix, set(param_buffer.keys())):
178 | current_input_shard_id += 1
179 | weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors"
180 | param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path))
181 |
182 | block_state_dict = {k: param_buffer[k] for k in param_buffer if k.startswith(prefix)}
183 | quant_utils.dequantize_state_dict(block_state_dict, dtype)
184 |
185 | for layer_name in quantized_layer_names[block_idx]:
186 | weight_state_dict = torch.load(
187 | os.path.join(args.quantized_model_path, layer_name, "quantized_weight.pt"),
188 | weights_only=True,
189 | map_location="cpu"
190 | )
191 | packed_weight_state_dict = pack_weight(weight_state_dict, args.bits, args.sym, args.group_size)
192 | block_state_dict.pop(f"{layer_name}.weight")
193 | block_state_dict.pop(f"{layer_name}.weight_scale_inv", None)
194 | block_state_dict.update({f"{layer_name}.{k}": v for k, v in packed_weight_state_dict.items()})
195 |
196 | # Save block
197 | current_output_shard_path = f"model-{current_output_shard_id:05}-of-{num_output_shards:05}.safetensors"
198 | save_file(
199 | block_state_dict,
200 | os.path.join(args.packed_model_path, current_output_shard_path)
201 | )
202 | for k in block_state_dict:
203 | safetensors_index[k] = current_output_shard_path
204 |
205 | for k in block_keys_with_prefix:
206 | param_buffer.pop(k, None)
207 |
208 | del block_state_dict
209 | gc.collect()
210 |
211 | # Load final shard
212 | if current_input_shard_id < 163:
213 | current_input_shard_id = 163
214 | weight_path = f"model-{current_input_shard_id:05}-of-000163.safetensors"
215 | param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path))
216 |
217 | # Save lm head
218 | current_output_shard_id += 1
219 | current_output_shard_path = f"model-{current_output_shard_id:05}-of-{num_output_shards:05}.safetensors"
220 | save_file(
221 | {
222 | "lm_head.weight": param_buffer["lm_head.weight"],
223 | "model.norm.weight": param_buffer["model.norm.weight"]
224 | },
225 | os.path.join(args.packed_model_path, current_output_shard_path)
226 | )
227 | safetensors_index["lm_head.weight"] = current_output_shard_path
228 | safetensors_index["model.norm.weight"] = current_output_shard_path
229 | # Save safetensors index
230 | with open(os.path.join(args.packed_model_path, "model.safetensors.index.json"), "w") as f:
231 | json.dump({"metadata": {}, "weight_map": safetensors_index}, f)
232 | # Add quantization metadata
233 | config.quantization_config = prepare_quantization_config(args)
234 | # Save configs
235 | config.save_pretrained(args.packed_model_path)
236 | model.generation_config.save_pretrained(args.packed_model_path)
237 | # Save tokenizer
238 | tokenizer.save_pretrained(args.packed_model_path)
239 | # Copy modeling script
240 | shutil.copy(
241 | os.path.join(args.model_name_or_path, "modeling_deepseek.py"),
242 | args.packed_model_path
243 | )
244 |
245 |
246 | if __name__ == "__main__":
247 | main()
--------------------------------------------------------------------------------
/quant.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import re
4 | import argparse
5 |
6 | from tqdm import tqdm
7 | import torch
8 | import torch.distributed as dist
9 | from accelerate import init_empty_weights
10 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
11 |
12 | try:
13 | import wandb
14 | wandb_enabled = True
15 | except ImportError:
16 | wandb_enabled = False
17 |
18 |
19 | from src import dist_utils, data_utils, model_utils, quant_utils, loading_utils, gptq
20 |
21 |
22 | ROUTED_EXPERTS_REGEX = r".*mlp.experts.\d+.(down|gate|up)_proj$"
23 | TIED_FFN_GROUPS = ("gate_proj", "up_proj")
24 |
25 |
26 | def parse_args():
27 | parser = argparse.ArgumentParser()
28 | # Model params
29 | parser.add_argument(
30 | "--model_name_or_path",
31 | type=str,
32 | required=True,
33 | help="The name or path to the DeepSeek model",
34 | )
35 | # Data params
36 | parser.add_argument(
37 | "--dataset_name_or_path",
38 | type=str,
39 | required=True,
40 | help="The name or path to calibration dataset",
41 | )
42 | parser.add_argument("--num_calibration_samples", default=128, type=int, help="Number of samples for calibration.")
43 | parser.add_argument("--max_sequence_length", default=8192, type=int, help="Calibration sequence length.")
44 | # Quantization params
45 | parser.add_argument(
46 | "--bits",
47 | type=int,
48 | default=4,
49 | choices=[4],
50 | help="Quantization bitwidth.",
51 | )
52 | parser.add_argument(
53 | "--group_size",
54 | type=int,
55 | default=None,
56 | help="How many weight columns (input features) are quantized with the same statistics, default = all of them",
57 | )
58 | parser.add_argument("--sym", action="store_true", help="Whether to use symmetric quantization")
59 | parser.add_argument("--rel_damp", type=float, default=1e-2)
60 | parser.add_argument("--block_size", type=int, default=128)
61 | parser.add_argument("--quantization_scale", type=str, default="absmax", choices=["absmax", "mse"])
62 | parser.add_argument("--quantization_order", type=str, default="default", choices=["default", "activation"])
63 | parser.add_argument(
64 | "--quantize_only_experts",
65 | default=False,
66 | action="store_true",
67 | help="Whether to quantize only routed (non-shared) experts.",
68 | )
69 | # Save params
70 | parser.add_argument("--save_dir", type=str, default=None, help="where to save quantized model.")
71 | # Logging params
72 | parser.add_argument("--log_wandb", default=False, action="store_true", help="Log to W&B")
73 | parser.add_argument("--log_error", default=False, action="store_true", help="Whether to log relative L2 error")
74 | # Misc params
75 | parser.add_argument("--offload_activations", action="store_true", help="whether to offload activations to CPU.")
76 | parser.add_argument("--tie_gptq_handles", action="store_true", help="whether to reuse hessian between gate and up projections.")
77 | parser.add_argument("--resume", action="store_true", help="whether to resume quantization from latest checkpoint.")
78 | parser.add_argument("--seed", default=0, type=int, help="Random seed.")
79 | parser.add_argument(
80 | "--dtype", default="float16", type=str, choices=["float16s", "bfloat16"], help="Torch dtype used."
81 | )
82 | args = parser.parse_args()
83 |
84 | return args
85 |
86 |
87 | def is_subset(set1: set, set2: set):
88 | return set1 <= set2
89 |
90 |
91 | def get_resume_block_idx(save_dir: os.PathLike) -> int:
92 | resume_block_idx = 0
93 | if os.path.exists(save_dir):
94 | for layer_name in os.listdir(save_dir):
95 | block_idx = int(layer_name.split(".")[2])
96 | resume_block_idx = max(resume_block_idx, block_idx)
97 | return resume_block_idx
98 |
99 |
100 | def main():
101 | args = parse_args()
102 | # Distributed init
103 | if dist.is_available():
104 | dist.init_process_group(backend="nccl", init_method="env://")
105 | world_size = dist_utils.get_world_size()
106 | rank = dist_utils.get_rank()
107 | # init device
108 | device = f"cuda:{rank}"
109 | torch.set_grad_enabled(False)
110 | torch.cuda.set_device(device)
111 | offload_device = "cpu" if args.offload_activations else None
112 | dtype = getattr(torch, args.dtype)
113 | # Init W&B logger
114 | if args.log_wandb and dist_utils.is_main():
115 | assert wandb_enabled, "wandb not installed. try `pip install wandb`"
116 | wandb.init(config=args)
117 |
118 | # Load DeepSeek model
119 | config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)
120 | # Sanity check
121 | assert config.architectures == ["DeepseekV3ForCausalLM"], "Only DeepseekV3 is supported!"
122 | if hasattr(config, "quantization_config"):
123 | delattr(config, "quantization_config")
124 | config.ep_size = world_size
125 |
126 | with init_empty_weights():
127 | model = AutoModelForCausalLM.from_config(
128 | config=config, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=dtype
129 | ).eval()
130 | model.config.use_cache = False
131 |
132 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
133 |
134 | # Prepare calibration dataset
135 | calibration_dataset = data_utils.prepare_calibration_dataset(
136 | args.dataset_name_or_path, tokenizer, args.max_sequence_length, args.num_calibration_samples, args.seed
137 | )
138 |
139 | # Take slices (if running on multiple workers)
140 | num_seq_per_rank = len(calibration_dataset) // world_size
141 | calibration_dataset = calibration_dataset[rank * num_seq_per_rank : (rank + 1) * num_seq_per_rank]
142 | dist_utils.barrier(device_ids=[rank])
143 |
144 | # Load initial weight shard
145 | weight_dir = args.model_name_or_path
146 | current_shard_id = 1
147 | weight_path = f"model-{current_shard_id:05}-of-000163.safetensors"
148 |
149 | param_buffer = {}
150 | if dist_utils.is_main():
151 | param_buffer = loading_utils.load_param_shard(weight_dir, weight_path)
152 | dist_utils.barrier(device_ids=[rank])
153 |
154 | # Get resume block id
155 | resume_block_idx = 0
156 | if args.resume:
157 | resume_block_idx = get_resume_block_idx(args.save_dir)
158 |
159 | # Prepare input embeddings and position ids
160 | inputs = []
161 | position_ids = []
162 | model.model.embed_tokens.to_empty(device=device)
163 | if dist_utils.is_main():
164 | model.model.embed_tokens.weight.data = param_buffer["model.embed_tokens.weight"].to(device=device, dtype=dtype)
165 | if dist_utils.is_dist_available_and_initialized():
166 | dist_utils.broadcast_parameters(model.model.embed_tokens)
167 | for i in range(num_seq_per_rank):
168 | seq_length = calibration_dataset[i].shape[1]
169 | inputs.append(model.model.embed_tokens(calibration_dataset[i].to(device)).to(offload_device))
170 | position_ids.append(torch.arange(0, seq_length, dtype=torch.long, device=device).unsqueeze(0))
171 | # Offload embeddings back to meta
172 | model.model.embed_tokens.to(device="meta")
173 | param_buffer.pop("model.embed_tokens.weight", None)
174 |
175 | for block_idx, block in tqdm(
176 | enumerate(model.model.layers), desc="Processing transformer blocks", total=len(model.model.layers)
177 | ):
178 | prefix = f"model.layers.{block_idx}."
179 |
180 | # Collect state dict keys from all processes
181 | rank_block_keys = [k for k in block.state_dict()]
182 | if dist_utils.is_main():
183 | block_keys_with_prefix = [f"{prefix}{k}" for k in rank_block_keys]
184 | other_ranks_keys = []
185 | for i in range(1, world_size):
186 | other_rank_keys = [None for _ in rank_block_keys]
187 | dist.recv_object_list(other_rank_keys, src=i)
188 | block_keys_with_prefix.extend([f"{prefix}{k}" for k in other_rank_keys])
189 | other_ranks_keys.append(other_rank_keys)
190 | # Make it a set
191 | block_keys_with_prefix = set(block_keys_with_prefix)
192 | else:
193 | block_keys_with_prefix = []
194 | other_ranks_keys = []
195 | dist.send_object_list(rank_block_keys, dst=0)
196 |
197 | if dist_utils.is_main():
198 | can_dequantize = True
199 | # Select weights corresponding to current block
200 | block_state_dict = {k[len(prefix) :]: v for k, v in param_buffer.items() if k.startswith(prefix)}
201 | while not (is_subset(block_keys_with_prefix, set(param_buffer.keys())) and can_dequantize):
202 | current_shard_id += 1
203 | weight_path = f"model-{current_shard_id:05}-of-000163.safetensors"
204 | param_buffer.update(loading_utils.load_param_shard(weight_dir, weight_path))
205 | # Update weights corresponding to current block
206 | block_state_dict = {k[len(prefix) :]: v for k, v in param_buffer.items() if k.startswith(prefix)}
207 | can_dequantize = quant_utils.can_dequantize_from_fp8(block_state_dict)
208 | # Dequantize weights corresponding to current block
209 | quant_utils.dequantize_state_dict(block_state_dict, dtype)
210 |
211 | # Put block onto GPU
212 | block.to_empty(device=device)
213 |
214 | # Simply load block state dict on master and broadcast
215 | if block_idx < model.config.first_k_dense_replace:
216 | if dist_utils.is_main():
217 | block.load_state_dict(block_state_dict)
218 | if dist_utils.is_dist_available_and_initialized():
219 | dist_utils.broadcast_parameters(block)
220 |         # Send dict with the subset of experts to each target device
221 | else:
222 | if dist_utils.is_main():
223 | # Load state dict on master
224 | rank_state_dict = {k: block_state_dict[k] for k in rank_block_keys}
225 | block.load_state_dict(rank_state_dict)
226 | # Send to other processes
227 | for i in range(1, world_size):
228 | rank_state_dict = {k: block_state_dict[k] for k in other_ranks_keys[i - 1]}
229 | for k in rank_state_dict:
230 | dist.send(rank_state_dict[k].to(device), dst=i)
231 | else:
232 | rank_state_dict = block.state_dict()
233 | for k in rank_state_dict:
234 | dist.recv(rank_state_dict[k], src=0)
235 | block.load_state_dict(rank_state_dict)
236 | del rank_state_dict
237 | # Clear memory before calibration
238 | torch.cuda.empty_cache()
239 | gc.collect()
240 |
241 | if block_idx >= resume_block_idx:
242 | # Hessian estimate
243 | layers = model_utils.select_layers(model, prefix, ".*", model_utils.LINEAR_LAYERS)
244 | handles = {}
245 | hooks = {}
246 |
247 | for layer_name, layer in layers.items():
248 |
249 | def update_handle_hook(name):
250 | def _hook(_, inp, out):
251 | handles[name].update(inp[0])
252 |
253 | return _hook
254 |
255 | if args.quantize_only_experts and re.search(ROUTED_EXPERTS_REGEX, layer_name) is None:
256 | continue
257 |
258 | tied_gptq_handle = None
259 | if args.tie_gptq_handles and layer_name.endswith("up_proj"):
260 | parent_name, _ = layer_name.rsplit(".", 1)
261 | tied_layer_name = f"{parent_name}.gate_proj"
262 | tied_gptq_handle = handles[tied_layer_name]
263 |
264 | handles[layer_name] = gptq.GPTQ(
265 | layer,
266 | args.group_size,
267 | args.sym,
268 | args.rel_damp,
269 | args.block_size,
270 | args.quantization_order,
271 | args.quantization_scale,
272 | is_distributed=re.search(ROUTED_EXPERTS_REGEX, layer_name) is None,
273 | tied_gptq_handle=tied_gptq_handle
274 | )
275 |
276 | if tied_gptq_handle is None:
277 | hooks[layer_name] = layer.register_forward_hook(update_handle_hook(layer_name))
278 |
279 | # Collect Hessians
280 | for i in range(num_seq_per_rank):
281 | block(inputs[i].to(device), position_ids=position_ids[i])
282 |
283 | for _, h in hooks.items():
284 | h.remove()
285 |
286 | dist_utils.barrier(device_ids=[rank])
287 |
288 | shared_handles = {k: v for k, v in handles.items() if re.search(ROUTED_EXPERTS_REGEX, k) is None}
289 | expert_handles = {k: v for k, v in handles.items() if k not in shared_handles}
290 |
291 |             # Quantize shared handles first
292 | num_issue_zero_samples = 0
293 | num_issue_nan_hessian = 0
294 | num_issue_non_invertible = 0
295 | for handle_name, handle in shared_handles.items():
296 | dist_utils.print_on_main(f"Quantizing layer {handle_name}")
297 | qweight, scale, zero = handle.quantize(args.bits)
298 | # Construct dequantized weight
299 | dequantized_weight = quant_utils.dequantize_linear_weight(qweight, scale, zero)
300 | assert (
301 | torch.isfinite(dequantized_weight).all().item()
302 | ), f"[rank{rank}] {handle_name} weight is broken after quantization."
303 | # Update issue tracker
304 | num_issue_zero_samples += handle.issue_zero_samples
305 | num_issue_nan_hessian += handle.issue_nan_hessian
306 | num_issue_non_invertible += handle.issue_non_invertible
307 |
308 | if args.log_error:
309 | if handle.has_hessian_issues():
310 | dist_utils.print_on_main(
311 |                         "An issue occurred during Hessian computation. Output error cannot be estimated."
312 | )
313 | else:
314 | relative_mse = quant_utils.get_relative_mse_error(
315 | dequantized_weight.float(), handle.layer.weight.float(), handle.H
316 | )
317 | dist_utils.print_on_main(f"Relative error: {relative_mse.item():.2e}")
318 | if args.log_wandb and dist_utils.is_main():
319 | wandb.log({f"relative_error/{handle_name}": relative_mse.item()}, step=0)
320 |
321 | if args.save_dir and dist_utils.is_main():
322 | os.makedirs(os.path.join(args.save_dir, handle_name), exist_ok=True)
323 | torch.save(
324 | {"qweight": qweight, "scale": scale, "zero": zero},
325 | os.path.join(args.save_dir, handle_name, f"quantized_weight.pt"),
326 | )
327 | # Replace original weight by quantized one
328 | handle.layer.weight.data = dequantized_weight
329 | # Destroy handle
330 | handle.reset()
331 |
332 | dist_utils.print_on_main("-" * 10)
333 | dist_utils.print_on_main(f"GPTQ calibration issues for shared modules:")
334 | dist_utils.print_on_main(f"Zero Hessian: {num_issue_zero_samples}")
335 | dist_utils.print_on_main(f"Non-invertible: {num_issue_non_invertible}")
336 | dist_utils.print_on_main(f"NaN Hessian: {num_issue_nan_hessian}")
337 | dist_utils.print_on_main("-" * 10)
338 |
339 | # Quantize experts
340 | num_issue_zero_samples = 0
341 | num_issue_nan_hessian = 0
342 | num_issue_non_invertible = 0
343 | if len(expert_handles) > 0:
344 | dist_utils.print_on_main(f"Processing experts")
345 |
346 | expert_messages = None
347 | if dist_utils.is_main():
348 | expert_messages = [None for _ in range(world_size)]
349 | rank_expert_message = ""
350 |
351 | for handle_name, handle in expert_handles.items():
352 | rank_expert_message += f"Quantizing layer {handle_name}\n"
353 | qweight, scale, zero = handle.quantize(args.bits)
354 | # Construct dequantized weight
355 | dequantized_weight = quant_utils.dequantize_linear_weight(qweight, scale, zero)
356 | assert (
357 | torch.isfinite(dequantized_weight).all().item()
358 | ), f"[rank{rank}] {handle_name} weight is broken after quantization."
359 | # Update issue tracker
360 | num_issue_zero_samples += handle.issue_zero_samples
361 | num_issue_nan_hessian += handle.issue_nan_hessian
362 | num_issue_non_invertible += handle.issue_non_invertible
363 |
364 | rank_expert_message += f"Tokens collected: {handle.tokens_collected}.\n"
365 |
366 | if args.log_error:
367 | if handle.has_hessian_issues():
368 | rank_expert_message += "Hessian issue. Output error cannot be estimated.\n"
369 | else:
370 | relative_mse = quant_utils.get_relative_mse_error(
371 | dequantized_weight.float(), handle.layer.weight.float(), handle.H
372 | )
373 | rank_expert_message += f"Relative error: {relative_mse.item():.2e}\n"
374 | # TODO send to main process
375 | if args.log_wandb and dist_utils.is_main():
376 | wandb.log({f"relative_error/{handle_name}": relative_mse.item()}, step=0)
377 |
378 | if args.save_dir:
379 | os.makedirs(os.path.join(args.save_dir, handle_name), exist_ok=True)
380 | torch.save(
381 | {"qweight": qweight, "scale": scale, "zero": zero},
382 | os.path.join(args.save_dir, handle_name, f"quantized_weight.pt"),
383 | )
384 | # Replace original weight by quantized one
385 | handle.layer.weight.data = dequantized_weight
386 | # Destroy handle
387 | handle.reset()
388 |
389 | dist_utils.barrier(device_ids=[rank])
390 |
391 | dist.gather_object(rank_expert_message, expert_messages)
392 | if dist_utils.is_main():
393 | for expert_message in expert_messages:
394 | dist_utils.print_on_main(expert_message)
395 |
396 | # TODO sync data from other processes
397 | dist_utils.print_on_main("-" * 10)
398 | dist_utils.print_on_main(f"GPTQ calibration issues for expert modules:")
399 | dist_utils.print_on_main(f"Zero Hessian: {num_issue_zero_samples}")
400 | dist_utils.print_on_main(f"Non-invertible: {num_issue_non_invertible}")
401 | dist_utils.print_on_main(f"NaN Hessian: {num_issue_nan_hessian}")
402 | dist_utils.print_on_main("-" * 10)
403 |
404 | del handles
405 | del shared_handles
406 | del expert_handles
407 | del hooks
408 | torch.cuda.empty_cache()
409 | gc.collect()
410 | else:
411 | dist_utils.print_on_main(f"Block {block_idx} is already quantized. Skipping quantization.")
412 |
413 | # Update activations
414 | for i in range(num_seq_per_rank):
415 | inputs[i] = block(inputs[i].to(device), position_ids=position_ids[i])[0].to(offload_device)
416 |             assert torch.isfinite(inputs[i]).all().item(), "NaN or inf encountered."
417 |
418 | # Offload block
419 | block.to(device="meta")
420 | for k in block_keys_with_prefix:
421 | param_buffer.pop(k, None)
422 |
423 | torch.cuda.empty_cache()
424 | gc.collect()
425 |
426 | # Save quantization metadata
427 |     if args.save_dir and dist_utils.is_main():
428 | torch.save(
429 | {
430 | "bits": args.bits,
431 | "group_size": args.group_size,
432 | "quantize_only_experts": args.quantize_only_experts
433 | },
434 | os.path.join(args.save_dir, "metadata.pt")
435 | )
436 |
437 | dist.destroy_process_group()
438 |
439 |
440 | if __name__ == "__main__":
441 | main()
442 |
--------------------------------------------------------------------------------
/src/data_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 |
3 | import torch
4 | from datasets import load_dataset
5 | from transformers import AutoTokenizer
6 |
7 |
8 | def prepare_open_thoughts(
9 | tokenizer: AutoTokenizer,
10 | max_sequence_length: int,
11 | num_calibration_samples: Optional[int] = None,
12 | seed: int = 42
13 | ) -> List[torch.Tensor]:
14 | train_dataset_raw = load_dataset("open-thoughts/OpenThoughts-114k", split="train")
15 | if num_calibration_samples:
16 | train_dataset_raw = train_dataset_raw.shuffle(seed=seed).select(range(num_calibration_samples))
17 | # Preprocess the data into the format the model is trained with.
18 | def preprocess(example):
19 | messages = []
20 | # add system prompt
21 | messages.append({"role": "system", "content": example['system']})
22 | # add dialogue
23 | for message in example['conversations']:
24 | messages.append({"role": message["from"], "content": message["value"]})
25 | return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
26 | train_dataset_raw = train_dataset_raw.map(preprocess)
27 | # Tokenize the data
28 | def tokenize(sample):
29 | return tokenizer(
30 | sample["text"],
31 | padding=False,
32 | max_length=max_sequence_length,
33 | truncation=True,
34 | add_special_tokens=False,
35 | )
36 | train_dataset = train_dataset_raw.map(tokenize, remove_columns=train_dataset_raw.column_names)
37 | train_dataset = [torch.tensor(sample['input_ids']).unsqueeze(0) for sample in train_dataset]
38 | return train_dataset
39 |
40 |
41 | def prepare_open_platypus(
42 | tokenizer: AutoTokenizer,
43 | max_sequence_length: int,
44 | num_calibration_samples: Optional[int] = None,
45 | seed: int = 42
46 | ) -> List[torch.Tensor]:
47 | train_dataset_raw = load_dataset("garage-bAInd/Open-Platypus", split="train")
48 | if num_calibration_samples:
49 | train_dataset_raw = train_dataset_raw.shuffle(seed=seed).select(range(num_calibration_samples))
50 | # Preprocess the data into the format the model is trained with.
51 | def preprocess(example):
52 | messages = [
53 | {"role": "user", "content": example["instruction"]},
54 | {"role": "assistant", "content": example["output"]},
55 | ]
56 | return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
57 | train_dataset_raw = train_dataset_raw.map(preprocess)
58 | # Tokenize the data
59 | def tokenize(sample):
60 | return tokenizer(
61 | sample["text"],
62 | padding=False,
63 | max_length=max_sequence_length,
64 | truncation=True,
65 | add_special_tokens=False,
66 | )
67 | train_dataset = train_dataset_raw.map(tokenize, remove_columns=train_dataset_raw.column_names)
68 | train_dataset = [torch.tensor(sample['input_ids']).unsqueeze(0) for sample in train_dataset]
69 | return train_dataset
70 |
71 |
72 | def prepare_fineweb_edu(
73 | tokenizer: AutoTokenizer,
74 | max_sequence_length: int,
75 | num_calibration_samples: Optional[int] = None,
76 | seed: int = 42
77 | ) -> List[torch.Tensor]:
78 | train_dataset_raw = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT", split="train", streaming=True)
79 | train_dataset_raw = train_dataset_raw.shuffle(seed=seed, buffer_size=1_000)
80 | train_dataset = []
81 | for i, sample in enumerate(train_dataset_raw):
82 | if i == num_calibration_samples:
83 | break
84 | tokenized_sample = tokenizer(
85 | sample["text"],
86 | max_length=max_sequence_length,
87 | truncation=True,
88 | return_tensors="pt"
89 | )
90 | train_dataset.append(tokenized_sample['input_ids'])
91 | return train_dataset
92 |
93 |
94 | def prepare_calibration_dataset(
95 | dataset_name: str,
96 | tokenizer: AutoTokenizer,
97 | max_sequence_length: int,
98 | num_calibration_samples: Optional[int] = None,
99 | seed: int = 42
100 | ) -> List[torch.Tensor]:
101 | if dataset_name == "open-thoughts":
102 | return prepare_open_thoughts(tokenizer, max_sequence_length, num_calibration_samples, seed)
103 | if dataset_name == "open-platypus":
104 | return prepare_open_platypus(tokenizer, max_sequence_length, num_calibration_samples, seed)
105 | if dataset_name == "fineweb-edu":
106 | return prepare_fineweb_edu(tokenizer, max_sequence_length, num_calibration_samples, seed)
107 | else:
108 | raise ValueError("Unknown dataset")
109 |
--------------------------------------------------------------------------------
/src/dist_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Any
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.distributed as dist
6 |
7 |
8 | __all__ = [
9 | "is_dist_available_and_initialized",
10 | "get_world_size",
11 | "get_rank",
12 | "is_main",
13 | "broadcast_parameters",
14 | "gather_into_tensor",
15 | "print_on_main",
16 | "barrier"
17 | ]
18 |
19 |
20 | def is_dist_available_and_initialized():
21 | return dist.is_available() and dist.is_initialized()
22 |
23 |
24 | def get_world_size():
25 | if is_dist_available_and_initialized():
26 | return dist.get_world_size()
27 | return 1
28 |
29 |
30 | def get_rank():
31 | if is_dist_available_and_initialized():
32 | return dist.get_rank()
33 | return 0
34 |
35 |
36 | def is_main():
37 | return get_rank() == 0
38 |
39 |
40 | def barrier(device_ids=None):
41 | if is_dist_available_and_initialized():
42 | dist.barrier(device_ids=device_ids)
43 |
44 |
45 | def broadcast_parameters(module: nn.Module, src: Any = 0, group: Optional[Any] = None):
46 | for param in module.parameters():
47 | dist.broadcast(param.data, src=src, group=group)
48 |
49 |
50 | def gather_into_tensor(tensor, dim: int = 0):
51 | world_size = get_world_size()
52 | if is_main():
53 | gathered_shape = (*tensor.shape[:dim], world_size * tensor.shape[dim], *tensor.shape[dim + 1 :])
54 | gathered_tensor = torch.empty(gathered_shape, device=tensor.device, dtype=tensor.dtype)
55 | gathered_tensor_chunks = list(gathered_tensor.chunk(world_size, dim=dim))
56 | else:
57 | gathered_tensor = None
58 | gathered_tensor_chunks = None
59 | dist.gather(tensor, gathered_tensor_chunks)
60 | return gathered_tensor
61 |
62 |
63 | def print_on_main(*args, **kwargs):
64 | if is_main():
65 | print(*args, **kwargs)
66 |
--------------------------------------------------------------------------------
/src/gptq.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.distributed as dist
7 | from torch import Tensor
8 | from torch.nn.modules.conv import _ConvNd
9 |
10 | from src import dist_utils, model_utils, linalg_utils, quant_utils, gptq_loop
11 |
12 |
13 | torch.backends.cuda.matmul.allow_tf32 = False
14 | torch.backends.cudnn.allow_tf32 = False
15 |
16 |
17 | class QuantizationOrder(Enum):
18 | DEFAULT = "default"
19 | ACTIVATION = "activation"
20 |
21 |
22 | class GPTQ:
23 |
24 | def __init__(
25 | self,
26 | layer: nn.Module,
27 | group_size: Optional[int] = None,
28 | sym: bool = False,
29 | rel_damp: float = 1e-2,
30 |         block_size: Optional[int] = None,
31 | quantization_order: str = "default",
32 | quantization_scale: str = "absmax",
33 | is_distributed: bool = False,
34 | tied_gptq_handle: Optional["GPTQ"] = None
35 | ):
36 | self._validate_layer(layer)
37 | self.layer = layer
38 | self.W = self.layer.weight
39 | self.d_row, self.d_col = model_utils.get_number_of_rows_and_cols(layer)
40 | # Quantization hyperparameters
41 | self.sym = sym
42 | self.group_size = group_size
43 | # GPTQ hyperparameters
44 | self.rel_damp = rel_damp
45 | self.block_size = block_size or self.d_col
46 | self.quantization_order = QuantizationOrder(quantization_order)
47 | self.quantization_scale = quantization_scale
48 | # backup layer properties
49 | self.W_device = self.W.device
50 | self.W_dtype = self.W.dtype
51 | self.W_shape = self.W.shape
52 | # init hessian
53 | self.H = None
54 | self.num_samples = 0
55 | self.is_distributed = is_distributed
56 | self.tied_gptq_handle = tied_gptq_handle
57 | self.num_tied_handles = 0
58 | if tied_gptq_handle is not None:
59 | tied_gptq_handle.num_tied_handles += 1
60 | # Flags indicating issues
61 | self.issue_zero_samples = False
62 | self.issue_nan_hessian = False
63 | self.issue_non_invertible = False
64 |
65 | @staticmethod
66 | def _validate_layer(layer):
67 |         assert isinstance(layer, (nn.Linear, _ConvNd)), "GPTQ supports only linear and convolutional layers."
68 |
69 | def has_hessian_issues(self) -> bool:
70 | return any([self.issue_zero_samples, self.issue_nan_hessian, self.issue_non_invertible])
71 |
72 | # preparatory methods
73 | @torch.no_grad()
74 | def update(self, input: Tensor) -> None:
75 | """
76 | Update the estimate of Hessian matrix from a batch of data.
77 |
78 | Args:
79 | input: batch of layer inputs
80 | """
81 | # init hessian
82 | if self.H is None:
83 | self.H = torch.zeros((self.d_col, self.d_col), device=input.device, dtype=torch.float32)
84 | # input reshaping
85 | if isinstance(self.layer, nn.Linear):
86 | input = input.reshape(-1, input.shape[-1])
87 | else:
88 | unfold = nn.Unfold(
89 | self.layer.kernel_size,
90 | dilation=self.layer.dilation,
91 | padding=self.layer.padding,
92 | stride=self.layer.stride,
93 | )
94 | # output size (batch_size, channels * \prod kernel_size, num_patches)
95 | input = unfold(input)
96 | input = input.transpose(1, 2).flatten(0, 1)
97 | input = input.float()
98 | # get number of samples (tokens) in batch
99 | num_new_samples = input.shape[0]
100 |         # running Hessian estimate: H <- beta * H + alpha * X^T X, with beta = n / (n + m) and alpha = 2 / (n + m)
101 | beta = self.num_samples / (self.num_samples + num_new_samples)
102 | alpha = 2.0 / (self.num_samples + num_new_samples)
103 | self.H.addmm_(input.T, input, beta=beta, alpha=alpha)
104 | # update number of collected samples
105 | self.num_samples += num_new_samples
106 |
107 | @property
108 | def tokens_collected(self) -> int:
109 | return self.num_samples
110 |
111 | def reset(self) -> None:
112 | self.W = self.layer.weight
113 | if self.num_tied_handles == 0:
114 | self.H = None
115 | elif self.tied_gptq_handle:
116 | self.tied_gptq_handle.num_tied_handles -= 1
117 | if self.tied_gptq_handle.num_tied_handles == 0:
118 | self.tied_gptq_handle.H = None
119 | self.num_samples = 0
120 | torch.cuda.empty_cache()
121 |
122 | @torch.no_grad()
123 | def quantization_pre_step(self) -> None:
124 | """
125 | Preparatory step with hessian regularization and weight reshaping.
126 | """
127 | # 1) Hessian preparation
128 | reduce_if_needed = True
129 | if self.H is None:
130 | if self.tied_gptq_handle:
131 | self.H = self.tied_gptq_handle.H
132 | else:
133 | self.H = torch.eye(self.d_col, device=self.W_device, dtype=torch.float32)
134 | self.issue_zero_samples = True
135 | # No need to reduce
136 | reduce_if_needed = False
137 | # synchronize Hessians
138 | if self.is_distributed and reduce_if_needed and dist_utils.is_dist_available_and_initialized():
139 | dist.all_reduce(self.H, op=dist.ReduceOp.AVG)
140 | # Replace matrix by identity in case of NaNs
141 | if torch.isnan(self.H).any().item():
142 | self.H = torch.eye(self.d_col, device=self.W_device, dtype=torch.float32)
143 | self.issue_nan_hessian = True
144 | # get ids of pruned channels
145 | pruned_ids = torch.diag(self.H) == 0
146 | self.H[pruned_ids, pruned_ids] = 1
147 | # 2) Weight preparation
148 | # copy weight, flatten
149 | self.W = self.W.clone().float()
150 | if isinstance(self.layer, _ConvNd):
151 | self.W = self.W.flatten(1, -1)
152 | self.W[:, pruned_ids] = 0
153 | # flag pre step as completed
154 | self.pre_step_completed = True
155 |
156 | @torch.no_grad()
157 | def _quantize(self, bits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
158 | """
159 | Quantize the weight matrix using GPTQ
160 | """
161 | # 1) Define constants and chunk
162 | d_row, d_col, block_size, device, dtype = self.d_row, self.d_col, self.block_size, self.W_device, self.W_dtype
163 | # 2) Get quantization group size
164 | group_size = self.group_size or d_col
165 | num_groups = d_col // group_size
166 |
167 | is_main_gptq_process = dist_utils.is_main() or not self.is_distributed
168 |
169 | if is_main_gptq_process:
170 | # Get scale, qzero
171 | scale, zero, maxq = quant_utils.get_quantization_grid(
172 | weight=self.W,
173 | group_size=self.group_size,
174 | bits=bits,
175 | symmetric=self.sym,
176 | dtype=dtype,
177 | quantization_scale=self.quantization_scale,
178 | )
179 | # Get permutation
180 | if self.quantization_order == QuantizationOrder.ACTIVATION:
181 | perm = torch.argsort(self.H.diag(), descending=True)
182 | else:
183 | perm = torch.arange(d_col, device=device)
184 | perm_inv = torch.argsort(perm)
185 | # Permute Hessian prior to inversion (if reusing hessian from other handle, Hessian is already permuted)
186 | if not self.tied_gptq_handle:
187 | self.H = self.H[perm][:, perm]
188 | # Get hessian inverse
189 | hessian_inv = self._get_hessian_inverse()
190 | # Quantize
191 | qweight = gptq_loop.gptq_loop(
192 | weight=self.W.transpose(-2, -1)[perm],
193 | hessian_inv=hessian_inv,
194 | scale=scale.transpose(-2, -1)[perm],
195 | qzero=zero.transpose(-2, -1)[perm],
196 | maxq=maxq,
197 | dtype=dtype,
198 | gptq_block_size=block_size,
199 | )[perm_inv].transpose(-2, -1).contiguous().to(torch.uint8)
200 | # Remove scale and zero replication
201 | scale = scale[:, ::group_size].to(dtype)
202 | zero = zero[:, ::group_size].to(dtype)
203 | else:
204 | qweight = torch.empty(d_row, d_col, device=device, dtype=torch.uint8)
205 | scale = torch.empty(d_row, num_groups, device=device, dtype=dtype)
206 | zero = torch.empty(d_row, num_groups, device=device, dtype=dtype)
207 |
208 | if self.is_distributed and dist_utils.is_dist_available_and_initialized():
209 | dist.barrier()
210 | dist.broadcast(qweight, src=0)
211 | dist.broadcast(scale, src=0)
212 | dist.broadcast(zero, src=0)
213 |
214 | return qweight, scale, zero
215 |
216 |     def quantize(self, bits: int) -> Tuple[Tensor, Tensor, Tensor]:
217 | self.quantization_pre_step()
218 | return self._quantize(bits)
219 |
220 | @torch.no_grad()
221 | def _get_hessian_inverse(self):
222 | w = self.W
223 | # Get columns with all zeros
224 | zero_cols = torch.nonzero(w.eq(0).all(dim=0))
225 | H = self.H
226 | # Regularize Hessian before quantization
227 | if not self.tied_gptq_handle:
228 | # Mask rows with zero input channels
229 | H[zero_cols, :] = 0
230 | H[:, zero_cols] = 0
231 | H[zero_cols, zero_cols] = 1
232 | # Hessian regularization
233 | damp = self.rel_damp * torch.diag(self.H).mean()
234 | self.H[range(self.d_col), range(self.d_col)] += damp
235 | # Invert
236 | try:
237 | H = linalg_utils.inv_sym(H)
238 | H_inv_cho = torch.linalg.cholesky(H, upper=True)
239 |         except Exception:
240 |             self.issue_non_invertible = True
241 |             H_inv_cho = torch.eye(self.d_col, device=H.device, dtype=torch.float32)
242 |         # Divide Hessian inverse by its diagonal (so we do not have to divide by it later)
243 |         H_inv_cho.div_(H_inv_cho.diag()[:, None])
244 |         return H_inv_cho
244 |
--------------------------------------------------------------------------------
/src/gptq_loop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | from triton import language as tl
4 |
5 | from src.quant_utils import tl_quantize, tl_dequantize
6 |
7 | torch.backends.cuda.matmul.allow_tf32 = False
8 | torch.backends.cudnn.allow_tf32 = False
9 | torch.set_float32_matmul_precision("highest")
10 |
11 |
12 | @triton.jit
13 | def quantize_error_triton_kernel(
14 | x_ptr,
15 | qx_ptr,
16 | error_ptr,
17 | scale_ptr,
18 | qzero_ptr,
19 | maxq_ptr,
20 | dtype_ptr,
21 | n_elements: int,
22 | BLOCK_SIZE: tl.constexpr,
23 | ):
24 | pid = tl.program_id(axis=0)
25 | offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
26 | mask = offsets < n_elements
27 |
28 | x = tl.load(x_ptr + offsets, mask=mask)
29 | scale = tl.load(scale_ptr + offsets, mask=mask)
30 | qzero = tl.load(qzero_ptr + offsets, mask=mask)
31 | maxq = tl.load(maxq_ptr)
32 | dtype = None if dtype_ptr is None else tl.load(dtype_ptr).dtype
33 |
34 | qx = tl_quantize(x, scale, qzero, maxq)
35 | y = tl_dequantize(qx, scale, qzero, dtype)
36 | error = y - x
37 |
38 | tl.store(x_ptr + offsets, y, mask=mask)
39 | tl.store(qx_ptr + offsets, qx, mask=mask)
40 | tl.store(error_ptr + offsets, error, mask=mask)
41 |
42 |
43 | def quantize_error_triton(
44 | x: torch.Tensor,
45 | qx: torch.Tensor,
46 | error: torch.Tensor,
47 | scale: torch.Tensor,
48 | qzero: torch.Tensor,
49 | maxq: torch.Tensor,
50 | dtype: torch.dtype = None,
51 | ) -> None:
52 |
53 | n_elements: int = x.numel()
54 | grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
55 | quantize_error_triton_kernel[grid](
56 | x,
57 | qx,
58 | error,
59 | scale,
60 | qzero,
61 | maxq,
62 | torch.empty(0, dtype=dtype) if dtype is not None else None,
63 | n_elements,
64 | BLOCK_SIZE=128,
65 | )
66 |
67 |
68 | @triton.jit
69 | def addvv_triton_kernel(
70 | vec_a_ptr,
71 | vec_b_ptr,
72 | mat_c_ptr,
73 | size_a: int,
74 | size_b: int,
75 | BLOCK_SIZE_B: tl.constexpr,
76 | ):
77 | pid = tl.program_id(axis=0)
78 | offset_a = pid % size_a
79 | offsets_b = pid // size_a * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B)
80 | mask = offsets_b < size_b
81 | c_ptrs = mat_c_ptr + offset_a * size_b + offsets_b
82 |
83 | a = tl.load(vec_a_ptr + offset_a)
84 | b = tl.load(vec_b_ptr + offsets_b, mask=mask)
85 | c = tl.load(c_ptrs, mask=mask)
86 | c = tl.fma(a, b, c)
87 |
88 | tl.store(c_ptrs, c, mask=mask)
89 |
90 |
91 | def addvv_triton(
92 | vec_a: torch.Tensor,
93 | vec_b: torch.Tensor,
94 | mat_c: torch.Tensor,
95 | ) -> None:
96 | size_a, size_b = mat_c.shape
97 | grid = lambda meta: (size_a * triton.cdiv(size_b, meta["BLOCK_SIZE_B"]),)
98 | addvv_triton_kernel[grid](
99 | vec_a,
100 | vec_b,
101 | mat_c,
102 | size_a,
103 | size_b,
104 | BLOCK_SIZE_B=256,
105 | )
106 |
107 |
108 | def gptq_loop_graph(
109 | weight: torch.Tensor,
110 | hessian_inv: torch.Tensor,
111 | scale: torch.Tensor,
112 | qzero: torch.Tensor,
113 | maxq: torch.Tensor,
114 | qweight: torch.Tensor = None,
115 | error_block: torch.Tensor = None,
116 | dtype: torch.dtype = None,
117 | gptq_block_size: int = 128,
118 | direct: bool = True,
119 | ) -> tuple[torch.Tensor, torch.Tensor]:
120 | """
121 | CUDA Graph wrapper for GPTQ loops
122 | """
123 | n_columns, n_rows = weight.shape
124 | w_dtype: torch.dtype = weight.dtype
125 | device: torch.device = weight.device
126 |
127 | if direct:
128 | if qweight is None:
129 | qweight: torch.Tensor = torch.empty_like(weight)
130 | if error_block is None:
131 | error_block: torch.Tensor = torch.empty(gptq_block_size, n_rows, dtype=w_dtype, device=device)
132 | assert (
133 | weight.is_contiguous()
134 | and hessian_inv.is_contiguous()
135 | and scale.is_contiguous()
136 | and qzero.is_contiguous()
137 | and maxq.is_contiguous()
138 | and qweight.is_contiguous()
139 | and error_block.is_contiguous()
140 | )
141 | for i1 in range(0, n_columns, gptq_block_size):
142 | i2: int = min(i1 + gptq_block_size, n_columns)
143 | for j in range(i1, i2):
144 | quantize_error_triton(
145 | weight[j], qweight[j], error_block[j - i1], scale[j], qzero[j], maxq, dtype,
146 | )
147 | addvv_triton(hessian_inv[j, j + 1 : i2], error_block[j - i1], weight[j + 1 : i2])
148 | weight[i2:].addmm_(hessian_inv[i1:i2, i2:].t(), error_block[: i2 - i1], beta=1, alpha=1)
149 | return qweight, weight
150 |
151 | previous_device: torch.device = torch.device(f"cuda:{torch.cuda.current_device()}")
152 | torch.cuda.set_device(weight.device)
153 | if not hasattr(gptq_loop_graph, "graph_info"):
154 | gptq_loop_graph.graph_info = {}
155 | graph_key: tuple = n_columns, n_rows, w_dtype, dtype, gptq_block_size, device
156 | if graph_key not in gptq_loop_graph.graph_info:
157 | graph: torch.cuda.CUDAGraph = torch.cuda.CUDAGraph()
158 | graph_tensors: dict[str, torch.Tensor] = {
159 | "weight": torch.empty_like(weight.contiguous()),
160 | "hessian_inv": torch.empty_like(hessian_inv.contiguous()),
161 | "scale": torch.empty_like(scale.contiguous()),
162 | "qzero": torch.empty_like(qzero.contiguous()),
163 | "maxq": torch.empty_like(maxq.contiguous()),
164 | "qweight": torch.empty_like(weight.contiguous()),
165 | "error_block": torch.empty(gptq_block_size, n_rows, dtype=w_dtype, device=device),
166 | }
167 | n_warmups: int = 5
168 | s: torch.cuda.Stream = torch.cuda.Stream()
169 | s.wait_stream(torch.cuda.current_stream())
170 | with torch.cuda.stream(s):
171 | for _ in range(n_warmups):
172 | gptq_loop_graph(**graph_tensors, dtype=dtype, gptq_block_size=gptq_block_size, direct=True)
173 | torch.cuda.current_stream().wait_stream(s)
174 | with torch.cuda.graph(graph):
175 | gptq_loop_graph(**graph_tensors, dtype=dtype, gptq_block_size=gptq_block_size, direct=True)
176 | gptq_loop_graph.graph_info[graph_key] = {"graph": graph, "tensors": graph_tensors}
177 |
178 | graph, graph_tensors = (
179 | gptq_loop_graph.graph_info[graph_key]["graph"],
180 | gptq_loop_graph.graph_info[graph_key]["tensors"],
181 | )
182 | graph_tensors["weight"].copy_(weight)
183 | graph_tensors["hessian_inv"].copy_(hessian_inv)
184 | graph_tensors["scale"].copy_(scale)
185 | graph_tensors["qzero"].copy_(qzero)
186 | graph_tensors["maxq"].copy_(maxq)
187 | graph.replay()
188 | weight.copy_(graph_tensors["weight"])
189 | torch.cuda.set_device(previous_device)
190 | return graph_tensors["qweight"], weight
191 |
192 |
193 | def gptq_loop(
194 | weight: torch.Tensor,
195 | hessian_inv: torch.Tensor,
196 | scale: torch.Tensor,
197 | qzero: torch.Tensor,
198 | maxq: torch.Tensor,
199 | dtype: torch.dtype,
200 | gptq_block_size: int = 128,
201 | ) -> torch.Tensor:
202 | """
203 |     Quantize a transposed weight tensor with the GPTQ algorithm and return the quantized codes
204 |     weight: (C, R), transposed weight tensor to quantize; overwritten in-place with its dequantized values
205 | hessian_inv: (C, C), inverse of Hessian matrix
206 | scale: (C, R), transposed scale tensor for quantization
207 | qzero: (C, R), transposed zero-point tensor for quantization
208 | maxq: (), maximum quantized value
209 |     dtype: dtype used when rounding the dequantized values, fp16 or bf16
210 |     gptq_block_size: block size for the GPTQ loop; independent of the quantization group size
211 | """
212 | if gptq_block_size <= 0:
213 | gptq_block_size = weight.size(-2)
214 |
215 | qweight, _ = gptq_loop_graph(
216 | weight=weight,
217 | hessian_inv=hessian_inv,
218 | scale=scale,
219 | qzero=qzero,
220 | maxq=maxq,
221 | dtype=dtype,
222 | gptq_block_size=gptq_block_size,
223 | direct=False,
224 | )
225 | return qweight # (C, R)
226 |
--------------------------------------------------------------------------------
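A hypothetical end-to-end call of `gptq_loop`, assuming a CUDA device. The identity matrix stands in for the inverse-Cholesky Hessian factor that `GPTQ._get_hessian_inverse` would normally provide (with an identity factor the loop degenerates to plain rounding); the layer size and bit width are illustrative:

import torch
from src.gptq_loop import gptq_loop
from src.quant_utils import get_quantization_grid

device = "cuda"
W = torch.randn(4096, 512, device=device)                              # (R, C) dense layer weight
scale, qzero, maxq = get_quantization_grid(W, group_size=128, bits=4)  # (R, C), (R, C), ()
H_inv_cho = torch.eye(512, device=device)                              # placeholder Hessian factor, (C, C)

qweight = gptq_loop(
    weight=W.t().contiguous(),      # (C, R), overwritten with its dequantized values
    hessian_inv=H_inv_cho,
    scale=scale.t().contiguous(),
    qzero=qzero.t().contiguous(),
    maxq=maxq,
    dtype=torch.float16,
)                                   # (C, R) quantized codes stored as floats
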
/src/linalg_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 |
5 | __all__ = ["inv_sym"]
6 |
7 |
8 | def inv_sym(X: Tensor):
9 | """
10 | More efficient and stable inversion of symmetric matrices.
11 | """
12 | return torch.cholesky_inverse(torch.linalg.cholesky(X))
--------------------------------------------------------------------------------
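A quick illustrative check that `inv_sym` matches the generic inverse on a symmetric positive-definite input:

import torch
from src.linalg_utils import inv_sym

A = torch.randn(64, 64, dtype=torch.float64)
S = A @ A.T + 64 * torch.eye(64, dtype=torch.float64)   # symmetric positive-definite
assert torch.allclose(inv_sym(S), torch.linalg.inv(S), atol=1e-8)
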
/src/loading_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from safetensors import safe_open
5 |
6 | def load_param_shard(weight_dir: str, weight_path: str) -> dict[str, torch.Tensor]:
7 | param_shard = {}
8 | with safe_open(os.path.join(weight_dir, weight_path), framework="pt", device="cpu") as f:
9 | param_shard_keys = f.keys()
10 | for k in param_shard_keys:
11 | param_shard[k] = f.get_tensor(k)
12 | return param_shard
13 |
--------------------------------------------------------------------------------
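Illustrative use of `load_param_shard`: merge every `*.safetensors` shard of a checkpoint directory into one state dict (the directory path below is hypothetical):

import os
from src.loading_utils import load_param_shard

weight_dir = "/path/to/checkpoint"
state_dict = {}
for fname in sorted(os.listdir(weight_dir)):
    if fname.endswith(".safetensors"):
        state_dict.update(load_param_shard(weight_dir, fname))
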
/src/model_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple, Type, Union
2 |
3 | import re
4 | import numpy as np
5 | import torch.nn as nn
6 |
7 |
8 | LINEAR_LAYERS = (nn.Linear,)
9 |
10 |
11 | def get_number_of_rows_and_cols(layer):
12 | return layer.weight.shape[0], np.prod(layer.weight.shape[1:])
13 |
14 |
15 | def select_layers(
16 | model: nn.Module,
17 |     layer_prefix: str = "",
18 |     layer_regex: str = ".*",
19 |     layer_classes: Union[Type[nn.Module], Tuple[Type[nn.Module], ...]] = nn.Module,
20 | ) -> Dict[str, nn.Module]:
21 | layers = {}
22 | for layer_name, layer in model.named_modules():
23 | if (
24 | isinstance(layer, layer_classes)
25 | and re.search(layer_regex, layer_name)
26 | and layer_name.startswith(layer_prefix)
27 | ):
28 | layers[layer_name] = layer
29 | return layers
30 |
--------------------------------------------------------------------------------
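Illustrative use of `select_layers`: collect the `nn.Linear` modules inside the decoder blocks of a Hugging Face-style causal LM. The prefix and regex below are typical module names, not something the repo prescribes:

import torch.nn as nn
from src.model_utils import LINEAR_LAYERS, select_layers

def collect_quantizable_layers(model: nn.Module) -> dict[str, nn.Linear]:
    # The prefix skips embeddings and lm_head; the regex narrows to projection layers.
    return select_layers(
        model,
        layer_prefix="model.layers.",
        layer_regex=".*(proj|gate|up|down|fc).*",
        layer_classes=LINEAR_LAYERS,
    )
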
/src/quant_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from enum import Enum
3 | from typing import Optional
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import triton
8 | from triton import language as tl
9 |
10 |
11 | torch.backends.cuda.matmul.allow_tf32 = False
12 | torch.backends.cudnn.allow_tf32 = False
13 | torch.set_float32_matmul_precision("highest")
14 |
15 |
16 | FP8_GROUP_SIZE = 128
17 | FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)
18 |
19 |
20 | class QuantizationScale(Enum):
21 | ABSMAX = "absmax"
22 | MSE = "mse"
23 |
24 |
25 | @triton.jit
26 | def tl_pow(x, a):
27 | return (x.abs().log() * a).exp() # TODO: triton does not have x.pow(a) or x ** a?
28 |
29 |
30 | @triton.jit
31 | def tl_round(x):
32 |     # TODO: triton does not have round()? We might want to change to
33 |     # round-to-even here.
34 |     return (x + 0.5).floor()
35 |
36 |
37 | @triton.jit
38 | def tl_round_fp(x, dtype):
39 | return x if dtype is None else x.cast(dtype, fp_downcast_rounding="rtne").cast(x.dtype)
40 |
41 |
42 | @triton.jit
43 | def tl_quantize(x, scale, qzero, maxq):
44 | return tl.clamp(tl_round(x / scale + qzero), 0.0, maxq)
45 |
46 |
47 | @triton.jit
48 | def tl_dequantize(qx, scale, qzero, dtype):
49 | return tl_round_fp((qx - qzero) * scale, dtype)
50 |
51 |
52 | @triton.jit
53 | def tl_dequantize_quantized(x, scale, qzero, maxq, dtype):
54 | return tl_dequantize(tl_quantize(x, scale, qzero, maxq), scale, qzero, dtype)
55 |
56 |
57 | def round_fp(x: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor:
58 | return x if dtype is None else x.to(dtype=dtype).to(x.dtype)
59 |
60 |
61 | def quantize(x: torch.Tensor, scale: torch.Tensor, qzero: torch.Tensor, maxq: torch.Tensor) -> torch.Tensor:
62 | return (x / scale + qzero).round().clamp(torch.zeros_like(maxq), maxq)
63 |
64 |
65 | def dequantize(qx: torch.Tensor, scale: torch.Tensor, qzero: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor:
66 | return round_fp((qx - qzero) * scale, dtype)
67 |
68 |
69 | def dequantize_quantized(
70 | x: torch.Tensor, scale: torch.Tensor, qzero: torch.Tensor, maxq: torch.Tensor, dtype: torch.dtype = None
71 | ) -> torch.Tensor:
72 | return dequantize(quantize(x, scale, qzero, maxq), scale, qzero, dtype)
73 |
74 |
75 | def find_quantization_meta(
76 | x: torch.Tensor,
77 | bit_width: int,
78 | symmetric: bool = False,
79 | dtype: torch.dtype = None,
80 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
81 | """
82 | Find quantization metadata over dim=-1
83 | x: (..., C), weight
84 | bit_width: int
85 | symmetric: bool, whether to set qzero to the middle
86 | dtype: torch.dtype, target scale dtype, fp16 or bf16
87 | """
88 | epsilon: float = 1e-12
89 | maxq = torch.tensor(2**bit_width - 1, dtype=x.dtype, device=x.device) # ()
90 |
91 |     x_min = x.amin(dim=-1)
92 |     x_max = x.amax(dim=-1)
93 |
94 | if symmetric:
95 | scale = (2.0 / maxq) * torch.maximum(x_min.abs(), x_max.abs())
96 | scale = round_fp(scale + epsilon, dtype) # (...)
97 | qzero = torch.full_like(scale, ((maxq + 1.0) * 0.5).item()) # (...)
98 | else:
99 | scale = round_fp((x_max - x_min) / maxq + epsilon, dtype) # (...)
100 | qzero = (-x_min / scale).round().clamp(0, maxq) # (...)
101 | return scale, qzero, maxq
102 |
103 |
104 | @triton.jit
105 | def mse_scale_triton_kernel(
106 | x_ptr,
107 | p_ptr,
108 | scale_ptr,
109 | qzero_ptr,
110 | maxq_ptr,
111 | dtype_ptr,
112 | norm: float,
113 | p_size: int,
114 | group_size: int,
115 | batch_size: int,
116 | BLOCK_SIZE_P: tl.constexpr,
117 | BLOCK_SIZE_G: tl.constexpr,
118 | BLOCK_SIZE_B: tl.constexpr,
119 | ):
120 | pid = tl.program_id(axis=0)
121 | b_offsets = pid * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B) # (R)
122 | b_mask = b_offsets < batch_size # (R)
123 | x_offsets = b_offsets[:, None] * group_size + tl.arange(0, BLOCK_SIZE_G) # (R, C)
124 | x_mask = b_mask[:, None] & (tl.arange(0, BLOCK_SIZE_G) < group_size) # (R, C)
125 | p_offsets = tl.arange(0, BLOCK_SIZE_P) # (P)
126 | p_mask = p_offsets < p_size # (P)
127 | scale_ptrs = scale_ptr + b_offsets # (R)
128 |
129 | x = tl.load(x_ptr + x_offsets, mask=x_mask)[:, None, :] # (R, 1, C)
130 | p = tl.load(p_ptr + p_offsets, mask=p_mask) # (P)
131 | scale = tl.load(scale_ptrs, mask=b_mask) # (R)
132 | qzero = tl.load(qzero_ptr + b_offsets, mask=b_mask)[:, None, None] # (R, 1, 1)
133 | maxq = tl.load(maxq_ptr) # ()
134 | dtype = None if dtype_ptr is None else tl.load(dtype_ptr).dtype
135 |
136 | scale_p = tl_round_fp(scale[:, None] * p, dtype)[:, :, None] # (R, P, 1)
137 | q = tl_dequantize_quantized(x, scale_p, qzero, maxq, dtype) # (R, P, C)
138 | best_idx = tl.argmin(tl.sum(tl_pow(q - x, norm), axis=-1), axis=-1, tie_break_left=False) # (R)
139 |
140 | scale = tl_round_fp(scale * tl.load(p_ptr + best_idx), dtype) # (R) # TODO: replace with tl.gather()
141 | tl.store(scale_ptrs, scale, mask=b_mask) # (R)
142 |
143 |
144 | def mse_scale(
145 | x: torch.Tensor,
146 | p: torch.Tensor,
147 | scale: torch.Tensor,
148 | qzero: torch.Tensor,
149 | maxq: torch.Tensor,
150 | dtype: torch.dtype = None,
151 | norm: float = 2.4,
152 | ) -> torch.Tensor:
153 | """
154 | Find the optimal scale for quantization with respect to the MSE loss
155 | x: (..., C), weight
156 | p: (P), shrinkage factors
157 | scale: (...), initial scale, modified in-place and returned
158 | qzero: (...), zero points
159 | maxq: ()
160 | dtype: torch.dtype, target scale dtype, fp16 or bf16
161 |     norm: float, exponent of the norm used as the reconstruction error
162 |     All tensor arguments must be contiguous and live on a CUDA device
163 | """
164 |
165 | assert (
166 | x.is_contiguous()
167 | and p.is_contiguous()
168 | and scale.is_contiguous()
169 | and qzero.is_contiguous()
170 | and maxq.is_contiguous()
171 | )
172 |     batch_size: int = math.prod(x.shape[:-1])
173 | previous_device: torch.device = torch.device(f"cuda:{torch.cuda.current_device()}")
174 | torch.cuda.set_device(x.device)
175 | grid = lambda meta: (triton.cdiv(batch_size, meta["BLOCK_SIZE_B"]),)
176 | mse_scale_triton_kernel[grid](
177 | x,
178 | p,
179 | scale,
180 | qzero,
181 | maxq,
182 | torch.empty(0, dtype=dtype) if dtype is not None else None,
183 | norm,
184 | p.size(-1),
185 | x.size(-1),
186 | batch_size,
187 |         BLOCK_SIZE_P=triton.next_power_of_2(p.size(-1)),
188 |         BLOCK_SIZE_G=triton.next_power_of_2(x.size(-1)),
189 | BLOCK_SIZE_B=1,
190 | )
191 | torch.cuda.set_device(previous_device)
192 | return scale
193 |
194 |
195 | @torch.no_grad()
196 | def get_quantization_grid(
197 | weight: torch.Tensor,
198 | group_size: int,
199 | bits: int,
200 | symmetric: bool = False,
201 | dtype: torch.dtype = None,
202 | quantization_scale: QuantizationScale = QuantizationScale.ABSMAX,
203 | quant_max_shrink: float = 0.2,
204 | quant_n_grid: int = 100,
205 | quant_norm: float = 2.4,
206 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
207 | """
208 |     Build the per-group quantization grid for a weight matrix
209 |     weight: (..., (R), C), last dimension must be divisible by group_size
210 |     returns scale: (..., (R), C), replicated across each group
211 |     returns qzero: (..., (R), C), replicated across each group
212 |     returns maxq: ()
213 | """
214 | weight = weight.unflatten(dim=-1, sizes=(-1, group_size)) # (..., G, gs)
215 |
216 | scale, qzero, maxq = find_quantization_meta(
217 | x=weight,
218 | bit_width=bits,
219 | symmetric=symmetric,
220 | dtype=dtype,
221 | ) # (..., G), (..., G), ()
222 | if quantization_scale == QuantizationScale.MSE:
223 | search_points = torch.linspace(1, quant_max_shrink, quant_n_grid, dtype=weight.dtype, device=weight.device)
224 | mse_scale(
225 | x=weight.contiguous(), # (..., G, gs)
226 |         p=search_points, # (P)
227 | scale=scale, # (..., G)
228 | qzero=qzero, # (..., G)
229 | maxq=maxq, # ()
230 | dtype=dtype,
231 | norm=quant_norm,
232 | )
233 |
234 | scale = scale.repeat_interleave(group_size, dim=-1) # (..., C)
235 | qzero = qzero.repeat_interleave(group_size, dim=-1) # (..., C)
236 |
237 | weight = weight.flatten(start_dim=-2) # (..., C)
238 |
239 | assert weight.shape == scale.shape == qzero.shape and maxq.shape == ()
240 | return scale, qzero, maxq # (..., (R), C), (..., (R), C), ()
241 |
242 |
243 | def dequantize_linear_weight(
244 | qweight: torch.Tensor,
245 | scale: torch.Tensor,
246 | zero: torch.Tensor,
247 | perm: Optional[torch.Tensor] = None,
248 | ):
249 | scale = scale.view(qweight.shape[0], -1, 1)
250 | zero = zero.view(qweight.shape[0], -1, 1)
251 | num_groups = scale.shape[1]
252 | weight = dequantize(qweight.view(qweight.shape[0], num_groups, -1), scale, zero).view_as(qweight)
253 | if perm is not None:
254 | invperm = perm.argsort()
255 | weight = weight[:, invperm]
256 | return weight
257 |
258 |
259 | def get_relative_mse_error(q: torch.Tensor, w: torch.Tensor, H: Optional[torch.Tensor] = None):
260 | delta = q - w
261 | if H is None:
262 | return delta.pow(2).mean() / w.pow(2).mean()
263 | else:
264 |         return delta.mm(H).mul(delta).mean() / (w.mm(H).mul(w).mean() + 1e-6)
265 |
266 |
267 | def dequantize_weight_from_fp8(W, s):
268 | g = FP8_GROUP_SIZE
269 | # Dequantize weight
270 | d_out, d_in = W.shape
271 | # Pad weight if needed
272 | pad_out = math.ceil(d_out / g) * g - d_out
273 | pad_in = math.ceil(d_in / g) * g - d_in
274 | W = F.pad(W, (0, pad_in, 0, pad_out))
275 | d_out_pad, d_in_pad = W.shape
276 |
277 | W = W.view(d_out_pad // g, g, d_in_pad // g, g)
278 | s = s.view(d_out_pad // g, 1, d_in_pad // g, 1)
279 | W = (W * s).view(d_out_pad, d_in_pad)
280 |
281 | # Remove padding
282 | W = W[:d_out, :d_in]
283 | return W
284 |
285 |
286 | def dequantize_state_dict(state_dict: dict[str, torch.Tensor], dtype: torch.dtype = torch.float16) -> None:
287 | state_dict_keys = list(state_dict.keys())
288 | # Dequantize
289 | for k in state_dict_keys:
290 | if k.endswith("scale_inv"):
291 | layer_name, _ = k.rsplit(".", 1)
292 |
293 | W = state_dict[f"{layer_name}.weight"].to(dtype)
294 | s = state_dict[f"{layer_name}.weight_scale_inv"].to(dtype)
295 |
296 | state_dict[f"{layer_name}.weight"] = dequantize_weight_from_fp8(W, s)
297 | del state_dict[f"{layer_name}.weight_scale_inv"]
298 |
299 |
300 | def can_dequantize_from_fp8(state_dict: dict[str, torch.Tensor]) -> bool:
301 | for k, v in state_dict.items():
302 | if v.dtype in FP8_DTYPES and f"{k}_scale_inv" not in state_dict:
303 | return False
304 | return True
305 |
--------------------------------------------------------------------------------
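To tie the helpers above together, an illustrative CPU-only round trip: build a 4-bit grid with `get_quantization_grid` (the default `QuantizationScale.ABSMAX` path needs no Triton kernels), fake-quantize with `dequantize_quantized`, and score it with `get_relative_mse_error`. The sizes and bit width are arbitrary:

import torch
from src.quant_utils import dequantize_quantized, get_quantization_grid, get_relative_mse_error

W = torch.randn(1024, 512)
scale, qzero, maxq = get_quantization_grid(W, group_size=128, bits=4)  # (1024, 512), (1024, 512), ()
W_hat = dequantize_quantized(W, scale, qzero, maxq)
print(get_relative_mse_error(W_hat, W).item())   # typically around 1e-2 for 4-bit groups of 128
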