├── .gitignore ├── .gitignore.git ├── README.md ├── bench ├── README.md ├── bench.jpg ├── tflops_int4_overall.jpg └── tflops_int8_overall.jpg ├── benchalltokens.py ├── benchflops.py ├── benchlatency.py ├── evalppl.py ├── examples ├── .gitignore ├── basic_generate.py ├── basic_quant.py ├── basic_quant_mix.py ├── basic_quant_opt.py ├── basic_quant_quik.py ├── basic_safetensors_generate.py ├── benchW8A8.ipynb ├── benchbitsand.py ├── benchkernels.sh ├── benchlayers.sh ├── benchmark.py ├── breakdown.sh ├── eval.py ├── mmlu.py ├── mmlu.sh ├── overhead.sh ├── smooth_quant_get_act.py ├── testgenerate.py └── utils.py ├── figures ├── awq32.gif ├── awq512.gif ├── mixq32.gif ├── mixq512.gif └── textmixq.jpg ├── generate.py ├── generate.sh ├── get_act.sh ├── mixquant ├── Cache.py ├── __init__.py ├── int8fusedkernel.py ├── kernel.py ├── models │ ├── __init__.py │ ├── aquila.py │ ├── auto.py │ ├── baichuan.py │ ├── base.py │ ├── basefuser.py │ ├── bloom.py │ ├── falcon.py │ ├── gpt_bigcode.py │ ├── gpt_neox.py │ ├── gptj.py │ ├── llama.py │ ├── mistral.py │ ├── mpt.py │ ├── opt.py │ └── sample.py ├── modules │ ├── __init__.py │ ├── act.py │ ├── fused │ │ ├── __init__.py │ │ ├── attn.py │ │ ├── block.py │ │ ├── cache.py │ │ ├── gptj_attn.py │ │ ├── mistral_attn.py │ │ ├── mlp.py │ │ ├── model.py │ │ └── norm.py │ ├── linear.py │ └── qlinear.py ├── quantize │ ├── __init__.py │ └── mixquant.py └── utils │ ├── __init__.py │ ├── calib_data.py │ ├── fused_utils.py │ ├── lm_eval_adaptor.py │ ├── module down_weight_only.py │ ├── module.py │ ├── module_.py │ ├── parallel.py │ └── utils.py ├── mmlu.sh ├── quant.sh ├── requirements.txt ├── runalltokens.sh ├── runlatency.sh ├── runppl.sh ├── runthroughput.sh └── utils └── utils ├── __init__.py ├── data_utils.py ├── exllama_utils.py ├── import_utils.py ├── peft_utils.py └── perplexity_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | third-party 3 | data/ 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | *.pyc 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbolsd 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
#.idea/

*.pt
**/*.pt
**/*.pyc
*.json
__pycache__

build/
output/
outputcprof/
outputcproffp16/
outputweight/
outputncu/
roofline/
weight/
*.onnx

--------------------------------------------------------------------------------
/.gitignore.git:
--------------------------------------------------------------------------------
*.onnx

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MixQ

MixQ: Taming Dynamic Outliers in Mixed-Precision Quantization by Online Prediction

We use mixed-precision GEMM to enhance throughput.

Please refer to https://github.com/Qcompiler/vllm-mixed-precision for end-to-end text generation.

## Comparison with AWQ

Assume the task is to compute the perplexity (PPL) of Wikitext2.
The Wikitext dataset contains 333088 validation samples.

For ```batch size = 32```, the task is divided into 10409 parts.

AWQ finished the task in 10 minutes at 16.71 it/s.

MixQ (W8A8O16) finished the task in 4.50 minutes at 35.02 it/s.

For ```batch size = 512```, the task is divided into 655 parts.

AWQ finished the task in 127 seconds at 5.2 it/s.

MixQ (W8A8O16) finished the task in 30 seconds at 21.34 it/s.

## Setup

Please download the mixlib kernel from https://github.com/Qcompiler/QComplier:

```
git clone git@github.com:Qcompiler/QComplier.git
cd EETQ
python setup.py install
```
```
cd quantkernel
python setup.py install
```

## Benchmarking the throughput

It is very easy to quantize an LLM and run it with the MIXQ 4-bit or 8-bit kernel.

Run the following command to quantize the LLM with the W8A8O16 kernel:

```
python examples/basic_quant_mix.py --model_path /mnt/data/checkpoint/Llama-2-7b --quant_file /home/dataset/quant/quant8/Llama-2-7b --w_bit 8
```

Benchmark the throughput of MIXQ by:

```
python benchflops.py --model_type mix --model_path /home/dataset/quant/quant8/Llama-2-7b --quant_file /home/dataset/quant/quant8/Llama-2-7b --batch_size 512 --bit 8
```

On an NVIDIA A100-PCIE-40GB, the output is:

```
Version: mix 8bit
| Batch Size | Decode Length | Decode tokens/s | Memory (VRAM)    |
|-----------:|--------------:|----------------:|:-----------------|
|        512 |          1024 |         10609.8 | 7.86 GB (19.97%) |
```
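After quantization, the checkpoint can also be loaded directly from Python. The snippet below is a minimal sketch distilled from ```benchlatency.py``` in this repository (```AutoForCausalLM.from_quantized``` plus a ```MixLibCache```); the checkpoint path is a placeholder, and the forward call mirrors the benchmark loop rather than a full generation API.

```python
import torch
from transformers import AutoTokenizer
from mixquant import AutoForCausalLM
from mixquant.Cache import MixLibCache

quant_path = "/home/dataset/quant/quant8/Llama-2-7b"  # placeholder: your W8A8O16 checkpoint
cache = MixLibCache(bit=8)                            # 8-bit mixed-precision cache

# Load the checkpoint produced by examples/basic_quant_mix.py
model = AutoForCausalLM.from_quantized(
    quant_path, quant_path, fuse_layers=True,
    max_new_tokens=32, batch_size=1,
    mix=True, cache=cache,
)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

# Run a single forward (decode) step, as in the benchmark scripts
inputs = tokenizer("Hello, MixQ!", return_tensors="pt").input_ids
inputs = inputs.to(next(model.parameters()).device)
with torch.inference_mode():
    out = model(inputs, use_cache=True)
```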
# News !!

We have integrated the MixedQLinear designed by QUIK into our framework! QUIK now supports a wide range of LLMs, including:

- Llama-2 7B/13B/70B
- Llama-3 8B
- Falcon 7B/40B
- ChatGLM 7B
- QWen2 7B

## How to Run

It is very easy to quantize an LLM and run it with the QUIK 4-bit kernel.

Run the following command to quantize the LLM:

```
python examples/basic_quant_quik.py --model_path /mnt/data/checkpoint/Llama-2-7b --quant_file /home/dataset/quant/quantquik4/Llama-2-7b --w_bit 4
```

Benchmark the throughput of QUIK by:

```
python benchflops.py --model_type quik --model_path /home/dataset/quant/quantquik4/Llama-2-7b \
                     --quant_file /home/dataset/quant/quantquik4/quik4/Llama-2-7b \
                     --batch_size 512 --bit 4
```

On an NVIDIA A100-PCIE-40GB, the output is:

```
Version: quik 4bit
| Batch Size | Decode Length | Decode tokens/s | Memory (VRAM)    |
|-----------:|--------------:|----------------:|:-----------------|
|        512 |          1024 |         8981.17 | 4.88 GB (12.40%) |
```
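Beyond throughput, the accuracy of the MIXQ and QUIK checkpoints can be checked with ```evalppl.py```. The snippet below is a hedged condensation of its ```mix``` branch (the ```Perplexity``` helper comes from ```utils/utils.py```); the checkpoint path is a placeholder.

```python
from transformers import AutoTokenizer
from mixquant import AutoForCausalLM
from mixquant.Cache import MixLibCache
from utils.utils import Perplexity

model_path = "/home/dataset/quant/quant8/Llama-2-7b"  # placeholder: quantized checkpoint dir
n_ctx, n_batch = 256, 256

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if not tokenizer.pad_token_id:
    tokenizer.pad_token_id = tokenizer.eos_token_id

cache = MixLibCache(n_batch)
model = AutoForCausalLM.from_quantized(
    model_path, model_path, fuse_layers=True, mix=True, cache=cache
)

# Perplexity on Wikitext-2, mirroring evalppl.py's defaults
ppl = Perplexity(model, tokenizer, 'wikitext', None, 'test', 'text', True)
all_ppl = ppl.calculate_perplexity(n_ctx, n_batch)
```

Equivalently, run ```python evalppl.py --model_type mix --model_path <checkpoint> --quant_file <checkpoint> --n_ctx 256 --n_batch 256``` from the repository root.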
# TensorRT-LLM implementation of QUIK and MIXQ

We now support end-to-end text generation in TRT-LLM and vLLM!

For TRT-LLM, please download the NVIDIA TensorRT docker image: [TensorRT-LLM docker](https://github.com/NVIDIA/TensorRT-LLM). DO NOT USE your local environment!

Please enter the e2eTRTLLM folder https://github.com/Qcompiler/MixQ_Tensorrt_LLM

```
git clone https://github.com/Qcompiler/MixQ_Tensorrt_LLM.git
docker pull registry.cn-hangzhou.aliyuncs.com/dongdongchen/dongdong:v1
```

Run the docker container:

```
export name=myname
bash -c " nvidia-smi; docker run --rm -it --ipc=host -p 6789:22 \
    -v /home/${name}/lianxiang/lianxiangTRT/:/code/tensorrt_llm \
    -v /mnt/octave/data/${name}/checkpoint:/dataset \
    -v /home/${name}/checkpoint:/code/checkpoint \
    -v /mnt/octave/data/${name}/lianxiang/checkpoint:/octave/checkpoint \
    --ulimit memlock=-1 --ulimit stack=67108864 \
    --gpus=all \
    --env 'CCACHE_DIR=/code/tensorrt_llm/cpp/.ccache' \
    --env 'CCACHE_BASEDIR=/code/tensorrt_llm' \
    --workdir /app/tensorrt_llm \
    --hostname hpc-release \
    --name tensorrt_llm-release-zhanghy \
    --tmpfs /tmp:exec \
    registry.cn-hangzhou.aliyuncs.com/dongdongchen/dongdong:v1 "
```

After starting the docker container, set the environment variables:

```
model=Llama-2-7b
ngpu=1
export model_dir=/code/tensorrt_llm/checkpoint/${model}
export quant_dir=/code/tensorrt_llm/checkpoint/checkpoinmix/tllm_checkpoint_${ngpu}gpu_fp16${model}
export out_dir=/code/tensorrt_llm/checkpoint/trt_enginesmix/tllm_checkpoint_${ngpu}gpu_fp16${model}
```

Quantize the model by:

```
CUDA_VISIBLE_DEVICES=0 python quantize.py --model_dir ${model_dir} \
    --output_dir ${quant_dir} --dtype float16 --device cpu \
    --qformat int8_mix --calib_size 32
```

Build the MIXQ engine by:

```
CUDA_VISIBLE_DEVICES=0 trtllm-build --checkpoint_dir ${quant_dir} \
    --output_dir ${out_dir} \
    --gemm_plugin float16 --mix_precision int8
```

Generate text with MIXQ by:

```
CUDA_VISIBLE_DEVICES=0 python summarize.py --test_trt_llm \
    --hf_model_dir ${model_dir} \
    --data_type fp16 \
    --engine_dir ${out_dir}
```

## Building the TRT-LLM MIXQ plugin with a 4-stage pipeline for Llama-2-70B

```
model=Llama-2-70b
ngpu=4
export model_dir=/code/tensorrt_llm/checkpoint/${model}
export quant_dir=/code/tensorrt_llm/checkpoint/checkpoinmix/tllm_checkpoint_${ngpu}gpu_fp16${model}
export out_dir=/code/tensorrt_llm/checkpoint/trt_enginesmix/tllm_checkpoint_${ngpu}gpu_fp16${model}
```

Quantize the model by:

```
CUDA_VISIBLE_DEVICES=0,1,2,3 python quantize.py --model_dir ${model_dir} \
    --output_dir ${quant_dir} --dtype float16 --device cpu \
    --qformat int8_mix --calib_size 32 --pp_size ${ngpu}
```

Build the MIXQ engine by:

```
CUDA_VISIBLE_DEVICES=0,1,2,3 trtllm-build --checkpoint_dir ${quant_dir} \
    --output_dir ${out_dir} \
    --gemm_plugin float16 --mix_precision int8
```

Generate text with MIXQ by:

```
CUDA_VISIBLE_DEVICES=0,1,2,3 mpirun -np 4 --allow-run-as-root python summarize.py --test_trt_llm \
    --hf_model_dir ${model_dir} \
    --data_type fp16 \
    --engine_dir ${out_dir}
```

## Text generation result

### Llama-2-7B FP16 baseline

When running the ```summarize.py``` of MIXQ (Llama-2-7B on an A100, 40 GB, PCIe), we get:

## Mixed-precision Inference in vLLM

Please follow https://github.com/Qcompiler/vllm-mixed-precision for mixed-precision inference.

Install vLLM by:
```
pip install vllm==0.6.2
```

Install the mixed-precision source code by:
```
git clone git@github.com:Qcompiler/vllm-mixed-precision.git
```

Copy the ".so" files from the installed vllm package into the project:

```
cp -r $PYTHON_PATH/lib/python3.11/site-packages/vllm/*.so vllm-mixed-precision/vllm/
```

Then uninstall vllm==0.6.2:
```
pip uninstall vllm
```

## Running 8-bit mixed-precision inference in vLLM

```
export PYTHONPATH=$( pwd )
python test8bit.py --quant 8
```

--------------------------------------------------------------------------------
/bench/README.md:
--------------------------------------------------------------------------------
# Benchmarking the kernels

# Benchmarking MIXQ on A100

For the 8-bit kernel evaluation on A100:

For the 4-bit kernel evaluation on A100:

# Benchmarking FP8 and INT8 on H100

On the Hopper architecture (H100), we benchmark the kernel performance of FP8, INT8, and FP16. We found that the FP8 kernel is slightly slower than the INT8 kernel. In the figures, the y-axis is the kernel throughput in TFLOPS and the x-axis is the GEMM shape, with M = N = K.
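As a reference point for these square-GEMM shapes, the FP16 baseline can be measured with a few lines of PyTorch. This is only a rough sketch for sanity-checking, not the script used to produce the plots above.

```python
import torch

def fp16_gemm_tflops(n, iters=50):
    """Measure FP16 GEMM throughput for a square M = N = K = n problem."""
    a = torch.randn(n, n, dtype=torch.float16, device="cuda")
    b = torch.randn(n, n, dtype=torch.float16, device="cuda")
    for _ in range(10):                      # warm-up iterations
        torch.matmul(a, b)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        torch.matmul(a, b)
    end.record()
    torch.cuda.synchronize()
    seconds = start.elapsed_time(end) / 1000.0 / iters   # elapsed_time is in ms
    return 2 * n ** 3 / seconds / 1e12                   # 2*M*N*K FLOPs per GEMM

for n in (1024, 2048, 4096, 8192):
    print(n, f"{fp16_gemm_tflops(n):.1f} TFLOPS")
```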
--------------------------------------------------------------------------------
/bench/bench.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/bench/bench.jpg

--------------------------------------------------------------------------------
/bench/tflops_int4_overall.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/bench/tflops_int4_overall.jpg

--------------------------------------------------------------------------------
/bench/tflops_int8_overall.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/bench/tflops_int8_overall.jpg

--------------------------------------------------------------------------------
/benchalltokens.py:
--------------------------------------------------------------------------------
import os
import sys

os.environ["WORLD_SIZE"] = "1"
import time
import torch
import argparse
import numpy as np
import pandas as pd

from transformers import AutoTokenizer
from torch.cuda import OutOfMemoryError
torch.manual_seed(0)

from mixquant.Cache import MixLibCache


def warmup(model):
    # One large matmul on the model's device to initialize the CUDA context and warm the GPU.
    warm_up = torch.randn((4096, 4096)).to(next(model.parameters()).device)
    torch.mm(warm_up, warm_up)


def prepare_data(_dataset_path='wikitext', _split='test', _text_column='text'):
    from datasets import load_dataset
    """
    Prepares the dataset by loading and formatting.

    Returns
    -------
    str
        The formatted dataset as a single string.
37 | """ 38 | if _dataset_path == 'wikitext': 39 | _dataset_name = 'wikitext-2-raw-v1' 40 | data = load_dataset(_dataset_path, _dataset_name, split=_split) 41 | 42 | elif _dataset_path == 'c4': 43 | _dataset_name = 'realnewslike' 44 | data = load_dataset(_dataset_path, _dataset_name, split=_split) 45 | else: 46 | _dataset_name = 'wikitext-2-raw-v1' 47 | data = load_dataset(os.path.join(_dataset_path,'wikitext'), 48 | _dataset_name, split=_split, cache_dir="/home/chenyidong/tmp") 49 | # Format the text column of the dataset 50 | text_list = [' \n' if s == '' else s for s in data[_text_column]] 51 | return ''.join(text_list) 52 | 53 | def decode_token(model, _tokenizer, _text, n_batch, repeat = 10): 54 | 55 | 56 | tokens = _tokenizer(_text, truncation=False, return_tensors='pt').input_ids.to('cuda') 57 | start = 0 58 | end = n_batch 59 | for j in range(repeat): 60 | 61 | batch_start = start + j * n_batch 62 | batch_size = min(end - batch_start, n_batch) 63 | 64 | token_org = tokens[0][batch_start].item() 65 | 66 | if j == 0: 67 | # Replace the first token with the BOS token 68 | tokens[0][batch_start] = _tokenizer.bos_token_id 69 | 70 | # Compute the logits for the current batch of tokens 71 | _compute_batch_logits(tokens, batch_start, batch_size) 72 | 73 | tokens[0][batch_start] = token_org 74 | 75 | def _compute_batch_logits(_model,tokens, batch_start, batch_size): 76 | # Compute the logits without keeping track of gradients 77 | 78 | outputs = _model(tokens[:, batch_start:batch_start+batch_size]) 79 | return outputs 80 | 81 | 82 | def generate(model, tokens, n_generate, batch_size, cache): 83 | context_time = 0 84 | generate_time = [] 85 | 86 | 87 | with torch.inference_mode(): 88 | 89 | 90 | # prefill context 91 | cache.is_prefill = False 92 | 93 | 94 | 95 | 96 | for i in range(10): 97 | batch_start = i * batch_size 98 | inputs = torch.as_tensor(tokens[:, batch_start:batch_start+batch_size], device=next(model.parameters()).device) 99 | inputs = inputs.reshape((batch_size,1,)) 100 | out = model(inputs,use_cache=True) 101 | 102 | 103 | 104 | 105 | 106 | with torch.inference_mode(): 107 | # cache.is_prefill = True 108 | # inputs = torch.as_tensor(input_ids, device=next(model.parameters()).device) 109 | # out = model(inputs,use_cache=True) 110 | # token = out[0][:, -1].max(1)[1].unsqueeze(1) 111 | 112 | for i in range(n_generate): 113 | batch_start = i * batch_size 114 | torch.cuda.synchronize() 115 | 116 | 117 | 118 | inputs = torch.as_tensor(tokens[:, batch_start:batch_start+batch_size], device=next(model.parameters()).device) 119 | inputs = inputs.reshape((batch_size,1,)) 120 | start = time.time() 121 | 122 | 123 | out = model(inputs,use_cache=True) 124 | torch.cuda.synchronize() 125 | 126 | 127 | generate_time.append(time.time() - start) 128 | 129 | 130 | print("--- generate time ---") 131 | #print(generate_time) 132 | return generate_time 133 | 134 | def run_round(model_path, quant_file, n_generate, token, batch_size, safetensors, model_type='fp16',mixlibcache=None): 135 | 136 | from mixquant import AutoForCausalLM 137 | model = AutoForCausalLM.from_quantized( 138 | model_path, quant_file, fuse_layers=True, 139 | max_new_tokens=n_generate, batch_size=batch_size, 140 | safetensors=safetensors, 141 | mix = True, 142 | cache = mixlibcache 143 | ) 144 | 145 | 146 | 147 | 148 | 149 | 150 | def main(args): 151 | rounds = [ 152 | 153 | {"context": args.seq_length, "n_generate": args.seq_length}, 154 | 155 | ] 156 | 157 | all_stats = [] 158 | 159 | cache = MixLibCache(bit=args.bit) 160 | 161 | 
print("downloading data......") 162 | text = prepare_data(args.dataset_path) 163 | print("done......") 164 | 165 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=args.use_fast_tokenizer, trust_remote_code=True) 166 | if not tokenizer.pad_token_id: 167 | tokenizer.pad_token_id = tokenizer.eos_token_id 168 | tokenizer.model_max_length = sys.maxsize 169 | tokens = tokenizer(text, truncation=False, return_tensors='pt').input_ids 170 | 171 | print( type(tokens[0])) 172 | exit(0) 173 | 174 | 175 | 176 | 177 | for settings in rounds: 178 | 179 | 180 | 181 | stats, model_version = run_round( 182 | args.model_path, 183 | args.quant_file, 184 | settings["n_generate"], 185 | tokens, 186 | args.batch_size, 187 | args.safetensors, 188 | args.model_type, 189 | cache 190 | ) 191 | 192 | 193 | 194 | if __name__ == "__main__": 195 | 196 | """ 197 | 198 | """ 199 | parser = argparse.ArgumentParser() 200 | parser.add_argument("--model_path", type=str, default="", help="path to the model") 201 | parser.add_argument("--quant_file", type=str, default="", help="weights filename") 202 | parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation") 203 | parser.add_argument("--model_type", type=str, default="fp16") 204 | parser.add_argument("--safetensors", default=False, action="store_true", help="Use for enabling safetensors") 205 | parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer") 206 | parser.add_argument("--seq_length", type=int, default=128) 207 | parser.add_argument("--dataset_path", type=str, default='wikitext', help="Path to the dataset.") 208 | parser.add_argument("--bit", type=int, default=8) 209 | args = parser.parse_args() 210 | 211 | main(args) -------------------------------------------------------------------------------- /benchlatency.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | 6 | os.environ["WORLD_SIZE"] = "1" 7 | import time 8 | import torch 9 | import argparse 10 | import numpy as np 11 | import pandas as pd 12 | 13 | 14 | from transformers import AutoTokenizer 15 | from torch.cuda import OutOfMemoryError 16 | torch.manual_seed(0) 17 | 18 | from mixquant.Cache import MixLibCache 19 | def warmup(model): 20 | warm_up = torch.randn((4096,4096)).to(next(model.parameters()).device) 21 | torch.mm(warm_up,warm_up) 22 | 23 | 24 | 25 | 26 | 27 | 28 | def prepare_data(_dataset_path = 'wikitext', _split='test', _text_column='text'): 29 | from datasets import load_dataset 30 | """ 31 | Prepares the dataset by loading and formatting. 32 | 33 | Returns 34 | ------- 35 | str 36 | The formatted dataset as a single string. 
37 | """ 38 | if _dataset_path == 'wikitext': 39 | _dataset_name = 'wikitext-2-raw-v1' 40 | data = load_dataset(_dataset_path, _dataset_name, split=_split) 41 | 42 | elif _dataset_path == 'c4': 43 | _dataset_name = 'realnewslike' 44 | data = load_dataset(_dataset_path, _dataset_name, split=_split) 45 | else: 46 | _dataset_name = 'wikitext-2-raw-v1' 47 | data = load_dataset(os.path.join(_dataset_path,'wikitext'), 48 | _dataset_name, split=_split, cache_dir="/home/chenyidong/tmp") 49 | # Format the text column of the dataset 50 | text_list = [' \n' if s == '' else s for s in data[_text_column]] 51 | return ''.join(text_list) 52 | 53 | def decode_token(model, _tokenizer, _text, n_batch, repeat = 10): 54 | 55 | 56 | tokens = _tokenizer(_text, truncation=False, return_tensors='pt').input_ids.to('cuda') 57 | start = 0 58 | end = n_batch 59 | for j in range(repeat): 60 | 61 | batch_start = start + j * n_batch 62 | batch_size = min(end - batch_start, n_batch) 63 | 64 | token_org = tokens[0][batch_start].item() 65 | 66 | if j == 0: 67 | # Replace the first token with the BOS token 68 | tokens[0][batch_start] = _tokenizer.bos_token_id 69 | 70 | # Compute the logits for the current batch of tokens 71 | _compute_batch_logits(tokens, batch_start, batch_size) 72 | 73 | tokens[0][batch_start] = token_org 74 | 75 | def _compute_batch_logits(_model,tokens, batch_start, batch_size): 76 | # Compute the logits without keeping track of gradients 77 | 78 | outputs = _model(tokens[:, batch_start:batch_start+batch_size]) 79 | return outputs 80 | 81 | 82 | def generate(model, tokens, n_generate, batch_size, cache): 83 | context_time = 0 84 | generate_time = [] 85 | 86 | 87 | with torch.inference_mode(): 88 | 89 | 90 | # prefill context 91 | cache.is_prefill = False 92 | 93 | 94 | 95 | 96 | for i in range(10): 97 | batch_start = i * batch_size 98 | inputs = torch.as_tensor(tokens[:, batch_start:batch_start+batch_size], device=next(model.parameters()).device) 99 | inputs = inputs.reshape((batch_size,1,)) 100 | out = model(inputs,use_cache=True) 101 | 102 | 103 | 104 | 105 | 106 | with torch.inference_mode(): 107 | # cache.is_prefill = True 108 | # inputs = torch.as_tensor(input_ids, device=next(model.parameters()).device) 109 | # out = model(inputs,use_cache=True) 110 | # token = out[0][:, -1].max(1)[1].unsqueeze(1) 111 | 112 | for i in range(n_generate): 113 | batch_start = i * batch_size 114 | torch.cuda.synchronize() 115 | 116 | 117 | # decode tokens 118 | cache.is_prefill = False 119 | inputs = torch.as_tensor(tokens[:, batch_start:batch_start+batch_size], device=next(model.parameters()).device) 120 | inputs = inputs.reshape((batch_size,1,)) 121 | start = time.time() 122 | 123 | 124 | out = model(inputs,use_cache=True) 125 | torch.cuda.synchronize() 126 | 127 | 128 | generate_time.append(time.time() - start) 129 | 130 | 131 | print("--- generate time ---") 132 | #print(generate_time) 133 | return generate_time 134 | 135 | def run_round(model_path, quant_file, n_generate, token, batch_size, safetensors, model_type='fp16',mixlibcache=None): 136 | if model_type == 'mix': 137 | from mixquant import AutoForCausalLM 138 | model = AutoForCausalLM.from_quantized( 139 | model_path, quant_file, fuse_layers=True, 140 | max_new_tokens=n_generate, batch_size=batch_size, 141 | safetensors=safetensors, 142 | mix = True, 143 | cache = mixlibcache 144 | ) 145 | 146 | 147 | 148 | if model_type == 'awq': 149 | 150 | import awq 151 | from awq import AutoAWQForCausalLM 152 | print(f" -- Loading model awq...") 153 | model = 
AutoAWQForCausalLM.from_quantized( 154 | model_path, quant_file, fuse_layers=True, 155 | max_new_tokens=n_generate, batch_size=batch_size, 156 | safetensors=safetensors 157 | ) 158 | if model_type == 'fp16': 159 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 160 | model = AutoModelForCausalLM.from_pretrained( 161 | model_path, torch_dtype=torch.float16, 162 | device_map='auto', trust_remote_code=True 163 | ) 164 | 165 | 166 | 167 | 168 | if model_type == 'bitsandbytes': 169 | from transformers import AutoModelForCausalLM 170 | model = AutoModelForCausalLM.from_pretrained( 171 | model_path, 172 | torch_dtype=torch.float16, 173 | load_in_8bit=True, 174 | trust_remote_code=True, 175 | max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB') 176 | 177 | 178 | 179 | if model_type == 'quik': 180 | from mixquant import AutoForCausalLM 181 | model = AutoForCausalLM.from_quantized( 182 | model_path, quant_file, fuse_layers=True, 183 | max_new_tokens=n_generate, batch_size=batch_size, 184 | safetensors=safetensors, 185 | mix = True, 186 | cache = mixlibcache 187 | ) 188 | 189 | 190 | 191 | 192 | print(model) 193 | print(f" -- Warming up...") 194 | warmup(model) 195 | 196 | print(f" -- Generating {n_generate} tokens, in context...") 197 | 198 | try: 199 | generate_time = generate(model, token, n_generate, batch_size, mixlibcache) 200 | successful_generate = True 201 | except RuntimeError as ex: 202 | if 'cuda out of memory' in str(ex).lower(): 203 | successful_generate = False 204 | else: 205 | raise RuntimeError(ex) 206 | 207 | device = next(model.parameters()).device 208 | memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3) 209 | memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100 210 | 211 | if successful_generate: 212 | # number of tokens in context / time for processing context * batch size 213 | # 1 second / median time per token in seconds * batch size 214 | decode_tokens_per_second = 1 / np.median(generate_time) * batch_size 215 | 216 | print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second") 217 | print(f" ** Max Memory (VRAM): {memory_used:.2f} GB ({memory_pct:.2f}%)") 218 | else: 219 | 220 | decode_tokens_per_second = 'OOM' 221 | 222 | return { 223 | "Batch Size": batch_size, 224 | "Decode Length": n_generate, 225 | "Decode tokens/s": decode_tokens_per_second, 226 | "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)", 227 | "latency" : float(np.median(generate_time)) 228 | }, args.model_type 229 | 230 | def main(args): 231 | rounds = [ 232 | 233 | {"context": args.seq_length, "n_generate": args.seq_length}, 234 | 235 | ] 236 | 237 | all_stats = [] 238 | 239 | cache = MixLibCache(bit=args.bit) 240 | 241 | print("downloading data......") 242 | text = prepare_data(args.dataset_path) 243 | print("done......") 244 | 245 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=args.use_fast_tokenizer, trust_remote_code=True) 246 | if not tokenizer.pad_token_id: 247 | tokenizer.pad_token_id = tokenizer.eos_token_id 248 | tokenizer.model_max_length = sys.maxsize 249 | tokens = tokenizer(text, truncation=False, return_tensors='pt').input_ids.to('cuda') 250 | 251 | 252 | 253 | 254 | 255 | 256 | for settings in rounds: 257 | 258 | 259 | 260 | stats, model_version = run_round( 261 | args.model_path, 262 | args.quant_file, 263 | settings["n_generate"], 264 | tokens, 265 | args.batch_size, 266 | args.safetensors, 267 | args.model_type, 268 | cache 269 | ) 
270 | 271 | all_stats.append(stats) 272 | 273 | if stats["Decode tokens/s"] == 'OOM': 274 | break 275 | 276 | df = pd.DataFrame(all_stats) 277 | print('GPU:', torch.cuda.get_device_name()) 278 | print('Model:', args.model_path) 279 | print('Version:', model_version) 280 | print(df.to_markdown(index=False)) 281 | try: 282 | os.mkdir('output/throughput/'+args.model_type) 283 | except: 284 | pass 285 | df.to_csv('output/throughput/'+args.model_type + '/' + args.quant_file.split("/")[-1] \ 286 | + str(args.batch_size) + '_' + str(args.bit) + ".csv") 287 | 288 | if __name__ == "__main__": 289 | 290 | 291 | parser = argparse.ArgumentParser() 292 | parser.add_argument("--model_path", type=str, default="", help="path to the model") 293 | parser.add_argument("--quant_file", type=str, default="", help="weights filename") 294 | parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation") 295 | parser.add_argument("--model_type", type=str, default="fp16") 296 | parser.add_argument("--safetensors", default=False, action="store_true", help="Use for enabling safetensors") 297 | parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer") 298 | parser.add_argument("--seq_length", type=int, default=128) 299 | parser.add_argument("--dataset_path", type=str, default='wikitext', help="Path to the dataset.") 300 | parser.add_argument("--bit", type=int, default=8) 301 | args = parser.parse_args() 302 | 303 | main(args) -------------------------------------------------------------------------------- /evalppl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | os.environ["WORLD_SIZE"] = "1" 5 | import argparse 6 | import pandas as pd 7 | import torch 8 | from utils.utils import Perplexity 9 | from transformers import AutoTokenizer 10 | 11 | 12 | 13 | 14 | 15 | 16 | def get_fp_features_num(module: torch.nn.Linear, args): 17 | fp_features_num = args.fp_features_num 18 | if args.fp_features_frac is not None: 19 | fp_features_num = max(int(module.in_features * args.fp_features_frac), fp_features_num) 20 | return fp_features_num 21 | def llama_replace_with_kernels(model, args): 22 | import modelutils 23 | layers = model.model.layers 24 | shared_inputs = {} 25 | 26 | assert not args.w_asym, 'Benchmarking only supports symmetric weight quantization!' 
27 | print("Replace with INT4 kernels.") 28 | for i in range(len(layers)): 29 | opt_block = layers[i] 30 | sequential = [ 31 | ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], 32 | ['self_attn.o_proj'], 33 | ['mlp.up_proj', 'mlp.gate_proj'], 34 | ['mlp.down_proj'] 35 | ] 36 | full = modelutils.find_layers(opt_block) 37 | for j, layer_group in enumerate(sequential): 38 | subset = {n: full[n] for n in layer_group} 39 | shared_inputs[f"{i}.{j}"] = qlinear.SharedQuantizedInput(len(layer_group)) 40 | for name in subset: 41 | layer = subset[name] 42 | if 'lm_head' in name or 'rotary_emb' in name: 43 | continue 44 | is_quantized = False 45 | bits = 16 46 | fp_features = 0 47 | import quant_sim 48 | import qlinear 49 | if isinstance(layer, quant_sim.ActQuantWrapper): 50 | if layer.quantizer.configured: 51 | is_quantized = True 52 | bits = layer.quantizer.bits 53 | fp_features = layer.fp_features_num 54 | layer = layer.module 55 | layer_weight = layer.weight.data 56 | 57 | layer_scale = save_dict['model.layers.{}.{}.scale'.format(i, name)] 58 | if fp_features == 0: 59 | fp_feature_idx = None 60 | else: 61 | print('---------------save act_scales----------------') 62 | layer_act_scales = act_scales['model.layers.{}.{}'.format(i, name)] 63 | fp_feature_idx = torch.sort(layer_act_scales)[1][-fp_features:] 64 | 65 | if is_quantized: 66 | int_mod = qlinear.MixedQLinear.from_float(layer, layer_weight, layer_scale, 67 | shared_inputs[f"{i}.{j}"], fp_feature_idx, 68 | bits=bits) 69 | else: 70 | int_mod = layer 71 | modelutils.replace_single_mod_opt(opt_block, name, int_mod) 72 | 73 | 74 | 75 | 76 | 77 | 78 | if __name__ == "__main__": 79 | """ 80 | Example usage. 81 | 82 | Default usage with GPT2 model: 83 | python examples/benchmark/perplexity.py 84 | 85 | Specify GPTQ quantized model: 86 | http_proxy=127.0.0.1:7890 https_proxy=127.0.0.1:7890 CUDA_VISIBLE_DEVICES=0 WORLD_SIZE=1 python examples/benchmark/perplexity.py \ 87 | --model_name /mnt/data/zhongrx/Llama-2-7b \ 88 | --model_basename gptq_model-4bit-128g \ 89 | --is_quantized 90 | 91 | Change your dataset: 92 | python examples/benchmark/perplexity.py --dataset_path tiny_shakespeare 93 | 94 | """ 95 | parser = argparse.ArgumentParser(description="Calculate Perplexity for a model.") 96 | parser.add_argument("--model_path", type=str, help="Model path") 97 | parser.add_argument("--quant_file", type=str, help="quant_file Model path") 98 | 99 | parser.add_argument("--model_type", type=str, default='bitsandbytesfp16') 100 | 101 | 102 | parser.add_argument("--n_ctx", type=int, default=256, help="Context size.") 103 | parser.add_argument("--n_batch", type=int, default=256, help="Batch size.") 104 | parser.add_argument("--dataset_path", type=str, default='wikitext', help="Path to the dataset.") 105 | parser.add_argument("--dataset_name", type=str, default=None, help="Name of the dataset.") 106 | parser.add_argument("--split", type=str, default='test', help="Dataset split to use.") 107 | parser.add_argument("--text_column", type=str, default='text', help="Column in the dataset containing the text.") 108 | parser.add_argument("--per_gpu_max_memory", type=int, default=None, help="Max memory used in each GPU.") 109 | parser.add_argument("--cpu_max_memory", type=int, default=None, help="Mx memory used in CPU.") 110 | 111 | parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file") 112 | parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer") 113 | 
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code") 114 | parser.add_argument("--disable_exllama", action="store_true", help="Whether to use disable exllama kernel") 115 | 116 | 117 | # Weight Quantization Params: 118 | parser.add_argument('--w_bits', type=int, default=16, choices=[4, 8, 16]) 119 | 120 | 121 | parser.add_argument('--int8_down_proj', action='store_true', help='Use INT8 for Down Projection') 122 | parser.add_argument('--fp_features_frac', type=float, default=None, help='Fraction of features to keep in FP16.') 123 | parser.add_argument("--fp_features_num", type=int, default=1, help="outliers") 124 | 125 | parser.add_argument('--eval_accuracy', type=bool, default=True) 126 | parser.add_argument('--eval_throughput', type=bool, default=False) 127 | 128 | 129 | args = parser.parse_args() 130 | 131 | if args.eval_throughput is True: 132 | args.eval_accuracy = False 133 | 134 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 135 | 136 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=args.use_fast_tokenizer, trust_remote_code=True) 137 | if not tokenizer.pad_token_id: 138 | tokenizer.pad_token_id = tokenizer.eos_token_id 139 | ppl = Perplexity(None, tokenizer, args.dataset_path, args.dataset_name, args.split, args.text_column, args.eval_accuracy) 140 | 141 | 142 | model_path = args.model_path 143 | quant_file = args.quant_file 144 | 145 | if args.model_type == 'bitsandbytesfp16': 146 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 147 | print(f" -- Loading model fp16...") 148 | # model = transformers.LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, 149 | # device_map='auto') 150 | model = AutoModelForCausalLM.from_pretrained( 151 | model_path, torch_dtype=torch.bfloat16, 152 | device_map='auto', trust_remote_code=True 153 | ) 154 | 155 | model = model.to('cuda') 156 | print(model) 157 | 158 | if args.model_type == 'bitsandbytesmix4': 159 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 160 | print(f" -- Loading model mix4...") 161 | 162 | n_gpus = torch.cuda.device_count() 163 | max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 164 | max_memory = {i: max_memory for i in range(n_gpus)} 165 | quantization_config = BitsAndBytesConfig( 166 | load_in_4bit=True, 167 | llm_int4_threshold=6.0, 168 | llm_int4_has_fp16_weight=False, 169 | ) 170 | model = AutoModelForCausalLM.from_pretrained( 171 | model_path, 172 | device_map='auto', 173 | max_memory=max_memory, 174 | quantization_config=quantization_config 175 | ) 176 | if args.model_type == 'bitsandbytes': 177 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 178 | print(f" -- Loading model mix bit8...") 179 | 180 | n_gpus = torch.cuda.device_count() 181 | max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 182 | max_memory = {i: max_memory for i in range(n_gpus)} 183 | quantization_config = BitsAndBytesConfig( 184 | load_in_8bit=True, 185 | llm_int8_threshold=6.0, 186 | llm_int8_has_fp16_weight=False, 187 | ) 188 | model = AutoModelForCausalLM.from_pretrained( 189 | model_path, 190 | device_map='auto', 191 | max_memory=max_memory, 192 | quantization_config=quantization_config 193 | ) 194 | 195 | 196 | if args.model_type == 'awq': 197 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 198 | print(f" -- Loading model awq...") 199 | 200 | 201 | from awq import AutoAWQForCausalLM 202 | model = 
AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True, mix = False) 203 | 204 | 205 | if 'mix' in args.model_type : 206 | from mixquant.Cache import MixLibCache 207 | from mixquant import AutoForCausalLM 208 | cache = MixLibCache(args.n_batch) 209 | 210 | 211 | model = AutoForCausalLM.from_quantized( 212 | model_path, quant_file, fuse_layers=True, 213 | mix = True, cache = cache 214 | ) 215 | 216 | if args.model_type == 'fp16': 217 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 218 | model = AutoModelForCausalLM.from_pretrained( 219 | model_path, torch_dtype=torch.float16, 220 | device_map='auto', trust_remote_code=True 221 | ) 222 | 223 | #model = model.to('cuda') 224 | 225 | if args.model_type == 'quik': 226 | from mixquant import AutoForCausalLM 227 | model = AutoForCausalLM.from_quantized( 228 | model_path, quant_file, fuse_layers=True, 229 | max_new_tokens=args.n_generate, batch_size=args.batch_size, 230 | safetensors=args.safetensors, 231 | mix = True, 232 | cache = cache 233 | ) 234 | print(model) 235 | ppl = Perplexity(model, tokenizer, args.dataset_path, args.dataset_name, 236 | args.split, args.text_column, args.eval_accuracy) 237 | allppl = ppl.calculate_perplexity(args.n_ctx, args.n_batch) 238 | 239 | data = pd.DataFrame(allppl) 240 | try: 241 | os.mkdir("output") 242 | except: 243 | pass 244 | data.to_csv("output/ppl_batchsize"+str(args.n_ctx)+"_"+args.model_type+"_"+model_path.split('/')[-1]+".csv" + str(args.fp_features_num)) 245 | 246 | 247 | 248 | 249 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | data/ 3 | -------------------------------------------------------------------------------- /examples/basic_generate.py: -------------------------------------------------------------------------------- 1 | from awq import AutoAWQForCausalLM 2 | from transformers import AutoTokenizer, TextStreamer 3 | 4 | quant_path = "TheBloke/Mistral-7B-OpenOrca-AWQ" 5 | 6 | # Load model 7 | model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, safetensors=True) 8 | tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True) 9 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 10 | 11 | # Convert prompt to tokens 12 | prompt_template = """\ 13 | <|im_start|>system 14 | You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!<|im_end|> 15 | <|im_start|>user 16 | {prompt}<|im_end|> 17 | <|im_start|>assistant""" 18 | 19 | prompt = "You're standing on the surface of the Earth. "\ 20 | "You walk one mile south, one mile west and one mile north. "\ 21 | "You end up exactly where you started. Where are you?" 
22 | 23 | tokens = tokenizer( 24 | prompt_template.format(prompt=prompt), 25 | return_tensors='pt' 26 | ).input_ids.cuda() 27 | 28 | # Generate output 29 | generation_output = model.generate( 30 | tokens, 31 | streamer=streamer, 32 | max_new_tokens=512 33 | ) -------------------------------------------------------------------------------- /examples/basic_quant.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "7" 4 | os.environ["WORLD_SIZE"] = "1" 5 | from awq import AutoAWQForCausalLM 6 | from transformers import AutoTokenizer 7 | 8 | # model_path = '/mnt/data/zhongrx/Llama-2-13b-hf' 9 | # quant_path = '/mnt/data/chenyd/Llama-2-13b-awq' 10 | 11 | 12 | model_path = '/mnt/data/huangkz/huggingface/hub/models--facebook--opt-30b/snapshots/ceea0a90ac0f6fae7c2c34bcb40477438c152546' 13 | quant_path = '/mnt/data/chenyd/models--facebook--opt-30b-awq' 14 | 15 | quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } 16 | print(quant_path) 17 | # Load model 18 | # NOTE: pass safetensors=True to load safetensors 19 | model = AutoAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True}) 20 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 21 | 22 | # Quantize 23 | model.quantize(tokenizer, quant_config=quant_config) 24 | 25 | # Save quantized model 26 | # NOTE: pass safetensors=True to save quantized model weights as safetensors 27 | model.save_quantized(quant_path) 28 | tokenizer.save_pretrained(quant_path) 29 | 30 | print(f'Model is quantized and saved at "{quant_path}"') -------------------------------------------------------------------------------- /examples/basic_quant_mix.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # 本测试使用 mixquant 的量化工具集合实现混和精度量化 4 | import os 5 | os.environ["WORLD_SIZE"] = "1" 6 | 7 | 8 | import sys 9 | from mixquant import AutoForCausalLM 10 | from transformers import AutoTokenizer 11 | 12 | import argparse 13 | parser = argparse.ArgumentParser(description="Calculate Perplexity for a model.") 14 | parser.add_argument("--model_path", type=str, help="Model path") 15 | parser.add_argument("--quant_file", type=str, help="quant_file Model path") 16 | parser.add_argument("--w_bit", type=int, default=8, help="weight bit") 17 | args = parser.parse_args() 18 | 19 | model_path = args.model_path 20 | quant_path = args.quant_file 21 | quant_config = { "w_bit": args.w_bit, "version": "MIX" } 22 | print(quant_path) 23 | # Load model 24 | # NOTE: pass safetensors=True to load safetensors 25 | model = AutoForCausalLM.from_pretrained(model_path, mix = True, **{"low_cpu_mem_usage": True},device_map='cpu') 26 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 27 | 28 | print(model) 29 | # Quantize 30 | model.quantize_mix(tokenizer, quant_config=quant_config) 31 | 32 | # Save quantized model 33 | # NOTE: pass safetensors=True to save quantized model weights as safetensors 34 | model.save_quantized(quant_path) 35 | tokenizer.save_pretrained(quant_path) 36 | 37 | print(f'Mix Model is quantized and saved at "{quant_path}"') -------------------------------------------------------------------------------- /examples/basic_quant_opt.py: -------------------------------------------------------------------------------- 1 | from awq import OptAWQForCausalLM 2 | from transformers import AutoTokenizer 3 | import os 4 | os.environ["CUDA_VISIBLE_DEVICES"] = 
"7" 5 | os.environ["WORLD_SIZE"] = "1" 6 | model_path = '/mnt/data/huangkz/huggingface/hub/models--facebook--opt-30b' 7 | quant_path = '/mnt/data/chenyd/models--facebook--opt-30b-awq' 8 | quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } 9 | print(quant_path) 10 | # Load model 11 | # NOTE: pass safetensors=True to load safetensors 12 | model = OptAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True}) 13 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 14 | 15 | # Quantize 16 | model.quantize(tokenizer, quant_config=quant_config) 17 | 18 | # Save quantized model 19 | # NOTE: pass safetensors=True to save quantized model weights as safetensors 20 | model.save_quantized(quant_path) 21 | tokenizer.save_pretrained(quant_path) 22 | 23 | print(f'Model is quantized and saved at "{quant_path}"') -------------------------------------------------------------------------------- /examples/basic_quant_quik.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | os.environ["WORLD_SIZE"] = "1" 5 | import sys 6 | from mixquant import AutoForCausalLM 7 | from transformers import AutoTokenizer 8 | 9 | import argparse 10 | parser = argparse.ArgumentParser(description="Calculate Perplexity for a model.") 11 | parser.add_argument("--model_path", type=str, help="Model path") 12 | parser.add_argument("--quant_file", type=str, help="quant_file Model path") 13 | parser.add_argument("--w_bit", type=int, default=4, help="weight bit") 14 | args = parser.parse_args() 15 | 16 | model_path = args.model_path 17 | quant_path = args.quant_file 18 | quant_config = { "w_bit": args.w_bit, "version": "QUIK" } 19 | print(quant_path) 20 | # Load model 21 | # NOTE: pass safetensors=True to load safetensors 22 | model = AutoForCausalLM.from_pretrained(model_path, mix = True, **{"low_cpu_mem_usage": True},device_map='cpu') 23 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 24 | 25 | print(model) 26 | # Quantize 27 | model.quantize_quik(tokenizer, quant_config=quant_config) 28 | 29 | # Save quantized model 30 | # NOTE: pass safetensors=True to save quantized model weights as safetensors 31 | model.save_quantized(quant_path) 32 | tokenizer.save_pretrained(quant_path) 33 | 34 | print(f'QUIK Model is quantized and saved at "{quant_path}"') -------------------------------------------------------------------------------- /examples/basic_safetensors_generate.py: -------------------------------------------------------------------------------- 1 | from awq import AutoAWQForCausalLM 2 | from transformers import AutoTokenizer, TextStreamer 3 | 4 | quant_path = "casperhansen/opt-125m-awq" 5 | 6 | # Load model 7 | model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, safetensors=True) 8 | tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True) 9 | streamer = TextStreamer(tokenizer, skip_special_tokens=True) 10 | 11 | # Convert prompt to tokens 12 | prompt_template = """\ 13 | A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. 
14 | 15 | USER: {prompt} 16 | ASSISTANT:""" 17 | 18 | tokens = tokenizer( 19 | prompt_template.format(prompt="How are you today?"), 20 | return_tensors='pt' 21 | ).input_ids.cuda() 22 | 23 | # Generate output 24 | generation_output = model.generate( 25 | tokens, 26 | streamer=streamer, 27 | max_new_tokens=512 28 | ) 29 | -------------------------------------------------------------------------------- /examples/benchkernels.sh: -------------------------------------------------------------------------------- 1 | for i in {0..39}; do 2 | # echo "evaluate ${file}" >> eval_out 3 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py $i result/llama13b-1.csv LLama-2-13b q >>temp13 4 | done 5 | 6 | for i in {0..31}; do 7 | # echo "evaluate ${file}" >> eval_out 8 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py $i result/llama7b-1.csv LLama-2-7b q >>temp7 9 | done 10 | 11 | for i in {0..31}; do 12 | # echo "evaluate ${file}" >> eval_out 13 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py $i result/llama70b-1.csv LLama-2-70b q >>temp70 14 | done 15 | 16 | for i in {0..31}; do 17 | # echo "evaluate ${file}" >> eval_out 18 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py $i result/llama70b-down.csv LLama-2-70b down 19 | done 20 | 21 | for i in {0..31}; do 22 | # echo "evaluate ${file}" >> eval_out 23 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py $i result/llama70b-up.csv LLama-2-70b up 24 | done 25 | -------------------------------------------------------------------------------- /examples/benchlayers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama13b-bd-unschedule.csv LLama-2-13b q 4 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 20 result/llama13b-bd-unschedule.csv LLama-2-13b q 5 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 39 result/llama13b-bd-unschedule.csv LLama-2-13b q 6 | 7 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama7b-bd-unschedule.csv LLama-2-7b q 8 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 result/llama7b-bd-unschedule.csv LLama-2-7b q 9 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama7b-bd-unschedule.csv LLama-2-7b q 10 | 11 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama70b-bd-unschedule.csv LLama-2-70b q 12 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 result/llama70b-bd-unschedule.csv LLama-2-70b q 13 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama70b-bd-unschedule.csv LLama-2-70b q 14 | 15 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama7b-up-1.csv LLama-2-7b up 16 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 result/llama7b-up-1.csv LLama-2-7b up 17 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama7b-up-1.csv LLama-2-7b up 18 | 19 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama7b-down-1.csv LLama-2-7b down 20 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 
result/llama7b-down-1.csv LLama-2-7b down 21 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama7b-down-1.csv LLama-2-7b down 22 | 23 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama70b-down-1.csv LLama-2-70b down 24 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 result/llama70b-down-1.csv LLama-2-70b down 25 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama70b-down-1.csv LLama-2-70b down 26 | 27 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama13b-down-1.csv LLama-2-13b down 28 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 20 result/llama13b-down-1.csv LLama-2-13b down 29 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 39 result/llama13b-down-1.csv LLama-2-13b down 30 | 31 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama70b-up-1.csv LLama-2-70b up 32 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 result/llama70b-up-1.csv LLama-2-70b up 33 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama70b-up-1.csv LLama-2-70b up 34 | 35 | -------------------------------------------------------------------------------- /examples/benchmark.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "6" 4 | os.environ["WORLD_SIZE"] = "1" 5 | import time 6 | import torch 7 | import argparse 8 | import numpy as np 9 | import pandas as pd 10 | from awq import AutoAWQForCausalLM 11 | from transformers import AutoTokenizer 12 | from torch.cuda import OutOfMemoryError 13 | 14 | def warmup(model): 15 | warm_up = torch.randn((4096,4096)).to(next(model.parameters()).device) 16 | torch.mm(warm_up,warm_up) 17 | 18 | 19 | def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): 20 | from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils 21 | config = LlamaConfig.from_pretrained(model) 22 | 23 | def noop(*args, **kwargs): 24 | pass 25 | 26 | torch.nn.init.kaiming_uniform_ = noop 27 | torch.nn.init.uniform_ = noop 28 | torch.nn.init.normal_ = noop 29 | 30 | torch.set_default_dtype(torch.half) 31 | modeling_utils._init_weights = False 32 | torch.set_default_dtype(torch.half) 33 | model = LlamaForCausalLM(config) 34 | torch.set_default_dtype(torch.float) 35 | if eval: 36 | model = model.eval() 37 | layers = find_layers(model) 38 | for name in ['lm_head']: 39 | if name in layers: 40 | del layers[name] 41 | quant.make_quant_linear(model, layers, wbits, groupsize) 42 | 43 | del layers 44 | 45 | print('Loading model ...') 46 | if checkpoint.endswith('.safetensors'): 47 | from safetensors.torch import load_file as safe_load 48 | model.load_state_dict(safe_load(checkpoint)) 49 | else: 50 | model.load_state_dict(torch.load(checkpoint)) 51 | 52 | if eval: 53 | quant.make_quant_attn(model) 54 | quant.make_quant_norm(model) 55 | if fused_mlp: 56 | quant.make_fused_mlp(model) 57 | 58 | if warmup_autotune: 59 | quant.autotune_warmup_linear(model, transpose=not (eval)) 60 | if eval and fused_mlp: 61 | quant.autotune_warmup_fused(model) 62 | model.seqlen = 2048 63 | print('Done.') 64 | 65 | return model 66 | 67 | def generate(model, input_ids, n_generate): 68 | context_time = 0 69 | generate_time 
= [] 70 | 71 | with torch.inference_mode(): 72 | for i in range(n_generate): 73 | torch.cuda.synchronize() 74 | start = time.time() 75 | 76 | if i == 0: 77 | # prefill context 78 | inputs = torch.as_tensor(input_ids, device=next(model.parameters()).device) 79 | else: 80 | # decode tokens 81 | inputs = torch.as_tensor(token, device=next(model.parameters()).device) 82 | 83 | out = model(inputs, use_cache=True) 84 | 85 | torch.cuda.synchronize() 86 | token = out[0][:, -1].max(1)[1].unsqueeze(1) 87 | 88 | if i == 0: 89 | context_time += time.time() - start 90 | else: 91 | generate_time.append(time.time() - start) 92 | 93 | return context_time, generate_time 94 | 95 | def run_round(model_path, quant_file, n_generate, input_ids, batch_size, safetensors): 96 | print(f" -- Loading model...") 97 | model = AutoAWQForCausalLM.from_quantized( 98 | model_path, quant_file, fuse_layers=True, 99 | max_new_tokens=n_generate, batch_size=batch_size, 100 | safetensors=safetensors 101 | ) 102 | 103 | print(f" -- Warming up...") 104 | warmup(model) 105 | 106 | print(f" -- Generating {n_generate} tokens, {input_ids.shape[1]} in context...") 107 | 108 | try: 109 | context_time, generate_time = generate(model, input_ids, n_generate) 110 | successful_generate = True 111 | except RuntimeError as ex: 112 | if 'cuda out of memory' in str(ex).lower(): 113 | successful_generate = False 114 | else: 115 | raise RuntimeError(ex) 116 | 117 | device = next(model.parameters()).device 118 | memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3) 119 | memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100 120 | 121 | if successful_generate: 122 | # number of tokens in context / time for processing context * batch size 123 | prefill_tokens_per_second = input_ids.shape[1] / context_time * batch_size 124 | # 1 second / median time per token in seconds * batch size 125 | decode_tokens_per_second = 1 / np.median(generate_time) * batch_size 126 | 127 | print(f" ** Speed (Prefill): {prefill_tokens_per_second:.2f} tokens/second") 128 | print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second") 129 | print(f" ** Max Memory (VRAM): {memory_used:.2f} GB ({memory_pct:.2f}%)") 130 | else: 131 | prefill_tokens_per_second = 'OOM' 132 | decode_tokens_per_second = 'OOM' 133 | 134 | return { 135 | "Batch Size": batch_size, 136 | "Prefill Length": input_ids.shape[1], 137 | "Decode Length": n_generate, 138 | "Prefill tokens/s": prefill_tokens_per_second, 139 | "Decode tokens/s": decode_tokens_per_second, 140 | "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)" 141 | }, model.quant_config["version"] 142 | 143 | def main(args): 144 | rounds = [ 145 | {"context": 32, "n_generate": 32}, 146 | {"context": 64, "n_generate": 64}, 147 | {"context": 128, "n_generate": 128}, 148 | {"context": 256, "n_generate": 256}, 149 | {"context": 512, "n_generate": 512}, 150 | {"context": 1024, "n_generate": 1024}, 151 | {"context": 2048, "n_generate": 2048}, 152 | ] 153 | 154 | all_stats = [] 155 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) 156 | 157 | for settings in rounds: 158 | input_ids = torch.randint(0, tokenizer.vocab_size, (args.batch_size, settings["context"])).cuda() 159 | 160 | stats, model_version = run_round( 161 | args.model_path, 162 | args.quant_file, 163 | settings["n_generate"], 164 | input_ids, 165 | args.batch_size, 166 | args.safetensors 167 | ) 168 | 169 | all_stats.append(stats) 170 | 171 | if stats["Prefill tokens/s"] 
== 'OOM': 172 | break 173 | 174 | df = pd.DataFrame(all_stats) 175 | print('GPU:', torch.cuda.get_device_name()) 176 | print('Model:', args.model_path) 177 | print('Version:', model_version) 178 | print(df.to_markdown(index=False)) 179 | df.to_csv(args.quant_file.split("/")[-1]+".csv"+str(args.batch_size)) 180 | 181 | if __name__ == "__main__": 182 | 183 | """ 184 | python examples/benchmark.py --model_path /mnt/data/zhongrx/Llama-2-7b-hf --quant_file /mnt/data/chenyd/Llama-2-7b-awq 185 | 186 | """ 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument("--model_path", type=str, default="casperhansen/vicuna-7b-v1.5-awq", help="path to the model") 189 | parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename") 190 | parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation") 191 | parser.add_argument("--safetensors", default=False, action="store_true", help="Use for enabling safetensors") 192 | args = parser.parse_args() 193 | 194 | main(args) -------------------------------------------------------------------------------- /examples/breakdown.sh: -------------------------------------------------------------------------------- 1 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama70b-up-1.csv LLama-2-70b up 2 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 result/llama70b-up-1.csv LLama-2-70b up 3 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama70b-up-1.csv LLama-2-70b up 4 | -------------------------------------------------------------------------------- /examples/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from lm_eval import evaluator 3 | from awq import AutoAWQForCausalLM 4 | from transformers import AutoTokenizer 5 | from awq.utils.lm_eval_adaptor import LMEvalAdaptor 6 | 7 | def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot, task_use_pretrained): 8 | """ 9 | Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness 10 | """ 11 | # Load model 12 | if task_use_pretrained: 13 | model = AutoAWQForCausalLM.from_pretrained(model_path) 14 | else: 15 | 16 | model = AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=False) 17 | 18 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 19 | 20 | # Load adapter 21 | lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size) 22 | 23 | # Evaluate perplexity of quantized model 24 | results = evaluator.simple_evaluate( 25 | model=lm_eval_model, 26 | tasks=tasks.split(','), 27 | batch_size=task_batch_size, 28 | no_cache=True, 29 | num_fewshot=task_n_shot, 30 | ) 31 | 32 | print(evaluator.make_table(results)) 33 | 34 | if __name__ == '__main__': 35 | """ 36 | - Run perplexity of quantized model: 37 | python examples/eval.py --model_path /mnt/data/zhongrx/Llama-2-7b-hf --quant_file /mnt/data/chenyd/Llama-2-7b-awq 38 | 39 | - Run perplexity unquantized FP16 model: 40 | python examples/eval.py --use_pretrained --model_path lmsys/vicuna-7b-v1.5 41 | """ 42 | 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('--model_path', type=str, help='Path to hf model') 45 | parser.add_argument('--quant_file', default='', type=str, help='Path to quantized AWQ model file') 46 | parser.add_argument('--device', type=str, 
default='cuda:0', help='Device to load model to') 47 | parser.add_argument("--use_pretrained", default=False, action='store_true', 48 | help="Pass '--use_pretrained' to use a pretrained model running FP16") 49 | parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. ' 50 | 'Separate tasks by comma for multiple tasks.' 51 | 'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md') 52 | parser.add_argument('--batch_size', type=int, default=1) 53 | parser.add_argument('--n_shot', type=int, default=0) 54 | args = parser.parse_args() 55 | 56 | run_eval(args.model_path, args.quant_file, args.device, 57 | args.tasks, args.batch_size, args.n_shot, args.use_pretrained) -------------------------------------------------------------------------------- /examples/mmlu.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | if [ $2 == a100 ] 4 | then 5 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python " 6 | else 7 | CMD="srun -p twills -A h100 --gres=gpu:h100:1 --export=ALL python" 8 | fi 9 | set -ex 10 | quantpath=/home/dataset/quant/quant 11 | modelpath=/home/dataset 12 | 13 | 14 | models=( "Llama-2-7b" ) 15 | ngpu=1 16 | 17 | 18 | data_type=$3 19 | if [ ${data_type} == mix8 ] 20 | then 21 | bit=${data_type:3:3} 22 | for model in "${models[@]}" 23 | do 24 | echo ${model} 25 | 26 | CUDA_VISIBLE_DEVICES=$1 ${CMD} mmlu.py \ 27 | --model_type ${data_type} --hf_model_dir ${quantpath}${bit}/${model} 28 | 29 | done 30 | fi 31 | 32 | if [ ${data_type} == fp16 ] 33 | then 34 | 35 | for model in "${models[@]}" 36 | do 37 | echo ${model} 38 | 39 | CUDA_VISIBLE_DEVICES=$1 ${CMD} mmlu.py \ 40 | --model_type ${data_type} --hf_model_dir ${modelpath}/${model} 41 | 42 | done 43 | fi 44 | 45 | -------------------------------------------------------------------------------- /examples/overhead.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python testoverhead.py -------------------------------------------------------------------------------- /examples/smooth_quant_get_act.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from datasets import load_dataset 5 | import functools 6 | from collections import defaultdict 7 | 8 | from functools import partial 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | 13 | def get_act_scales(model, tokenizer, dataset_path, num_samples=512, seq_len=512): 14 | model.eval() 15 | device = next(model.parameters()).device 16 | act_scales = {} 17 | 18 | def stat_tensor(name, tensor): 19 | hidden_dim = tensor.shape[-1] 20 | tensor = tensor.view(-1, hidden_dim).abs().detach() 21 | comming_max = torch.max(tensor, dim=0)[0].float().cpu() 22 | if name in act_scales: 23 | act_scales[name] = torch.max(act_scales[name], comming_max) 24 | else: 25 | act_scales[name] = comming_max 26 | 27 | def stat_input_hook(m, x, y, name): 28 | if isinstance(x, tuple): 29 | x = x[0] 30 | stat_tensor(name, x) 31 | 32 | hooks = [] 33 | for name, m in model.named_modules(): 34 | if isinstance(m, nn.Linear): 35 | hooks.append( 36 | m.register_forward_hook(functools.partial(stat_input_hook, name=name)) 37 | ) 38 | 39 | dataset = load_dataset("json", data_files=dataset_path, split="train") 40 | dataset = dataset.shuffle(seed=42) 41 | 42 | for i in tqdm(range(num_samples)): 43 | input_ids = tokenizer( 44 | 
dataset[i]["text"], return_tensors="pt", max_length=seq_len, truncation=True 45 | ).input_ids.to(device) 46 | model(input_ids) 47 | 48 | for h in hooks: 49 | h.remove() 50 | 51 | return act_scales 52 | 53 | 54 | @torch.no_grad() 55 | def get_static_decoder_layer_scales( 56 | model, 57 | tokenizer, 58 | dataset_path, 59 | num_samples=512, 60 | seq_len=512, 61 | ): 62 | model.eval() 63 | device = next(model.parameters()).device 64 | 65 | act_dict = defaultdict(dict) 66 | 67 | def stat_io_hook(m, x, y, name): 68 | if isinstance(x, tuple): 69 | x = x[0] 70 | if name not in act_dict or "input" not in act_dict[name]: 71 | act_dict[name]["input"] = x.detach().abs().max().item() 72 | else: 73 | act_dict[name]["input"] = max( 74 | act_dict[name]["input"], x.detach().abs().max().item() 75 | ) 76 | if isinstance(y, tuple): 77 | y = y[0] 78 | if name not in act_dict or "output" not in act_dict[name]: 79 | act_dict[name]["output"] = y.detach().abs().max().item() 80 | else: 81 | act_dict[name]["output"] = max( 82 | act_dict[name]["output"], y.detach().abs().max().item() 83 | ) 84 | 85 | hooks = [] 86 | for name, m in model.named_modules(): 87 | if isinstance(m, torch.nn.Linear): 88 | hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) 89 | 90 | print("Collecting activation scales...") 91 | pbar = tqdm(range(num_samples)) 92 | dataset = load_dataset("json", data_files=dataset_path, split="train") 93 | dataset = dataset.shuffle(seed=42) 94 | for i in pbar: 95 | input_ids = tokenizer( 96 | dataset[i]["text"], return_tensors="pt", max_length=seq_len, truncation=True 97 | ).input_ids.to(device) 98 | model(input_ids) 99 | mean_scale = np.mean([v["input"] for v in act_dict.values()]) 100 | pbar.set_description(f"Mean input scale: {mean_scale:.2f}") 101 | for hook in hooks: 102 | hook.remove() 103 | 104 | decoder_layer_scales = [] 105 | for idx in range(model.config.num_hidden_layers): 106 | scale_dict = {} 107 | scale_dict["attn_input_scale"] = ( 108 | act_dict[f"model.decoder.layers.{idx}.self_attn.q_proj"]["input"] / 127 109 | ) 110 | scale_dict["q_output_scale"] = ( 111 | act_dict[f"model.decoder.layers.{idx}.self_attn.q_proj"]["output"] / 127 112 | ) 113 | scale_dict["k_output_scale"] = ( 114 | act_dict[f"model.decoder.layers.{idx}.self_attn.k_proj"]["output"] / 127 115 | ) 116 | scale_dict["v_output_scale"] = ( 117 | act_dict[f"model.decoder.layers.{idx}.self_attn.v_proj"]["output"] / 127 118 | ) 119 | scale_dict["out_input_scale"] = ( 120 | act_dict[f"model.decoder.layers.{idx}.self_attn.out_proj"]["input"] / 127 121 | ) 122 | scale_dict["fc1_input_scale"] = ( 123 | act_dict[f"model.decoder.layers.{idx}.fc1"]["input"] / 127 124 | ) 125 | scale_dict["fc2_input_scale"] = ( 126 | act_dict[f"model.decoder.layers.{idx}.fc2"]["input"] / 127 127 | ) 128 | decoder_layer_scales.append(scale_dict) 129 | 130 | return decoder_layer_scales, act_dict 131 | 132 | 133 | import torch 134 | import os 135 | 136 | from transformers import ( 137 | AutoModelForCausalLM, 138 | AutoTokenizer, 139 | ) 140 | import argparse 141 | 142 | 143 | def build_model_and_tokenizer(model_name): 144 | 145 | print(model_name) 146 | tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512, trust_remote_code=True) 147 | kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"} 148 | model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) 149 | return model, tokenizer 150 | 151 | 152 | def parse_args(): 153 | parser = argparse.ArgumentParser() 154 | parser.add_argument('--model-name', 
type=str, 155 | default='facebook/opt-1.3b', help='model name') 156 | parser.add_argument('--output-path', type=str, default='act_scales/opt-1.3b.pt', 157 | help='where to save the act scales') 158 | parser.add_argument('--dataset-path', type=str, default='dataset/val.jsonl.zst', 159 | help='location of the calibration dataset, we use the validation set of the Pile dataset') 160 | parser.add_argument('--num-samples', type=int, default=512) 161 | parser.add_argument('--seq-len', type=int, default=512) 162 | args = parser.parse_args() 163 | return args 164 | 165 | if __name__ == '__main__': 166 | 167 | 168 | 169 | 170 | args = parse_args() 171 | model, tokenizer = build_model_and_tokenizer(args.model_name) 172 | 173 | act_scales = get_act_scales(model, tokenizer, args.dataset_path, 174 | args.num_samples, args.seq_len) 175 | 176 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 177 | torch.save(act_scales, args.output_path) 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /examples/testgenerate.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | import time 6 | import psutil 7 | import random 8 | import torch 9 | import numpy as np 10 | from torch.nn.parameter import Parameter 11 | 12 | from transformers import AutoTokenizer, LlamaModel, LlamaForCausalLM, LlamaTokenizer, AutoConfig, AutoModelForCausalLM 13 | 14 | 15 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 16 | 17 | def set_random_seed(seed): 18 | random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.backends.cudnn.deterministic = True 21 | torch.cuda.manual_seed_all(seed) 22 | np.random.seed(seed) 23 | 24 | 25 | def test_from_fp16(): 26 | torch.set_printoptions(precision=6, sci_mode=False) 27 | torch.set_grad_enabled(False) 28 | set_random_seed(1) 29 | 30 | # model_name = '/root/data/models/2023/llama-13B-v1/' 31 | model_name = '/mnt/octave/data/chenyidong/checkpoint/Llama-2-7b' 32 | MAX_NEW_TOKENS = 32 33 | 34 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) 35 | config = AutoConfig.from_pretrained(model_name) 36 | # config.num_hidden_layers = 1 37 | 38 | free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3) 39 | max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-1}GB' 40 | 41 | n_gpus = torch.cuda.device_count() 42 | max_memory = {i: max_memory for i in range(n_gpus)} 43 | 44 | model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16) 45 | model.eval() 46 | 47 | # from eetq.utils import eet_accelerator 48 | # eet_accelerator(model, quantize=True, fused_attn=True, dev="cuda:0") 49 | model.to("cuda:0") 50 | # for k, v in model.state_dict().items(): 51 | # print(k, v.shape, v.dtype, v.device) 52 | # torch.save(model, "eetq_llama13B_model_fused_attn_v2.pt") 53 | 54 | prompt_template = "[INST] {prompt} [/INST]" 55 | 56 | prompt = "You're standing on the surface of the Earth. "\ 57 | "You walk one mile south, one mile west and one mile north. "\ 58 | "You end up exactly where you started. Where are you?" 
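    # --- editor's note (hedged sketch): `model.chat(tokenizer, messages)` below assumes a
    # checkpoint that ships a custom chat() method via trust_remote_code (e.g. Baichuan-style
    # models); a plain Llama-2 checkpoint loaded with AutoModelForCausalLM only exposes
    # generate(). The standard Hugging Face path, reusing the tokenizer, model,
    # prompt_template and MAX_NEW_TOKENS defined above, would be:
    inputs = tokenizer(prompt_template.format(prompt=prompt), return_tensors="pt").to("cuda:0")
    output_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)
    print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))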
59 | 60 | 61 | messages = [] 62 | messages.append({"role": "user", "content": prompt}) 63 | response = model.chat(tokenizer, messages) 64 | print(response) 65 | test_from_fp16() -------------------------------------------------------------------------------- /figures/awq32.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/figures/awq32.gif -------------------------------------------------------------------------------- /figures/awq512.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/figures/awq512.gif -------------------------------------------------------------------------------- /figures/mixq32.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/figures/mixq32.gif -------------------------------------------------------------------------------- /figures/mixq512.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/figures/mixq512.gif -------------------------------------------------------------------------------- /figures/textmixq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/figures/textmixq.jpg -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | os.environ["WORLD_SIZE"] = "1" 5 | import argparse 6 | import pandas as pd 7 | import torch 8 | from utils.utils import Perplexity 9 | from transformers import AutoTokenizer 10 | 11 | 12 | 13 | 14 | 15 | 16 | def get_fp_features_num(module: torch.nn.Linear, args): 17 | fp_features_num = args.fp_features_num 18 | if args.fp_features_frac is not None: 19 | fp_features_num = max(int(module.in_features * args.fp_features_frac), fp_features_num) 20 | return fp_features_num 21 | def llama_replace_with_kernels(model, args): 22 | import modelutils 23 | layers = model.model.layers 24 | shared_inputs = {} 25 | 26 | assert not args.w_asym, 'Benchmarking only supports symmetric weight quantization!' 
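    # --- editor's note: the replacement loop below keeps the activation channels with the
    # largest calibration scales ("outliers") in FP16 and maps the remaining channels to the
    # low-bit kernels. The index selection it relies on is a plain top-k over per-channel
    # activation scales; a self-contained illustration with dummy data (names prefixed
    # `_demo_` are illustrative only, not part of this module):
    _demo_act_scales = torch.rand(4096)                  # one calibration scale per input channel
    _demo_fp_idx = torch.sort(_demo_act_scales)[1][-8:]  # indices of the 8 largest channels kept in FP16
    del _demo_act_scales, _demo_fp_idx
    # Note: `save_dict` and `act_scales` referenced inside the loop are expected to be
    # provided by the caller; they are not defined in this file.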
27 | print("Replace with INT4 kernels.") 28 | for i in range(len(layers)): 29 | opt_block = layers[i] 30 | sequential = [ 31 | ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], 32 | ['self_attn.o_proj'], 33 | ['mlp.up_proj', 'mlp.gate_proj'], 34 | ['mlp.down_proj'] 35 | ] 36 | full = modelutils.find_layers(opt_block) 37 | for j, layer_group in enumerate(sequential): 38 | subset = {n: full[n] for n in layer_group} 39 | shared_inputs[f"{i}.{j}"] = qlinear.SharedQuantizedInput(len(layer_group)) 40 | for name in subset: 41 | layer = subset[name] 42 | if 'lm_head' in name or 'rotary_emb' in name: 43 | continue 44 | is_quantized = False 45 | bits = 16 46 | fp_features = 0 47 | import quant_sim 48 | import qlinear 49 | if isinstance(layer, quant_sim.ActQuantWrapper): 50 | if layer.quantizer.configured: 51 | is_quantized = True 52 | bits = layer.quantizer.bits 53 | fp_features = layer.fp_features_num 54 | layer = layer.module 55 | layer_weight = layer.weight.data 56 | 57 | layer_scale = save_dict['model.layers.{}.{}.scale'.format(i, name)] 58 | if fp_features == 0: 59 | fp_feature_idx = None 60 | else: 61 | print('---------------save act_scales----------------') 62 | layer_act_scales = act_scales['model.layers.{}.{}'.format(i, name)] 63 | fp_feature_idx = torch.sort(layer_act_scales)[1][-fp_features:] 64 | 65 | if is_quantized: 66 | int_mod = qlinear.MixedQLinear.from_float(layer, layer_weight, layer_scale, 67 | shared_inputs[f"{i}.{j}"], fp_feature_idx, 68 | bits=bits) 69 | else: 70 | int_mod = layer 71 | modelutils.replace_single_mod_opt(opt_block, name, int_mod) 72 | 73 | 74 | 75 | 76 | 77 | 78 | if __name__ == "__main__": 79 | """ 80 | Example usage. 81 | 82 | Default usage with GPT2 model: 83 | python examples/benchmark/perplexity.py 84 | 85 | Specify GPTQ quantized model: 86 | http_proxy=127.0.0.1:7890 https_proxy=127.0.0.1:7890 CUDA_VISIBLE_DEVICES=0 WORLD_SIZE=1 python examples/benchmark/perplexity.py \ 87 | --model_name /mnt/data/zhongrx/Llama-2-7b \ 88 | --model_basename gptq_model-4bit-128g \ 89 | --is_quantized 90 | 91 | Change your dataset: 92 | python examples/benchmark/perplexity.py --dataset_path tiny_shakespeare 93 | 94 | """ 95 | parser = argparse.ArgumentParser(description="Calculate Perplexity for a model.") 96 | parser.add_argument("--model_path", type=str, help="Model path") 97 | parser.add_argument("--quant_file", type=str, help="quant_file Model path") 98 | 99 | parser.add_argument("--model_type", type=str, default='bitsandbytesfp16') 100 | 101 | 102 | parser.add_argument("--n_ctx", type=int, default=256, help="Context size.") 103 | parser.add_argument("--n_batch", type=int, default=256, help="Batch size.") 104 | parser.add_argument("--dataset_path", type=str, default='wikitext', help="Path to the dataset.") 105 | parser.add_argument("--dataset_name", type=str, default=None, help="Name of the dataset.") 106 | parser.add_argument("--split", type=str, default='test', help="Dataset split to use.") 107 | parser.add_argument("--text_column", type=str, default='text', help="Column in the dataset containing the text.") 108 | parser.add_argument("--per_gpu_max_memory", type=int, default=None, help="Max memory used in each GPU.") 109 | parser.add_argument("--cpu_max_memory", type=int, default=None, help="Mx memory used in CPU.") 110 | 111 | parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file") 112 | parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer") 113 | 
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code") 114 | parser.add_argument("--disable_exllama", action="store_true", help="Whether to use disable exllama kernel") 115 | 116 | 117 | # Weight Quantization Params: 118 | parser.add_argument('--w_bits', type=int, default=16, choices=[4, 8, 16]) 119 | 120 | 121 | parser.add_argument('--int8_down_proj', action='store_true', help='Use INT8 for Down Projection') 122 | parser.add_argument('--fp_features_frac', type=float, default=None, help='Fraction of features to keep in FP16.') 123 | parser.add_argument("--fp_features_num", type=int, default=1, help="outliers") 124 | 125 | parser.add_argument('--eval_accuracy', type=bool, default=True) 126 | parser.add_argument('--eval_throughput', type=bool, default=False) 127 | 128 | 129 | args = parser.parse_args() 130 | 131 | if args.eval_throughput is True: 132 | args.eval_accuracy = False 133 | 134 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 135 | 136 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=args.use_fast_tokenizer, trust_remote_code=True) 137 | if not tokenizer.pad_token_id: 138 | tokenizer.pad_token_id = tokenizer.eos_token_id 139 | ppl = Perplexity(None, tokenizer, args.dataset_path, args.dataset_name, args.split, args.text_column, args.eval_accuracy) 140 | 141 | 142 | model_path = args.model_path 143 | quant_file = args.quant_file 144 | 145 | if args.model_type == 'bitsandbytesfp16': 146 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 147 | print(f" -- Loading model fp16...") 148 | # model = transformers.LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, 149 | # device_map='auto') 150 | model = AutoModelForCausalLM.from_pretrained( 151 | model_path, torch_dtype=torch.bfloat16, 152 | device_map='auto', trust_remote_code=True 153 | ) 154 | 155 | model = model.to('cuda') 156 | print(model) 157 | 158 | if args.model_type == 'bitsandbytesmix4': 159 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 160 | print(f" -- Loading model mix4...") 161 | 162 | n_gpus = torch.cuda.device_count() 163 | max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 164 | max_memory = {i: max_memory for i in range(n_gpus)} 165 | quantization_config = BitsAndBytesConfig( 166 | load_in_4bit=True, 167 | llm_int4_threshold=6.0, 168 | llm_int4_has_fp16_weight=False, 169 | ) 170 | model = AutoModelForCausalLM.from_pretrained( 171 | model_path, 172 | device_map='auto', 173 | max_memory=max_memory, 174 | quantization_config=quantization_config 175 | ) 176 | if args.model_type == 'bitsandbytes': 177 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 178 | print(f" -- Loading model mix bit8...") 179 | 180 | n_gpus = torch.cuda.device_count() 181 | max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 182 | max_memory = {i: max_memory for i in range(n_gpus)} 183 | quantization_config = BitsAndBytesConfig( 184 | load_in_8bit=True, 185 | llm_int8_threshold=6.0, 186 | llm_int8_has_fp16_weight=False, 187 | ) 188 | model = AutoModelForCausalLM.from_pretrained( 189 | model_path, 190 | device_map='auto', 191 | max_memory=max_memory, 192 | quantization_config=quantization_config 193 | ) 194 | 195 | 196 | if args.model_type == 'awq': 197 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 198 | print(f" -- Loading model awq...") 199 | 200 | 201 | from awq import AutoAWQForCausalLM 202 | model = 
AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True, mix = False) 203 | 204 | 205 | if args.model_type == 'mix': 206 | from mixquant.Cache import MixLibCache 207 | from mixquant import AutoForCausalLM 208 | cache = MixLibCache(args.n_batch) 209 | 210 | 211 | model = AutoForCausalLM.from_quantized( 212 | model_path, quant_file, fuse_layers=True, 213 | mix = True, cache = cache 214 | ) 215 | 216 | if args.model_type == 'fp16': 217 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 218 | model = AutoModelForCausalLM.from_pretrained( 219 | model_path, torch_dtype=torch.float16, 220 | device_map='auto', trust_remote_code=True 221 | ) 222 | 223 | #model = model.to('cuda') 224 | 225 | if args.model_type == 'quik': 226 | from mixquant import AutoForCausalLM 227 | model = AutoForCausalLM.from_quantized( 228 | model_path, quant_file, fuse_layers=True, 229 | max_new_tokens=args.n_generate, batch_size=args.batch_size, 230 | safetensors=args.safetensors, 231 | mix = True, 232 | cache = cache 233 | ) 234 | print(model) 235 | ppl = Perplexity(model, tokenizer, args.dataset_path, args.dataset_name, 236 | args.split, args.text_column, args.eval_accuracy) 237 | allppl = ppl.calculate_perplexity(args.n_ctx, args.n_batch) 238 | 239 | 240 | 241 | 242 | -------------------------------------------------------------------------------- /generate.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | if [ $2 == a100 ] 4 | then 5 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python " 6 | fi 7 | 8 | if [ $2 == h100 ] 9 | then 10 | CMD="srun -p twills -A h100 --gres=gpu:h100:1 --export=ALL python" 11 | fi 12 | 13 | export http_proxy=127.0.0.1:7890 14 | export https_proxy=127.0.0.1:7890 15 | set -x 16 | 17 | quantpath=/home/dataset/quant/quant 18 | modelpath=/home/dataset 19 | 20 | for batch in 32 21 | #for batch in 1 22 | 23 | do 24 | for seq in 1024 25 | do 26 | ##model_type=Aquila2 27 | #model_type=opt 28 | #model_type=Mistral 29 | #model_type=gpt-j 30 | #model_type=falcon 31 | model_type=$3 32 | 33 | 34 | 35 | data_types=( "mix" ) 36 | for bit in 8 37 | do 38 | for data_type in "${data_types[@]}" 39 | do 40 | model=${model_type} 41 | echo ${model} 42 | rm -r ${quantpath}${bit}/${model}/model.safetensors 43 | CUDA_VISIBLE_DEVICES=$1 ${CMD} generate.py --model_type ${data_type} --model_path \ 44 | ${quantpath}${bit}/${model} \ 45 | --quant_file ${quantpath}${bit}/${model} \ 46 | --n_batch ${batch} --n_ctx ${batch} --dataset_path /home/chenyidong/checkpoint/dataset 47 | done 48 | done 49 | 50 | 51 | # data_types=( "quik" ) 52 | # for bit in 4 53 | # do 54 | # for data_type in "${data_types[@]}" 55 | # do 56 | # model=${model_type} 57 | # echo ${model} 58 | # CUDA_VISIBLE_DEVICES=$1 ${CMD} generate.py --model_type ${data_type} --model_path \ 59 | # ${quantpath}quik${bit}/${model} \ 60 | # --quant_file ${quantpath}quik${bit}/${model} \ 61 | # --batch_size ${batch} --bit ${bit} --dataset_path /home/chenyidong/checkpoint/dataset 62 | # done 63 | # done 64 | # data_types=( "bitsandbytes" ) 65 | # for data_type in "${data_types[@]}" 66 | # do 67 | # model=${model_type} 68 | 69 | # echo ${model} 70 | # CUDA_VISIBLE_DEVICES=$1 ${CMD} benchflops.py --model_type ${data_type} --model_path \ 71 | # ${modelpath}/${model} \ 72 | # --quant_file ${modelpath}/${model} --batch_size ${batch} --dataset_path /home/chenyidong/checkpoint/dataset 73 | 74 | # done 75 | # data_types=( "awq" ) 76 | # for data_type in 
"${data_types[@]}" 77 | # do 78 | # for model in "${models[@]}" 79 | # do 80 | # echo ${model} 81 | # CUDA_VISIBLE_DEVICES=$1 ${CMD} benchflops.py --model_type ${data_type} --model_path \ 82 | # ${quantpath}/awqquant/${model} \ 83 | # --quant_file ${quantpath}t/awqquant/${model} --batch_size ${batch} 84 | # done 85 | # done 86 | 87 | 88 | 89 | done 90 | done 91 | -------------------------------------------------------------------------------- /get_act.sh: -------------------------------------------------------------------------------- 1 | 2 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python " 3 | set -x 4 | 5 | model=( Llama-2-7b ) 6 | model=( Aquila2-7b ) 7 | model=( Baichuan2-7b ) 8 | $CMD examples/smooth_quant_get_act.py --model-name /mnt/octave/data/chenyidong/checkpoint/${model} \ 9 | --output-path /home/chenyidong/SC3/MixQ/src/act_scales/${model}.pt --dataset-path /home/chenyidong/val.jsonl.zst 10 | 11 | -------------------------------------------------------------------------------- /mixquant/Cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | class MixLibCache: 6 | def __init__(self, inputdim = 1024, sigma = 6, bit = 8, eval_ppl = False, locality = False): 7 | self.device = 'cuda' 8 | self.x_scale = torch.zeros((inputdim,1),dtype=torch.float16).to('cuda') 9 | 10 | self.sigma = torch.zeros((1,1),dtype=torch.float16).to('cuda') 11 | self.zeros = torch.zeros((inputdim,12288*3),dtype=torch.float16).to('cuda') 12 | self.sigma[0] = sigma 13 | 14 | 15 | self.ind = None 16 | self.shape = None 17 | self.activation_outliers = None 18 | self.is_prefill = False 19 | self.bit = bit 20 | 21 | self.max_outliers = 256 22 | self.stop = 2 23 | 24 | self.eval_ppl = eval_ppl 25 | self.locality = locality 26 | def do_bench_cudagraph(self, fn): 27 | if torch.cuda.current_stream() == torch.cuda.default_stream(): 28 | raise RuntimeError("Cannot capture graph in default stream. 
Please use side stream in benchmark code.") 29 | # warmup 30 | for i in range(10): 31 | fn() 32 | g = torch.cuda.CUDAGraph() 33 | with torch.cuda.graph(g): 34 | fn() 35 | torch.cuda.synchronize() 36 | 37 | 38 | return g 39 | 40 | 41 | 42 | class MLPCache: 43 | def __init__(self, max_batch_size = 4096): 44 | self.device = 'cuda' 45 | self.x_scale = torch.zeros((max_batch_size,1),dtype=torch.float16).to('cuda') 46 | self.ind = None 47 | self.shape = None 48 | self.activation_outliers = None 49 | 50 | -------------------------------------------------------------------------------- /mixquant/__init__.py: -------------------------------------------------------------------------------- 1 | from mixquant.models.auto import AutoForCausalLM -------------------------------------------------------------------------------- /mixquant/int8fusedkernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import triton 4 | import triton.language as tl 5 | 6 | from triton import Config, autotune, cdiv, heuristics, jit 7 | from triton import language as tl 8 | from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time 9 | # `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: 10 | def upcast_if_fp8(a): 11 | if "fp8" in str(a): 12 | return torch.float16 13 | return a 14 | 15 | _ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32] 16 | 17 | def get_higher_dtype(a, b): 18 | a = upcast_if_fp8(a) 19 | b = upcast_if_fp8(b) 20 | if a is b: 21 | return a 22 | 23 | assert a in _ordered_datatypes 24 | assert b in _ordered_datatypes 25 | 26 | for d in _ordered_datatypes: 27 | if a is d: 28 | return b 29 | if b is d: 30 | return a 31 | 32 | 33 | def init_to_zero(name): 34 | return lambda nargs: nargs[name].zero_() 35 | 36 | 37 | def get_configs_io_bound(): 38 | configs = [] 39 | for num_stages in [2, 3, 4, 5, 6]: 40 | for block_m in [16, 32]: 41 | for block_k in [32, 64]: 42 | for block_kfp in [32, 64]: 43 | for block_n in [32, 64, 128, 256]: 44 | num_warps = 2 if block_n <= 64 else 4 45 | configs.append( 46 | Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,'BLOCK_Kfp': block_kfp, 'SPLIT_K': 1}, 47 | num_stages=num_stages, num_warps=num_warps)) 48 | # split_k 49 | for split_k in [2, 4, 8, 16]: 50 | configs.append( 51 | Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'BLOCK_Kfp': block_kfp, 'SPLIT_K': split_k}, 52 | num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) 53 | return configs 54 | 55 | 56 | 57 | 58 | 59 | @autotune( 60 | configs=[ 61 | # basic configs for compute-bound matmuls 62 | Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8), 63 | Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8), 64 | Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4), 65 | Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4), 66 | Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4), 67 | Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4), 68 | Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4), 69 
| Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4), 70 | Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=5, num_warps=2), 71 | # good for int8 72 | Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8), 73 | Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8), 74 | Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4), 75 | Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4), 76 | Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4), 77 | Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4), 78 | Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4), 79 | Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4), 80 | Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=5, num_warps=2), 81 | ] + get_configs_io_bound(), 82 | key=['M', 'N'], 83 | prune_configs_by={ 84 | 'early_config_prune': early_config_prune, 85 | 'perf_model': estimate_matmul_time, 86 | 'top_k': 10, 87 | }, 88 | ) 89 | @heuristics({ 90 | 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, 91 | }) 92 | 93 | @triton.jit 94 | def matmul_kernelint8(x,w, A, B, C, M, N, K, # 95 | stride_am, stride_ak, # 96 | stride_bk, stride_bn, # 97 | stride_cm, stride_cn, # 98 | 99 | Afp, Bfp, Cfp, 100 | Kfp, 101 | stride_amfp, stride_akfp, # 102 | stride_bkfp, stride_bnfp, # 103 | stride_cmfp, stride_cnfp, 104 | 105 | acc_dtype: tl.constexpr, # 106 | allow_tf32: tl.constexpr, # 107 | fp8_fast_accum: tl.constexpr, # 108 | BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # 109 | BLOCK_Kfp: tl.constexpr, 110 | GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr # 111 | ): 112 | # matrix multiplication 113 | pid = tl.program_id(0) 114 | pid_z = tl.program_id(1) 115 | grid_m = tl.cdiv(M, BLOCK_M) 116 | grid_n = tl.cdiv(N, BLOCK_N) 117 | # re-order program ID for better L2 performance 118 | width = GROUP_M * grid_n 119 | group_id = pid // width 120 | group_size = min(grid_m - group_id * GROUP_M, GROUP_M) 121 | pid_m = group_id * GROUP_M + (pid % group_size) 122 | pid_n = (pid % width) // (group_size) 123 | # do matrix multiplication 124 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 125 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 126 | ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) 127 | rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) 128 | rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) 129 | # pointers 130 | A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) 131 | B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) 132 | 133 | 134 | acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) 135 | 136 | 137 | for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): 138 | if EVEN_K: 139 | a = tl.load(A) 140 | b = tl.load(B) 141 | else: 142 | k_remaining = K - k * (BLOCK_K * SPLIT_K) 143 | _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) 144 | a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) 
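            # (editor's comment) masked tail load: when K is not a multiple of BLOCK_K * SPLIT_K,
            # out-of-range elements of A (and of B, loaded next) are filled with zeros, so the
            # INT32 accumulation performed by tl.dot below is unaffected by the padding.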
145 | b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) 146 | acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=False) 147 | 148 | A += BLOCK_K * SPLIT_K * stride_ak 149 | B += BLOCK_K * SPLIT_K * stride_bk 150 | 151 | 152 | 153 | 154 | # rematerialize rm and rn to save registers 155 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 156 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 157 | 158 | 159 | #Cfp = Cfp + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) 160 | C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) 161 | mask = (rm < M)[:, None] & (rn < N)[None, :] 162 | 163 | # handles write-back with reduction-splitting 164 | 165 | # x_ = tl.load(x + 0) 166 | 167 | # accfp = acc.to(tl.float32) 168 | # accfp *= x_ 169 | # accfp = accfp.to(tl.float16) 170 | if SPLIT_K == 1: 171 | tl.store(C, acc, mask=mask) 172 | 173 | else: 174 | tl.atomic_add(C, acc, mask=mask) 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | def matmulint8_fused_dequant(x,w, a, b, afp, bfp, c, cfp16, M, N, K, Kfp): 184 | 185 | grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K']) 186 | 187 | 188 | allow_tf32=True 189 | fp8_fast_accum=True 190 | matmul_kernelint8[grid]( 191 | x, w, 192 | a, b, c, # 193 | M, N, K, # 194 | K, 1, # 195 | 1, K, # 196 | N, 1, # 197 | afp, bfp, cfp16, Kfp[0], 198 | Kfp, 1, # 199 | 1, Kfp, # 200 | N, 1, # 201 | allow_tf32=allow_tf32, # 202 | fp8_fast_accum=fp8_fast_accum, # 203 | GROUP_M=8, acc_dtype=tl.int32,AB_DTYPE=None 204 | ) 205 | return c, cfp16 -------------------------------------------------------------------------------- /mixquant/kernel.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | import triton 5 | import triton.language as tl 6 | 7 | from triton import Config, autotune, cdiv, heuristics, jit 8 | from triton import language as tl 9 | from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time 10 | 11 | 12 | 13 | def get_configs_fp_io_bound(): 14 | configs = [] 15 | for num_stages in [2, 3, 4, 5, 6]: 16 | for block_m in [16, 32]: 17 | for block_kfp in [32, 64]: 18 | for block_n in [32, 64, 128, 256]: 19 | num_warps = 2 if block_n <= 64 else 4 20 | configs.append( 21 | Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_kfp, 'SPLIT_K': 1}, 22 | num_stages=num_stages, num_warps=num_warps)) 23 | 24 | return configs 25 | 26 | @autotune( 27 | configs=[ 28 | # basic configs for compute-bound matmuls 29 | Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 30 | Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 31 | Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 32 | Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 33 | Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 34 | Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 35 | Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 36 | Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 37 | Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=5, num_warps=2), 38 | # good for int8 39 | Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 
40 | Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 41 | Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 42 | Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 43 | Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 44 | Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 45 | Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 46 | Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 47 | Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=5, num_warps=2), 48 | ] + get_configs_fp_io_bound(), 49 | key=['M', 'N'], 50 | prune_configs_by={ 51 | 'early_config_prune': early_config_prune, 52 | 'perf_model': estimate_matmul_time, 53 | 'top_k': 10, 54 | }, 55 | ) 56 | 57 | @triton.jit 58 | def matmul_kernelfp16(A, B, C, M, N, K, 59 | stride_amfp, stride_akfp, # 60 | stride_bkfp, stride_bnfp, # 61 | stride_cmfp, stride_cnfp, 62 | BLOCK_M: tl.constexpr, 63 | BLOCK_N: tl.constexpr, 64 | BLOCK_K: tl.constexpr, 65 | SPLIT_K: tl.constexpr, 66 | GROUP_M: tl.constexpr 67 | ): 68 | # matrix multiplication 69 | pid = tl.program_id(0) 70 | pid_z = tl.program_id(1) 71 | grid_m = tl.cdiv(M, BLOCK_M) 72 | grid_n = tl.cdiv(N, BLOCK_N) 73 | # re-order program ID for better L2 performance 74 | width = GROUP_M * grid_n 75 | group_id = pid // width 76 | group_size = min(grid_m - group_id * GROUP_M, GROUP_M) 77 | pid_m = group_id * GROUP_M + (pid % group_size) 78 | pid_n = (pid % width) // (group_size) 79 | # do matrix multiplication 80 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 81 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 82 | ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) 83 | rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) 84 | 85 | # pointers 86 | 87 | rkfp = tl.arange(0, BLOCK_K) 88 | A = A + (ram[:, None] * stride_amfp + rkfp[None, :] * stride_akfp) 89 | B = B + (rkfp[:, None] * stride_bkfp + rbn[None, :] * stride_bnfp) 90 | 91 | accfp = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) 92 | 93 | 94 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 95 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 96 | 97 | rmfp = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 98 | rnfp = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 99 | 100 | 101 | afp = tl.zeros((BLOCK_M, BLOCK_K), dtype=C.dtype.element_ty) 102 | bfp = tl.zeros((BLOCK_K, BLOCK_N), dtype=C.dtype.element_ty) 103 | C = C + (rmfp[:, None] * stride_cmfp + rnfp[None, :] * stride_cnfp) 104 | mask = (rm < M)[:, None] & (rn < N)[None, :] 105 | 106 | 107 | 108 | K_ = tl.load(K + 0) 109 | if K_ == 0: 110 | 111 | return 112 | 113 | maxK = tl.cdiv(K_, BLOCK_K ) 114 | for k in range(0, maxK - 1): 115 | 116 | afp = tl.load(A) 117 | bfp = tl.load(B) 118 | 119 | A += BLOCK_K * stride_akfp 120 | B += BLOCK_K * stride_bkfp 121 | 122 | accfp = tl.dot(afp, bfp, accfp, out_dtype=tl.float32, allow_tf32=False) 123 | 124 | k = maxK - 1 125 | if K_ % ( BLOCK_K ) == 0: 126 | afp = tl.load(A) 127 | bfp = tl.load(B) 128 | else: 129 | k_remainingfp = K_ - k * (BLOCK_K ) 130 | afp = tl.load(A, mask=rkfp[None, :] < k_remainingfp, other=0.0) 131 | bfp = tl.load(B, mask=rkfp[:, None] < k_remainingfp, other=0.0) 132 | 133 | accfp = tl.dot(afp, bfp, accfp, out_dtype=tl.float32, allow_tf32=False) 134 | 135 | accfp = 
accfp.to(tl.float16) 136 | 137 | # rematerialize rm and rn to save registers 138 | 139 | 140 | tl.store(C, accfp, mask=mask) 141 | 142 | 143 | 144 | 145 | 146 | def matmulfp16( afp, bfp, cfp16, M, N, K): 147 | 148 | grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K']) 149 | 150 | 151 | 152 | matmul_kernelfp16[grid]( 153 | afp, bfp, cfp16,M, N, K, 154 | 1, M, # 155 | N, 1, # 156 | N, 1, # 157 | GROUP_M=8 158 | ) 159 | 160 | return -------------------------------------------------------------------------------- /mixquant/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama import LlamaMixQForCausalLM 2 | from .opt import OptMixForCausalLM 3 | #from .mistral import MistralMixForCausalLM 4 | from .gptj import GPTJMixForCausalLM 5 | from .falcon import FalconMixForCausalLM 6 | from .baichuan import BaichuanMixQForCausalLM -------------------------------------------------------------------------------- /mixquant/models/aquila.py: -------------------------------------------------------------------------------- 1 | ## Reference from llama.py 2 | from .base import BaseAWQForCausalLM 3 | from typing import Dict 4 | from transformers.models.llama.modeling_llama import ( 5 | LlamaDecoderLayer as AquilaDecoderLayer, 6 | LlamaForCausalLM as AquilaForCausalLM, 7 | LlamaAttention as AquilaAttention, 8 | LlamaRMSNorm as AquilaRMSNorm, 9 | LlamaMLP as AquilaMLP 10 | ) 11 | 12 | class AquilaAWQForCausalLM(BaseAWQForCausalLM): 13 | layer_type = "AquilaDecoderLayer" 14 | max_new_tokens_key = "max_position_embeddings" 15 | 16 | @staticmethod 17 | def fuse_layers(model: AquilaForCausalLM, quant_config: Dict): 18 | fuser = AquilaFuser(model, quant_config) 19 | fuser.fuse_attention() 20 | fuser.fuse_rmsnorm() 21 | fuser.fuse_mlp() 22 | 23 | @staticmethod 24 | def get_model_layers(model: AquilaForCausalLM): 25 | return model.model.layers 26 | 27 | @staticmethod 28 | def get_act_for_scaling(module: AquilaDecoderLayer): 29 | return dict( 30 | is_scalable=False 31 | ) 32 | 33 | @staticmethod 34 | def move_embed(model: AquilaForCausalLM, device: str): 35 | model.model.embed_tokens = model.model.embed_tokens.to(device) 36 | 37 | @staticmethod 38 | def get_layers_for_scaling(module: AquilaDecoderLayer, input_feat, module_kwargs): 39 | layers = [] 40 | 41 | # attention input 42 | layers.append(dict( 43 | prev_op=module.input_layernorm, 44 | layers=[module.self_attn.q_proj, 45 | module.self_attn.k_proj, module.self_attn.v_proj], 46 | inp=input_feat['self_attn.q_proj'], 47 | module2inspect=module.self_attn, kwargs=module_kwargs, 48 | )) 49 | 50 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: 51 | layers.append(dict( 52 | prev_op=module.self_attn.v_proj, 53 | layers=[module.self_attn.o_proj], 54 | inp=input_feat['self_attn.o_proj'], 55 | )) 56 | 57 | # linear 1 58 | layers.append(dict( 59 | prev_op=module.post_attention_layernorm, 60 | layers=[module.mlp.gate_proj, module.mlp.up_proj], 61 | inp=input_feat['mlp.gate_proj'], 62 | module2inspect=module.mlp, 63 | )) 64 | 65 | # linear 2 66 | layers.append(dict( 67 | prev_op=module.mlp.up_proj, 68 | layers=[module.mlp.down_proj], 69 | inp=input_feat['mlp.down_proj'], 70 | )) 71 | 72 | return layers 73 | 74 | -------------------------------------------------------------------------------- /mixquant/models/auto.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import AutoConfig 
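# (editor's note) Hedged usage sketch for the dispatcher defined in this file; the
# checkpoint path is a placeholder:
#
#   from mixquant import AutoForCausalLM
#   from mixquant.Cache import MixLibCache
#
#   model = AutoForCausalLM.from_quantized(
#       "/path/to/quantized/Llama-2-7b",      # quant_path
#       "/path/to/quantized/Llama-2-7b",      # quant_filename
#       fuse_layers=True, mix=True, cache=MixLibCache(32),
#   )
#
# check_and_get_model_type() below reads config.model_type from the Hugging Face config,
# maps it through CAUSAL_LM_MODEL_MAP, and special-cases BaichuanForCausalLM architectures.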
3 | from mixquant.models import * 4 | from mixquant.models.base import BaseForCausalLM 5 | 6 | CAUSAL_LM_MODEL_MAP = { 7 | 8 | "llama": LlamaMixQForCausalLM, 9 | "baichuan": BaichuanMixQForCausalLM, 10 | "aquila": LlamaMixQForCausalLM, 11 | #"mistral": MistralMixForCausalLM, 12 | "gptj" : GPTJMixForCausalLM, 13 | "falcon": FalconMixForCausalLM, 14 | "opt": OptMixForCausalLM 15 | } 16 | 17 | def check_and_get_model_type(model_dir, trust_remote_code=True): 18 | config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code) 19 | if config.model_type not in CAUSAL_LM_MODEL_MAP.keys(): 20 | raise TypeError(f"{config.model_type} isn't supported yet.") 21 | model_type = config.model_type 22 | if config.architectures[0]=="BaichuanForCausalLM": 23 | model_type="baichuan" 24 | return model_type 25 | 26 | class AutoForCausalLM: 27 | def __init__(self): 28 | raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n' 29 | 'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained') 30 | 31 | @classmethod 32 | def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False, 33 | device_map=None, mix = False, **model_init_kwargs) -> BaseForCausalLM: 34 | model_type = check_and_get_model_type(model_path, trust_remote_code) 35 | 36 | return CAUSAL_LM_MODEL_MAP[model_type].from_pretrained( 37 | model_path, model_type, trust_remote_code=trust_remote_code, safetensors=safetensors, 38 | device_map=device_map, mix = mix, **model_init_kwargs 39 | ) 40 | 41 | @classmethod 42 | def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None, 43 | trust_remote_code=True, fuse_layers=True, 44 | batch_size=1, safetensors=False, 45 | max_memory=None, offload_folder=None, mix = False, cache = None) -> BaseForCausalLM: 46 | 47 | model_type = check_and_get_model_type(quant_path, trust_remote_code) 48 | os.environ["BATCH_SIZE"] = str(batch_size) 49 | return CAUSAL_LM_MODEL_MAP[model_type].from_quantized( 50 | quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code, 51 | fuse_layers=fuse_layers, safetensors=safetensors, 52 | max_memory=max_memory, offload_folder=offload_folder, mix = mix, cache = cache 53 | ) 54 | -------------------------------------------------------------------------------- /mixquant/models/baichuan.py: -------------------------------------------------------------------------------- 1 | from .base import BaseForCausalLM 2 | from typing import Dict 3 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM 4 | 5 | class BaichuanMixQForCausalLM(BaseForCausalLM): 6 | layer_type = "LlamaDecoderLayer" 7 | max_new_tokens_key = "max_position_embeddings" 8 | 9 | @staticmethod 10 | def fuse_layers(model: LlamaForCausalLM, quant_config: Dict, mix = False, cache = None): 11 | 12 | fuser = LlamaFuser(model, quant_config) 13 | 14 | fuser.fuse_attention(MixGemmCache = cache) 15 | 16 | fuser.fuse_mlp(mix, MixGemmCache = cache) 17 | fuser.fuse_rmsnorm(MixGemmCache = cache) 18 | 19 | 20 | for layer in model.model.layers: 21 | layer.input_layernorm.next_layer = layer.self_attn.W_pack 22 | layer.post_attention_layernorm.next_layer = layer.mlp.up_proj_ 23 | 24 | @staticmethod 25 | def get_model_layers(model: LlamaForCausalLM): 26 | return model.model.layers 27 | 28 | 29 | 30 | @staticmethod 31 | def move_embed(model: LlamaForCausalLM, device: str): 32 | model.model.embed_tokens = model.model.embed_tokens.to(device) 33 | 34 | 35 | import torch 36 | from typing import List, Tuple, 
Union 37 | from mixquant.utils.utils import set_module_name 38 | from mixquant.modules.fused.mlp import MixLlamaMLP 39 | from mixquant.modules.fused.attn import QuantAttentionFused, QuantAttentionFusedBaichuan13B 40 | from mixquant.modules.fused.norm import FasterTransformerRMSNorm 41 | from mixquant.modules.linear import MixLinear_GEMM 42 | 43 | from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP 44 | import sys 45 | 46 | 47 | #from modeling_baichuan import Attention 48 | class LlamaFuser: 49 | def __init__(self, model, quant_config): 50 | self.model = model 51 | self.quant_config = quant_config 52 | 53 | #print(model.model.layers[0].self_attn.o_proj) # 确认一下模型的权重的格式 54 | 55 | #需要加入百川的 Attention 56 | self.attention_modules: List[Tuple[str, LlamaAttention]] = [ 57 | (name, module) for name, module in self.model.named_modules() 58 | if isinstance(module, LlamaAttention) or "Attention" in str(module.__class__) 59 | ] 60 | #print(self.attention_modules) 61 | 62 | self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [ 63 | (name, module) for name, module in self.model.named_modules() 64 | if isinstance(module, LlamaRMSNorm) or "RMSNorm" in str(module.__class__) 65 | ] 66 | 67 | self.mlp_modules: List[Tuple[str, LlamaMLP]] = [ 68 | (name, module) for name, module in self.model.named_modules() 69 | if isinstance(module, LlamaMLP) or "MLP" in str(module.__class__) 70 | ] 71 | 72 | def fuse_attention(self, MixGemmCache): 73 | for name, module in self.attention_modules: 74 | qkv_layer = self._fuse_qkv(module, MixGemmCache) 75 | try: 76 | num_key_value_heads = module.num_key_value_heads 77 | except: 78 | # 为了处理百川的模型 79 | num_key_value_heads = 32 80 | 81 | if self.model.config.num_hidden_layers == 40: 82 | attn = QuantAttentionFusedBaichuan13B( 83 | module.hidden_size, 84 | module.num_heads, 85 | num_key_value_heads, 86 | qkv_layer, 87 | module.o_proj, 88 | next(iter(qkv_layer.state_dict().values())).device, 89 | self.model.config.max_new_tokens, 90 | MixGemmCache = MixGemmCache 91 | ) 92 | 93 | else: 94 | attn = QuantAttentionFused( 95 | module.hidden_size, 96 | module.num_heads, 97 | num_key_value_heads, 98 | qkv_layer, 99 | module.o_proj, 100 | next(iter(qkv_layer.state_dict().values())).device, 101 | self.model.config.max_new_tokens, 102 | MixGemmCache = MixGemmCache 103 | ) 104 | set_module_name(self.model, name, attn) 105 | 106 | def fuse_attention(self, MixGemmCache): 107 | 108 | for name, module in self.attention_modules: 109 | 110 | layer_idx = int(name.split('.')[2]) 111 | qkv_layer = self._fuse_qkv(module, MixGemmCache) 112 | try: 113 | num_key_value_heads = module.num_key_value_heads 114 | except: 115 | # 为了处理百川的模型 116 | print("do not find the attr module.num_key_value_heads") 117 | num_key_value_heads = 32 118 | attn = QuantAttentionFused( 119 | module.hidden_size, 120 | module.num_heads, 121 | num_key_value_heads, 122 | qkv_layer, 123 | module.o_proj, 124 | next(iter(qkv_layer.state_dict().values())).device, 125 | self.model.config.max_new_tokens, 126 | MixGemmCache = MixGemmCache, 127 | layer_idx = layer_idx 128 | ) 129 | set_module_name(self.model, name, attn) 130 | 131 | def _fuse_qkv(self, module: LlamaAttention,cache): 132 | try: 133 | q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj 134 | except: 135 | qkv_layer = module.W_pack 136 | return qkv_layer 137 | 138 | 139 | 140 | if not isinstance(q_proj, MixLinear_GEMM) : 141 | raise "no implement error" 142 | 143 | if isinstance(q_proj, MixLinear_GEMM): 144 | qkv_layer = 
MixLinear_GEMM(q_proj.in_features,q_proj.out_features + k_proj.out_features + v_proj.out_features, 145 | q_proj.bias is not None, 146 | next(iter(module.state_dict().values())).device, 147 | bit = self.quant_config['w_bit'], 148 | weight_only=False, 149 | cache=cache) 150 | 151 | 152 | 153 | if isinstance(qkv_layer, MixLinear_GEMM): 154 | shapew = qkv_layer.q_weight.shape 155 | 156 | if qkv_layer.weight_only: 157 | qkv_layer.q_weight = torch.cat([q_proj.q_weight, k_proj.q_weight, v_proj.q_weight], dim=1) 158 | qkv_layer.scale_col = torch.cat([q_proj.scale_col, k_proj.scale_col, v_proj.scale_col], dim=0) 159 | 160 | else: 161 | qkv_layer.q_weight = torch.cat([q_proj.q_weight, k_proj.q_weight, v_proj.q_weight], dim=0) 162 | qkv_layer.scale_col = torch.cat([q_proj.scale_col, k_proj.scale_col, v_proj.scale_col], dim=1) 163 | assert shapew[0] == qkv_layer.q_weight.shape[0] 164 | assert shapew[1] == qkv_layer.q_weight.shape[1] 165 | assert shapew[0] == qkv_layer.scale_col.shape[1] 166 | assert 1 == qkv_layer.scale_col.shape[0] 167 | if self.quant_config['w_bit'] == 4: 168 | 169 | 170 | 171 | qkv_layer.weight_cache.copy_(torch.cat([q_proj.weight_cache, 172 | k_proj.weight_cache, 173 | v_proj.weight_cache], dim=0)) 174 | 175 | 176 | 177 | qkv_layer.ind.copy_(q_proj.ind) 178 | 179 | 180 | 181 | 182 | 183 | 184 | if q_proj.bias is not None: 185 | raise NotImplementedError 186 | else: 187 | qkv_layer.bias = None 188 | 189 | else: 190 | raise "no implement" 191 | 192 | q_proj.q_weight = q_proj.q_weight.to('cpu') 193 | k_proj.q_weight = k_proj.q_weight.to('cpu') 194 | v_proj.q_weight = v_proj.q_weight.to('cpu') 195 | q_proj.scale_col = q_proj.scale_col.to('cpu') 196 | k_proj.scale_col = k_proj.scale_col.to('cpu') 197 | v_proj.scale_col = v_proj.scale_col.to('cpu') 198 | torch.cuda.empty_cache() 199 | return qkv_layer 200 | 201 | def fuse_rmsnorm(self, MixGemmCache): 202 | for name, module in self.rmsnorm_modules: 203 | norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon, MixGemmCache) 204 | set_module_name(self.model, name, norm) 205 | 206 | def fuse_mlp(self,mix, MixGemmCache = None): 207 | for name, module in self.mlp_modules: 208 | if mix: 209 | assert MixGemmCache is not None 210 | mlp = MixLlamaMLP(module.gate_proj, module.down_proj, module.up_proj , MixGemmCache) 211 | set_module_name(self.model, name, mlp) -------------------------------------------------------------------------------- /mixquant/models/basefuser.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/mixquant/models/basefuser.py -------------------------------------------------------------------------------- /mixquant/models/bloom.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | from transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomBlock 3 | 4 | class BloomMixForCausalLM(BaseAWQForCausalLM): 5 | layer_type = "BloomBlock" 6 | 7 | @staticmethod 8 | def get_model_layers(model: BloomForCausalLM): 9 | return model.transformer.h 10 | 11 | 12 | @staticmethod 13 | def move_embed(model: BloomForCausalLM, device: str): 14 | model.transformer.word_embeddings = model.transformer.word_embeddings.to(device) 15 | model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device) 16 | 17 | -------------------------------------------------------------------------------- 
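# (editor's note, between model adapters) The adapter files under mixquant/models share a
# small interface: a `layer_type` string, `get_model_layers` exposing the decoder blocks,
# `move_embed` placing the embeddings, and, for the fused inference path, an optional
# `fuse_layers` hook. A hedged sketch of a new adapter (class and layer names here are
# hypothetical):
#
#   from .base import BaseForCausalLM
#
#   class MyModelMixForCausalLM(BaseForCausalLM):
#       layer_type = "MyDecoderLayer"
#       max_new_tokens_key = "max_position_embeddings"
#
#       @staticmethod
#       def get_model_layers(model):
#           return model.model.layers
#
#       @staticmethod
#       def move_embed(model, device):
#           model.model.embed_tokens = model.model.embed_tokens.to(device)
#
# Registering such a class in CAUSAL_LM_MODEL_MAP (mixquant/models/auto.py) makes it
# loadable through AutoForCausalLM.from_quantized.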
/mixquant/models/falcon.py: -------------------------------------------------------------------------------- 1 | from .base import BaseForCausalLM 2 | from typing import Dict 3 | from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention 4 | from transformers.models.falcon.configuration_falcon import FalconConfig 5 | from transformers.models.falcon.modeling_falcon import FalconMLP 6 | 7 | 8 | from mixquant.modules.fused.mlp import MixFalconMLP 9 | from mixquant.utils.utils import set_module_name 10 | import torch 11 | from typing import Optional, Tuple, Union, List 12 | from torch import nn 13 | from torch.nn import functional as F 14 | 15 | class FalconMixForCausalLM(BaseForCausalLM): 16 | layer_type = "FalconDecoderLayer" 17 | 18 | @staticmethod 19 | def fuse_layers(model: FalconForCausalLM, quant_config: Dict, mix, cache): 20 | fuser = FalconFuser(model) 21 | 22 | 23 | fuser.fuse_mlp(mix, cache) 24 | 25 | @staticmethod 26 | def get_model_layers(model: FalconForCausalLM): 27 | return model.transformer.h 28 | 29 | @staticmethod 30 | def move_embed(model: FalconForCausalLM, device): 31 | model.transformer.word_embeddings = model.transformer.word_embeddings.to(device) 32 | 33 | 34 | 35 | 36 | class FalconFuser: 37 | def __init__(self, model: FalconForCausalLM): 38 | self.model = model 39 | 40 | self.attention_modules = [ 41 | (name, module) for name, module in self.model.named_modules() 42 | if "Attention" in str(module.__class__) 43 | ] 44 | self.mlp_modules: List[Tuple[str, FalconMLP]] = [ 45 | (name, module) for name, module in self.model.named_modules() 46 | if isinstance(module, FalconMLP) or "MLP" in str(module.__class__) 47 | ] 48 | 49 | def fuse_mlp(self,mix, MixGemmCache = None): 50 | for name, module in self.mlp_modules: 51 | if mix: 52 | assert MixGemmCache is not None 53 | mlp = MixFalconMLP(module.dense_h_to_4h, module.dense_4h_to_h, MixGemmCache) 54 | set_module_name(self.model, name, mlp) -------------------------------------------------------------------------------- /mixquant/models/gpt_bigcode.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM, GPTBigCodeBlock as OldGptBigCodeBlock 3 | 4 | class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM): 5 | layer_type = "GPTBigCodeBlock" 6 | max_new_tokens_key = "n_positions" 7 | 8 | @staticmethod 9 | def get_model_layers(model: GPTBigCodeForCausalLM): 10 | return model.transformer.h 11 | 12 | @staticmethod 13 | def get_act_for_scaling(module: OldGptBigCodeBlock): 14 | return dict( 15 | is_scalable=True, 16 | scale_name="mlp.act", 17 | scale_layer=module.mlp.act, 18 | scale_shape=module.mlp.c_fc.out_features 19 | ) 20 | 21 | @staticmethod 22 | def move_embed(model: GPTBigCodeForCausalLM, device): 23 | model.transformer.wte = model.transformer.wte.to(device) 24 | model.transformer.wpe = model.transformer.wpe.to(device) 25 | model.transformer.drop = model.transformer.drop.to(device) 26 | 27 | @staticmethod 28 | def get_layers_for_scaling(module:OldGptBigCodeBlock, input_feat, module_kwargs): 29 | layers = [] 30 | 31 | # attention input 32 | layers.append(dict( 33 | prev_op=module.ln_1, 34 | layers=[module.attn.c_attn], 35 | inp=input_feat['attn.c_attn'], 36 | module2inspect=module.attn, 37 | kwargs=module_kwargs 38 | )) 39 | 40 | # linear 1 41 | layers.append(dict( 42 | prev_op=module.ln_2, 43 | 
layers=[module.mlp.c_fc], 44 | inp=input_feat['mlp.c_fc'], 45 | module2inspect=module.mlp 46 | )) 47 | 48 | # linear 2 49 | layers.append(dict( 50 | prev_op=module.mlp.act, 51 | layers=[module.mlp.c_proj], 52 | inp=input_feat['mlp.c_proj'] 53 | )) 54 | 55 | return layers 56 | -------------------------------------------------------------------------------- /mixquant/models/gpt_neox.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer, GPTNeoXForCausalLM 3 | 4 | class GPTNeoXAWQForCausalLM(BaseAWQForCausalLM): 5 | layer_type = "GPTNeoXDecoderLayer" 6 | max_new_tokens_key = "max_position_embeddings" 7 | 8 | @staticmethod 9 | def get_model_layers(model: GPTNeoXForCausalLM): 10 | return model.gpt_neox.layers 11 | 12 | @staticmethod 13 | def get_act_for_scaling(module: GPTNeoXLayer): 14 | return dict( 15 | is_scalable=True, 16 | scale_name="mlp.act", 17 | scale_layer=module.mlp.act, 18 | scale_shape=module.mlp.dense_h_to_4h.out_features, 19 | ) 20 | 21 | @staticmethod 22 | def move_embed(model: GPTNeoXForCausalLM, device: str): 23 | model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device) 24 | 25 | @staticmethod 26 | def get_layers_for_scaling(module: GPTNeoXLayer, input_feat, module_kwargs): 27 | layers = [] 28 | 29 | # attention input 30 | layers.append(dict( 31 | prev_op=module.input_layernorm, 32 | layers=[module.attention.query_key_value], 33 | inp=input_feat['attention.query_key_value'], 34 | )) 35 | 36 | # attention out 37 | # Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469 38 | """ 39 | layers.append(dict( 40 | prev_op=module.attention.query_key_value, 41 | layers=[module.attention.dense], 42 | inp=input_feat['attention.dense'], 43 | )) 44 | """ 45 | 46 | # linear 1 47 | layers.append(dict( 48 | prev_op=module.post_attention_layernorm, 49 | layers=[module.mlp.dense_h_to_4h], 50 | inp=input_feat['mlp.dense_h_to_4h'], 51 | )) 52 | 53 | # linear 2 54 | layers.append(dict( 55 | prev_op=module.mlp.act, 56 | layers=[module.mlp.dense_4h_to_h], 57 | inp=input_feat['mlp.dense_4h_to_h'], 58 | )) 59 | 60 | return layers 61 | -------------------------------------------------------------------------------- /mixquant/models/gptj.py: -------------------------------------------------------------------------------- 1 | from .base import BaseForCausalLM 2 | from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJBlock, GPTJAttention, GPTJMLP 3 | from typing import List, Tuple, Union 4 | from mixquant.modules.linear import MixLinear_GEMM 5 | import torch 6 | from mixquant.modules.fused.gptj_attn import QuantGPTJAttentionFused 7 | from mixquant.modules.fused.mlp import MixGPTJMLP 8 | from mixquant.utils.utils import set_module_name 9 | 10 | 11 | class GPTJMixForCausalLM(BaseForCausalLM): 12 | layer_type = "GPTJBlock" 13 | max_new_tokens_key = "n_positions" 14 | 15 | @staticmethod 16 | def get_model_layers(model: GPTJForCausalLM): 17 | return model.transformer.h 18 | 19 | 20 | @staticmethod 21 | def move_embed(model: GPTJForCausalLM, device: str): 22 | model.transformer.wte = model.transformer.wte.to(device) 23 | 24 | 25 | 26 | 27 | 28 | @staticmethod 29 | def fuse_layers(model: GPTJForCausalLM, quant_config, mix, cache): 30 | fuser = GPTJFuser(model) 31 | 32 | 33 | fuser.fuse_mlp(mix, cache) 34 | fuser.fuse_attention(MixGemmCache = cache) 35 | 36 | 37 | 38 | 39 | 40 | class GPTJFuser: 41 | def 
__init__(self, model: GPTJForCausalLM): 42 | self.model = model 43 | 44 | self.attention_modules: List[Tuple[str, GPTJAttention]] = [ 45 | (name, module) for name, module in self.model.named_modules() 46 | if isinstance(module, GPTJAttention) or "Attention" in str(module.__class__) 47 | ] 48 | self.mlp_modules: List[Tuple[str, GPTJMLP]] = [ 49 | (name, module) for name, module in self.model.named_modules() 50 | if isinstance(module, GPTJMLP) or "MLP" in str(module.__class__) 51 | ] 52 | def fuse_mlp(self,mix, MixGemmCache = None): 53 | for name, module in self.mlp_modules: 54 | if mix: 55 | assert MixGemmCache is not None 56 | mlp = MixGPTJMLP(module, self.model.config, MixGemmCache) 57 | set_module_name(self.model, name, mlp) 58 | 59 | 60 | def _fuse_qkv(self, module,cache): 61 | try: 62 | q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj 63 | except: 64 | qkv_layer = module.W_pack 65 | return qkv_layer 66 | 67 | 68 | 69 | if not isinstance(q_proj, MixLinear_GEMM) : 70 | raise "no implement error" 71 | 72 | if isinstance(q_proj, MixLinear_GEMM): 73 | qkv_layer = MixLinear_GEMM(q_proj.in_features,q_proj.out_features + k_proj.out_features + v_proj.out_features, 74 | q_proj.bias is not None, 75 | next(iter(module.state_dict().values())).device, 76 | False, 77 | cache) 78 | 79 | 80 | 81 | if isinstance(qkv_layer, MixLinear_GEMM): 82 | shapew = qkv_layer.q_weight.shape 83 | qkv_layer.q_weight = torch.cat([q_proj.q_weight, k_proj.q_weight, v_proj.q_weight], dim=0) 84 | qkv_layer.scale_col = torch.cat([q_proj.scale_col, k_proj.scale_col, v_proj.scale_col], dim=1) 85 | 86 | assert shapew[0] == qkv_layer.q_weight.shape[0] 87 | assert shapew[1] == qkv_layer.q_weight.shape[1] 88 | assert shapew[0] == qkv_layer.scale_col.shape[1] 89 | assert 1 == qkv_layer.scale_col.shape[0] 90 | 91 | if q_proj.bias is not None: 92 | raise NotImplementedError 93 | else: 94 | qkv_layer.bias = None 95 | 96 | else: 97 | raise "no implement" 98 | 99 | q_proj.q_weight = q_proj.q_weight.to('cpu') 100 | k_proj.q_weight = k_proj.q_weight.to('cpu') 101 | v_proj.q_weight = v_proj.q_weight.to('cpu') 102 | q_proj.scale_col = q_proj.scale_col.to('cpu') 103 | k_proj.scale_col = k_proj.scale_col.to('cpu') 104 | v_proj.scale_col = v_proj.scale_col.to('cpu') 105 | torch.cuda.empty_cache() 106 | return qkv_layer 107 | 108 | 109 | def fuse_attention(self, MixGemmCache): 110 | for name, module in self.attention_modules: 111 | qkv_layer = self._fuse_qkv(module) 112 | 113 | attn = QuantGPTJAttentionFused( 114 | self.model.config, 115 | module, 116 | qkv_layer, 117 | module.out_proj, 118 | MixGemmCache = MixGemmCache 119 | ) 120 | set_module_name(self.model, name, attn) -------------------------------------------------------------------------------- /mixquant/models/llama.py: -------------------------------------------------------------------------------- 1 | from .base import BaseForCausalLM 2 | from typing import Dict 3 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM 4 | 5 | class LlamaMixQForCausalLM(BaseForCausalLM): 6 | layer_type = "LlamaDecoderLayer" 7 | max_new_tokens_key = "max_position_embeddings" 8 | 9 | @staticmethod 10 | def fuse_layers(model: LlamaForCausalLM, quant_config: Dict, mix = False, cache = None): 11 | 12 | fuser = LlamaFuser(model, quant_config) 13 | 14 | fuser.fuse_attention(MixGemmCache = cache) 15 | 16 | fuser.fuse_mlp(mix, MixGemmCache = cache) 17 | fuser.fuse_rmsnorm(MixGemmCache = cache) 18 | 19 | 20 | for layer in model.model.layers: 21 | 
layer.input_layernorm.next_layer = layer.self_attn.W_pack
22 |             layer.post_attention_layernorm.next_layer = layer.mlp.up_proj_
23 | 
24 | 
25 |     @staticmethod
26 |     def get_model_layers(model: LlamaForCausalLM):
27 |         return model.model.layers
28 | 
29 | 
30 | 
31 |     @staticmethod
32 |     def move_embed(model: LlamaForCausalLM, device: str):
33 |         model.model.embed_tokens = model.model.embed_tokens.to(device)
34 | 
35 | 
36 | import torch
37 | from typing import List, Tuple, Union
38 | from mixquant.utils.utils import set_module_name
39 | from mixquant.modules.fused.mlp import MixLlamaMLP
40 | from mixquant.modules.fused.attn import QuantAttentionFused
41 | from mixquant.modules.fused.norm import FasterTransformerRMSNorm
42 | from mixquant.modules.linear import MixLinear_GEMM
43 | 
44 | from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
45 | import sys
46 | 
47 | 
48 | #from modeling_baichuan import Attention
49 | class LlamaFuser:
50 |     def __init__(self, model, quant_config):
51 |         self.model = model
52 |         self.quant_config = quant_config
53 | 
54 |         #print(model.model.layers[0].self_attn.o_proj)  # check the layout of the model weights
55 | 
56 |         # also collect Baichuan attention modules (matched by class name)
57 |         self.attention_modules: List[Tuple[str, LlamaAttention]] = [
58 |             (name, module) for name, module in self.model.named_modules()
59 |             if isinstance(module, LlamaAttention) or "Attention" in str(module.__class__)
60 |         ]
61 |         #print(self.attention_modules)
62 | 
63 |         self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
64 |             (name, module) for name, module in self.model.named_modules()
65 |             if isinstance(module, LlamaRMSNorm) or "RMSNorm" in str(module.__class__)
66 |         ]
67 | 
68 |         self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
69 |             (name, module) for name, module in self.model.named_modules()
70 |             if isinstance(module, LlamaMLP) or "MLP" in str(module.__class__)
71 |         ]
72 | 
73 |     def fuse_attention(self, MixGemmCache):
74 | 
75 |         for name, module in self.attention_modules:
76 | 
77 |             layer_idx = int(name.split('.')[2])
78 |             qkv_layer = self._fuse_qkv(module, MixGemmCache)
79 |             try:
80 |                 num_key_value_heads = module.num_key_value_heads
81 |             except AttributeError:
82 |                 # fall back for Baichuan models, which do not expose this attribute
83 |                 print("module.num_key_value_heads not found, falling back to 32")
84 |                 num_key_value_heads = 32
85 |             attn = QuantAttentionFused(
86 |                 module.hidden_size,
87 |                 module.num_heads,
88 |                 num_key_value_heads,
89 |                 qkv_layer,
90 |                 module.o_proj,
91 |                 next(iter(qkv_layer.state_dict().values())).device,
92 |                 self.model.config.max_new_tokens,
93 |                 MixGemmCache = MixGemmCache,
94 |                 layer_idx = layer_idx
95 |             )
96 |             set_module_name(self.model, name, attn)
97 | 
98 |     def _fuse_qkv(self, module: LlamaAttention, cache):
99 |         try:
100 |             q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
101 |         except AttributeError:
102 |             qkv_layer = module.W_pack  # Baichuan stores Q, K and V in a single packed projection
103 |             return qkv_layer
104 | 
105 | 
106 | 
107 |         if not isinstance(q_proj, MixLinear_GEMM):
108 |             raise NotImplementedError("expected q_proj to be a MixLinear_GEMM layer")
109 | 
110 |         if isinstance(q_proj, MixLinear_GEMM):
111 |             qkv_layer = MixLinear_GEMM(q_proj.in_features, q_proj.out_features + k_proj.out_features + v_proj.out_features,
112 |                                        q_proj.bias is not None,
113 |                                        next(iter(module.state_dict().values())).device,
114 |                                        bit = self.quant_config['w_bit'],
115 |                                        weight_only=False,
116 |                                        cache=cache)
117 | 
118 | 
119 | 
120 |         if isinstance(qkv_layer, MixLinear_GEMM):
121 |             shapew = qkv_layer.q_weight.shape
122 | 
123 |             if qkv_layer.weight_only:
124 |                 qkv_layer.q_weight = torch.cat([q_proj.q_weight, k_proj.q_weight, v_proj.q_weight], dim=1)
125 |                 qkv_layer.scale_col = torch.cat([q_proj.scale_col,
k_proj.scale_col, v_proj.scale_col], dim=0) 126 | 127 | else: 128 | qkv_layer.q_weight = torch.cat([q_proj.q_weight, k_proj.q_weight, v_proj.q_weight], dim=0) 129 | qkv_layer.scale_col = torch.cat([q_proj.scale_col, k_proj.scale_col, v_proj.scale_col], dim=1) 130 | assert shapew[0] == qkv_layer.q_weight.shape[0] 131 | assert shapew[1] == qkv_layer.q_weight.shape[1] 132 | assert shapew[0] == qkv_layer.scale_col.shape[1] 133 | assert 1 == qkv_layer.scale_col.shape[0] 134 | if self.quant_config['w_bit'] == 4: 135 | 136 | 137 | 138 | qkv_layer.weight_cache.copy_(torch.cat([q_proj.weight_cache, 139 | k_proj.weight_cache, 140 | v_proj.weight_cache], dim=0)) 141 | 142 | 143 | 144 | qkv_layer.ind.copy_(q_proj.ind) 145 | 146 | 147 | 148 | 149 | 150 | 151 | if q_proj.bias is not None: 152 | raise NotImplementedError 153 | else: 154 | qkv_layer.bias = None 155 | 156 | else: 157 | raise "no implement" 158 | 159 | q_proj.q_weight = q_proj.q_weight.to('cpu') 160 | k_proj.q_weight = k_proj.q_weight.to('cpu') 161 | v_proj.q_weight = v_proj.q_weight.to('cpu') 162 | q_proj.scale_col = q_proj.scale_col.to('cpu') 163 | k_proj.scale_col = k_proj.scale_col.to('cpu') 164 | v_proj.scale_col = v_proj.scale_col.to('cpu') 165 | torch.cuda.empty_cache() 166 | return qkv_layer 167 | 168 | def fuse_rmsnorm(self, MixGemmCache): 169 | for name, module in self.rmsnorm_modules: 170 | norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon, MixGemmCache) 171 | set_module_name(self.model, name, norm) 172 | 173 | def fuse_mlp(self,mix, MixGemmCache = None): 174 | for name, module in self.mlp_modules: 175 | if mix: 176 | assert MixGemmCache is not None 177 | mlp = MixLlamaMLP(module.gate_proj, module.down_proj, module.up_proj , MixGemmCache) 178 | set_module_name(self.model, name, mlp) -------------------------------------------------------------------------------- /mixquant/models/mistral.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from .base import BaseForCausalLM 3 | from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralForCausalLM 4 | 5 | class MistralMixForCausalLM(BaseForCausalLM): 6 | layer_type = "MistralDecoderLayer" 7 | max_new_tokens_key = "max_position_embeddings" 8 | 9 | @staticmethod 10 | def fuse_layers(model: MistralForCausalLM, quant_config: Dict, mix, cache): 11 | fuser = MistralFuser(model, quant_config) 12 | fuser.fuse_attention(cache) 13 | fuser.fuse_rmsnorm() 14 | fuser.fuse_mlp() 15 | 16 | @staticmethod 17 | def get_model_layers(model: MistralForCausalLM): 18 | return model.model.layers 19 | 20 | @staticmethod 21 | def get_act_for_scaling(module: MistralDecoderLayer): 22 | return dict( 23 | is_scalable=False 24 | ) 25 | 26 | @staticmethod 27 | def move_embed(model: MistralForCausalLM, device: str): 28 | model.model.embed_tokens = model.model.embed_tokens.to(device) 29 | 30 | 31 | 32 | import torch 33 | from typing import List, Tuple, Union 34 | from mixquant.utils.utils import set_module_name 35 | from mixquant.modules.fused.mlp import MixLlamaMLP 36 | from mixquant.modules.fused.mistral_attn import MistralQuantAttentionFused 37 | from mixquant.modules.fused.norm import FasterTransformerRMSNorm 38 | from mixquant.modules.linear import MixLinear_GEMM 39 | 40 | from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRMSNorm, MistralMLP 41 | 42 | class MistralFuser: 43 | def __init__(self, model, quant_config): 44 | self.model = model 45 | self.quant_config = 
quant_config 46 | 47 | self.attention_modules: List[Tuple[str, MistralAttention]] = [ 48 | (name, module) for name, module in self.model.named_modules() 49 | if isinstance(module, MistralAttention) 50 | ] 51 | 52 | self.rmsnorm_modules: List[Tuple[str, MistralRMSNorm]] = [ 53 | (name, module) for name, module in self.model.named_modules() 54 | if isinstance(module, MistralRMSNorm) 55 | ] 56 | 57 | self.mlp_modules: List[Tuple[str, MistralMLP]] = [ 58 | (name, module) for name, module in self.model.named_modules() 59 | if isinstance(module, MistralMLP) 60 | ] 61 | 62 | def fuse_attention(self,cache): 63 | for name, module in self.attention_modules: 64 | qkv_layer = self._fuse_qkv(module, cache) 65 | attn = MistralQuantAttentionFused( 66 | module.hidden_size, 67 | module.num_heads, 68 | module.num_key_value_heads, 69 | qkv_layer, 70 | module.o_proj, 71 | next(iter(qkv_layer.state_dict().values())).device, 72 | self.model.config.max_new_tokens, 73 | MixGemmCache = cache, 74 | config = self.model.config 75 | ) 76 | set_module_name(self.model, name, attn) 77 | 78 | def _fuse_qkv(self, module: MistralAttention, cache): 79 | try: 80 | q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj 81 | except: 82 | qkv_layer = module.W_pack 83 | return qkv_layer 84 | 85 | 86 | 87 | if not isinstance(q_proj, MixLinear_GEMM) : 88 | raise NotImplementedError 89 | 90 | if isinstance(q_proj, MixLinear_GEMM): 91 | qkv_layer = MixLinear_GEMM(q_proj.in_features,q_proj.out_features + k_proj.out_features + v_proj.out_features, 92 | q_proj.bias is not None, 93 | next(iter(module.state_dict().values())).device, 94 | False, 95 | cache) 96 | 97 | 98 | 99 | if isinstance(qkv_layer, MixLinear_GEMM): 100 | shapew = qkv_layer.q_weight.shape 101 | qkv_layer.q_weight = torch.cat([q_proj.q_weight, k_proj.q_weight, v_proj.q_weight], dim=0) 102 | qkv_layer.scale_col = torch.cat([q_proj.scale_col, k_proj.scale_col, v_proj.scale_col], dim=1) 103 | 104 | assert shapew[0] == qkv_layer.q_weight.shape[0] 105 | assert shapew[1] == qkv_layer.q_weight.shape[1] 106 | assert shapew[0] == qkv_layer.scale_col.shape[1] 107 | assert 1 == qkv_layer.scale_col.shape[0] 108 | 109 | if q_proj.bias is not None: 110 | raise NotImplementedError 111 | else: 112 | qkv_layer.bias = None 113 | 114 | else: 115 | raise NotImplementedError 116 | 117 | q_proj.q_weight = q_proj.q_weight.to('cpu') 118 | k_proj.q_weight = k_proj.q_weight.to('cpu') 119 | v_proj.q_weight = v_proj.q_weight.to('cpu') 120 | q_proj.scale_col = q_proj.scale_col.to('cpu') 121 | k_proj.scale_col = k_proj.scale_col.to('cpu') 122 | v_proj.scale_col = v_proj.scale_col.to('cpu') 123 | torch.cuda.empty_cache() 124 | return qkv_layer 125 | 126 | def fuse_rmsnorm(self): 127 | for name, module in self.rmsnorm_modules: 128 | norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon) 129 | set_module_name(self.model, name, norm) 130 | 131 | def fuse_mlp(self): 132 | for name, module in self.mlp_modules: 133 | mlp = MixLlamaMLP(module.gate_proj, module.down_proj, module.up_proj) 134 | set_module_name(self.model, name, mlp) 135 | -------------------------------------------------------------------------------- /mixquant/models/mpt.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAWQForCausalLM 2 | from typing import Dict 3 | from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM 4 | 5 | class MptAWQForCausalLM(BaseAWQForCausalLM): 6 | layer_type = "MPTBlock" 7 | 
max_new_tokens_key = "max_seq_len" 8 | 9 | @staticmethod 10 | def fuse_layers(model: MptForCausalLM, quant_config: Dict): 11 | fuser = MptFuser(model) 12 | fuser.fuse_transformer() 13 | 14 | @staticmethod 15 | def get_model_layers(model: MptForCausalLM): 16 | return model.transformer.blocks 17 | 18 | @staticmethod 19 | def get_act_for_scaling(module: OldMptBlock): 20 | return dict( 21 | is_scalable=True, 22 | scale_name="ffn.act", 23 | scale_layer=module.ffn.act, 24 | scale_shape=module.ffn.up_proj.out_features 25 | ) 26 | 27 | @staticmethod 28 | def move_embed(model: MptForCausalLM, device: str): 29 | model.transformer.wte = model.transformer.wte.to(device) 30 | model.transformer.emb_drop = model.transformer.emb_drop.to(device) 31 | 32 | @staticmethod 33 | def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs): 34 | layers = [] 35 | 36 | # attention input 37 | layers.append(dict( 38 | prev_op=module.norm_1, 39 | layers=[module.attn.Wqkv], 40 | inp=input_feat['attn.Wqkv'], 41 | module2inspect=module.attn, 42 | kwargs=module_kwargs 43 | )) 44 | 45 | # attention output 46 | layers.append(dict( 47 | prev_op=module.attn.Wqkv, 48 | layers=[module.attn.out_proj], 49 | inp=input_feat['attn.out_proj'] 50 | )) 51 | 52 | # linear 1 53 | layers.append(dict( 54 | prev_op=module.norm_2, 55 | layers=[module.ffn.up_proj], 56 | inp=input_feat['ffn.up_proj'], 57 | module2inspect=module.ffn 58 | )) 59 | 60 | # linear 2 61 | layers.append(dict( 62 | prev_op=module.ffn.act, 63 | layers=[module.ffn.down_proj], 64 | inp=input_feat['ffn.down_proj'] 65 | )) 66 | 67 | return layers 68 | 69 | from typing import List, Tuple 70 | from awq.utils.utils import set_module_name 71 | from awq.modules.fused.block import MPTBlock 72 | from awq.modules.fused.model import MPTModel 73 | 74 | class MptFuser: 75 | def __init__(self, model: MptForCausalLM): 76 | self.model = model 77 | 78 | self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [ 79 | (name, module) for name, module in self.model.named_modules() 80 | if 'mptblock' in module.__class__.__name__.lower() 81 | ] 82 | 83 | def fuse_transformer(self): 84 | blocks = [] 85 | 86 | module: OldMptBlock 87 | for module in self.model.transformer.blocks: 88 | blocks.append(MPTBlock( 89 | self.model.config.d_model, 90 | self.model.config.n_heads, 91 | module.attn.Wqkv, 92 | module.attn.out_proj, 93 | module.ffn, 94 | module.norm_1, 95 | module.norm_2, 96 | next(iter(module.state_dict().values())).device, 97 | self.model.config.max_new_tokens 98 | )) 99 | 100 | self.model.transformer = MPTModel( 101 | self.model.config.vocab_size, 102 | blocks, 103 | self.model.transformer.wte, 104 | self.model.transformer.norm_f, 105 | ) -------------------------------------------------------------------------------- /mixquant/models/opt.py: -------------------------------------------------------------------------------- 1 | from .base import BaseForCausalLM 2 | from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTDecoderLayer 3 | 4 | class OptMixForCausalLM(BaseForCausalLM): 5 | layer_type = "OPTDecoderLayer" 6 | max_new_tokens_key = "max_position_embeddings" 7 | 8 | @staticmethod 9 | def get_model_layers(model: OPTForCausalLM): 10 | return model.model.decoder.layers 11 | 12 | 13 | @staticmethod 14 | def move_embed(model: OPTForCausalLM, device: str): 15 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device) 16 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device) 17 | 18 | 
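Each file under mixquant/models wires one architecture into the same adapter contract (layer_type, get_model_layers, move_embed, and optionally fuse_layers), while the pseudocode in mixquant/models/sample.py just below expresses the core mixed-precision GEMM declaratively: activation channels listed in ind stay in high precision, everything else goes through an int8 GEMM, and the two partial products are summed. The sketch below is a runnable toy of that decomposition in plain PyTorch. It is only an illustration under stated assumptions: symmetric per-tensor activation scaling, per-output-column weight scaling, float32 standing in for fp16, a made-up outlier threshold, and no use of the library's CUDA kernels.

# --- illustrative sketch, not part of the repository ---
import torch

torch.manual_seed(0)
M, K, N = 4, 64, 32                       # tokens, in_features, out_features (toy sizes)
x = torch.randn(M, K)                     # activations (float32 stands in for fp16 here)
w = torch.randn(N, K)                     # weight stored as [out_features, in_features]
x[:, :3] *= 20.0                          # inject a few large "outlier" input channels

# channels whose magnitude exceeds a purely illustrative threshold stay in high precision
ind = (x.abs().amax(dim=0) > 6.0).nonzero().squeeze(1)

x_fp = x[:, ind].clone()                  # high-precision slice of the activations
w_fp = w[:, ind].clone()                  # matching input-channel slice of the weight

# int8 path: zero the outlier channels, then quantize symmetrically
x_main = x.clone()
x_main[:, ind] = 0
x_scale = x_main.abs().amax() / 127.0                    # per-tensor activation scale
w_scale = w.abs().amax(dim=1, keepdim=True) / 127.0      # per-output-column weight scale
x_q = torch.round(x_main / x_scale).to(torch.int8)
w_q = torch.round(w / w_scale).to(torch.int8)

# int8 GEMM (exact in float32 at these sizes; real kernels accumulate in int32)
y_int = x_q.float() @ w_q.float().T
# rescale and add the high-precision correction for the outlier channels
y = y_int * (x_scale * w_scale.T) + x_fp @ w_fp.T

ref = x @ w.T
print((y - ref).abs().max())              # residual quantization error, small relative to ref
# --- end of sketch ---

Because the outlier channels of the activation are zeroed before quantization, the int8 product contributes nothing for those channels, so the high-precision correction term can simply be added on top.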
-------------------------------------------------------------------------------- /mixquant/models/sample.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | @codegen(dtype = [torch.float16, torch.int8 , torch.int32], codegen = "cutlass") 5 | def mixgemm(a, bint8, ind): 6 | afp = a[:,ind] 7 | bfp = bint8[:,ind].to(torch.float16).scale() 8 | a[:,ind] = 0 9 | aint8 = a.to(torch.int8) 10 | c = (aint8 * bint8).to(torch.float16).scale() + afp * bfp 11 | out = torch.relu(c) 12 | return out 13 | 14 | 15 | 16 | class Attention: 17 | def __init__(self,weight,oproj): 18 | self.qweight = weight 19 | self.oproj = oproj 20 | 21 | 22 | @quant([function=self.qweight, type="Mix",activation="A8",weight="W8",codegen = "cutlass"], 23 | [function=self.oproj, type="Weight",activation="A16",weight="W8",codegen = "cutlass"]) 24 | def attention(self,hidden_state): 25 | 26 | qx = self.qweight(hidden_state) 27 | o = self.oproj(qx) 28 | 29 | return o -------------------------------------------------------------------------------- /mixquant/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/mixquant/modules/__init__.py -------------------------------------------------------------------------------- /mixquant/modules/act.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class ScaledActivation(nn.Module): 4 | def __init__(self, module, scales): 5 | super().__init__() 6 | self.act = module 7 | self.scales = nn.Parameter(scales.data) 8 | 9 | def forward(self, x): 10 | return self.act(x) / self.scales.view(1, 1, -1).to(x.device) 11 | -------------------------------------------------------------------------------- /mixquant/modules/fused/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/mixquant/modules/fused/__init__.py -------------------------------------------------------------------------------- /mixquant/modules/fused/block.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | from awq.modules.fused.attn import QuantAttentionFused 4 | 5 | class MPTBlock(nn.Module): 6 | def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len): 7 | super().__init__() 8 | self.n_heads = n_heads 9 | self.n_kv_heads = 0 10 | self.hidden_size = hidden_size 11 | self.norm_1 = norm_1 12 | self.attn = QuantAttentionFused( 13 | hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj, 14 | dev=dev, max_seq_len=max_seq_len, use_alibi=True 15 | ).to(dev) 16 | self.norm_2 = norm_2 17 | self.ffn = mpt_mlp.to(dev) 18 | 19 | def forward( 20 | self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None 21 | ): 22 | norm_out = self.norm_1(hidden_states) 23 | attn_output, _, past_key_value = self.attn.forward( 24 | hidden_states=norm_out, 25 | past_key_value=past_key_value, 26 | attention_mask=attention_mask, 27 | position_ids=None, 28 | output_attentions=False, 29 | use_cache=True 30 | ) 31 | 32 | h = hidden_states + attn_output 33 | out = h + self.ffn.forward(self.norm_2(h)) 34 | return out, None, past_key_value 35 | 36 | class FalconDecoderLayer(nn.Module): 37 | def __init__(self, hidden_size, 
n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len, 38 | input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True): 39 | super().__init__() 40 | self.n_heads = n_heads 41 | self.n_kv_heads = 8 if new_decoder_arch else 0 42 | self.hidden_size = hidden_size 43 | self.new_decoder_arch = new_decoder_arch 44 | 45 | if new_decoder_arch: 46 | attention_shapes = None 47 | else: 48 | attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads) 49 | 50 | # TODO: Falcon has ALiBi implemented but which model uses it? 51 | self.attn = QuantAttentionFused( 52 | hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj, 53 | dev=dev, max_seq_len=max_seq_len, use_alibi=False, 54 | attention_shapes=attention_shapes 55 | ).to(dev) 56 | 57 | if new_decoder_arch: 58 | self.ln_attn = ln_attn # before attention 59 | self.ln_mlp = ln_mlp # before mlp 60 | else: 61 | self.input_layernorm = input_layernorm # before attention 62 | 63 | self.mlp = mlp 64 | 65 | def _get_attention_shapes(self, n_heads, max_seq_len, head_dim): 66 | batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1")) 67 | 68 | self.attention_shapes = { 69 | # following fastertransformer definition 70 | "cache_v": (batch_size, 1, max_seq_len, head_dim,), 71 | # 8: pack 8 fp16 in FT, if fp32 then use 4 72 | "cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,), 73 | "xqkv_view": (n_heads+2, head_dim), 74 | "xq_slice": lambda xqkv: xqkv[:, :, :-2], 75 | "xk_slice": lambda xqkv: xqkv[:, :, [-2]], 76 | "xv_slice": lambda xqkv: xqkv[:, :, [-1]], 77 | "xq_view": (n_heads, head_dim), 78 | "xk_view": (1, head_dim), 79 | "xv_view": (1, head_dim), 80 | "xk_reshape": (1, head_dim // 8, 8), 81 | "single_xq_view": (n_heads, head_dim), 82 | "single_xk_view": (1, head_dim), 83 | "single_xv_view": (1, head_dim) 84 | } 85 | 86 | return self.attention_shapes 87 | 88 | def forward( 89 | self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None 90 | ): 91 | if self.new_decoder_arch: 92 | layernorm_out = self.ln_attn(hidden_states) 93 | mlp_layernorm_out = self.ln_mlp(hidden_states) 94 | else: 95 | layernorm_out = self.input_layernorm(hidden_states) 96 | 97 | attn_output, _, past_key_value = self.attn.forward( 98 | hidden_states=layernorm_out, 99 | past_key_value=past_key_value, 100 | attention_mask=attention_mask, 101 | position_ids=None, 102 | output_attentions=False, 103 | use_cache=True 104 | ) 105 | 106 | h_attn = hidden_states + attn_output 107 | 108 | if self.new_decoder_arch: 109 | h_mlp = self.mlp.forward(mlp_layernorm_out) 110 | else: 111 | h_mlp = self.mlp.forward(layernorm_out) 112 | 113 | out = h_attn + h_mlp 114 | 115 | return out, None, past_key_value 116 | -------------------------------------------------------------------------------- /mixquant/modules/fused/cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class WindowedCache: 4 | def __init__(self, cache_v_shape, cache_k_shape, device): 5 | """ 6 | The window size is the same as the max_new_tokens. The window will 7 | automatically roll once max_new_tokens is exceeded. 
8 | """ 9 | # [batch_size, n_kv_heads, max_seq_len, head_dim] 10 | self.v = torch.zeros(cache_v_shape).to(device).half() 11 | # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor] 12 | self.k = torch.zeros(cache_k_shape).to(device).half() 13 | 14 | def get_kv(self, batch_size, start_pos, seqlen, head_dim): 15 | xv = self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous() 16 | xk = self.k[:batch_size, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous() 17 | xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous() 18 | 19 | return xv, xk 20 | 21 | def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen): 22 | self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store 23 | self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store 24 | 25 | def roll_kv(self, roll_len, start_pos): 26 | # Roll only the necessary part of the cache to the left 27 | self.v[:, :, :-roll_len, :] = self.v[:, :, roll_len:, :] 28 | self.k[:, :, :, :-roll_len, :] = self.k[:, :, :, roll_len:, :] 29 | 30 | # Zero out the new part 31 | self.v[:, :, -roll_len:, :] = 0 32 | self.k[:, :, :, -roll_len:, :] = 0 33 | 34 | return start_pos - roll_len 35 | 36 | def to(self, device): 37 | self.k = self.k.to(device) 38 | self.v = self.v.to(device) 39 | -------------------------------------------------------------------------------- /mixquant/modules/fused/gptj_attn.py: -------------------------------------------------------------------------------- 1 | from transformers.utils.import_utils import ( 2 | is_torch_fx_proxy, 3 | ) 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | from typing import Optional, Tuple, Union 9 | 10 | 11 | 12 | 13 | def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor: 14 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) 15 | sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float() 16 | return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1) 17 | @torch.fx.wrap 18 | def get_embed_positions(embed_positions, position_ids): 19 | return embed_positions.to(position_ids.device).repeat(position_ids.shape[0], 1, 1) 20 | 21 | 22 | def GPTJrotate_every_two(x: torch.Tensor) -> torch.Tensor: 23 | x1 = x[:, :, :, ::2] 24 | x2 = x[:, :, :, 1::2] 25 | x = torch.stack((-x2, x1), dim=-1) 26 | return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... 
(d j)') 27 | 28 | 29 | def GPTJapply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor: 30 | sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3) 31 | cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3) 32 | return (tensor * cos) + (GPTJrotate_every_two(tensor) * sin) 33 | 34 | class QuantGPTJAttentionFused(nn.Module): 35 | 36 | 37 | 38 | 39 | def __init__(self, config,module,qkv_layer,out_proj,MixGemmCache): 40 | super().__init__() 41 | 42 | max_positions = config.max_position_embeddings 43 | self.register_buffer( 44 | "bias", 45 | torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( 46 | 1, 1, max_positions, max_positions 47 | ), 48 | persistent=False, 49 | ) 50 | self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False) 51 | 52 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 53 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 54 | 55 | self.embed_dim = config.hidden_size 56 | self.hidden_size = config.hidden_size 57 | self.num_attention_heads = config.num_attention_heads 58 | self.head_dim = self.embed_dim // self.num_attention_heads 59 | if self.head_dim * self.num_attention_heads != self.embed_dim: 60 | raise ValueError( 61 | f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and" 62 | f" `num_attention_heads`: {self.num_attention_heads})." 63 | ) 64 | self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) 65 | 66 | self.qkv_proj = qkv_layer 67 | self.out_proj = out_proj 68 | self.MixGemmCache = MixGemmCache 69 | self.rotary_dim = config.rotary_dim 70 | pos_embd_dim = self.rotary_dim or self.embed_dim 71 | self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim) 72 | 73 | def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): 74 | """ 75 | Splits hidden dim into attn_head_size and num_attention_heads 76 | """ 77 | new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) 78 | tensor = tensor.view(new_shape) 79 | if rotary: 80 | return tensor 81 | if len(tensor.shape) == 5: 82 | return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features) 83 | elif len(tensor.shape) == 4: 84 | return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 85 | else: 86 | raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") 87 | 88 | def _merge_heads(self, tensor, num_attention_heads, attn_head_size): 89 | """ 90 | Merges attn_head_size dim and num_attn_heads dim into hidden dim 91 | """ 92 | if len(tensor.shape) == 5: 93 | tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() 94 | elif len(tensor.shape) == 4: 95 | tensor = tensor.permute(0, 2, 1, 3).contiguous() 96 | else: 97 | raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") 98 | new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) 99 | return tensor.view(new_shape) 100 | 101 | def _attn( 102 | self, 103 | query, 104 | key, 105 | value, 106 | attention_mask=None, 107 | head_mask=None, 108 | ): 109 | # compute causal mask from causal mask buffer 110 | self.bias = self.bias.to(query.device) 111 | query_length, key_length = query.size(-2), key.size(-2) 112 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] 113 | 114 | # Keep the attention weights computation in fp32 to avoid overflow issues 115 | query = query.to(torch.float32) 116 | key 
= key.to(torch.float32) 117 | 118 | attn_weights = torch.matmul(query, key.transpose(-1, -2)) 119 | 120 | mask_value = torch.finfo(attn_weights.dtype).min 121 | # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 122 | # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` 123 | mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) 124 | attn_weights = torch.where(causal_mask, attn_weights, mask_value) 125 | 126 | attn_weights = attn_weights / self.scale_attn 127 | 128 | if attention_mask is not None: 129 | # Apply the attention mask 130 | attn_weights = attn_weights + attention_mask 131 | 132 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 133 | attn_weights = attn_weights.to(value.dtype) 134 | 135 | 136 | attn_weights = self.attn_dropout(attn_weights) 137 | 138 | # Mask heads if we want to 139 | if head_mask is not None: 140 | attn_weights = attn_weights * head_mask 141 | 142 | attn_output = torch.matmul(attn_weights, value) 143 | 144 | return attn_output, attn_weights 145 | 146 | def _get_embed_positions(self, position_ids): 147 | embed_positions = self.embed_positions 148 | if embed_positions.device != position_ids.device: 149 | embed_positions = embed_positions.to(position_ids.device) 150 | self.embed_positions = embed_positions 151 | return embed_positions.repeat(position_ids.shape[0], 1, 1) 152 | 153 | def forward( 154 | self, 155 | hidden_states: torch.FloatTensor, 156 | layer_past: Optional[Tuple[torch.Tensor]] = None, 157 | attention_mask: Optional[torch.FloatTensor] = None, 158 | position_ids: Optional[torch.LongTensor] = None, 159 | head_mask: Optional[torch.FloatTensor] = None, 160 | use_cache: Optional[bool] = False, 161 | output_attentions: Optional[bool] = False, 162 | ) -> Union[ 163 | Tuple[torch.Tensor, Tuple[torch.Tensor]], 164 | Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], 165 | ]: 166 | 167 | 168 | 169 | proj = self.qkv_proj(hidden_states, self.MixGemmCache) 170 | 171 | 172 | proj = ( 173 | proj.unflatten(-1, (3, self.hidden_size)) 174 | .unsqueeze(0) 175 | .transpose(0, -2) 176 | .squeeze(-2) 177 | ) 178 | query = proj[0] 179 | key = proj[1] 180 | value = proj[2] 181 | 182 | 183 | 184 | query = self._split_heads(query, self.num_attention_heads, self.head_dim, True) 185 | key = self._split_heads(key, self.num_attention_heads, self.head_dim, True) 186 | value = self._split_heads(value, self.num_attention_heads, self.head_dim, False) 187 | 188 | if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): 189 | # The logic to conditionally copy to GPU could not be traced, so we do this 190 | # every time in the torch.fx case 191 | embed_positions = get_embed_positions(self.embed_positions, position_ids) 192 | else: 193 | embed_positions = self._get_embed_positions(position_ids) 194 | 195 | repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) 196 | sincos = torch.gather(embed_positions, 1, repeated_position_ids) 197 | sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) 198 | 199 | if self.rotary_dim is not None: 200 | k_rot = key[:, :, :, : self.rotary_dim] 201 | k_pass = key[:, :, :, self.rotary_dim :] 202 | 203 | q_rot = query[:, :, :, : self.rotary_dim] 204 | q_pass = query[:, :, :, self.rotary_dim :] 205 | 206 | k_rot = GPTJapply_rotary_pos_emb(k_rot, sin, cos) 207 | q_rot = GPTJapply_rotary_pos_emb(q_rot, sin, cos) 208 | 209 | key = 
torch.cat([k_rot, k_pass], dim=-1) 210 | query = torch.cat([q_rot, q_pass], dim=-1) 211 | else: 212 | key = GPTJapply_rotary_pos_emb(key, sin, cos) 213 | query = GPTJapply_rotary_pos_emb(query, sin, cos) 214 | 215 | key = key.permute(0, 2, 1, 3) 216 | query = query.permute(0, 2, 1, 3) 217 | 218 | if layer_past is not None: 219 | past_key = layer_past[0] 220 | past_value = layer_past[1] 221 | key = torch.cat((past_key, key), dim=-2) 222 | value = torch.cat((past_value, value), dim=-2) 223 | 224 | if use_cache is True: 225 | present = (key, value) 226 | else: 227 | present = None 228 | 229 | # compute self-attention: V x Softmax(QK^T) 230 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 231 | 232 | attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) 233 | attn_output = self.out_proj(attn_output) 234 | attn_output = self.resid_dropout(attn_output) 235 | 236 | outputs = (attn_output, present) 237 | if output_attentions: 238 | outputs += (attn_weights,) 239 | 240 | return outputs # a, present, (attentions) -------------------------------------------------------------------------------- /mixquant/modules/fused/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | import torch 4 | from mixquant.Cache import MixLibCache, MLPCache 5 | import mixlib 6 | 7 | 8 | class MixFalconMLP(nn.Module): 9 | 10 | def __init__( 11 | self, 12 | dense_h_to_4h, 13 | dense_4h_to_h, 14 | MixGemmCache = None 15 | ): 16 | super().__init__() 17 | 18 | 19 | 20 | self.dense_h_to_4h = dense_h_to_4h 21 | self.dense_4h_to_h = dense_4h_to_h 22 | self.act = nn.GELU() 23 | self.MixGemmCache = MixGemmCache 24 | 25 | 26 | def forward(self, x): 27 | 28 | x = self.act(self.dense_h_to_4h(x, self.MixGemmCache)) 29 | 30 | 31 | x = self.dense_4h_to_h(x, self.MixGemmCache,True) 32 | 33 | return x 34 | 35 | 36 | import time 37 | class MixLlamaMLP(nn.Module): 38 | 39 | def __init__( 40 | self, 41 | gate_proj, 42 | down_proj, 43 | up_proj, 44 | MixGemmCache = None 45 | ): 46 | super().__init__() 47 | 48 | 49 | 50 | self.down_proj_ = down_proj 51 | self.gate_proj_ = gate_proj 52 | self.up_proj_ = up_proj 53 | self.out_features = down_proj.out_features 54 | self.MLPCache = MixGemmCache 55 | 56 | 57 | def forward(self, x): 58 | 59 | 60 | 61 | up_output = self.up_proj_(x, self.MLPCache) 62 | gate_output = self.gate_proj_.forward_without_preconditionFusedSilu(x, self.MLPCache) 63 | 64 | gate_output *= up_output 65 | 66 | 67 | y = self.down_proj_(gate_output,None,True) 68 | 69 | 70 | return y 71 | 72 | 73 | from transformers.activations import ACT2FN 74 | class MixGPTJMLP(nn.Module): 75 | def __init__(self, module, config, MixGemmCache = None): # in MLP: intermediate_size= 4 * embed_dim 76 | super().__init__() 77 | 78 | 79 | self.fc_in = module.fc_in 80 | self.fc_out = module.fc_out 81 | 82 | self.act = ACT2FN[config.activation_function] 83 | self.dropout = nn.Dropout(config.resid_pdrop) 84 | 85 | 86 | self.MLPCache = MLPCache() 87 | 88 | def forward(self, hidden_states) -> torch.FloatTensor: 89 | 90 | hidden_states = self.fc_in(hidden_states, self.MLPCache) 91 | hidden_states = self.act(hidden_states) 92 | hidden_states = self.fc_out(hidden_states, self.MLPCache) 93 | hidden_states = self.dropout(hidden_states) 94 | return hidden_states -------------------------------------------------------------------------------- /mixquant/modules/fused/model.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import List 4 | from mixquant.modules.fused.block import MPTBlock, FalconDecoderLayer 5 | from transformers.modeling_outputs import BaseModelOutputWithPast 6 | 7 | class MPTModel(nn.Module): 8 | def __init__(self, vocab_size, blocks, wte, norm_f): 9 | super().__init__() 10 | self.vocab_size = vocab_size 11 | self.wte = wte 12 | self.blocks: List[MPTBlock] = nn.ModuleList(blocks) 13 | self.norm_f = norm_f 14 | self.attn_uses_sequence_id = False 15 | self.prefix_lm = False 16 | 17 | @torch.inference_mode() 18 | def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs): 19 | _bsz, seqlen = input_ids.shape 20 | h = self.wte(input_ids) 21 | 22 | mask = None 23 | if seqlen > 1: 24 | mask = torch.full( 25 | (1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device 26 | ) 27 | mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h) 28 | 29 | for layer in self.blocks: 30 | h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal) 31 | h = self.norm_f(h) 32 | 33 | return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=()) 34 | 35 | class FalconModel(nn.Module): 36 | def __init__(self, vocab_size, blocks, word_embeddings, ln_f): 37 | super().__init__() 38 | self.vocab_size = vocab_size 39 | self.word_embeddings = word_embeddings 40 | self.blocks: List[FalconDecoderLayer] = nn.ModuleList(blocks) 41 | self.ln_f = ln_f 42 | self.attn_uses_sequence_id = False 43 | self.prefix_lm = False 44 | 45 | @torch.inference_mode() 46 | def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs): 47 | # NOTE: falcon input ids contain full context 48 | # after context is processed, slice to latest token 49 | if self.blocks[0].attn.start_pos != 0 and input_ids.shape[-1] != 1: 50 | input_ids = input_ids[:, self.blocks[0].attn.start_pos:] 51 | 52 | _bsz, seqlen = input_ids.shape 53 | h = self.word_embeddings(input_ids) 54 | 55 | mask = None 56 | if seqlen > 1: 57 | mask = torch.full( 58 | (1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device 59 | ) 60 | mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h) 61 | 62 | for layer in self.blocks: 63 | h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal) 64 | h = self.ln_f(h) 65 | 66 | return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=()) 67 | -------------------------------------------------------------------------------- /mixquant/modules/fused/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import mixlib 4 | 5 | 6 | class FasterTransformerRMSNorm(nn.Module): 7 | def __init__(self, weight, eps=1e-6, cache = None): 8 | super().__init__() 9 | self.weight = weight.cuda().to(torch.float16) 10 | self.variance_epsilon = eps 11 | self.cache = cache 12 | self.next_layer = None 13 | 14 | @torch.no_grad() 15 | def forward(self, x): 16 | 17 | 18 | output = torch.empty_like(x) 19 | 20 | if self.next_layer is None: 21 | mixlib.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon) 22 | 23 | else: 24 | if self.next_layer.bit == 8: 25 | self.cache.activation_outliers, self.cache.q_xcache = mixlib.layernorm_forward_cuda_extract_outliers(x, 26 | 
self.weight, 27 | output, self.variance_epsilon, 28 | self.next_layer.ind, self.cache.x_scale) 29 | elif self.next_layer.bit == 4: 30 | self.cache.activation_outliers, self.cache.q_xcache = mixlib.layernorm_forward_cuda_extract_outliers_int4(x, 31 | self.weight, 32 | output, self.variance_epsilon, 33 | self.next_layer.ind, self.cache.x_scale) 34 | 35 | 36 | else: 37 | raise NotImplementedError 38 | 39 | return output 40 | -------------------------------------------------------------------------------- /mixquant/quantize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/mixquant/quantize/__init__.py -------------------------------------------------------------------------------- /mixquant/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/mixquant/utils/__init__.py -------------------------------------------------------------------------------- /mixquant/utils/calib_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | from typing import List, Union 4 | from datasets import load_dataset 5 | 6 | def get_calib_dataset(data: Union[str, List[str]] = "pileval", 7 | tokenizer=None, n_samples=512, block_size=512, 8 | split="train", text_column="text"): 9 | if isinstance(data, str): 10 | if data == "pileval": 11 | dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") 12 | else: 13 | dataset = load_dataset(data, split=split) 14 | text_column = 'question' 15 | 16 | dataset = dataset.shuffle(seed=42) 17 | 18 | elif isinstance(data, list): 19 | dataset = [{text_column: text} for text in data] 20 | else: 21 | raise NotImplementedError( 22 | "Either pass a string to a huggingface dataset or a list" 23 | "that is preprocessed with one sample of text per element.") 24 | 25 | samples = [] 26 | n_run = 0 27 | 28 | for data in dataset: 29 | line = data[text_column] 30 | line = line.strip() 31 | line_encoded = tokenizer.encode(line) 32 | if len(line_encoded) > 512: 33 | continue 34 | sample = torch.tensor([line_encoded]) 35 | if sample.numel() == 0: 36 | continue 37 | samples.append(sample) 38 | n_run += 1 39 | if n_run == n_samples: 40 | break 41 | # now concatenate all samples and split according to block size 42 | cat_samples = torch.cat(samples, dim=1) 43 | n_split = cat_samples.shape[1] // block_size 44 | logging.debug(f" * Split into {n_split} blocks") 45 | return [cat_samples[:, i*block_size:(i+1)*block_size] for i in range(n_split)] 46 | -------------------------------------------------------------------------------- /mixquant/utils/fused_utils.py: -------------------------------------------------------------------------------- 1 | 2 | def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim): 3 | if attention_shapes is not None: 4 | attention_shapes = attention_shapes 5 | 6 | elif n_kv_heads == 0: 7 | attention_shapes = { 8 | # following fastertransformer definition 9 | "cache_v": (cache_batch_size, n_heads, max_seq_len, head_dim,), 10 | # 8: pack 8 fp16 in FT, if fp32 then use 4 11 | "cache_k": (cache_batch_size, n_heads, head_dim // 8, max_seq_len, 8,), 12 | "xqkv_view": (-1, n_heads, head_dim), 13 | "xq_slice": lambda xqkv: xqkv[:, :, 0], 14 | "xk_slice": lambda xqkv: xqkv[:, :, 1], 15 | 
"xv_slice": lambda xqkv: xqkv[:, :, 2], 16 | "xq_view": (n_heads, head_dim), 17 | "xk_view": (n_heads, head_dim), 18 | "xv_view": (n_heads, head_dim), 19 | "xk_reshape": (n_heads, head_dim // 8, 8), 20 | "single_xq_view": (n_heads, head_dim), 21 | "single_xk_view": (n_heads, head_dim), 22 | "single_xv_view": (n_heads, head_dim) 23 | } 24 | 25 | else: 26 | attention_shapes = { 27 | # following fastertransformer definition 28 | "cache_v": (cache_batch_size, n_kv_heads, max_seq_len, head_dim,), 29 | # 8: pack 8 fp16 in FT, if fp32 then use 4 30 | "cache_k": (cache_batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8,), 31 | "xqkv_view": (n_heads + n_kv_heads * 2, head_dim), 32 | "xq_slice": lambda xqkv: xqkv[:, :, 0 : n_heads], 33 | "xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)], 34 | "xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads :], 35 | "xq_view": (n_heads, head_dim), 36 | "xk_view": (n_kv_heads, head_dim), 37 | "xv_view": (n_kv_heads, head_dim), 38 | "xk_reshape": (n_kv_heads, head_dim // 8, 8), 39 | "single_xq_view": (n_heads, head_dim), 40 | "single_xk_view": (n_kv_heads, head_dim), 41 | "single_xv_view": (n_kv_heads, head_dim) 42 | } 43 | 44 | return attention_shapes -------------------------------------------------------------------------------- /mixquant/utils/lm_eval_adaptor.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | from lm_eval.base import BaseLM 4 | import fnmatch 5 | import logging 6 | 7 | class LMEvalAdaptor(BaseLM): 8 | 9 | def __init__(self, model_name, model, tokenizer, device, batch_size=1, max_length=-1): 10 | super().__init__() 11 | 12 | assert isinstance(batch_size, int) 13 | 14 | self.model_name = model_name 15 | self.model = model.to(device) 16 | self.model.eval() 17 | 18 | self.tokenizer = tokenizer 19 | 20 | # assert isinstance(self.tokenizer, ( 21 | # transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast, 22 | # transformers.T5Tokenizer, transformers.T5TokenizerFast, 23 | # )), "this tokenizer has not been checked for compatibility yet!" 
24 | 25 | self.vocab_size = self.tokenizer.vocab_size 26 | 27 | self._batch_size = batch_size 28 | 29 | self._max_length = max_length 30 | 31 | @property 32 | def eot_token_id(self): 33 | # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* 34 | return self.tokenizer.eos_token_id 35 | 36 | @property 37 | def max_length(self): 38 | if self._max_length != -1: 39 | return self._max_length 40 | if hasattr(self.model.config, 'n_ctx'): 41 | return self.model.config.n_ctx 42 | elif hasattr(self.model.config, 'max_position_embeddings'): 43 | return self.model.config.max_position_embeddings 44 | elif hasattr(self.model.config, 'n_positions'): 45 | return self.model.config.n_positions 46 | elif 'bloom' in self.model_name: 47 | return 2048 48 | elif 'llama' in self.model_name: 49 | return 2048 # TODO: did not check this 50 | elif 'mpt' in self.model_name: 51 | return 2048 52 | elif 'falcon' in self.model_name: 53 | return 2048 54 | else: 55 | logging.debug(self.model.config) 56 | raise NotImplementedError 57 | 58 | @property 59 | def max_gen_toks(self): 60 | return 256 61 | 62 | @property 63 | def batch_size(self): 64 | return self._batch_size 65 | 66 | @property 67 | def device(self): 68 | return "cuda" 69 | 70 | def tok_encode(self, string: str): 71 | return self.tokenizer.encode(string, add_special_tokens=False) 72 | 73 | def tok_decode(self, tokens): 74 | return self.tokenizer.decode(tokens) 75 | 76 | def _model_call(self, inps): 77 | """ 78 | inps: a torch tensor of shape [batch, sequence] 79 | the size of sequence may vary from call to call 80 | 81 | returns: a torch tensor of shape [batch, sequence, vocab] with the 82 | logits returned from the model 83 | """ 84 | with torch.no_grad(): 85 | if isinstance(self.model, transformers.models.t5.modeling_t5.T5ForConditionalGeneration): 86 | dec_inps = torch.cat( 87 | [ 88 | torch.tensor( 89 | self.model.generation_config.decoder_start_token_id, 90 | ) 91 | .tile(len(inps), 1) 92 | .to(inps), 93 | inps, 94 | ], 95 | dim=1, 96 | ) 97 | 98 | kwargs = {"decoder_input_ids": dec_inps,} 99 | else: 100 | kwargs = {} 101 | out = self.model(inps, **kwargs)[0] 102 | if "opt" in self.model_name: # there are a few extra tokens in opt, which we should omit 103 | return out[:, :, :50257] 104 | else: 105 | return out # [:, :, :self.tokenizer.vocab_size] 106 | 107 | def _model_generate(self, context, max_length, eos_token_id): 108 | return self.model.generate( 109 | context, 110 | max_length=max_length, 111 | eos_token_id=eos_token_id, 112 | do_sample=False 113 | ) 114 | 115 | -------------------------------------------------------------------------------- /mixquant/utils/module down_weight_only.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | eightbit_only_name = ["down_proj", "fc_out"] 3 | 4 | weight_only_map = { 5 | "GPTJForCausalLM": ["fc_out"], 6 | "LlamaForCausalLM": ["down_proj"], 7 | "AquilaForCausalLM": ["down_proj"], 8 | "BaichuanForCausalLM": ["down_proj"], 9 | "MistralForCausalLM": ["down_proj"], 10 | "FalconForCausalLM" : [], 11 | 12 | } 13 | 14 | def get_named_linears(module): 15 | return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)} 16 | 17 | def get_op_by_name(module, op_name): 18 | # get the op by its name relative to the module 19 | for name, m in module.named_modules(): 20 | if name == op_name: 21 | return m 22 | raise ValueError(f"Cannot find op {op_name} in module {module}") 23 | 24 | 25 | def 
set_op_by_name(layer, name, new_module): 26 | levels = name.split('.') 27 | if len(levels) > 1: 28 | mod_ = layer 29 | for l_idx in range(len(levels)-1): 30 | if levels[l_idx].isdigit(): 31 | mod_ = mod_[int(levels[l_idx])] 32 | else: 33 | mod_ = getattr(mod_, levels[l_idx]) 34 | setattr(mod_, levels[-1], new_module) 35 | else: 36 | setattr(layer, name, new_module) 37 | 38 | 39 | def get_op_name(module, op): 40 | # get the name of the op relative to the module 41 | for name, m in module.named_modules(): 42 | if m is op: 43 | return name 44 | raise ValueError(f"Cannot find op {op} in module {module}") 45 | 46 | 47 | def append_str_prefix(x, prefix): 48 | if isinstance(x, str): 49 | return prefix + x 50 | elif isinstance(x, tuple): 51 | return tuple([append_str_prefix(y, prefix) for y in x]) 52 | elif isinstance(x, list): 53 | return [append_str_prefix(y, prefix) for y in x] 54 | else: 55 | return x -------------------------------------------------------------------------------- /mixquant/utils/module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | eightbit_only_name = ["down_proj", "o_proj", "fc_out"] 3 | 4 | weight_only_map = { 5 | "GPTJForCausalLM": ["fc_out"], 6 | "LlamaForCausalLM": [], 7 | "AquilaForCausalLM": [], 8 | "BaichuanForCausalLM": [], 9 | "MistralForCausalLM": [], 10 | "FalconForCausalLM" : [], 11 | 12 | } 13 | 14 | def get_named_linears(module): 15 | return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)} 16 | 17 | def get_op_by_name(module, op_name): 18 | # get the op by its name relative to the module 19 | for name, m in module.named_modules(): 20 | if name == op_name: 21 | return m 22 | raise ValueError(f"Cannot find op {op_name} in module {module}") 23 | 24 | 25 | def set_op_by_name(layer, name, new_module): 26 | levels = name.split('.') 27 | if len(levels) > 1: 28 | mod_ = layer 29 | for l_idx in range(len(levels)-1): 30 | if levels[l_idx].isdigit(): 31 | mod_ = mod_[int(levels[l_idx])] 32 | else: 33 | mod_ = getattr(mod_, levels[l_idx]) 34 | setattr(mod_, levels[-1], new_module) 35 | else: 36 | setattr(layer, name, new_module) 37 | 38 | 39 | def get_op_name(module, op): 40 | # get the name of the op relative to the module 41 | for name, m in module.named_modules(): 42 | if m is op: 43 | return name 44 | raise ValueError(f"Cannot find op {op} in module {module}") 45 | 46 | 47 | def append_str_prefix(x, prefix): 48 | if isinstance(x, str): 49 | return prefix + x 50 | elif isinstance(x, tuple): 51 | return tuple([append_str_prefix(y, prefix) for y in x]) 52 | elif isinstance(x, list): 53 | return [append_str_prefix(y, prefix) for y in x] 54 | else: 55 | return x -------------------------------------------------------------------------------- /mixquant/utils/module_.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | eightbit_only_name = ["down_proj", "fc_out"] 3 | 4 | weight_only_map = { 5 | "GPTJForCausalLM": ["fc_out"], 6 | "LlamaForCausalLM": [], 7 | "AquilaForCausalLM": ["down_proj"], 8 | "BaichuanForCausalLM": ["down_proj"], 9 | "MistralForCausalLM": ["down_proj"], 10 | "FalconForCausalLM" : [], 11 | 12 | } 13 | 14 | def get_named_linears(module): 15 | return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)} 16 | 17 | def get_op_by_name(module, op_name): 18 | # get the op by its name relative to the module 19 | for name, m in module.named_modules(): 20 | if name == op_name: 21 | return m 
22 | raise ValueError(f"Cannot find op {op_name} in module {module}") 23 | 24 | 25 | def set_op_by_name(layer, name, new_module): 26 | levels = name.split('.') 27 | if len(levels) > 1: 28 | mod_ = layer 29 | for l_idx in range(len(levels)-1): 30 | if levels[l_idx].isdigit(): 31 | mod_ = mod_[int(levels[l_idx])] 32 | else: 33 | mod_ = getattr(mod_, levels[l_idx]) 34 | setattr(mod_, levels[-1], new_module) 35 | else: 36 | setattr(layer, name, new_module) 37 | 38 | 39 | def get_op_name(module, op): 40 | # get the name of the op relative to the module 41 | for name, m in module.named_modules(): 42 | if m is op: 43 | return name 44 | raise ValueError(f"Cannot find op {op} in module {module}") 45 | 46 | 47 | def append_str_prefix(x, prefix): 48 | if isinstance(x, str): 49 | return prefix + x 50 | elif isinstance(x, tuple): 51 | return tuple([append_str_prefix(y, prefix) for y in x]) 52 | elif isinstance(x, list): 53 | return [append_str_prefix(y, prefix) for y in x] 54 | else: 55 | return x -------------------------------------------------------------------------------- /mixquant/utils/parallel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import gc 4 | import logging 5 | 6 | 7 | def auto_parallel(args): 8 | model_size = args.model_path.split("-")[-1] 9 | if model_size.endswith("m"): 10 | model_gb = 1 11 | else: 12 | model_gb = float(model_size[:-1]) 13 | if model_gb < 20: 14 | n_gpu = 1 15 | elif model_gb < 50: 16 | n_gpu = 4 17 | else: 18 | n_gpu = 8 19 | args.parallel = n_gpu > 1 20 | cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) 21 | if isinstance(cuda_visible_devices, str): 22 | cuda_visible_devices = cuda_visible_devices.split(",") 23 | else: 24 | cuda_visible_devices = list(range(8)) 25 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( 26 | [str(dev) for dev in cuda_visible_devices[:n_gpu]]) 27 | logging.debug("CUDA_VISIBLE_DEVICES: %s", os.environ["CUDA_VISIBLE_DEVICES"]) 28 | return cuda_visible_devices 29 | -------------------------------------------------------------------------------- /mixquant/utils/utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | import accelerate 4 | 5 | 6 | def get_module_by_name_suffix(model, module_name: str): 7 | for name, module in model.named_modules(): 8 | if name.endswith(module_name): 9 | return module 10 | 11 | def simple_dispatch_model(model, device_map): 12 | from accelerate.hooks import add_hook_to_module, AlignDevicesHook 13 | 14 | if "" in device_map: 15 | d = device_map[""] 16 | model = model.to(torch.device(d)) 17 | model.hf_device_map = device_map 18 | return model 19 | 20 | tied_params = accelerate.utils.modeling.find_tied_parameters(model) 21 | if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}: 22 | main_device = "cpu" 23 | else: 24 | main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0] 25 | 26 | cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"] 27 | prev_hook = None 28 | for idx, (n, d) in enumerate(cpu_offload_group): 29 | m = get_module_by_name_suffix(model, n) 30 | _, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook) 31 | # set first cpu offload module's prev_module_hook to the last cpu offload module's hook 32 | if len(cpu_offload_group) > 1: 33 | get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = 
prev_hook 34 | 35 | for n, d in device_map.items(): 36 | m = get_module_by_name_suffix(model, n) 37 | if d != "cpu": 38 | d = torch.device(d) 39 | hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True) 40 | add_hook_to_module(m, hook) 41 | accelerate.utils.modeling.retie_parameters(model, tied_params) 42 | model.hf_device_map = device_map 43 | 44 | return model 45 | 46 | def set_module_name(model, name, value): 47 | if '.' in name: 48 | parent_name = name.rsplit('.', 1)[0] 49 | child_name = name[len(parent_name) + 1:] 50 | parent = model.get_submodule(parent_name) 51 | else: 52 | parent_name = '' 53 | parent = model 54 | child_name = name 55 | 56 | setattr(parent, child_name, value) 57 | 58 | def clear_memory(weight=None): 59 | if weight is not None: 60 | del weight 61 | gc.collect() 62 | torch.cuda.empty_cache() 63 | 64 | def compute_memory_used_pct(device): 65 | memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3) 66 | memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100 67 | return memory_pct -------------------------------------------------------------------------------- /mmlu.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | CMD="srun -N 1 --gres=gpu:4090:1 --pty python " 4 | set -ex 5 | basepath=/home/chenyidong/data/mixqdata 6 | _dataset_path=/code/checkpoint/dataset 7 | 8 | 9 | data_type=$1 10 | 11 | 12 | 13 | models=( "falcon-7b" "vicuna-7b" "chatglm2-6b" ) 14 | ngpu=1 15 | if [ ${data_type} == mix8 ] 16 | then 17 | for model in "${models[@]}" 18 | do 19 | 20 | echo ${model} 21 | bit=${data_type:3:3} 22 | CUDA_VISIBLE_DEVICES=0 ${CMD} examples/mmlu.py --model_type ${data_type} \ 23 | --hf_model_dir ${basepath}/quant8/${model} \ 24 | --data_dir ${basepath}/data/data 25 | 26 | done 27 | fi 28 | 29 | 30 | if [ ${data_type} == fp16 ] || [ ${data_type} == bitsandbytes ] 31 | then 32 | for model in "${models[@]}" 33 | do 34 | echo ${model} 35 | export TRANSFORMERS_VERBOSITY=error 36 | CUDA_VISIBLE_DEVICES=0 ${CMD} examples/mmlu.py --model_type ${data_type} \ 37 | --hf_model_dir ${basepath}/${model} \ 38 | --data_dir ${basepath}/data/data 39 | done 40 | fi 41 | 42 | 43 | if [ ${data_type} == awq ] 44 | then 45 | 46 | for model in "${models[@]}" 47 | do 48 | echo ${model} 49 | 50 | CUDA_VISIBLE_DEVICES=0 ${CMD} examples/mmlu.py --model_type ${data_type} \ 51 | --hf_model_dir ${basepath}/${model}-AWQ --data_dir ${basepath}/data/data 52 | 53 | done 54 | 55 | fi 56 | 57 | 58 | -------------------------------------------------------------------------------- /quant.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PYTHONPATH:/home/chenyidong/MIXQ 2 | 3 | if [ $2 == a100 ] 4 | then 5 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python " 6 | #CMD=" python " 7 | fi 8 | if [ $2 == direct ] 9 | then 10 | CMD=" python " 11 | #CMD=" python " 12 | fi 13 | 14 | if [ $2 == h100 ] 15 | then 16 | CMD="srun -p twills -A h100 --gres=gpu:h100:1 --export=ALL python" 17 | fi 18 | if [ $2 == 4090 ] 19 | then 20 | CMD=" srun -N 1 --gres=gpu:4090:1 --pty python" 21 | fi 22 | #CMD=" python" 23 | 24 | set -x 25 | 26 | # model=65 27 | # CUDA_VISIBLE_DEVICES=$1 http_proxy=127.0.0.1:7890 https_proxy=127.0.0.1:7890 ${CMD} \ 28 | # python examples/basic_quant_mix.py \ 29 | # --model_path /home/dataset/llama-2/checkpoint/Llama-${model}b \ 30 | # --quant_file 
/home/dataset/llama-2/checkpoint/quant/Llama-${model}b 31 | 32 | 33 | models=( "Baichuan2-7b" "Baichuan2-13b" "Aquila2-7b" "Llama-2-7b" "Mistral-7b" ) 34 | models=( "llama-2-hf" ) 35 | models=( "falcon-40b" ) 36 | models=( "Llama-2-7b" "falcon-7b" "vicuna-7b" "chatglm2-6b" ) 37 | quantpath=/home/chenyidong/data/mixqdata/quant 38 | modelpath=/home/chenyidong/data/mixqdata 39 | 40 | for bit in 8 41 | do 42 | for model in "${models[@]}" 43 | do 44 | echo ${model} 45 | ${CMD} examples/basic_quant_mix.py \ 46 | --model_path ${modelpath}/${model} \ 47 | --quant_file ${quantpath}${bit}/${model} --w_bit ${bit} 48 | done 49 | done 50 | 51 | 52 | # for bit in 4 53 | # do 54 | # for model in "${models[@]}" 55 | # do 56 | # echo ${model} 57 | # ${CMD} \ 58 | # examples/basic_quant_quik.py \ 59 | # --model_path ${modelpath}/${model} \ 60 | # --quant_file ${quantpath}quik${bit}/${model} --w_bit ${bit} 61 | # done 62 | # done 63 | 64 | 65 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.14.7 2 | torch==2.4.0 3 | flash_attn==2.5.8 4 | -------------------------------------------------------------------------------- /runalltokens.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchalltokens.py --model_type mix --model_path /home/dataset/quant/quant8/Llama-2-7b --quant_file /home/dataset/quant/quant8/Llama-2-7b --batch_size 512 --bit 8 --dataset_path /home/chenyidong/checkpoint/dataset 3 | -------------------------------------------------------------------------------- /runlatency.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python " 4 | export http_proxy=127.0.0.1:7890 5 | export https_proxy=127.0.0.1:7890 6 | set -x 7 | 8 | 9 | bit=4 10 | for batch in 512 11 | #for batch in 1 12 | 13 | do 14 | for seq in 64 15 | do 16 | ##model_type=Aquila2 17 | #model_type=opt 18 | #model_type=Mistral 19 | #model_type=gpt-j 20 | #model_type=falcon 21 | model_type=Llama-2 22 | 23 | 24 | models=( "Llama-2-7b" ) 25 | data_types=( "mix" ) 26 | 27 | for data_type in "${data_types[@]}" 28 | do 29 | for model in "${models[@]}" 30 | do 31 | echo ${model} 32 | CUDA_VISIBLE_DEVICES=$1 http_proxy=127.0.0.1:7890 https_proxy=127.0.0.1:7890 \ 33 | ${CMD} benchlatency.py --model_type ${data_type} --model_path \ 34 | /home/dataset/quant${bit}/${model} \ 35 | --quant_file /home/dataset/quant${bit}/${model} \ 36 | --batch_size ${batch} --bit ${bit} 37 | 38 | done 39 | done 40 | 41 | data_types=( "fp16" "bitsandbytes" ) 42 | for data_type in "${data_types[@]}" 43 | do 44 | for model in "${models[@]}" 45 | do 46 | echo ${model} 47 | CUDA_VISIBLE_DEVICES=$1 http_proxy=127.0.0.1:7890 https_proxy=127.0.0.1:7890 \ 48 | ${CMD} benchflops.py --model_type ${data_type} --model_path \ 49 | /mnt/octave/data/chenyidong/checkpoint/${model} \ 50 | --quant_file /mnt/octave/data/chenyidong/checkpoint/${model} --batch_size ${batch} 51 | 52 | 53 | done 54 | done 55 | data_types=( "awq" ) 56 | for data_type in "${data_types[@]}" 57 | do 58 | for model in "${models[@]}" 59 | do 60 | echo ${model} 61 | CUDA_VISIBLE_DEVICES=$1 http_proxy=127.0.0.1:7890 https_proxy=127.0.0.1:7890 \ 62 | ${CMD} benchflops.py --model_type ${data_type} --model_path \ 63 | 
/mnt/octave/data/chenyidong/checkpoint/awqquant/${model} \ 64 | --quant_file /mnt/octave/data/chenyidong/checkpoint/awqquant/${model} --batch_size ${batch} 65 | done 66 | done 67 | 68 | 69 | 70 | done 71 | done 72 | -------------------------------------------------------------------------------- /runppl.sh: -------------------------------------------------------------------------------- 1 | if [ $2 == a100 ] 2 | then 3 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python " 4 | #CMD=" python " 5 | fi 6 | if [ $2 == direct ] 7 | then 8 | CMD=" python " 9 | #CMD=" python " 10 | fi 11 | 12 | if [ $2 == h100 ] 13 | then 14 | CMD="srun -p twills -A h100 --gres=gpu:h100:1 --export=ALL python" 15 | fi 16 | if [ $2 == 4090 ] 17 | then 18 | CMD=" srun -N 1 --gres=gpu:4090:1 --pty python" 19 | fi 20 | set -x 21 | 22 | quantpath=/home/dataset/quant 23 | modelpath=/home/dataset 24 | dataset_path=/home/dataset/quant/checkpoint/dataset 25 | 26 | model=$3 27 | data_type=$4 28 | down_weight_only=0 29 | 30 | for batch in 512 31 | do 32 | for seq in 64 33 | do 34 | 35 | 36 | if [ ${data_type} == fp16 ] 37 | then 38 | 39 | 40 | echo ${model} 41 | CUDA_VISIBLE_DEVICES=$1 ${CMD} evalppl.py --fp_features_num 128 --model_type ${data_type} --model_path \ 42 | ${modelpath}/${model} \ 43 | --quant_file ${modelpath}/${model} \ 44 | --n_ctx $batch --n_batch $batch --eval_accuracy True --dataset_path ${dataset_path} 45 | 46 | 47 | 48 | fi 49 | 50 | if [ ${data_type} == bitsandbytes ] 51 | then 52 | 53 | 54 | echo ${model} 55 | CUDA_VISIBLE_DEVICES=$1 ${CMD} evalppl.py --fp_features_num 128 --model_type ${data_type} --model_path \ 56 | ${modelpath}/${model} \ 57 | --quant_file ${modelpath}/${model} \ 58 | --n_ctx $batch --n_batch $batch --eval_accuracy True --dataset_path ${dataset_path} 59 | 60 | 61 | 62 | fi 63 | 64 | if [ ${data_type} == awq ] 65 | then 66 | pip install transformers==4.35 67 | 68 | echo ${model} 69 | CUDA_VISIBLE_DEVICES=$1 \ 70 | ${CMD} evalppl.py --model_type ${data_type} --model_path \ 71 | ${quantpath}/awqquant/${model} \ 72 | --quant_file ${quantpath}/awqquant/${model} \ 73 | --n_ctx $batch --n_batch $batch --dataset_path ${dataset_path} --eval_accuracy True 74 | 75 | pip install transformers==4.41.2 76 | 77 | fi 78 | 79 | if [ ${data_type} == mix8 ] 80 | then 81 | bit=8 82 | echo "---------run mix 8--------" 83 | 84 | echo ${model} 85 | if [ ${down_weight_only} == 1 ] 86 | then 87 | rm -r ${quantpath}/quant${bit}/down_weight_only/${model}/model.safetensors 88 | CUDA_VISIBLE_DEVICES=$1 \ 89 | ${CMD} evalppl.py --model_type ${data_type} --model_path \ 90 | ${quantpath}/quant${bit}/down_weight_only/${model} \ 91 | --quant_file ${quantpath}/quant${bit}/down_weight_only/${model} \ 92 | --n_ctx ${batch} --n_batch $batch --dataset_path ${dataset_path} --eval_accuracy True 93 | fi 94 | if [ ${down_weight_only} == 0 ] 95 | then 96 | rm -r ${quantpath}/quant${bit}/${model}/model.safetensors 97 | CUDA_VISIBLE_DEVICES=$1 \ 98 | ${CMD} evalppl.py --model_type ${data_type} --model_path \ 99 | ${quantpath}/quant${bit}/${model} \ 100 | --quant_file ${quantpath}/quant${bit}/${model} \ 101 | --n_ctx ${batch} --n_batch $batch --dataset_path ${dataset_path} --eval_accuracy True 102 | fi 103 | 104 | fi 105 | if [ ${data_type} == mix4 ] 106 | then 107 | bit=4 108 | echo "---------run mix 4--------" 109 | 110 | rm -r ${quantpath}/quant${bit}/down_weight_only/${model}/model.safetensors 111 | echo ${model} 112 | CUDA_VISIBLE_DEVICES=$1 \ 113 | ${CMD} evalppl.py --model_type ${data_type} --model_path \ 
114 | ${quantpath}/quant${bit}/down_weight_only/${model} \ 115 | --quant_file ${quantpath}/quant${bit}/down_weight_only/${model} \ 116 | --n_ctx ${batch} --n_batch $batch --dataset_path ${dataset_path} --eval_accuracy True 117 | 118 | 119 | 120 | fi 121 | 122 | if [ ${data_type} == quik ] 123 | then 124 | bit=4 125 | echo "---------run quik 4--------" 126 | 127 | 128 | echo ${model} 129 | CUDA_VISIBLE_DEVICES=$1 \ 130 | ${CMD} evalppl.py --model_type ${data_type} --model_path \ 131 | ${quantpath}/quantquik${bit}/${model} \ 132 | --quant_file ${quantpath}/quantquik${bit}/${model} \ 133 | --n_ctx ${batch} --n_batch $batch --dataset_path ${dataset_path} --eval_accuracy True 134 | 135 | 136 | 137 | fi 138 | done 139 | done -------------------------------------------------------------------------------- /runthroughput.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | if [ $2 == a100 ] 4 | then 5 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python " 6 | fi 7 | 8 | if [ $2 == h100 ] 9 | then 10 | CMD="srun -p twills -A h100 --gres=gpu:h100:1 --export=ALL python" 11 | fi 12 | 13 | export http_proxy=127.0.0.1:7890 14 | export https_proxy=127.0.0.1:7890 15 | set -x 16 | 17 | quantpath=/home/dataset/quant/quant 18 | modelpath=/home/dataset 19 | 20 | for batch in 32 21 | #for batch in 1 22 | 23 | do 24 | for seq in 1024 25 | do 26 | ##model_type=Aquila2 27 | #model_type=opt 28 | #model_type=Mistral 29 | #model_type=gpt-j 30 | #model_type=falcon 31 | model_type=$3 32 | 33 | 34 | 35 | # data_types=( "mix" ) 36 | # for bit in 8 37 | # do 38 | # for data_type in "${data_types[@]}" 39 | # do 40 | # model=${model_type} 41 | # echo ${model} 42 | # rm -r ${quantpath}${bit}/${model}/model.safetensors 43 | # CUDA_VISIBLE_DEVICES=$1 ${CMD} benchflops.py --model_type ${data_type} --model_path \ 44 | # ${quantpath}${bit}/${model} \ 45 | # --quant_file ${quantpath}${bit}/${model} \ 46 | # --batch_size ${batch} --bit ${bit} --dataset_path /home/chenyidong/checkpoint/dataset 47 | # done 48 | # done 49 | 50 | 51 | # data_types=( "quik" ) 52 | # for bit in 4 53 | # do 54 | # for data_type in "${data_types[@]}" 55 | # do 56 | # model=${model_type} 57 | # echo ${model} 58 | # CUDA_VISIBLE_DEVICES=$1 ${CMD} benchflops.py --model_type ${data_type} --model_path \ 59 | # ${quantpath}quik${bit}/${model} \ 60 | # --quant_file ${quantpath}quik${bit}/${model} \ 61 | # --batch_size ${batch} --bit ${bit} --dataset_path /home/chenyidong/checkpoint/dataset 62 | # done 63 | # done 64 | data_types=( "fp16" "bitsandbytes" ) 65 | for data_type in "${data_types[@]}" 66 | do 67 | model=${model_type} 68 | 69 | echo ${model} 70 | CUDA_VISIBLE_DEVICES=$1 ${CMD} benchflops.py --model_type ${data_type} --model_path \ 71 | ${modelpath}/${model} \ 72 | --quant_file ${modelpath}/${model} --batch_size ${batch} --dataset_path /home/chenyidong/checkpoint/dataset 73 | 74 | done 75 | # data_types=( "awq" ) 76 | # for data_type in "${data_types[@]}" 77 | # do 78 | # for model in "${models[@]}" 79 | # do 80 | # echo ${model} 81 | # CUDA_VISIBLE_DEVICES=$1 ${CMD} benchflops.py --model_type ${data_type} --model_path \ 82 | # ${quantpath}/awqquant/${model} \ 83 | # --quant_file ${quantpath}t/awqquant/${model} --batch_size ${batch} 84 | # done 85 | # done 86 | 87 | 88 | 89 | done 90 | done 91 | -------------------------------------------------------------------------------- /utils/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.perplexity_utils import Perplexity -------------------------------------------------------------------------------- /utils/utils/exllama_utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | 4 | def exllama_set_max_input_length(model, max_input_length: int): 5 | """ 6 | This method does not necessarily require `model` to inherit from BaseGPTQForCausalLM. 7 | 8 | When using the exllama backend with act-order, it is necessary to initialize a buffer that depends on the maximum expected input length. In case the 9 | default used (EXLLAMA_DEFAULT_MAX_INPUT_LENGTH) is too short, this method can be called to extend the buffer size without reloading the whole model. 10 | """ 11 | 12 | # The import is set here to avoid a global import. Arguably this is quite ugly, it would be better to have lazy loading. 13 | from exllama_kernels import prepare_buffers, cleanup_buffers_cuda 14 | 15 | if not model.quantize_config.desc_act: 16 | raise ValueError("The method exllama_set_max_input_length should be called only when using the exllama backend **with act-order**.") 17 | 18 | device_to_buffers_size = {} 19 | for device, buffers in model.device_to_buffers.items(): 20 | device_to_buffers_size[device] = {"max_dq_buffer_size": buffers["max_dq_buffer_size"], "max_inner_outer_dim": buffers["max_inner_outer_dim"]} 21 | 22 | # For an unknown reason calling just `del model.device_to_buffers` raises an AttributeError. 23 | for key in list(model.device_to_buffers.keys()): 24 | del model.device_to_buffers[key] 25 | model.device_to_buffers = None 26 | del model.device_to_buffers 27 | 28 | gc.collect() 29 | torch.cuda.empty_cache() 30 | cleanup_buffers_cuda() 31 | 32 | device_to_buffers = {} 33 | for device, buffers_size in device_to_buffers_size.items(): 34 | # The temp_state buffer is required to reorder X in the act-order case. 35 | # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. 36 | device_to_buffers[device] = { 37 | "temp_state": torch.zeros((max_input_length, buffers_size["max_inner_outer_dim"]), dtype=torch.float16, device=device), 38 | "temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device), 39 | "max_dq_buffer_size": buffers_size["max_dq_buffer_size"], 40 | "max_inner_outer_dim": buffers_size["max_inner_outer_dim"], 41 | } 42 | 43 | prepare_buffers(device, device_to_buffers[device]["temp_state"], device_to_buffers[device]["temp_dq"]) 44 | 45 | # Buffers need to be persistent to avoid any bug. 
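# Keeping them referenced on the model also keeps the temp_state / temp_dq tensors registered via prepare_buffers alive, so the exllama kernels never read from freed memory.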
46 | model.device_to_buffers = device_to_buffers 47 | 48 | return model 49 | -------------------------------------------------------------------------------- /utils/utils/import_utils.py: -------------------------------------------------------------------------------- 1 | from packaging.version import parse as parse_version 2 | from logging import getLogger 3 | import torch 4 | 5 | try: 6 | import triton 7 | 8 | TRITON_AVAILABLE = True 9 | except ImportError: 10 | TRITON_AVAILABLE = False 11 | 12 | try: 13 | import autogptq_cuda_256 14 | import autogptq_cuda_64 15 | 16 | AUTOGPTQ_CUDA_AVAILABLE = True 17 | except: 18 | AUTOGPTQ_CUDA_AVAILABLE = False 19 | 20 | 21 | try: 22 | import exllama_kernels 23 | 24 | EXLLAMA_KERNELS_AVAILABLE = True 25 | except: 26 | EXLLAMA_KERNELS_AVAILABLE = False 27 | 28 | try: 29 | import exllamav2_kernels 30 | 31 | EXLLAMAV2_KERNELS_AVAILABLE = True 32 | except: 33 | EXLLAMAV2_KERNELS_AVAILABLE = False 34 | 35 | try: 36 | import cQIGen as qinfer 37 | 38 | QIGEN_AVAILABLE = True 39 | except: 40 | QIGEN_AVAILABLE = False 41 | 42 | logger = getLogger(__name__) 43 | 44 | 45 | def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = True, disable_exllamav2:bool = False, use_qigen: bool = False): 46 | if use_qigen: 47 | from ..nn_modules.qlinear.qlinear_qigen import QuantLinear 48 | else: 49 | if use_triton: 50 | if torch.version.hip: 51 | logger.warning("Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. Please use use_triton=False.") 52 | 53 | from ..nn_modules.qlinear.qlinear_triton import QuantLinear 54 | else: 55 | if bits == 4 and not disable_exllamav2 and EXLLAMAV2_KERNELS_AVAILABLE: 56 | from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear 57 | elif bits == 4 and not disable_exllama and EXLLAMA_KERNELS_AVAILABLE: 58 | from ..nn_modules.qlinear.qlinear_exllama import QuantLinear 59 | elif not desc_act or group_size == -1: 60 | from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear 61 | else: 62 | from ..nn_modules.qlinear.qlinear_cuda import QuantLinear 63 | 64 | return QuantLinear 65 | 66 | 67 | def compare_transformers_version( 68 | version: str = "v4.28.0", 69 | op: str = "eq" 70 | ): 71 | assert op in ["eq", "lt", "le", "gt", "ge"] 72 | 73 | from transformers import __version__ 74 | 75 | return getattr(parse_version(__version__), f"__{op}__")(parse_version(version)) 76 | 77 | 78 | def compare_pytorch_version( 79 | version: str = "v2.0.0", 80 | op: str = "eq" 81 | ): 82 | assert op in ["eq", "lt", "le", "gt", "ge"] 83 | 84 | from torch import __version__ 85 | 86 | return getattr(parse_version(__version__), f"__{op}__")(parse_version(version)) 87 | -------------------------------------------------------------------------------- /utils/utils/perplexity_utils.py: -------------------------------------------------------------------------------- 1 | # this code is taken from GPTQ: 2 | import sys 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from datasets import load_dataset 7 | from transformers import AutoTokenizer, AutoModelForCausalLM 8 | import os 9 | 10 | class Perplexity: 11 | """ 12 | A class for calculating the perplexity of a language model. 13 | """ 14 | 15 | def __init__(self, model, tokenizer, dataset_path='wikitext', dataset_name=None, 16 | split='test', text_column='text', 17 | eval_accuracy = True): 18 | """ 19 | Calculate perplexity using the same method as seen in llama.cpp. 
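The input text is tokenized as one continuous stream and evaluated in windows of ``n_ctx`` tokens; the negative log-likelihood is accumulated over roughly the second half of each window (see ``calculate_perplexity`` and ``_process_batch``).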
20 | 21 | Parameters 22 | ---------- 23 | model : AutoModelForCausalLM 24 | The language model for which the perplexity is calculated. 25 | tokenizer : AutoTokenizer 26 | The tokenizer corresponding to the model. 27 | device : str, optional 28 | The device to run the calculations on. If auto, the device that your model uses 29 | will be the device used for these calculations. Default is 'auto'. 30 | dataset_path : str, optional 31 | The path to the dataset on the Hugging Face dataset hub. Default is 'wikitext'. 32 | dataset_name : str, optional 33 | The name of the dataset. Default is None. 34 | split : str, optional 35 | The split of the dataset to use. Default is 'test'. 36 | text_column : str, optional 37 | The name of the column in the dataset that contains the text data. Default is 'text'. 38 | """ 39 | self._model = model 40 | self._tokenizer = tokenizer 41 | self._dataset_path = dataset_path 42 | self._dataset_name = dataset_name 43 | self._split = split 44 | self._text_column = text_column 45 | self._text = self._prepare_data() 46 | self.eval_accuracy = eval_accuracy 47 | 48 | 49 | 50 | def _get_device(self): 51 | if torch.backends.mps.is_available(): 52 | return 'mps' 53 | elif torch.cuda.is_available(): 54 | return 'cuda:0' 55 | else: 56 | return 'cpu' 57 | 58 | def _prepare_data(self): 59 | """ 60 | Prepares the dataset by loading and formatting. 61 | 62 | Returns 63 | ------- 64 | str 65 | The formatted dataset as a single string. 66 | """ 67 | _dataset_name = 'wikitext-2-raw-v1' 68 | _dataset_path = self._dataset_path 69 | if _dataset_path == 'wikitext': 70 | _dataset_name = 'wikitext-2-raw-v1' 71 | if _dataset_path == 'c4': 72 | _dataset_name = 'realnewslike' 73 | # Load the dataset 74 | print(_dataset_path) 75 | 76 | #data = load_dataset(data_dir=_dataset_path, name='wikitext-2-raw-v1', split=_split) 77 | data = load_dataset(os.path.join(_dataset_path,"wikitext"), name='wikitext-2-raw-v1', 78 | cache_dir=".cache", split=self._split) 79 | # Format the text column of the dataset 80 | text_list = [' \n' if s == '' else s for s in data[self._text_column]] 81 | return ''.join(text_list) 82 | 83 | @staticmethod 84 | def softmax(logits): 85 | """ 86 | Static method for applying the softmax function. 87 | 88 | Parameters 89 | ---------- 90 | logits : np.ndarray 91 | The input to the softmax function. 92 | 93 | Returns 94 | ------- 95 | np.ndarray 96 | The output of the softmax function. 97 | """ 98 | e_x = np.exp(logits - np.max(logits)) 99 | return e_x / e_x.sum(axis=0) 100 | 101 | def calculate_perplexity(self, n_ctx=512, n_batch=512): 102 | """ 103 | Calculates the perplexity of the language model. 104 | 105 | Parameters 106 | ---------- 107 | n_ctx : int 108 | The context size. 109 | n_batch : int 110 | The batch size. 111 | 112 | Returns 113 | ------- 114 | list 115 | The list of perplexity scores calculated. 
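
Examples
--------
Illustrative only: ``model`` and ``tokenizer`` are assumed to be an already
loaded causal LM and its tokenizer, and ``dataset_path`` must point to a local
directory containing the ``wikitext`` data expected by ``_prepare_data``.

>>> ppl = Perplexity(model, tokenizer, dataset_path='/path/to/dataset')
>>> scores = ppl.calculate_perplexity(n_ctx=512, n_batch=512)
>>> print(scores[-1])  # running perplexity after the last processed window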
116 | """ 117 | # Tokenize the text 118 | self._tokenizer.model_max_length = sys.maxsize 119 | tokens = self._tokenizer(self._text, truncation=False, return_tensors='pt').input_ids.to('cuda') 120 | 121 | nll = 0.0 # Negative log likelihood 122 | count = 0 # Counter for processed tokens 123 | curr_ppl = 0 124 | all_perplexity = [] 125 | 126 | with tqdm(range(len(tokens[0]) // n_ctx), desc="Perplexity: - ") as progress: 127 | for i in progress: 128 | # Process each batch of tokens 129 | nll, count = self._process_batch(i, n_ctx, n_batch, tokens, nll, count) 130 | 131 | 132 | # Calculate and display the current perplexity 133 | if self.eval_accuracy is True: 134 | curr_ppl = np.exp(nll / count) 135 | else: 136 | curr_ppl = 0 137 | all_perplexity.append(curr_ppl) 138 | progress.set_description(f"Perplexity: {curr_ppl:.4f}") 139 | #print("-----------------------") 140 | 141 | 142 | return all_perplexity 143 | 144 | def _process_batch(self, i, n_ctx, n_batch, tokens, nll, count): 145 | """ 146 | Processes each batch of tokens. 147 | 148 | Parameters 149 | ---------- 150 | i : int 151 | The batch index. 152 | n_ctx : int 153 | The context size. 154 | n_batch : int 155 | The batch size. 156 | tokens : torch.Tensor 157 | The tokenized text. 158 | nll : float 159 | The current negative log likelihood. 160 | count : int 161 | The current count of processed tokens. 162 | 163 | Returns 164 | ------- 165 | float 166 | The updated negative log likelihood. 167 | int 168 | The updated count of processed tokens. 169 | """ 170 | start = i * n_ctx 171 | end = start + n_ctx 172 | 173 | num_batches = (n_ctx + n_batch - 1) // n_batch 174 | 175 | logits = [] 176 | if self._tokenizer.bos_token_id is None: 177 | self._tokenizer.bos_token_id = 11 178 | for j in range(num_batches): 179 | batch_start = start + j * n_batch 180 | batch_size = min(end - batch_start, n_batch) 181 | 182 | token_org = tokens[0][batch_start].item() 183 | 184 | if j == 0: 185 | # Replace the first token with the BOS token 186 | tokens[0][batch_start] = self._tokenizer.bos_token_id 187 | 188 | # Compute the logits for the current batch of tokens 189 | batch_logits = self._compute_batch_logits(tokens, batch_start, batch_size) 190 | 191 | tokens[0][batch_start] = token_org 192 | 193 | logits.append(batch_logits) 194 | 195 | # We rely on the fact that attention in the forward pass only looks at previous 196 | # tokens here, so the logits returned for each token are an accurate representation 197 | # of what the model would have predicted at that point. 198 | # 199 | # Example, we have a context window of 512, we will compute perplexity for each of the 200 | # last 256 tokens. Then, we split the input up into context window size chunks to 201 | # process the entire prompt. 202 | if self.eval_accuracy is True: 203 | for j in range(min(512, n_ctx // 2), n_ctx - 1): 204 | tok_logits = logits[0][0][j] 205 | # Compute the probability of the next token 206 | prob = torch.softmax(tok_logits,dim=-1)[tokens[0][start + j + 1]] 207 | prob = prob.to(torch.float16).cpu().numpy() 208 | # Update the negative log likelihood and the count of processed tokens 209 | nll += -np.log(prob, where=prob>0) 210 | count += 1 211 | 212 | return nll, count 213 | 214 | def _compute_batch_logits(self, tokens, batch_start, batch_size): 215 | """ 216 | Computes the logits for a batch of tokens. 217 | 218 | Parameters 219 | ---------- 220 | tokens : torch.Tensor 221 | The tokenized text. 222 | batch_start : int 223 | The start index of the batch. 
224 | batch_size : int 225 | The size of the batch. 226 | 227 | Returns 228 | ------- 229 | torch.Tensor 230 | The logits for the batch of tokens. 231 | """ 232 | # Compute the logits without keeping track of gradients 233 | with torch.no_grad(): 234 | outputs = self._model(tokens[:, batch_start:batch_start+batch_size]) 235 | return outputs.logits.detach() 236 | --------------------------------------------------------------------------------