├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── act_scales └── README.md ├── assets └── SmoothQuant.pdf ├── examples ├── export_int8_model.py ├── generate_act_scales.py ├── smoothquant_opt_demo.ipynb └── smoothquant_opt_real_int8_demo.ipynb ├── figures └── throughput_latency.png ├── setup.py └── smoothquant ├── __init__.py ├── calibration.py ├── fake_quant.py ├── opt.py └── smooth.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pt filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dataset/ 2 | scripts/ 3 | int8_models/ 4 | .DS_Store 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 MIT HAN Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SmoothQuant+: Accurate and Efficient 4-bit Post-Training Weight Quantization for LLM 2 | 3 | [[paper](https://arxiv.org/abs/2312.03788)] 4 | 5 | 4-bit weight-only quantization models for Code Llama are released in the [model zoo](https://github.com/Adlik/model_zoo/tree/main/smooth_quant_plus). You can use [Adlik/vllm](https://github.com/Adlik/vllm/tree/dev) to evaluate them. 6 | 7 | We will release the code soon; please stay tuned. 8 | 9 | ```bash 10 | python examples/test_codellama_7b_human_eval.py \ 11 | --model-path /path/codellama-7b-base-search-4-bits/ \ 12 | --output-path codellama-7b-base-smooth-search \ 13 | --quant-mode weight_int4 14 | ``` 15 | 16 | 17 | 18 | ![intuition](figures/throughput_latency.png) 19 | 20 | ## Abstract 21 | 22 | Large language models (LLMs) have shown remarkable capabilities in various tasks.
However, their huge model size and the consequent demand for computational and memory resources also pose challenges to model deployment. Currently, 4-bit post-training quantization (PTQ) has achieved some success in LLMs, reducing the memory footprint by approximately 75% compared to FP16 models, albeit with some accuracy loss. In this paper, we propose SmoothQuant+, an accurate and efficient 4-bit weight-only PTQ method that requires no additional training and, for the first time, enables lossless accuracy for LLMs. Based on the fact that the loss from weight quantization is amplified by activation outliers, SmoothQuant+ smooths the activation outliers by channel before quantization, while adjusting the corresponding weights for mathematical equivalence, and then performs group-wise 4-bit weight quantization for linear layers. We have integrated SmoothQuant+ into the vLLM framework, an advanced high-throughput inference engine specially developed for LLMs, and equipped it with efficient W4A16 CUDA kernels, so that vLLM can seamlessly support SmoothQuant+ 4-bit weight quantization. Our results show that, with SmoothQuant+, the Code Llama-34B model can be quantized and deployed on an A100 40GB GPU, achieving lossless accuracy and a throughput increase of 1.9 to 4.0 times compared to the FP16 model deployed on two A100 40GB GPUs. Moreover, the latency per token is only 68% of that of the FP16 model deployed on two A100 40GB GPUs. To the best of our knowledge, this is the state-of-the-art 4-bit weight quantization method for LLMs. 23 | 24 | 25 | -------------------------------------------------------------------------------- /act_scales/README.md: -------------------------------------------------------------------------------- 1 | # Activation Channel Scales 2 | 3 | We provide the activation channel scales for OPT and BLOOM models at [Huggingface](https://huggingface.co/mit-han-lab/smoothquant-scales). We obtain these scales with 512 random sentences from the Pile validation set. You can use `../examples/smoothquant_opt_demo.ipynb` to test smoothing and quantizing those models. 
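The snippet below is a minimal sketch of how these scales are consumed, mirroring `../examples/export_int8_model.py`; substitute the model name and scale file for the model you want to smooth:

```python
import torch
from transformers.models.opt.modeling_opt import OPTForCausalLM

from smoothquant.smooth import smooth_lm

# Load an FP16 OPT model and its pre-computed activation channel scales.
model = OPTForCausalLM.from_pretrained(
    "facebook/opt-13b", device_map="auto", torch_dtype=torch.float16)
act_scales = torch.load("act_scales/opt-13b.pt")

# Fold the per-channel smoothing factors into the preceding LayerNorms and
# the linear-layer weights, with migration strength alpha = 0.5.
smooth_lm(model, act_scales, 0.5)
```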
4 | -------------------------------------------------------------------------------- /assets/SmoothQuant.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Adlik/smoothquantplus/df36110a315c548cea04e83c096ea7128fba79fa/assets/SmoothQuant.pdf -------------------------------------------------------------------------------- /examples/export_int8_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | 5 | from pathlib import Path 6 | 7 | from transformers.models.opt.modeling_opt import OPTForCausalLM 8 | from transformers import AutoTokenizer 9 | 10 | from smoothquant.opt import Int8OPTForCausalLM 11 | from smoothquant.smooth import smooth_lm 12 | 13 | from smoothquant.calibration import get_static_decoder_layer_scales 14 | 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--model-name", type=str, default='facebook/opt-13b') 19 | parser.add_argument("--num-samples", type=int, default=512) 20 | parser.add_argument("--seq-len", type=int, default=512) 21 | parser.add_argument("--act-scales", type=str, 22 | default='act_scales/opt-13b.pt') 23 | parser.add_argument("--output-path", type=str, default='int8_models') 24 | parser.add_argument('--dataset-path', type=str, default='dataset/val.jsonl.zst', 25 | help='location of the calibration dataset, we use the validation set of the Pile dataset') 26 | parser.add_argument('--export-FT', default=False, action="store_true") 27 | args = parser.parse_args() 28 | model = OPTForCausalLM.from_pretrained( 29 | args.model_name, device_map="auto", torch_dtype=torch.float16) 30 | act_scales = torch.load(args.act_scales) 31 | smooth_lm(model, act_scales, 0.5) 32 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 33 | 34 | if not os.path.exists(args.dataset_path): 35 | print(f'Cannot find the dataset at {args.dataset_path}') 36 | print('Please download the Pile dataset and put the validation set at the path') 37 | print('You can download the validation dataset of the Pile at https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst') 38 | raise FileNotFoundError 39 | 40 | decoder_layer_scales, raw_scales = get_static_decoder_layer_scales(model, 41 | tokenizer, 42 | args.dataset_path, 43 | num_samples=args.num_samples, 44 | seq_len=args.seq_len) 45 | output_path = Path(args.output_path) / (Path(args.model_name).name + "-smoothquant.pt") 46 | if args.export_FT: 47 | model.save_pretrained(output_path) 48 | print(f"Saved smoothed model at {output_path}") 49 | 50 | output_path = Path(args.output_path) / (Path(args.model_name).name + "-smoothquant-scales.pt") 51 | torch.save(raw_scales, output_path) 52 | print(f"Saved scaling factors at {output_path}") 53 | else: 54 | int8_model = Int8OPTForCausalLM.from_float(model, decoder_layer_scales) 55 | int8_model.save_pretrained(output_path) 56 | print(f"Saved int8 model at {output_path}") 57 | -------------------------------------------------------------------------------- /examples/generate_act_scales.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | from transformers import ( 5 | AutoModelForCausalLM, 6 | AutoTokenizer, 7 | ) 8 | import argparse 9 | 10 | from smoothquant.calibration import get_act_scales 11 | 12 | def build_model_and_tokenizer(model_name): 13 | tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512) 14 | 
kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"} 15 | model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) 16 | return model, tokenizer 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--model-name', type=str, 21 | default='facebook/opt-1.3b', help='model name') 22 | parser.add_argument('--output-path', type=str, default='act_scales/opt-1.3b.pt', 23 | help='where to save the act scales') 24 | parser.add_argument('--dataset-path', type=str, default='dataset/val.jsonl.zst', 25 | help='location of the calibration dataset, we use the validation set of the Pile dataset') 26 | parser.add_argument('--num-samples', type=int, default=512) 27 | parser.add_argument('--seq-len', type=int, default=512) 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | @torch.no_grad() 33 | def main(): 34 | args = parse_args() 35 | model, tokenizer = build_model_and_tokenizer(args.model_name) 36 | 37 | if not os.path.exists(args.dataset_path): 38 | print(f'Cannot find the dataset at {args.dataset_path}') 39 | print('Please download the Pile dataset and put the validation set at the path') 40 | print('You can download the validation dataset of the Pile at https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst') 41 | raise FileNotFoundError 42 | 43 | act_scales = get_act_scales(model, tokenizer, args.dataset_path, 44 | args.num_samples, args.seq_len) 45 | 46 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 47 | torch.save(act_scales, args.output_path) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /examples/smoothquant_opt_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# SmoothQuant on OPT-13B\n", 8 | "\n", 9 | "### Guangxuan Xiao\\*, Ji Lin\\*, Mickael Seznec, Julien Demouth, Song Han\n", 10 | "\n", 11 | "In this notebook, we use the OPT-13B model to demonstrate that SmoothQuant can use 8-bit for both weights and activations to achieve the same accuracy as FP16 models. Unlike the previous method [[Dettmers *et al.*, 2022]](https://arxiv.org/abs/2208.07339), SmoothQuant enables fully INT8 GEMMs for linear layers and does not require high precision numbers to represent outliers. " 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "This notebook demonstrates SmoothQuant on OPT-13B in consideration of the user's resource constraints. We have tested SmoothQuant on up to 176 billion parameter models (OPT-175B, BLOOM-176B, GLM-130B). You can also adjust the model name to validate SmoothQuant on other models. `../act_scales/` provides the activation channel scales for OPT and BLOOM models."
19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "In order to run this notebook, you need to install the following packages:\n", 26 | "\n", 27 | "- smoothquant\n", 28 | "- PyTorch\n", 29 | "- Transformers\n", 30 | "- Accelerate" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 1, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "import torch\n", 40 | "from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer, OPTForCausalLM\n", 41 | "from transformers import GPT2Tokenizer\n", 42 | "from smoothquant.smooth import smooth_lm\n", 43 | "from smoothquant.fake_quant import W8A8Linear" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "In this notebook, we simulate the 8-bit dynamic per-tensor weight and activation quantization with FP16, i.e., fake quantization. We have implemented the real 8-bit quantization with INT8 CUTLASS GEMM kernels for both PyTorch and FasterTransformer. Please stay tuned for the release." 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "def quantize_model(model, weight_quant='per_tensor', act_quant='per_tensor', quantize_bmm_input=True):\n", 60 | " for name, m in model.model.named_modules():\n", 61 | " if isinstance(m, OPTDecoderLayer):\n", 62 | " m.fc1 = W8A8Linear.from_float(m.fc1, weight_quant=weight_quant, act_quant=act_quant)\n", 63 | " m.fc2 = W8A8Linear.from_float(m.fc2, weight_quant=weight_quant, act_quant=act_quant)\n", 64 | " elif isinstance(m, OPTAttention):\n", 65 | " # Her we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj\n", 66 | " m.q_proj = W8A8Linear.from_float(\n", 67 | " m.q_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)\n", 68 | " m.k_proj = W8A8Linear.from_float(\n", 69 | " m.k_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)\n", 70 | " m.v_proj = W8A8Linear.from_float(\n", 71 | " m.v_proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input)\n", 72 | " m.out_proj = W8A8Linear.from_float(m.out_proj, weight_quant=weight_quant, act_quant=act_quant)\n", 73 | " return model\n" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "The following is an evaluator to see the performance of the model. We use a toy dataset (the first 1000 examples in the validation set of the Lambada dataset) to evaluate the model. You can replace it with your own dataset. The conclusion should be the same." 81 | ] 82 | }, 83 | { 84 | "attachments": {}, 85 | "cell_type": "markdown", 86 | "metadata": { 87 | "tags": [] 88 | }, 89 | "source": [ 90 | "**In this demo, we have simplified the evaluation by using the first 1,000 samples from the LAMBADA dataset's validation set. We employ the \"Last Token Prediction Accuracy\" as our evaluation metric. This approximate evaluation is intended for demonstration purposes, providing simple but meaningful comparisons of relative performance between methods. 
For a more strict assessment, we recommend using the [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) to obtain the \"Last Word Prediction Accuracy\" for the LAMBADA dataset, which is the reported metric in our paper.**" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "class Evaluator:\n", 100 | " def __init__(self, dataset, tokenizer, device):\n", 101 | " self.dataset = dataset\n", 102 | " self.tokenizer = tokenizer\n", 103 | " self.device = device\n", 104 | "\n", 105 | " # tokenize the dataset\n", 106 | " def tokenize_function(examples):\n", 107 | " example = self.tokenizer(examples['text'])\n", 108 | " return example\n", 109 | "\n", 110 | " self.dataset = self.dataset.map(tokenize_function, batched=True)\n", 111 | " self.dataset.set_format(type='torch', columns=['input_ids'])\n", 112 | "\n", 113 | " @torch.no_grad()\n", 114 | " def evaluate(self, model):\n", 115 | " model.eval()\n", 116 | " # The task is to predict the last word of the input.\n", 117 | " total, hit = 0, 0\n", 118 | " for batch in self.dataset:\n", 119 | " input_ids = batch['input_ids'].to(self.device).unsqueeze(0)\n", 120 | " label = input_ids[:, -1]\n", 121 | " outputs = model(input_ids)\n", 122 | " last_token_logits = outputs.logits[:, -2, :]\n", 123 | " pred = last_token_logits.argmax(dim=-1)\n", 124 | " total += label.size(0)\n", 125 | " hit += (pred == label).sum().item()\n", 126 | " acc = hit / total\n", 127 | " return acc\n" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "from datasets import load_dataset\n", 137 | "\n", 138 | "tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-13b')\n", 139 | "dataset = load_dataset('lambada', split='validation[:1000]')\n", 140 | "evaluator = Evaluator(dataset, tokenizer, 'cuda')\n" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "## FP16 Model Accuracy" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "Let's first check the performance of the original FP16 model." 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 5, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "model_fp16 = OPTForCausalLM.from_pretrained('facebook/opt-13b', torch_dtype=torch.float16, device_map='auto')" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 6, 169 | "metadata": {}, 170 | "outputs": [ 171 | { 172 | "name": "stdout", 173 | "output_type": "stream", 174 | "text": [ 175 | "Original model (fp16) accuracy: 0.786\n" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "acc_fp16 = evaluator.evaluate(model_fp16)\n", 181 | "print(f'Original model (fp16) accuracy: {acc_fp16}')\n" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": {}, 187 | "source": [ 188 | "We then quantize the model to W8A8 and check the performance." 
189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "## Naive W8A8 Quantized Model Accuracy" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "model_w8a8 = quantize_model(model_fp16)\n", 205 | "print(model_w8a8)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 8, 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "name": "stdout", 215 | "output_type": "stream", 216 | "text": [ 217 | "Naive W8A8 quantized model accuracy: 0.048\n" 218 | ] 219 | } 220 | ], 221 | "source": [ 222 | "acc_w8a8 = evaluator.evaluate(model_w8a8)\n", 223 | "print(f'Naive W8A8 quantized model accuracy: {acc_w8a8}')" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "We can see there is a significant accuracy drop. This is consistent with LLM.int8()'s finding: when the model size increases larger than 6.7B, systematic outliers will emerge in activations, which makes fully INT8 quantization impossible." 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "## SmoothQuant W8A8 Quantized Model Accuracy" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "Let's smooth the model, quantize it, and check the performance! In `../act_scales`, we provide the activation scales for OPT and BLOOM models. You can also use this notebook to test quantizing those models." 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "model = OPTForCausalLM.from_pretrained('facebook/opt-13b', torch_dtype=torch.float16, device_map='auto')\n", 254 | "act_scales = torch.load('../act_scales/opt-13b.pt')\n", 255 | "smooth_lm(model, act_scales, 0.5)\n", 256 | "model_smoothquant_w8a8 = quantize_model(model)\n", 257 | "print(model_smoothquant_w8a8)" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": {}, 263 | "source": [ 264 | "We can see the smoothed model has the same accuracy as the FP16 model. This is because SmoothQuant smooths the outliers in activations and moves the quantization difficulty from activations to weights." 
265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 10, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "SmoothQuant W8A8 quantized model accuracy: 0.793\n" 277 | ] 278 | } 279 | ], 280 | "source": [ 281 | "acc_smoothquant_w8a8 = evaluator.evaluate(model_smoothquant_w8a8)\n", 282 | "print(f'SmoothQuant W8A8 quantized model accuracy: {acc_smoothquant_w8a8}')" 283 | ] 284 | } 285 | ], 286 | "metadata": { 287 | "kernelspec": { 288 | "display_name": "Python 3.10.4 (conda)", 289 | "language": "python", 290 | "name": "python3" 291 | }, 292 | "language_info": { 293 | "codemirror_mode": { 294 | "name": "ipython", 295 | "version": 3 296 | }, 297 | "file_extension": ".py", 298 | "mimetype": "text/x-python", 299 | "name": "python", 300 | "nbconvert_exporter": "python", 301 | "pygments_lexer": "ipython3", 302 | "version": "3.10.4" 303 | }, 304 | "orig_nbformat": 4, 305 | "vscode": { 306 | "interpreter": { 307 | "hash": "c458cb81aeeb610631c72e4cc4799f00f630d4dfa7a554b37f8134a7fe160cb8" 308 | } 309 | } 310 | }, 311 | "nbformat": 4, 312 | "nbformat_minor": 2 313 | } 314 | -------------------------------------------------------------------------------- /examples/smoothquant_opt_real_int8_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# SmoothQuant Real-INT8 Inference for PyTorch\n", 8 | "\n", 9 | "### Guangxuan Xiao\\*, Ji Lin\\*, Mickael Seznec, Julien Demouth, Song Han\n", 10 | "\n", 11 | "In this notebook, we use OPT-30B model to demonstrate the latency and memory advantages of SmoothQuant. We implement SmoothQuant real-INT8 inference for PyTorch with [CUTLASS](https://github.com/NVIDIA/cutlass) INT8 GEMM kernels, which are wrapped as PyTorch modules in [torch-int](https://github.com/Guangxuan-Xiao/torch-int).\n", 12 | "\n", 13 | "This notebook demonstrates SmoothQuant on OPT-30B because it is the largest model we can run both FP16 and INT8 inference on a single A100 GPU. For larger models requiring multiple GPUs, we recommend using the [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) implementation of SmoothQuant.\n", 14 | "\n", 15 | "In order to run this notebook, you need to install the following packages:\n", 16 | "\n", 17 | "- [smoothquant](https://github.com/mit-han-lab/smoothquant)\n", 18 | "- [torch-int](https://github.com/Guangxuan-Xiao/torch-int)\n", 19 | "- [PyTorch](https://pytorch.org/)\n", 20 | "- [Transformers](https://github.com/huggingface/transformers)\n", 21 | "- [Accelerate](https://github.com/huggingface/accelerate)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import torch\n", 31 | "from transformers.models.opt.modeling_opt import OPTForCausalLM\n", 32 | "from transformers import GPT2Tokenizer\n", 33 | "from smoothquant.opt import Int8OPTForCausalLM\n", 34 | "import os\n", 35 | "import gc\n", 36 | "from torch.nn.functional import pad\n", 37 | "\n", 38 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "The following is an evaluator to see the performance of the model. We use a toy dataset (the first 1000 examples in the validation set of the Lambada dataset) to evaluate the model. You can replace it with your dataset. 
The conclusion should be the same." 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "**In this demo, we have simplified the evaluation by using the first 1,000 samples from the LAMBADA dataset's validation set. We employ the \"Last Token Prediction Accuracy\" as our evaluation metric. This approximate evaluation is intended for demonstration purposes, providing simple but meaningful comparisons of relative performance between methods. For a more strict assessment, we recommend using the [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) to obtain the \"Last Word Prediction Accuracy\" for the LAMBADA dataset, which is the reported metric in our paper.**" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "class Evaluator:\n", 62 | " def __init__(self, dataset, tokenizer):\n", 63 | " self.dataset = dataset\n", 64 | " self.tokenizer = tokenizer\n", 65 | "\n", 66 | " # tokenize the dataset\n", 67 | " def tokenize_function(examples):\n", 68 | " example = self.tokenizer(examples['text'])\n", 69 | " return example\n", 70 | "\n", 71 | " self.dataset = self.dataset.map(tokenize_function, batched=True)\n", 72 | " self.dataset.set_format(type='torch', columns=['input_ids'])\n", 73 | "\n", 74 | " @torch.no_grad()\n", 75 | " def evaluate(self, model):\n", 76 | " model.eval()\n", 77 | " # The task is to predict the last word of the input.\n", 78 | " total, hit = 0, 0\n", 79 | " start = torch.cuda.Event(enable_timing=True)\n", 80 | " end = torch.cuda.Event(enable_timing=True)\n", 81 | " latency = 0\n", 82 | " for batch in self.dataset:\n", 83 | " input_ids = batch['input_ids'].cuda().unsqueeze(0)\n", 84 | " label = input_ids[:, -1]\n", 85 | " pad_len = 512 - input_ids.shape[1]\n", 86 | " input_ids = pad(input_ids, (0, pad_len), value=1)\n", 87 | " torch.cuda.synchronize()\n", 88 | " start.record()\n", 89 | " outputs = model(input_ids)\n", 90 | " end.record()\n", 91 | " torch.cuda.synchronize()\n", 92 | " latency += start.elapsed_time(end)\n", 93 | " last_token_logits = outputs.logits[:, -2-pad_len, :]\n", 94 | " pred = last_token_logits.argmax(dim=-1)\n", 95 | " total += label.size(0)\n", 96 | " hit += (pred == label).sum().item()\n", 97 | "\n", 98 | " acc = hit / total\n", 99 | " lantecy = latency / len(self.dataset)\n", 100 | " return acc, lantecy\n", 101 | "\n", 102 | "\n", 103 | "def print_model_size(model):\n", 104 | " # https://discuss.pytorch.org/t/finding-model-size/130275\n", 105 | " param_size = 0\n", 106 | " for param in model.parameters():\n", 107 | " param_size += param.nelement() * param.element_size()\n", 108 | " buffer_size = 0\n", 109 | " for buffer in model.buffers():\n", 110 | " buffer_size += buffer.nelement() * buffer.element_size()\n", 111 | "\n", 112 | " size_all_mb = (param_size + buffer_size) / 1024**2\n", 113 | " print('Model size: {:.3f}MB'.format(size_all_mb))\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "from datasets import load_dataset\n", 123 | "tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')\n", 124 | "dataset = load_dataset('lambada', split='validation[:1000]')\n", 125 | "evaluator = Evaluator(dataset, tokenizer)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "## FP16 Model Accuracy and Latency" 133 | ] 134 | }, 135 | { 136 | "cell_type": 
"code", 137 | "execution_count": 4, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "Model size: 57171.898MB\n", 145 | "FP16 accuracy: 0.807, per-sample lantecy: 263.633ms\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "model_fp16 = OPTForCausalLM.from_pretrained(\n", 151 | " 'facebook/opt-30b', torch_dtype=torch.float16, device_map='auto')\n", 152 | "print_model_size(model_fp16)\n", 153 | "acc_fp16, lantecy_fp16 = evaluator.evaluate(model_fp16)\n", 154 | "print(f'FP16 accuracy: {acc_fp16}, per-sample lantecy: {lantecy_fp16:.3f}ms')" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 5, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "del model_fp16\n", 164 | "gc.collect()\n", 165 | "torch.cuda.empty_cache()" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "## SmoothQuant W8A8 Quantized Model Accuracy and Latency" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "We provide the already smoothed and quantized OPT model at `https://huggingface.co/mit-han-lab/opt-[MODEL-SIZE]-smoothquant`, where `[MODEL-SIZE]` can be `125m`, `1.3B`, `2.7B`, `6.7B`, `13B`, `30b`, and `66b`. You can load the INT8 model with the following code:\n", 180 | "\n", 181 | "```python\n", 182 | "from smoothquant.opt import Int8OPTForCausalLM\n", 183 | "model = Int8OPTForCausalLM.from_pretrained(\"mit-han-lab/opt-30b-smoothquant\")\n", 184 | "```\n", 185 | "\n", 186 | "We implement the following quantization flow for OPT models, which you can see details in [smoothquant/opt.py](../smoothquant/opt.py).\n", 187 | "\n", 188 | "![quantization flow](../figures/quantization_flow.png)\n", 189 | "\n", 190 | "You can also check [generate_act_scales.py](../examples/generate_act_scales.py) and [export_int8_model.py](../examples/export_int8_model.py) to see how we smooth, quantize and export INT8 models." 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 6, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "Model size: 28945.603MB\n", 203 | "SmoothQuant INT8 accuracy: 0.798, per-sample lantecy: 212.361ms\n" 204 | ] 205 | } 206 | ], 207 | "source": [ 208 | "model_smoothquant = Int8OPTForCausalLM.from_pretrained(\n", 209 | " 'mit-han-lab/opt-30b-smoothquant', torch_dtype=torch.float16, device_map='auto')\n", 210 | "print_model_size(model_smoothquant)\n", 211 | "acc_smoothquant, lantecy_smoothquant = evaluator.evaluate(model_smoothquant)\n", 212 | "print(\n", 213 | " f'SmoothQuant INT8 accuracy: {acc_smoothquant}, per-sample lantecy: {lantecy_smoothquant:.3f}ms')" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "## Conlusion\n", 221 | "\n", 222 | "We can see that the SmoothQuant model has a similar accuracy as the FP16 model, but it is faster and uses less memory. This is because SmoothQuant reduces the quantization difficulty of activations and enables the use of INT8 GEMM kernels." 
223 | ] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": "Python 3.8.15 (conda)", 229 | "language": "python", 230 | "name": "python3" 231 | }, 232 | "language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.8.15" 243 | }, 244 | "orig_nbformat": 4, 245 | "vscode": { 246 | "interpreter": { 247 | "hash": "b18562e22caa2a2bb5e6615862f7e7ce92f781ef7fc2a883871422ecfcd6595c" 248 | } 249 | } 250 | }, 251 | "nbformat": 4, 252 | "nbformat_minor": 2 253 | } 254 | -------------------------------------------------------------------------------- /figures/throughput_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Adlik/smoothquantplus/df36110a315c548cea04e83c096ea7128fba79fa/figures/throughput_latency.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from struct import pack 2 | from setuptools import setup, find_packages 3 | setup( 4 | name='smoothquant', 5 | packages=find_packages(exclude=['figures', 'act_scales']) 6 | ) 7 | -------------------------------------------------------------------------------- /smoothquant/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /smoothquant/calibration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from datasets import load_dataset 5 | import functools 6 | from collections import defaultdict 7 | 8 | from functools import partial 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | 13 | def get_act_scales(model, tokenizer, dataset_path, num_samples=512, seq_len=512): 14 | model.eval() 15 | device = next(model.parameters()).device 16 | act_scales = {} 17 | 18 | def stat_tensor(name, tensor): 19 | hidden_dim = tensor.shape[-1] 20 | tensor = tensor.view(-1, hidden_dim).abs().detach() 21 | comming_max = torch.max(tensor, dim=0)[0].float().cpu() 22 | if name in act_scales: 23 | act_scales[name] = torch.max(act_scales[name], comming_max) 24 | else: 25 | act_scales[name] = comming_max 26 | 27 | def stat_input_hook(m, x, y, name): 28 | if isinstance(x, tuple): 29 | x = x[0] 30 | stat_tensor(name, x) 31 | 32 | hooks = [] 33 | for name, m in model.named_modules(): 34 | if isinstance(m, nn.Linear): 35 | hooks.append( 36 | m.register_forward_hook( 37 | functools.partial(stat_input_hook, name=name)) 38 | ) 39 | 40 | dataset = load_dataset("json", data_files=dataset_path, split="train") 41 | dataset = dataset.shuffle(seed=42) 42 | 43 | for i in tqdm(range(num_samples)): 44 | input_ids = tokenizer(dataset[i]["text"], return_tensors="pt", 45 | max_length=seq_len, truncation=True).input_ids.to(device) 46 | model(input_ids) 47 | 48 | for h in hooks: 49 | h.remove() 50 | 51 | return act_scales 52 | 53 | 54 | @torch.no_grad() 55 | def get_static_decoder_layer_scales(model, 56 | tokenizer, 57 | dataset_path, 58 | num_samples=512, 59 | seq_len=512, 60 | ): 61 | model.eval() 62 | device = next(model.parameters()).device 63 | 64 | act_dict = defaultdict(dict) 65 | 66 | def stat_io_hook(m, x, y, name): 67 | 
if isinstance(x, tuple): 68 | x = x[0] 69 | if name not in act_dict or "input" not in act_dict[name]: 70 | act_dict[name]["input"] = x.detach().abs().max().item() 71 | else: 72 | act_dict[name]["input"] = max( 73 | act_dict[name]["input"], x.detach().abs().max().item()) 74 | if isinstance(y, tuple): 75 | y = y[0] 76 | if name not in act_dict or "output" not in act_dict[name]: 77 | act_dict[name]["output"] = y.detach().abs().max().item() 78 | else: 79 | act_dict[name]["output"] = max( 80 | act_dict[name]["output"], y.detach().abs().max().item()) 81 | 82 | hooks = [] 83 | for name, m in model.named_modules(): 84 | if isinstance(m, torch.nn.Linear): 85 | hooks.append(m.register_forward_hook( 86 | partial(stat_io_hook, name=name))) 87 | 88 | print("Collecting activation scales...") 89 | pbar = tqdm(range(num_samples)) 90 | dataset = load_dataset('json', data_files=dataset_path, split="train") 91 | dataset = dataset.shuffle(seed=42) 92 | for i in pbar: 93 | input_ids = tokenizer(dataset[i]["text"], return_tensors="pt", 94 | max_length=seq_len, truncation=True).input_ids.to(device) 95 | model(input_ids) 96 | mean_scale = np.mean([v["input"] for v in act_dict.values()]) 97 | pbar.set_description(f"Mean input scale: {mean_scale:.2f}") 98 | for hook in hooks: 99 | hook.remove() 100 | 101 | decoder_layer_scales = [] 102 | for idx in range(model.config.num_hidden_layers): 103 | scale_dict = {} 104 | scale_dict["attn_input_scale"] = act_dict[ 105 | f"model.decoder.layers.{idx}.self_attn.q_proj"]['input'] / 127 106 | scale_dict["q_output_scale"] = act_dict[ 107 | f"model.decoder.layers.{idx}.self_attn.q_proj"]['output'] / 127 108 | scale_dict["k_output_scale"] = act_dict[ 109 | f"model.decoder.layers.{idx}.self_attn.k_proj"]['output'] / 127 110 | scale_dict["v_output_scale"] = act_dict[ 111 | f"model.decoder.layers.{idx}.self_attn.v_proj"]['output'] / 127 112 | scale_dict["out_input_scale"] = act_dict[ 113 | f"model.decoder.layers.{idx}.self_attn.out_proj"]['input'] / 127 114 | scale_dict["fc1_input_scale"] = act_dict[ 115 | f"model.decoder.layers.{idx}.fc1"]['input'] / 127 116 | scale_dict["fc2_input_scale"] = act_dict[ 117 | f"model.decoder.layers.{idx}.fc2"]["input"] / 127 118 | decoder_layer_scales.append(scale_dict) 119 | 120 | return decoder_layer_scales, act_dict 121 | -------------------------------------------------------------------------------- /smoothquant/fake_quant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from functools import partial 4 | 5 | 6 | def quantize_weight_per_channel_absmax(w, n_bits=8): 7 | # w: (out_features, in_features) 8 | scales = w.abs().max(dim=-1, keepdim=True)[0] 9 | q_max = 2**(n_bits-1)-1 10 | scales.clamp_(min=1e-5).div_(q_max) 11 | w.div_(scales).round_().mul_(scales) 12 | return w 13 | 14 | 15 | @torch.no_grad() 16 | def quantize_weight_per_tensor_absmax(w, n_bits=8): 17 | # w: (out_features, in_features) 18 | scales = w.abs().max() 19 | q_max = 2**(n_bits-1)-1 20 | scales.clamp_(min=1e-5).div_(q_max) 21 | w.div_(scales).round_().mul_(scales) 22 | return w 23 | 24 | 25 | @torch.no_grad() 26 | def quantize_activation_per_token_absmax(t, n_bits=8): 27 | t_shape = t.shape 28 | t.view(-1, t_shape[-1]) 29 | scales = t.abs().max(dim=-1, keepdim=True)[0] 30 | q_max = 2**(n_bits-1)-1 31 | scales.clamp_(min=1e-5).div_(q_max) 32 | t.div_(scales).round_().mul_(scales) 33 | return t 34 | 35 | 36 | @torch.no_grad() 37 | def quantize_activation_per_tensor_absmax(t, n_bits=8): 38 | t_shape 
= t.shape 39 | t.view(-1, t_shape[-1]) 40 | scales = t.abs().max() 41 | q_max = 2**(n_bits-1)-1 42 | scales.clamp_(min=1e-5).div_(q_max) 43 | t.div_(scales).round_().mul_(scales) 44 | return t 45 | 46 | 47 | class W8A8Linear(nn.Module): 48 | def __init__(self, in_features, out_features, bias=True, act_quant='per_token', quantize_output=False): 49 | super().__init__() 50 | self.in_features = in_features 51 | self.out_features = out_features 52 | 53 | self.register_buffer('weight', torch.randn(self.out_features, 54 | self.in_features, dtype=torch.float16, requires_grad=False)) 55 | if bias: 56 | self.register_buffer('bias', torch.zeros( 57 | (1, self.out_features), dtype=torch.float16, requires_grad=False)) 58 | else: 59 | self.register_buffer('bias', None) 60 | 61 | if act_quant == 'per_token': 62 | self.act_quant_name = 'per_token' 63 | self.act_quant = partial( 64 | quantize_activation_per_token_absmax, n_bits=8) 65 | elif act_quant == 'per_tensor': 66 | self.act_quant_name = 'per_tensor' 67 | self.act_quant = partial( 68 | quantize_activation_per_tensor_absmax, n_bits=8) 69 | else: 70 | raise ValueError(f'Invalid act_quant: {act_quant}') 71 | 72 | if quantize_output: 73 | self.output_quant_name = self.act_quant_name 74 | self.output_quant = self.act_quant 75 | else: 76 | self.output_quant_name = 'None' 77 | self.output_quant = lambda x: x 78 | 79 | def to(self, *args, **kwargs): 80 | super(W8A8Linear, self).to(*args, **kwargs) 81 | self.weight = self.weight.to(*args, **kwargs) 82 | if self.bias is not None: 83 | self.bias = self.bias.to(*args, **kwargs) 84 | return self 85 | 86 | @torch.no_grad() 87 | def forward(self, x): 88 | q_x = self.act_quant(x) 89 | y = torch.functional.F.linear(q_x, self.weight, self.bias) 90 | q_y = self.output_quant(y) 91 | return q_y 92 | 93 | @staticmethod 94 | def from_float(module, weight_quant='per_channel', act_quant='per_token', quantize_output=False): 95 | assert isinstance(module, torch.nn.Linear) 96 | new_module = W8A8Linear( 97 | module.in_features, module.out_features, module.bias is not None, act_quant=act_quant, quantize_output=quantize_output) 98 | if weight_quant == 'per_channel': 99 | new_module.weight = quantize_weight_per_channel_absmax( 100 | module.weight, n_bits=8) # use 8-bit integer for weight 101 | elif weight_quant == 'per_tensor': 102 | new_module.weight = quantize_weight_per_tensor_absmax( 103 | module.weight, n_bits=8) 104 | else: 105 | raise ValueError(f'Invalid weight_quant: {weight_quant}') 106 | new_module.weight_quant_name = weight_quant 107 | if module.bias is not None: 108 | new_module.bias = module.bias 109 | return new_module 110 | 111 | def __repr__(self): 112 | return f'W8A8Linear({self.in_features}, {self.out_features}, bias={self.bias is not None}, weight_quant={self.weight_quant_name}, act_quant={self.act_quant_name}, output_quant={self.output_quant_name})' 113 | -------------------------------------------------------------------------------- /smoothquant/opt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers.models.opt.modeling_opt import ( 4 | OPTConfig, 5 | OPTForCausalLM, 6 | OPTModel, 7 | OPTPreTrainedModel, 8 | OPTLearnedPositionalEmbedding, 9 | OPTAttention, 10 | OPTDecoderLayer, 11 | OPTDecoder, 12 | BaseModelOutputWithPast 13 | ) 14 | from typing import Optional, Tuple, List 15 | from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU 16 | from torch_int.nn.fused import LayerNormQ 
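# Going by their names and how they are used in this file, the torch_int modules imported
# here wrap the INT8 CUTLASS kernels that the layers below are built on:
#   W8A8B8O8Linear       - INT8 activations/weights/bias, INT8 output (q/k/v projections)
#   W8A8BFP32OFP32Linear - INT8 activations/weights, FP32 bias and FP32 output (out_proj, fc2)
#   W8A8B8O8LinearReLU   - INT8 linear fused with ReLU, INT8 output (fc1)
#   BMM_S8T_S8N_F32T     - INT8 batched matmul with FP32 output (Q @ K^T)
#   BMM_S8T_S8N_S8T      - INT8 batched matmul with INT8 output (attn_probs @ V)
#   LayerNormQ           - LayerNorm that emits INT8 output for the following INT8 linear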
17 | from transformers.utils import logging 18 | from torch_int.nn.bmm import BMM_S8T_S8N_S8T, BMM_S8T_S8N_F32T 19 | logger = logging.get_logger(__name__) 20 | 21 | 22 | class Int8OPTAttention(nn.Module): 23 | """Multi-headed attention from 'Attention Is All You Need' paper""" 24 | 25 | def __init__( 26 | self, 27 | embed_dim: int, 28 | num_heads: int, 29 | ): 30 | super().__init__() 31 | self.embed_dim = embed_dim 32 | self.num_heads = num_heads 33 | self.head_dim = embed_dim // num_heads 34 | 35 | if (self.head_dim * num_heads) != self.embed_dim: 36 | raise ValueError( 37 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" 38 | f" and `num_heads`: {num_heads})." 39 | ) 40 | 41 | self.attention_weight_scale = 1.0 42 | 43 | self.qk_bmm = BMM_S8T_S8N_F32T(1.0) 44 | self.pv_bmm = BMM_S8T_S8N_S8T(1.0) 45 | 46 | self.k_proj = W8A8B8O8Linear(embed_dim, embed_dim) 47 | self.v_proj = W8A8B8O8Linear(embed_dim, embed_dim) 48 | self.q_proj = W8A8B8O8Linear(embed_dim, embed_dim) 49 | self.out_proj = W8A8BFP32OFP32Linear(embed_dim, embed_dim) 50 | 51 | @staticmethod 52 | @torch.no_grad() 53 | def from_float(module: OPTAttention, 54 | input_scale: float, 55 | q_output_scale: float, 56 | k_output_scale: float, 57 | v_output_scale: float, 58 | out_input_scale: float): 59 | int8_module = Int8OPTAttention(module.embed_dim, module.num_heads) 60 | # Fuse the scaling into the q_proj output scale 61 | q_output_scale = q_output_scale * module.scaling 62 | module.q_proj.weight *= module.scaling 63 | module.q_proj.bias *= module.scaling 64 | int8_module.q_proj = W8A8B8O8Linear.from_float( 65 | module.q_proj, input_scale, q_output_scale) 66 | int8_module.k_proj = W8A8B8O8Linear.from_float( 67 | module.k_proj, input_scale, k_output_scale) 68 | int8_module.v_proj = W8A8B8O8Linear.from_float( 69 | module.v_proj, input_scale, v_output_scale) 70 | int8_module.out_proj = W8A8BFP32OFP32Linear.from_float( 71 | module.out_proj, out_input_scale) 72 | int8_module.qk_bmm = BMM_S8T_S8N_F32T.from_scale( 73 | q_output_scale, k_output_scale) 74 | 75 | # alpha = s_prob * s_v / s_out, where s_prob = 1 / 127 76 | int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale( 77 | 1.0 / 127, v_output_scale, out_input_scale) 78 | return int8_module 79 | 80 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 81 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 82 | 83 | @torch.no_grad() 84 | def forward( 85 | self, 86 | hidden_states: torch.Tensor, 87 | key_value_states: Optional[torch.Tensor] = None, 88 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 89 | attention_mask: Optional[torch.Tensor] = None, 90 | layer_head_mask: Optional[torch.Tensor] = None, 91 | output_attentions: bool = False, 92 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 93 | """Input shape: Batch x Time x Channel""" 94 | # if key_value_states are provided this layer is used as a cross-attention layer 95 | # for the decoder 96 | is_cross_attention = key_value_states is not None 97 | 98 | bsz, tgt_len, _ = hidden_states.size() 99 | 100 | # get query proj 101 | query_states = self.q_proj(hidden_states) 102 | # get key, value proj 103 | if is_cross_attention and past_key_value is not None: 104 | # reuse k,v, cross_attentions 105 | key_states = past_key_value[0] 106 | value_states = past_key_value[1] 107 | elif is_cross_attention: 108 | # cross_attentions 109 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 110 | value_states = 
self._shape(self.v_proj(key_value_states), -1, bsz) 111 | elif past_key_value is not None: 112 | # reuse k, v, self_attention 113 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 114 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 115 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 116 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 117 | else: 118 | # self_attention 119 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 120 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 121 | 122 | past_key_value = (key_states, value_states) 123 | 124 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 125 | query_states = self._shape( 126 | query_states, tgt_len, bsz).view(*proj_shape) 127 | key_states = key_states.view(*proj_shape) 128 | value_states = value_states.view(*proj_shape) 129 | 130 | src_len = key_states.size(1) 131 | attn_weights = self.qk_bmm(query_states, key_states) 132 | 133 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 134 | raise ValueError( 135 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" 136 | f" {attn_weights.size()}" 137 | ) 138 | 139 | if attention_mask is not None: 140 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 141 | raise ValueError( 142 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 143 | ) 144 | attn_weights = attn_weights.view( 145 | bsz, self.num_heads, tgt_len, src_len) + attention_mask 146 | attn_weights = torch.max(attn_weights, torch.tensor( 147 | torch.finfo(attn_weights.dtype).min)) 148 | attn_weights = attn_weights.view( 149 | bsz * self.num_heads, tgt_len, src_len) 150 | 151 | attn_probs = nn.functional.softmax(attn_weights, dim=-1) 152 | 153 | if layer_head_mask is not None: 154 | if layer_head_mask.size() != (self.num_heads,): 155 | raise ValueError( 156 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" 157 | f" {layer_head_mask.size()}" 158 | ) 159 | attn_probs = layer_head_mask.view( 160 | 1, -1, 1, 1) * attn_probs.view(bsz, self.num_heads, tgt_len, src_len) 161 | attn_probs = attn_probs.view( 162 | bsz * self.num_heads, tgt_len, src_len) 163 | 164 | if output_attentions: 165 | # this operation is a bit awkward, but it's required to 166 | # make sure that attn_weights keeps its gradient. 
167 | # In order to do so, attn_weights have to be reshaped 168 | # twice and have to be reused in the following 169 | attn_probs_reshaped = attn_probs.view( 170 | bsz, self.num_heads, tgt_len, src_len) 171 | attn_probs = attn_probs_reshaped.view( 172 | bsz * self.num_heads, tgt_len, src_len) 173 | else: 174 | attn_probs_reshaped = None 175 | 176 | # (A_row V_row)_row = (A_row V_col ^T)_row 177 | attn_probs.mul_(127).round_() 178 | attn_probs = attn_probs.to(torch.int8) 179 | 180 | value_states = value_states.transpose(1, 2).contiguous() 181 | attn_output = self.pv_bmm(attn_probs, value_states) 182 | 183 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 184 | raise ValueError( 185 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" 186 | f" {attn_output.size()}" 187 | ) 188 | 189 | attn_output = attn_output.view( 190 | bsz, self.num_heads, tgt_len, self.head_dim) 191 | attn_output = attn_output.transpose(1, 2) 192 | 193 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 194 | # partitioned aross GPUs when using tensor-parallelism. 195 | attn_output = attn_output.reshape( 196 | bsz, tgt_len, self.embed_dim).contiguous() 197 | attn_output = self.out_proj(attn_output) 198 | 199 | return attn_output, attn_probs_reshaped, past_key_value 200 | 201 | 202 | class Int8OPTDecoderLayer(nn.Module): 203 | def __init__(self, embed_dim, num_attention_heads, ffn_dim): 204 | super().__init__() 205 | self.embed_dim = embed_dim 206 | self.self_attn = Int8OPTAttention( 207 | embed_dim=self.embed_dim, 208 | num_heads=num_attention_heads 209 | ) 210 | 211 | self.self_attn_layer_norm = LayerNormQ( 212 | self.embed_dim) 213 | self.fc1 = W8A8B8O8LinearReLU(self.embed_dim, ffn_dim) 214 | self.fc2 = W8A8BFP32OFP32Linear( 215 | ffn_dim, self.embed_dim) 216 | self.final_layer_norm = LayerNormQ(self.embed_dim) 217 | 218 | @staticmethod 219 | def from_float(module: OPTDecoderLayer, 220 | attn_input_scale: float, 221 | q_output_scale: float, 222 | k_output_scale: float, 223 | v_output_scale: float, 224 | out_input_scale: float, 225 | fc1_input_scale: float, 226 | fc2_input_scale: float): 227 | int8_module = Int8OPTDecoderLayer( 228 | module.embed_dim, 229 | module.self_attn.num_heads, 230 | module.fc1.out_features 231 | ) 232 | int8_module.self_attn_layer_norm = LayerNormQ.from_float( 233 | module.self_attn_layer_norm, attn_input_scale) 234 | int8_module.self_attn = Int8OPTAttention.from_float( 235 | module.self_attn, attn_input_scale, q_output_scale, k_output_scale, v_output_scale, out_input_scale) 236 | int8_module.final_layer_norm = LayerNormQ.from_float( 237 | module.final_layer_norm, fc1_input_scale) 238 | int8_module.fc1 = W8A8B8O8LinearReLU.from_float( 239 | module.fc1, fc1_input_scale, fc2_input_scale) 240 | int8_module.fc2 = W8A8BFP32OFP32Linear.from_float( 241 | module.fc2, fc2_input_scale) 242 | return int8_module 243 | 244 | def forward( 245 | self, 246 | hidden_states: torch.Tensor, 247 | attention_mask: Optional[torch.Tensor] = None, 248 | layer_head_mask: Optional[torch.Tensor] = None, 249 | output_attentions: Optional[bool] = False, 250 | use_cache: Optional[bool] = False, 251 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 252 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 253 | """ 254 | Args: 255 | hidden_states (`torch.Int8Tensor`): the output of previous layer's layernorm in INT8 256 | attention_mask (`torch.FloatTensor`, 
*optional*): attention mask of size 257 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 258 | layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size 259 | `(encoder_attention_heads,)`. 260 | output_attentions (`bool`, *optional*): 261 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 262 | returned tensors for more detail. 263 | use_cache (`bool`, *optional*): 264 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 265 | (see `past_key_values`). 266 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 267 | """ 268 | 269 | # Self Attention 270 | residual = hidden_states 271 | hidden_states = self.self_attn_layer_norm(hidden_states) 272 | 273 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 274 | hidden_states=hidden_states, 275 | past_key_value=past_key_value, 276 | attention_mask=attention_mask, 277 | layer_head_mask=layer_head_mask, 278 | output_attentions=output_attentions, 279 | ) 280 | 281 | residual.add_(hidden_states.to(residual.dtype)) 282 | 283 | hidden_states = self.final_layer_norm(residual) 284 | 285 | hidden_states = self.fc1(hidden_states) 286 | 287 | hidden_states = self.fc2(hidden_states) 288 | 289 | residual.add_(hidden_states.to(residual.dtype)) 290 | 291 | outputs = (residual,) 292 | 293 | if output_attentions: 294 | outputs += (self_attn_weights,) 295 | 296 | if use_cache: 297 | outputs += (present_key_value,) 298 | 299 | return outputs 300 | 301 | 302 | class Int8OPTDecoder(OPTPreTrainedModel): 303 | """ 304 | Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Int8OPTDecoderLayer`] 305 | 306 | """ 307 | 308 | def __init__(self, config): 309 | super().__init__(config) 310 | self.padding_idx = config.pad_token_id 311 | self.max_target_positions = config.max_position_embeddings 312 | self.vocab_size = config.vocab_size 313 | 314 | self.embed_tokens = nn.Embedding( 315 | config.vocab_size, config.word_embed_proj_dim, self.padding_idx) 316 | self.embed_positions = OPTLearnedPositionalEmbedding( 317 | config.max_position_embeddings, config.hidden_size) 318 | 319 | if config.word_embed_proj_dim != config.hidden_size: 320 | self.project_out = nn.Linear( 321 | config.hidden_size, config.word_embed_proj_dim, bias=False) 322 | else: 323 | self.project_out = None 324 | 325 | if config.word_embed_proj_dim != config.hidden_size: 326 | self.project_in = nn.Linear( 327 | config.word_embed_proj_dim, config.hidden_size, bias=False) 328 | else: 329 | self.project_in = None 330 | 331 | # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility 332 | # with checkpoints that have been fine-tuned before transformers v4.20.1 333 | # see https://github.com/facebookresearch/metaseq/pull/164 334 | if config.do_layer_norm_before and not config._remove_final_layer_norm: 335 | self.final_layer_norm = nn.LayerNorm(config.hidden_size) 336 | else: 337 | self.final_layer_norm = None 338 | 339 | self.layers = nn.ModuleList( 340 | [Int8OPTDecoderLayer(config.hidden_size, config.num_attention_heads, config.ffn_dim) for _ in range(config.num_hidden_layers)]) 341 | 342 | self.gradient_checkpointing = False 343 | # Initialize weights and apply final processing 344 | self.post_init() 345 | 346 | get_input_embeddings = OPTDecoder.get_input_embeddings 347 | set_input_embeddings = OPTDecoder.set_input_embeddings 348 | _prepare_decoder_attention_mask = OPTDecoder._prepare_decoder_attention_mask 349 | old_forward = OPTDecoder.forward 350 | 351 | @staticmethod 352 | def from_float(module, decoder_layer_scales): 353 | int8_module = Int8OPTDecoder(module.config) 354 | int8_module.embed_tokens = module.embed_tokens 355 | int8_module.embed_positions = module.embed_positions 356 | int8_module.project_out = module.project_out 357 | int8_module.final_layer_norm = module.final_layer_norm 358 | for i, layer in enumerate(module.layers): 359 | int8_module.layers[i] = Int8OPTDecoderLayer.from_float( 360 | layer, **decoder_layer_scales[i]) 361 | return int8_module 362 | 363 | def forward( 364 | self, 365 | input_ids: torch.LongTensor, 366 | attention_mask: Optional[torch.Tensor] = None, 367 | head_mask: Optional[torch.Tensor] = None, 368 | past_key_values: Optional[List[torch.FloatTensor]] = None, 369 | inputs_embeds: Optional[torch.FloatTensor] = None, 370 | use_cache: Optional[bool] = None, 371 | output_attentions: Optional[bool] = None, 372 | output_hidden_states: Optional[bool] = None, 373 | return_dict: Optional[bool] = None, 374 | ) -> BaseModelOutputWithPast: 375 | # pad the input to the multiple of 16 376 | input_len = input_ids.shape[1] 377 | from torch.nn.functional import pad 378 | if input_len % 16 != 0: 379 | # is 1 380 | padding_len = 16 - input_len % 16 381 | input_ids = pad(input_ids, (0, padding_len), value=1) 382 | if attention_mask is not None: 383 | attention_mask = pad(attention_mask, (0, padding_len), value=0) 384 | output = self.old_forward( 385 | input_ids=input_ids, 386 | attention_mask=attention_mask, 387 | head_mask=head_mask, 388 | past_key_values=past_key_values, 389 | inputs_embeds=inputs_embeds, 390 | 
390 |             use_cache=use_cache,
391 |             output_attentions=output_attentions,
392 |             output_hidden_states=output_hidden_states
393 |         )
394 |         # slice the output to the original length
395 |         if input_len % 16 != 0:
396 |             output.last_hidden_state = output.last_hidden_state[:,
397 |                                                                 :input_len, :]
398 |         return output
399 |
400 |
401 | class Int8OPTModel(OPTPreTrainedModel):
402 |     def __init__(self, config: OPTConfig):
403 |         super().__init__(config)
404 |         self.decoder = Int8OPTDecoder(config)
405 |         # Initialize weights and apply final processing
406 |         self.post_init()
407 |     get_input_embeddings = OPTModel.get_input_embeddings
408 |     set_input_embeddings = OPTModel.set_input_embeddings
409 |     get_decoder = OPTModel.get_decoder
410 |     forward = OPTModel.forward
411 |
412 |     @staticmethod
413 |     def from_float(module, decoder_layer_scales):
414 |         int8_module = Int8OPTModel(module.config)
415 |         int8_module.decoder = Int8OPTDecoder.from_float(
416 |             module.decoder, decoder_layer_scales)
417 |         return int8_module
418 |
419 |
420 | class Int8OPTForCausalLM(OPTPreTrainedModel):
421 |     _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
422 |
423 |     def __init__(self, config):
424 |         super().__init__(config)
425 |         self.model = Int8OPTModel(config)
426 |
427 |         # the lm_head weight is automatically tied to the embed tokens weight
428 |         self.lm_head = nn.Linear(
429 |             config.word_embed_proj_dim, config.vocab_size, bias=False)
430 |
431 |         # Initialize weights and apply final processing
432 |         self.post_init()
433 |
434 |     @staticmethod
435 |     def from_float(module, decoder_layer_scales):
436 |         int8_module = Int8OPTForCausalLM(module.config)
437 |         int8_module.model = Int8OPTModel.from_float(
438 |             module.model, decoder_layer_scales)
439 |         int8_module.lm_head = module.lm_head
440 |         return int8_module
441 |
442 |     get_input_embeddings = OPTForCausalLM.get_input_embeddings
443 |     set_input_embeddings = OPTForCausalLM.set_input_embeddings
444 |     get_output_embeddings = OPTForCausalLM.get_output_embeddings
445 |     set_output_embeddings = OPTForCausalLM.set_output_embeddings
446 |     set_decoder = OPTForCausalLM.set_decoder
447 |     get_decoder = OPTForCausalLM.get_decoder
448 |     forward = OPTForCausalLM.forward
449 |     prepare_inputs_for_generation = OPTForCausalLM.prepare_inputs_for_generation
450 |     _reorder_cache = OPTForCausalLM._reorder_cache
451 |
--------------------------------------------------------------------------------
/smoothquant/smooth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers.models.opt.modeling_opt import OPTDecoderLayer
5 | from transformers.models.bloom.modeling_bloom import BloomBlock
6 |
7 |
8 | @torch.no_grad()
9 | def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
10 |     if not isinstance(fcs, list):
11 |         fcs = [fcs]
12 |     assert isinstance(ln, nn.LayerNorm)
13 |     for fc in fcs:
14 |         assert isinstance(fc, nn.Linear)
15 |         assert ln.weight.numel() == fc.in_features == act_scales.numel()
16 |
17 |     device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
18 |     act_scales = act_scales.to(device=device, dtype=dtype)
19 |     weight_scales = torch.cat([fc.weight.abs().max(
20 |         dim=0, keepdim=True)[0] for fc in fcs], dim=0)
21 |     weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
22 |
23 |     scales = (act_scales.pow(alpha) / weight_scales.pow(1-alpha)
24 |               ).clamp(min=1e-5).to(device).to(dtype)
25 |
26 |     ln.weight.div_(scales)
27 |     ln.bias.div_(scales)
28 |
29 |     for fc in fcs:
30 |         fc.weight.mul_(scales.view(1, -1))
31 |
32 |
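# A worked example of the rescaling performed by smooth_ln_fcs above, with
# illustrative numbers (not taken from the repository): for alpha = 0.5, a
# channel whose calibrated activation maximum is 100.0 and whose largest
# weight magnitude is 1.0 receives
#     s = 100.0**0.5 / 1.0**0.5 = 10.0
# The LayerNorm weight and bias of that channel are divided by 10 and the
# matching input column of every following Linear is multiplied by 10, so
# (x / 10) @ (10 * w) == x @ w: the block output is mathematically unchanged
# while the activation outlier becomes 10x smaller and easier to quantize.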
33 | @torch.no_grad()
34 | def smooth_lm(model, scales, alpha=0.5):
35 |     for name, module in model.named_modules():
36 |         if isinstance(module, OPTDecoderLayer):
37 |             attn_ln = module.self_attn_layer_norm
38 |             qkv = [module.self_attn.q_proj,
39 |                    module.self_attn.k_proj, module.self_attn.v_proj]
40 |             qkv_input_scales = scales[name + '.self_attn.q_proj']
41 |             smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)
42 |
43 |             ffn_ln = module.final_layer_norm
44 |             fc1 = module.fc1
45 |             fc1_input_scales = scales[name + '.fc1']
46 |             smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
47 |         elif isinstance(module, BloomBlock):
48 |             attn_ln = module.input_layernorm
49 |             qkv = module.self_attention.query_key_value
50 |             qkv_input_scales = scales[name + '.self_attention.query_key_value']
51 |             smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)
52 |
53 |             ffn_ln = module.post_attention_layernorm
54 |             fc1 = module.mlp.dense_h_to_4h
55 |             fc1_input_scales = scales[name + '.mlp.dense_h_to_4h']
56 |             smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
57 |
--------------------------------------------------------------------------------
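For reference, a minimal usage sketch of the two modules listed above, assuming a Hugging Face OPT checkpoint and an activation-scale file produced by `examples/generate_act_scales.py`; the checkpoint name, the scale-file path, and the `decoder_layer_scales` variable below are illustrative placeholders rather than fixed project interfaces:

```python
import torch
from transformers import OPTForCausalLM

from smoothquant.smooth import smooth_lm
from smoothquant.opt import Int8OPTForCausalLM

# FP16 OPT checkpoint to be smoothed (model name is illustrative)
model = OPTForCausalLM.from_pretrained(
    "facebook/opt-1.3b", torch_dtype=torch.float16)

# per-channel activation maxima collected on calibration data
# (path is illustrative; see examples/generate_act_scales.py)
act_scales = torch.load("act_scales/opt-1.3b.pt")

# fold activation outliers into the preceding LayerNorms and Linears
smooth_lm(model, act_scales, alpha=0.5)

# `decoder_layer_scales` is a per-layer list of quantization-scale dicts;
# examples/export_int8_model.py collects it before converting the model:
# int8_model = Int8OPTForCausalLM.from_float(model, decoder_layer_scales)
```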