├── .gitignore
├── .gitignore.git
├── README.md
├── bench
│   ├── README.md
│   ├── bench.jpg
│   ├── tflops_int4_overall.jpg
│   └── tflops_int8_overall.jpg
├── benchalltokens.py
├── benchflops.py
├── benchlatency.py
├── evalppl.py
├── examples
│   ├── .gitignore
│   ├── basic_generate.py
│   ├── basic_quant.py
│   ├── basic_quant_mix.py
│   ├── basic_quant_opt.py
│   ├── basic_quant_quik.py
│   ├── basic_safetensors_generate.py
│   ├── benchW8A8.ipynb
│   ├── benchbitsand.py
│   ├── benchkernels.sh
│   ├── benchlayers.sh
│   ├── benchmark.py
│   ├── breakdown.sh
│   ├── eval.py
│   ├── mmlu.py
│   ├── mmlu.sh
│   ├── overhead.sh
│   ├── smooth_quant_get_act.py
│   ├── testgenerate.py
│   └── utils.py
├── figures
│   ├── awq32.gif
│   ├── awq512.gif
│   ├── mixq32.gif
│   ├── mixq512.gif
│   └── textmixq.jpg
├── generate.py
├── generate.sh
├── get_act.sh
├── mixquant
│   ├── Cache.py
│   ├── __init__.py
│   ├── int8fusedkernel.py
│   ├── kernel.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── aquila.py
│   │   ├── auto.py
│   │   ├── baichuan.py
│   │   ├── base.py
│   │   ├── basefuser.py
│   │   ├── bloom.py
│   │   ├── falcon.py
│   │   ├── gpt_bigcode.py
│   │   ├── gpt_neox.py
│   │   ├── gptj.py
│   │   ├── llama.py
│   │   ├── mistral.py
│   │   ├── mpt.py
│   │   ├── opt.py
│   │   └── sample.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── act.py
│   │   ├── fused
│   │   │   ├── __init__.py
│   │   │   ├── attn.py
│   │   │   ├── block.py
│   │   │   ├── cache.py
│   │   │   ├── gptj_attn.py
│   │   │   ├── mistral_attn.py
│   │   │   ├── mlp.py
│   │   │   ├── model.py
│   │   │   └── norm.py
│   │   ├── linear.py
│   │   └── qlinear.py
│   ├── quantize
│   │   ├── __init__.py
│   │   └── mixquant.py
│   └── utils
│       ├── __init__.py
│       ├── calib_data.py
│       ├── fused_utils.py
│       ├── lm_eval_adaptor.py
│       ├── module down_weight_only.py
│       ├── module.py
│       ├── module_.py
│       ├── parallel.py
│       └── utils.py
├── mmlu.sh
├── quant.sh
├── requirements.txt
├── runalltokens.sh
├── runlatency.sh
├── runppl.sh
├── runthroughput.sh
└── utils
    └── utils
        ├── __init__.py
        ├── data_utils.py
        ├── exllama_utils.py
        ├── import_utils.py
        ├── peft_utils.py
        └── perplexity_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | third-party
3 | data/
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | *.pyc
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # poetry
102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106 | #poetry.lock
107 |
108 | # pdm
109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110 | #pdm.lock
111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112 | # in version control.
113 | # https://pdm.fming.dev/#use-with-ide
114 | .pdm.toml
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | #.idea/
165 |
166 | *.pt
167 | **/*.pt
168 | **/*.pyc
169 | *.json
170 | __pycache__
171 |
172 |
173 |
174 |
175 | build/
176 | output/
177 | outputcprof/
178 | outputcproffp16/
179 | outputweight/
180 | outputncu/
181 | roofline/
182 | weight/
183 | *.onnx
184 |
185 |
--------------------------------------------------------------------------------
/.gitignore.git:
--------------------------------------------------------------------------------
1 | *.onnx
2 |
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MixQ
2 |
3 |
4 |
5 | MixQ: Taming Dynamic Outliers in Mixed-Precision Quantization by Online Prediction
6 |
7 | We use mixed-precision GEMM to enhance throughput.
8 |
9 | Please refer to https://github.com/Qcompiler/vllm-mixed-precision for end-to-end text generation.
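As background, the following sketch illustrates the mixed-precision GEMM decomposition that MixQ builds on: the few outlier activation channels are kept in FP16 while the remaining channels are quantized to INT8, and the two partial products are summed. This is a simplified PyTorch reference (per-tensor symmetric quantization, illustrative function name), not the fused MixQ CUDA kernel.

```python
import torch

def mixed_precision_gemm(x, w, outlier_idx):
    """x: (M, K) activations, w: (N, K) weights, outlier_idx: indices of outlier channels."""
    mask = torch.zeros(x.shape[1], dtype=torch.bool, device=x.device)
    mask[outlier_idx] = True

    # high-precision path: only the few outlier channels are computed in full precision
    y_outlier = x[:, mask] @ w[:, mask].T

    # INT8 path: quantize the remaining channels (per-tensor symmetric for simplicity)
    x_low, w_low = x[:, ~mask].float(), w[:, ~mask].float()
    sx = x_low.abs().amax() / 127
    sw = w_low.abs().amax() / 127
    xq = torch.round(x_low / sx).clamp_(-127, 127)
    wq = torch.round(w_low / sw).clamp_(-127, 127)
    y_low = (xq @ wq.T) * (sx * sw)

    return y_outlier + y_low.to(x.dtype)

# Toy usage: channels 0 and 1 carry large outliers
x = torch.randn(4, 64) * torch.cat([torch.full((2,), 20.0), torch.ones(62)])
w = torch.randn(8, 64)
print(mixed_precision_gemm(x, w, outlier_idx=[0, 1]).shape)  # torch.Size([4, 8])
```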
10 |
11 | ## Comparison with AWQ
12 |
13 | Assume the task is to compute the perplexity (PPL) of WikiText-2.
14 | The WikiText-2 dataset contains 333088 validation samples.
15 |
16 | For ```batch size = 32```, the task is divided into 10409 parts.
17 |
18 | AWQ finished the task in 10 minutes with 16.71 it/s.
19 |
20 |
21 |
22 | MixQ (W8A8O16) finished the task in 4.50 minutes with 35.02 it/s.
23 |
24 |
25 |
26 | For ```batch size = 512```, the task is divided into 655 parts.
27 |
28 | AWQ finished the task in 127 seconds with 5.2 it/s.
29 |
30 |
31 |
32 | MixQ (W8A8O16) finished the task in 30 seconds with 21.34 it/s.
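The iteration counts above follow roughly from dividing the number of samples by the batch size (the exact split depends on how the evaluation script chunks the data):

```python
samples = 333088                 # WikiText-2 validation samples
print(samples // 32)             # ≈ 10409 iterations at batch size 32
print(samples // 512)            # ≈ 650 iterations at batch size 512
print(f"{35.02 / 16.71:.2f}x")   # ≈ 2.1x throughput of MixQ (W8A8O16) over AWQ at batch size 32
```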
33 |
34 |
35 |
36 |
37 | ## Setup
38 |
39 | Please download the mixlib kernel from https://github.com/Qcompiler/QComplier:
40 |
41 | ```
42 | git clone git@github.com:Qcompiler/QComplier.git
43 | cd EETQ
44 | python setup.py install
45 | ```
46 | ```
47 | cd quantkernel
48 | python setup.py install
49 | ```
50 |
51 | ## Benchmarking the throughput
52 |
53 |
54 |
55 | It is easy to quantize an LLM and run it with the MIXQ 4-bit or 8-bit kernel.
56 |
57 | Run the following command to quantize the LLM with the W8A8O16 kernel:
58 |
59 | ```
60 | python examples/basic_quant_mix.py --model_path /mnt/data/checkpoint/Llama-2-7b --quant_file /home/dataset/quant/quant8/Llama-2-7b --w_bit 8
61 | ```
62 |
63 | Benchmark the throughput of MIXQ by:
64 |
65 | ```
66 | python benchflops.py --model_type mix --model_path /home/dataset/quant/quant8/Llama-2-7b --quant_file /home/dataset/quant/quant8/Llama-2-7b --batch_size 512 --bit 8
67 | ```
68 |
69 | On an NVIDIA A100-PCIE-40GB, the output is:
70 |
71 | ```
72 | Version: mix 8bit
73 | | Batch Size | Decode Length | Decode tokens/s | Memory (VRAM) |
74 | |-------------:|----------------:|------------------:|:-----------------|
75 | | 512 | 1024 | 10609.8 | 7.86 GB (19.97%) |
76 | ```
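To load the resulting checkpoint from your own Python code, the minimal sketch below follows the call pattern used in `evalppl.py` and `benchlatency.py` (the checkpoint path is a placeholder):

```python
import torch
from transformers import AutoTokenizer
from mixquant import AutoForCausalLM
from mixquant.Cache import MixLibCache

quant_path = "/home/dataset/quant/quant8/Llama-2-7b"   # placeholder path

cache = MixLibCache(bit=8)
model = AutoForCausalLM.from_quantized(
    quant_path, quant_path, fuse_layers=True,
    mix=True, cache=cache,
)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

tokens = tokenizer("Hello, my name is", return_tensors="pt").input_ids.cuda()
with torch.inference_mode():
    out = model(tokens, use_cache=True)   # forward pass, as in the benchmark scripts
print(out[0].shape)                       # logits: (batch, seq_len, vocab)
```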
77 |
78 |
79 |
80 | # News !!
81 |
82 | We have integrated the MixedQLinear designed by QUIK into our framework! QUIK now supports a wide range of LLMs, including:
83 |
84 |
85 | - Llama-2 7B/13B/70B
86 | - Llama-3 8B
87 | - Falcon 7B/40B
88 | - ChatGLM 7B
89 | - QWen2 7B
90 |
91 |
92 | ## How to Run
93 |
94 | It is easy to quantize an LLM and run it with the QUIK 4-bit kernel.
95 |
96 | Run the following command to quantize the LLM:
97 |
98 | ```
99 | python examples/basic_quant_quik.py --model_path /mnt/data/checkpoint/Llama-2-7b --quant_file /home/dataset/quant/quantquik4/Llama-2-7b --w_bit 4
100 | ```
101 |
102 | Benchmark the throughput of QUIK by:
103 |
104 | ```
105 | python benchflops.py --model_type quik --model_path /home/dataset/quant/quantquik4/Llama-2-7b \
106 | --quant_file /home/dataset/quant/quantquik4/quik4/Llama-2-7b \
107 | --batch_size 512 --bit 4
108 | ```
109 |
110 | On an NVIDIA A100-PCIE-40GB, the output is:
111 |
112 | ```
113 | Version: quik 4bit
114 | | Batch Size | Decode Length | Decode tokens/s | Memory (VRAM) |
115 | |-------------:|----------------:|------------------:|:-----------------|
116 | | 512 | 1024 | 8981.17 | 4.88 GB (12.40%) |
117 | ```
118 |
119 |
120 |
121 | # Tensorrt-LLM implementation of QUIK and MIXQ
122 |
123 | We now support end-to-end text generation in both TRT-LLM and vLLM!
124 |
125 | For TRT-LLM, please use the NVIDIA [TensorRT-LLM docker](https://github.com/NVIDIA/TensorRT-LLM). DO NOT use your local environment!
126 |
127 | Please clone the e2eTRTLLM project from https://github.com/Qcompiler/MixQ_Tensorrt_LLM:
128 |
129 | ```
130 | git clone https://github.com/Qcompiler/MixQ_Tensorrt_LLM.git
131 | docker pull registry.cn-hangzhou.aliyuncs.com/dongdongchen/dongdong:v1
132 | ```
133 |
134 |
135 |
136 | Run the docker container:
137 |
138 | ```
139 | export name=myname
140 | bash -c " nvidia-smi; docker run --rm -it --ipc=host -p 6789:22 \
141 | -v /home/${name}/lianxiang/lianxiangTRT/:/code/tensorrt_llm \
142 | -v /mnt/octave/data/${name}/checkpoint:/dataset \
143 | -v /home/${name}/checkpoint:/code/checkpoint \
144 | -v /mnt/octave/data/${name}/lianxiang/checkpoint:/octave/checkpoint \
145 | --ulimit memlock=-1 --ulimit stack=67108864 \
146 | --gpus=all \
147 | --env 'CCACHE_DIR=/code/tensorrt_llm/cpp/.ccache' \
148 | --env 'CCACHE_BASEDIR=/code/tensorrt_llm' \
149 | --workdir /app/tensorrt_llm \
150 | --hostname hpc-release \
151 | --name tensorrt_llm-release-zhanghy \
152 | --tmpfs /tmp:exec \
153 | registry.cn-hangzhou.aliyuncs.com/dongdongchen/dongdong:v1 "
154 |
155 | ```
156 |
157 |
158 | After starting the container, set the environment variables:
159 |
160 | ```
161 | model=Llama-2-7b
162 | ngpu=1
163 | export model_dir=/code/tensorrt_llm/checkpoint/${model}
164 | export quant_dir=/code/tensorrt_llm/checkpoint/checkpoinmix/tllm_checkpoint_${ngpu}gpu_fp16${model}
165 | export out_dir=/code/tensorrt_llm/checkpoint/trt_enginesmix/tllm_checkpoint_${ngpu}gpu_fp16${model}
166 | ```
167 |
168 | Please quantize the model by:
169 |
170 | ```
171 | CUDA_VISIBLE_DEVICES=0 python quantize.py --model_dir ${model_dir} \
172 | --output_dir ${quant_dir} --dtype float16 --device cpu \
173 | --qformat int8_mix --calib_size 32
174 | ```
175 |
176 | Please build the MIXQ model by:
177 |
178 | ```
179 | CUDA_VISIBLE_DEVICES=0 trtllm-build --checkpoint_dir ${quant_dir} \
180 | --output_dir ${out_dir} \
181 | --gemm_plugin float16 --mix_precision int8
182 | ```
183 |
184 |
185 | Generate text with MIXQ by:
186 |
187 | ```
188 | CUDA_VISIBLE_DEVICES=0 python summarize.py --test_trt_llm \
189 | --hf_model_dir ${model_dir} \
190 | --data_type fp16 \
191 | --engine_dir ${out_dir}
192 | ```
193 |
194 |
195 | ## Building the TRT-LLM MIXQ plugin with a 4-stage pipeline for Llama-2-70B
196 |
197 |
198 |
199 | ```
200 | model=Llama-2-70b
201 | ngpu=4
202 | export model_dir=/code/tensorrt_llm/checkpoint/${model}
203 | export quant_dir=/code/tensorrt_llm/checkpoint/checkpoinmix/tllm_checkpoint_${ngpu}gpu_fp16${model}
204 | export out_dir=/code/tensorrt_llm/checkpoint/trt_enginesmix/tllm_checkpoint_${ngpu}gpu_fp16${model}
205 | ```
206 |
207 | Please quantize the model by:
208 |
209 | ```
210 | CUDA_VISIBLE_DEVICES=0,1,2,3 python quantize.py --model_dir ${model_dir} \
211 | --output_dir ${quant_dir} --dtype float16 --device cpu \
212 |                 --qformat int8_mix --calib_size 32 --pp_size ${ngpu}
213 | ```
214 |
215 | Please build the MIXQ model by:
216 |
217 | ```
218 | CUDA_VISIBLE_DEVICES=0,1,2,3 trtllm-build --checkpoint_dir ${quant_dir} \
219 | --output_dir ${out_dir} \
220 | --gemm_plugin float16 --mix_precision int8
221 | ```
222 |
223 |
224 | Generate text with MIXQ by:
225 |
226 | ```
227 | CUDA_VISIBLE_DEVICES=0,1,2,3 mpirun -np 4 --allow-run-as-root python summarize.py --test_trt_llm \
228 | --hf_model_dir ${model_dir} \
229 | --data_type fp16 \
230 | --engine_dir ${out_dir}
231 | ```
232 |
233 |
234 | ## Text generation result
235 |
236 | # Llama-2-7B FP16 baseline
237 |
238 |
239 |
240 | When running ```summarize.py``` with MIXQ (Llama-2-7B on an A100, 40 GB, PCIe), we get:
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 | ## Mixed-precision Inference in vLLM
249 |
250 | Please follow https://github.com/Qcompiler/vllm-mixed-precision for mixed-precision inference.
251 |
252 | Please install vLLM:
253 | ```
254 | pip install vllm==0.6.2
255 | ```
256 |
257 |
258 | Then install the mixed-precision source code:
259 | ```
260 | git clone git@github.com:Qcompiler/vllm-mixed-precision.git
261 | ```
262 |
263 | Copy the compiled ".so" files from the installed vllm package into the cloned project:
264 |
265 | ```
266 | cp -r $PYTHON_PATH/lib/python3.11/site-packages/vllm/*.so vllm-mixed-precision/vllm/
267 | ```
268 |
269 | Finally, uninstall vllm==0.6.2:
270 | ```
271 | pip uninstall vllm
272 | ```
273 |
274 |
275 |
276 | ## Running 8-bit mixed-precision inference in vLLM
277 |
278 | ```
279 | export PYTHONPATH=$( pwd )
280 | python test8bit.py --quant 8
281 | ```
282 |
--------------------------------------------------------------------------------
/bench/README.md:
--------------------------------------------------------------------------------
1 | # Benchmarking for kernels
2 |
3 |
4 | # Benchmarking for MIXQ in A100
5 |
6 | For the 8-bit kernel evaluation in A100:
7 |
8 |
9 |
10 | For the 4-bit kernel evaluation in A100:
11 |
12 |
13 |
14 |
15 |
16 | # Benchmarking for FP8 and INT8 in H100
17 |
18 | On the Hopper architecture (H100), we benchmark the kernel performance of FP8, INT8, and FP16. We found that the FP8 kernel is slightly slower than the INT8 kernel. In the figure, the y-axis is kernel TFLOPS and the x-axis is the GEMM shape (M = N = K).
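A minimal sketch of how such a TFLOPS number can be measured with plain PyTorch (FP16 only; the FP8 and INT8 results come from dedicated kernels that this sketch does not cover):

```python
import time
import torch

def gemm_tflops(m, n, k, dtype=torch.float16, iters=20):
    a = torch.randn(m, k, device="cuda", dtype=dtype)
    b = torch.randn(k, n, device="cuda", dtype=dtype)
    torch.matmul(a, b)                        # warm-up
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        torch.matmul(a, b)
    torch.cuda.synchronize()
    elapsed = (time.time() - start) / iters
    return 2 * m * n * k / elapsed / 1e12     # a GEMM costs 2*M*N*K FLOPs

for size in (4096, 8192):                     # M = N = K, as in the figure
    print(size, f"{gemm_tflops(size, size, size):.1f} TFLOPS")
```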
19 |
20 |
--------------------------------------------------------------------------------
/bench/bench.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/bench/bench.jpg
--------------------------------------------------------------------------------
/bench/tflops_int4_overall.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/bench/tflops_int4_overall.jpg
--------------------------------------------------------------------------------
/bench/tflops_int8_overall.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/bench/tflops_int8_overall.jpg
--------------------------------------------------------------------------------
/benchalltokens.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 |
5 |
6 | os.environ["WORLD_SIZE"] = "1"
7 | import time
8 | import torch
9 | import argparse
10 | import numpy as np
11 | import pandas as pd
12 |
13 |
14 | from transformers import AutoTokenizer
15 | from torch.cuda import OutOfMemoryError
16 | torch.manual_seed(0)
17 |
18 | from mixquant.Cache import MixLibCache
19 | def warmup(model):
20 | warm_up = torch.randn((4096,4096)).to(next(model.parameters()).device)
21 | torch.mm(warm_up,warm_up)
22 |
23 |
24 |
25 |
26 |
27 |
28 | def prepare_data(_dataset_path = 'wikitext', _split='test', _text_column='text'):
29 | from datasets import load_dataset
30 | """
31 | Prepares the dataset by loading and formatting.
32 |
33 | Returns
34 | -------
35 | str
36 | The formatted dataset as a single string.
37 | """
38 | if _dataset_path == 'wikitext':
39 | _dataset_name = 'wikitext-2-raw-v1'
40 | data = load_dataset(_dataset_path, _dataset_name, split=_split)
41 |
42 | elif _dataset_path == 'c4':
43 | _dataset_name = 'realnewslike'
44 | data = load_dataset(_dataset_path, _dataset_name, split=_split)
45 | else:
46 | _dataset_name = 'wikitext-2-raw-v1'
47 | data = load_dataset(os.path.join(_dataset_path,'wikitext'),
48 | _dataset_name, split=_split, cache_dir="/home/chenyidong/tmp")
49 | # Format the text column of the dataset
50 | text_list = [' \n' if s == '' else s for s in data[_text_column]]
51 | return ''.join(text_list)
52 |
53 | def decode_token(model, _tokenizer, _text, n_batch, repeat = 10):
54 |
55 |
56 | tokens = _tokenizer(_text, truncation=False, return_tensors='pt').input_ids.to('cuda')
57 | start = 0
58 | end = n_batch
59 | for j in range(repeat):
60 |
61 | batch_start = start + j * n_batch
62 | batch_size = min(end - batch_start, n_batch)
63 |
64 | token_org = tokens[0][batch_start].item()
65 |
66 | if j == 0:
67 | # Replace the first token with the BOS token
68 | tokens[0][batch_start] = _tokenizer.bos_token_id
69 |
70 | # Compute the logits for the current batch of tokens
71 |         _compute_batch_logits(model, tokens, batch_start, batch_size)
72 |
73 | tokens[0][batch_start] = token_org
74 |
75 | def _compute_batch_logits(_model,tokens, batch_start, batch_size):
76 | # Compute the logits without keeping track of gradients
77 |
78 | outputs = _model(tokens[:, batch_start:batch_start+batch_size])
79 | return outputs
80 |
81 |
82 | def generate(model, tokens, n_generate, batch_size, cache):
83 | context_time = 0
84 | generate_time = []
85 |
86 |
87 | with torch.inference_mode():
88 |
89 |
90 | # prefill context
91 | cache.is_prefill = False
92 |
93 |
94 |
95 |
96 | for i in range(10):
97 | batch_start = i * batch_size
98 | inputs = torch.as_tensor(tokens[:, batch_start:batch_start+batch_size], device=next(model.parameters()).device)
99 | inputs = inputs.reshape((batch_size,1,))
100 | out = model(inputs,use_cache=True)
101 |
102 |
103 |
104 |
105 |
106 | with torch.inference_mode():
107 | # cache.is_prefill = True
108 | # inputs = torch.as_tensor(input_ids, device=next(model.parameters()).device)
109 | # out = model(inputs,use_cache=True)
110 | # token = out[0][:, -1].max(1)[1].unsqueeze(1)
111 |
112 | for i in range(n_generate):
113 | batch_start = i * batch_size
114 | torch.cuda.synchronize()
115 |
116 |
117 |
118 | inputs = torch.as_tensor(tokens[:, batch_start:batch_start+batch_size], device=next(model.parameters()).device)
119 | inputs = inputs.reshape((batch_size,1,))
120 | start = time.time()
121 |
122 |
123 | out = model(inputs,use_cache=True)
124 | torch.cuda.synchronize()
125 |
126 |
127 | generate_time.append(time.time() - start)
128 |
129 |
130 | print("--- generate time ---")
131 | #print(generate_time)
132 | return generate_time
133 |
134 | def run_round(model_path, quant_file, n_generate, token, batch_size, safetensors, model_type='fp16',mixlibcache=None):
135 |
136 | from mixquant import AutoForCausalLM
137 | model = AutoForCausalLM.from_quantized(
138 | model_path, quant_file, fuse_layers=True,
139 | max_new_tokens=n_generate, batch_size=batch_size,
140 | safetensors=safetensors,
141 | mix = True,
142 | cache = mixlibcache
143 | )
144 |
145 |
146 |
147 |
148 |
149 |
150 | def main(args):
151 | rounds = [
152 |
153 | {"context": args.seq_length, "n_generate": args.seq_length},
154 |
155 | ]
156 |
157 | all_stats = []
158 |
159 | cache = MixLibCache(bit=args.bit)
160 |
161 | print("downloading data......")
162 | text = prepare_data(args.dataset_path)
163 | print("done......")
164 |
165 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=args.use_fast_tokenizer, trust_remote_code=True)
166 | if not tokenizer.pad_token_id:
167 | tokenizer.pad_token_id = tokenizer.eos_token_id
168 | tokenizer.model_max_length = sys.maxsize
169 | tokens = tokenizer(text, truncation=False, return_tensors='pt').input_ids
170 |
171 | print( type(tokens[0]))
172 | exit(0)
173 |
174 |
175 |
176 |
177 | for settings in rounds:
178 |
179 |
180 |
181 | stats, model_version = run_round(
182 | args.model_path,
183 | args.quant_file,
184 | settings["n_generate"],
185 | tokens,
186 | args.batch_size,
187 | args.safetensors,
188 | args.model_type,
189 | cache
190 | )
191 |
192 |
193 |
194 | if __name__ == "__main__":
195 |
196 | """
197 |
198 | """
199 | parser = argparse.ArgumentParser()
200 | parser.add_argument("--model_path", type=str, default="", help="path to the model")
201 | parser.add_argument("--quant_file", type=str, default="", help="weights filename")
202 | parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation")
203 | parser.add_argument("--model_type", type=str, default="fp16")
204 | parser.add_argument("--safetensors", default=False, action="store_true", help="Use for enabling safetensors")
205 |     parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer")
206 | parser.add_argument("--seq_length", type=int, default=128)
207 | parser.add_argument("--dataset_path", type=str, default='wikitext', help="Path to the dataset.")
208 | parser.add_argument("--bit", type=int, default=8)
209 | args = parser.parse_args()
210 |
211 | main(args)
--------------------------------------------------------------------------------
/benchlatency.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 |
5 |
6 | os.environ["WORLD_SIZE"] = "1"
7 | import time
8 | import torch
9 | import argparse
10 | import numpy as np
11 | import pandas as pd
12 |
13 |
14 | from transformers import AutoTokenizer
15 | from torch.cuda import OutOfMemoryError
16 | torch.manual_seed(0)
17 |
18 | from mixquant.Cache import MixLibCache
19 | def warmup(model):
20 | warm_up = torch.randn((4096,4096)).to(next(model.parameters()).device)
21 | torch.mm(warm_up,warm_up)
22 |
23 |
24 |
25 |
26 |
27 |
28 | def prepare_data(_dataset_path = 'wikitext', _split='test', _text_column='text'):
29 | from datasets import load_dataset
30 | """
31 | Prepares the dataset by loading and formatting.
32 |
33 | Returns
34 | -------
35 | str
36 | The formatted dataset as a single string.
37 | """
38 | if _dataset_path == 'wikitext':
39 | _dataset_name = 'wikitext-2-raw-v1'
40 | data = load_dataset(_dataset_path, _dataset_name, split=_split)
41 |
42 | elif _dataset_path == 'c4':
43 | _dataset_name = 'realnewslike'
44 | data = load_dataset(_dataset_path, _dataset_name, split=_split)
45 | else:
46 | _dataset_name = 'wikitext-2-raw-v1'
47 | data = load_dataset(os.path.join(_dataset_path,'wikitext'),
48 | _dataset_name, split=_split, cache_dir="/home/chenyidong/tmp")
49 | # Format the text column of the dataset
50 | text_list = [' \n' if s == '' else s for s in data[_text_column]]
51 | return ''.join(text_list)
52 |
53 | def decode_token(model, _tokenizer, _text, n_batch, repeat = 10):
54 |
55 |
56 | tokens = _tokenizer(_text, truncation=False, return_tensors='pt').input_ids.to('cuda')
57 | start = 0
58 | end = n_batch
59 | for j in range(repeat):
60 |
61 | batch_start = start + j * n_batch
62 | batch_size = min(end - batch_start, n_batch)
63 |
64 | token_org = tokens[0][batch_start].item()
65 |
66 | if j == 0:
67 | # Replace the first token with the BOS token
68 | tokens[0][batch_start] = _tokenizer.bos_token_id
69 |
70 | # Compute the logits for the current batch of tokens
71 |         _compute_batch_logits(model, tokens, batch_start, batch_size)
72 |
73 | tokens[0][batch_start] = token_org
74 |
75 | def _compute_batch_logits(_model,tokens, batch_start, batch_size):
76 | # Compute the logits without keeping track of gradients
77 |
78 | outputs = _model(tokens[:, batch_start:batch_start+batch_size])
79 | return outputs
80 |
81 |
82 | def generate(model, tokens, n_generate, batch_size, cache):
83 | context_time = 0
84 | generate_time = []
85 |
86 |
87 | with torch.inference_mode():
88 |
89 |
90 | # prefill context
91 | cache.is_prefill = False
92 |
93 |
94 |
95 |
96 | for i in range(10):
97 | batch_start = i * batch_size
98 | inputs = torch.as_tensor(tokens[:, batch_start:batch_start+batch_size], device=next(model.parameters()).device)
99 | inputs = inputs.reshape((batch_size,1,))
100 | out = model(inputs,use_cache=True)
101 |
102 |
103 |
104 |
105 |
106 | with torch.inference_mode():
107 | # cache.is_prefill = True
108 | # inputs = torch.as_tensor(input_ids, device=next(model.parameters()).device)
109 | # out = model(inputs,use_cache=True)
110 | # token = out[0][:, -1].max(1)[1].unsqueeze(1)
111 |
112 | for i in range(n_generate):
113 | batch_start = i * batch_size
114 | torch.cuda.synchronize()
115 |
116 |
117 | # decode tokens
118 | cache.is_prefill = False
119 | inputs = torch.as_tensor(tokens[:, batch_start:batch_start+batch_size], device=next(model.parameters()).device)
120 | inputs = inputs.reshape((batch_size,1,))
121 | start = time.time()
122 |
123 |
124 | out = model(inputs,use_cache=True)
125 | torch.cuda.synchronize()
126 |
127 |
128 | generate_time.append(time.time() - start)
129 |
130 |
131 | print("--- generate time ---")
132 | #print(generate_time)
133 | return generate_time
134 |
135 | def run_round(model_path, quant_file, n_generate, token, batch_size, safetensors, model_type='fp16',mixlibcache=None):
136 | if model_type == 'mix':
137 | from mixquant import AutoForCausalLM
138 | model = AutoForCausalLM.from_quantized(
139 | model_path, quant_file, fuse_layers=True,
140 | max_new_tokens=n_generate, batch_size=batch_size,
141 | safetensors=safetensors,
142 | mix = True,
143 | cache = mixlibcache
144 | )
145 |
146 |
147 |
148 | if model_type == 'awq':
149 |
150 | import awq
151 | from awq import AutoAWQForCausalLM
152 | print(f" -- Loading model awq...")
153 | model = AutoAWQForCausalLM.from_quantized(
154 | model_path, quant_file, fuse_layers=True,
155 | max_new_tokens=n_generate, batch_size=batch_size,
156 | safetensors=safetensors
157 | )
158 | if model_type == 'fp16':
159 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
160 | model = AutoModelForCausalLM.from_pretrained(
161 | model_path, torch_dtype=torch.float16,
162 | device_map='auto', trust_remote_code=True
163 | )
164 |
165 |
166 |
167 |
168 | if model_type == 'bitsandbytes':
169 | from transformers import AutoModelForCausalLM
170 | model = AutoModelForCausalLM.from_pretrained(
171 | model_path,
172 | torch_dtype=torch.float16,
173 | load_in_8bit=True,
174 | trust_remote_code=True,
175 | max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB')
176 |
177 |
178 |
179 | if model_type == 'quik':
180 | from mixquant import AutoForCausalLM
181 | model = AutoForCausalLM.from_quantized(
182 | model_path, quant_file, fuse_layers=True,
183 | max_new_tokens=n_generate, batch_size=batch_size,
184 | safetensors=safetensors,
185 | mix = True,
186 | cache = mixlibcache
187 | )
188 |
189 |
190 |
191 |
192 | print(model)
193 | print(f" -- Warming up...")
194 | warmup(model)
195 |
196 | print(f" -- Generating {n_generate} tokens, in context...")
197 |
198 | try:
199 | generate_time = generate(model, token, n_generate, batch_size, mixlibcache)
200 | successful_generate = True
201 | except RuntimeError as ex:
202 | if 'cuda out of memory' in str(ex).lower():
203 | successful_generate = False
204 | else:
205 | raise RuntimeError(ex)
206 |
207 | device = next(model.parameters()).device
208 | memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
209 | memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100
210 |
211 | if successful_generate:
212 | # number of tokens in context / time for processing context * batch size
213 | # 1 second / median time per token in seconds * batch size
214 | decode_tokens_per_second = 1 / np.median(generate_time) * batch_size
215 |
216 | print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second")
217 | print(f" ** Max Memory (VRAM): {memory_used:.2f} GB ({memory_pct:.2f}%)")
218 | else:
219 |
220 | decode_tokens_per_second = 'OOM'
221 |
222 | return {
223 | "Batch Size": batch_size,
224 | "Decode Length": n_generate,
225 | "Decode tokens/s": decode_tokens_per_second,
226 | "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)",
227 |         "latency": float(np.median(generate_time)) if successful_generate else 'OOM'  # generate_time is undefined after an OOM
228 | }, args.model_type
229 |
230 | def main(args):
231 | rounds = [
232 |
233 | {"context": args.seq_length, "n_generate": args.seq_length},
234 |
235 | ]
236 |
237 | all_stats = []
238 |
239 | cache = MixLibCache(bit=args.bit)
240 |
241 | print("downloading data......")
242 | text = prepare_data(args.dataset_path)
243 | print("done......")
244 |
245 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=args.use_fast_tokenizer, trust_remote_code=True)
246 | if not tokenizer.pad_token_id:
247 | tokenizer.pad_token_id = tokenizer.eos_token_id
248 | tokenizer.model_max_length = sys.maxsize
249 | tokens = tokenizer(text, truncation=False, return_tensors='pt').input_ids.to('cuda')
250 |
251 |
252 |
253 |
254 |
255 |
256 | for settings in rounds:
257 |
258 |
259 |
260 | stats, model_version = run_round(
261 | args.model_path,
262 | args.quant_file,
263 | settings["n_generate"],
264 | tokens,
265 | args.batch_size,
266 | args.safetensors,
267 | args.model_type,
268 | cache
269 | )
270 |
271 | all_stats.append(stats)
272 |
273 | if stats["Decode tokens/s"] == 'OOM':
274 | break
275 |
276 | df = pd.DataFrame(all_stats)
277 | print('GPU:', torch.cuda.get_device_name())
278 | print('Model:', args.model_path)
279 | print('Version:', model_version)
280 | print(df.to_markdown(index=False))
281 | try:
282 | os.mkdir('output/throughput/'+args.model_type)
283 | except:
284 | pass
285 | df.to_csv('output/throughput/'+args.model_type + '/' + args.quant_file.split("/")[-1] \
286 | + str(args.batch_size) + '_' + str(args.bit) + ".csv")
287 |
288 | if __name__ == "__main__":
289 |
290 |
291 | parser = argparse.ArgumentParser()
292 | parser.add_argument("--model_path", type=str, default="", help="path to the model")
293 | parser.add_argument("--quant_file", type=str, default="", help="weights filename")
294 | parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation")
295 | parser.add_argument("--model_type", type=str, default="fp16")
296 | parser.add_argument("--safetensors", default=False, action="store_true", help="Use for enabling safetensors")
297 |     parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer")
298 | parser.add_argument("--seq_length", type=int, default=128)
299 | parser.add_argument("--dataset_path", type=str, default='wikitext', help="Path to the dataset.")
300 | parser.add_argument("--bit", type=int, default=8)
301 | args = parser.parse_args()
302 |
303 | main(args)
--------------------------------------------------------------------------------
/evalppl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | os.environ["WORLD_SIZE"] = "1"
5 | import argparse
6 | import pandas as pd
7 | import torch
8 | from utils.utils import Perplexity
9 | from transformers import AutoTokenizer
10 |
11 |
12 |
13 |
14 |
15 |
16 | def get_fp_features_num(module: torch.nn.Linear, args):
17 | fp_features_num = args.fp_features_num
18 | if args.fp_features_frac is not None:
19 | fp_features_num = max(int(module.in_features * args.fp_features_frac), fp_features_num)
20 | return fp_features_num
21 | def llama_replace_with_kernels(model, args):
22 | import modelutils
23 | layers = model.model.layers
24 | shared_inputs = {}
25 |
26 | assert not args.w_asym, 'Benchmarking only supports symmetric weight quantization!'
27 | print("Replace with INT4 kernels.")
28 | for i in range(len(layers)):
29 | opt_block = layers[i]
30 | sequential = [
31 | ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
32 | ['self_attn.o_proj'],
33 | ['mlp.up_proj', 'mlp.gate_proj'],
34 | ['mlp.down_proj']
35 | ]
36 | full = modelutils.find_layers(opt_block)
37 | for j, layer_group in enumerate(sequential):
38 | subset = {n: full[n] for n in layer_group}
39 | shared_inputs[f"{i}.{j}"] = qlinear.SharedQuantizedInput(len(layer_group))
40 | for name in subset:
41 | layer = subset[name]
42 | if 'lm_head' in name or 'rotary_emb' in name:
43 | continue
44 | is_quantized = False
45 | bits = 16
46 | fp_features = 0
47 | import quant_sim
48 | import qlinear
49 | if isinstance(layer, quant_sim.ActQuantWrapper):
50 | if layer.quantizer.configured:
51 | is_quantized = True
52 | bits = layer.quantizer.bits
53 | fp_features = layer.fp_features_num
54 | layer = layer.module
55 | layer_weight = layer.weight.data
56 |
57 | layer_scale = save_dict['model.layers.{}.{}.scale'.format(i, name)]
58 | if fp_features == 0:
59 | fp_feature_idx = None
60 | else:
61 | print('---------------save act_scales----------------')
62 | layer_act_scales = act_scales['model.layers.{}.{}'.format(i, name)]
63 | fp_feature_idx = torch.sort(layer_act_scales)[1][-fp_features:]
64 |
65 | if is_quantized:
66 | int_mod = qlinear.MixedQLinear.from_float(layer, layer_weight, layer_scale,
67 | shared_inputs[f"{i}.{j}"], fp_feature_idx,
68 | bits=bits)
69 | else:
70 | int_mod = layer
71 | modelutils.replace_single_mod_opt(opt_block, name, int_mod)
72 |
73 |
74 |
75 |
76 |
77 |
78 | if __name__ == "__main__":
79 | """
80 | Example usage.
81 |
82 | Default usage with GPT2 model:
83 | python examples/benchmark/perplexity.py
84 |
85 | Specify GPTQ quantized model:
86 | http_proxy=127.0.0.1:7890 https_proxy=127.0.0.1:7890 CUDA_VISIBLE_DEVICES=0 WORLD_SIZE=1 python examples/benchmark/perplexity.py \
87 | --model_name /mnt/data/zhongrx/Llama-2-7b \
88 | --model_basename gptq_model-4bit-128g \
89 | --is_quantized
90 |
91 | Change your dataset:
92 | python examples/benchmark/perplexity.py --dataset_path tiny_shakespeare
93 |
94 | """
95 | parser = argparse.ArgumentParser(description="Calculate Perplexity for a model.")
96 | parser.add_argument("--model_path", type=str, help="Model path")
97 | parser.add_argument("--quant_file", type=str, help="quant_file Model path")
98 |
99 | parser.add_argument("--model_type", type=str, default='bitsandbytesfp16')
100 |
101 |
102 | parser.add_argument("--n_ctx", type=int, default=256, help="Context size.")
103 | parser.add_argument("--n_batch", type=int, default=256, help="Batch size.")
104 | parser.add_argument("--dataset_path", type=str, default='wikitext', help="Path to the dataset.")
105 | parser.add_argument("--dataset_name", type=str, default=None, help="Name of the dataset.")
106 | parser.add_argument("--split", type=str, default='test', help="Dataset split to use.")
107 | parser.add_argument("--text_column", type=str, default='text', help="Column in the dataset containing the text.")
108 | parser.add_argument("--per_gpu_max_memory", type=int, default=None, help="Max memory used in each GPU.")
109 |     parser.add_argument("--cpu_max_memory", type=int, default=None, help="Max memory used in CPU.")
110 |
111 | parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file")
112 |     parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer")
113 | parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
114 | parser.add_argument("--disable_exllama", action="store_true", help="Whether to use disable exllama kernel")
115 |
116 |
117 | # Weight Quantization Params:
118 | parser.add_argument('--w_bits', type=int, default=16, choices=[4, 8, 16])
119 |
120 |
121 | parser.add_argument('--int8_down_proj', action='store_true', help='Use INT8 for Down Projection')
122 | parser.add_argument('--fp_features_frac', type=float, default=None, help='Fraction of features to keep in FP16.')
123 | parser.add_argument("--fp_features_num", type=int, default=1, help="outliers")
124 |
125 | parser.add_argument('--eval_accuracy', type=bool, default=True)
126 | parser.add_argument('--eval_throughput', type=bool, default=False)
127 |
128 |
129 | args = parser.parse_args()
130 |
131 | if args.eval_throughput is True:
132 | args.eval_accuracy = False
133 |
134 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
135 |
136 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=args.use_fast_tokenizer, trust_remote_code=True)
137 | if not tokenizer.pad_token_id:
138 | tokenizer.pad_token_id = tokenizer.eos_token_id
139 | ppl = Perplexity(None, tokenizer, args.dataset_path, args.dataset_name, args.split, args.text_column, args.eval_accuracy)
140 |
141 |
142 | model_path = args.model_path
143 | quant_file = args.quant_file
144 |
145 | if args.model_type == 'bitsandbytesfp16':
146 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
147 | print(f" -- Loading model fp16...")
148 | # model = transformers.LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16,
149 | # device_map='auto')
150 | model = AutoModelForCausalLM.from_pretrained(
151 | model_path, torch_dtype=torch.bfloat16,
152 | device_map='auto', trust_remote_code=True
153 | )
154 |
155 | model = model.to('cuda')
156 | print(model)
157 |
158 | if args.model_type == 'bitsandbytesmix4':
159 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
160 | print(f" -- Loading model mix4...")
161 |
162 | n_gpus = torch.cuda.device_count()
163 | max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
164 | max_memory = {i: max_memory for i in range(n_gpus)}
165 | quantization_config = BitsAndBytesConfig(
166 | load_in_4bit=True,
167 | llm_int4_threshold=6.0,
168 | llm_int4_has_fp16_weight=False,
169 | )
170 | model = AutoModelForCausalLM.from_pretrained(
171 | model_path,
172 | device_map='auto',
173 | max_memory=max_memory,
174 | quantization_config=quantization_config
175 | )
176 | if args.model_type == 'bitsandbytes':
177 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
178 | print(f" -- Loading model mix bit8...")
179 |
180 | n_gpus = torch.cuda.device_count()
181 | max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
182 | max_memory = {i: max_memory for i in range(n_gpus)}
183 | quantization_config = BitsAndBytesConfig(
184 | load_in_8bit=True,
185 | llm_int8_threshold=6.0,
186 | llm_int8_has_fp16_weight=False,
187 | )
188 | model = AutoModelForCausalLM.from_pretrained(
189 | model_path,
190 | device_map='auto',
191 | max_memory=max_memory,
192 | quantization_config=quantization_config
193 | )
194 |
195 |
196 | if args.model_type == 'awq':
197 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
198 | print(f" -- Loading model awq...")
199 |
200 |
201 | from awq import AutoAWQForCausalLM
202 | model = AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True, mix = False)
203 |
204 |
205 | if 'mix' in args.model_type :
206 | from mixquant.Cache import MixLibCache
207 | from mixquant import AutoForCausalLM
208 | cache = MixLibCache(args.n_batch)
209 |
210 |
211 | model = AutoForCausalLM.from_quantized(
212 | model_path, quant_file, fuse_layers=True,
213 | mix = True, cache = cache
214 | )
215 |
216 | if args.model_type == 'fp16':
217 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
218 | model = AutoModelForCausalLM.from_pretrained(
219 | model_path, torch_dtype=torch.float16,
220 | device_map='auto', trust_remote_code=True
221 | )
222 |
223 | #model = model.to('cuda')
224 |
225 |     if args.model_type == 'quik':
226 |         from mixquant.Cache import MixLibCache
227 |         from mixquant import AutoForCausalLM
228 |         cache = MixLibCache(args.n_batch)  # this script has no n_generate/batch_size/safetensors args,
229 |         model = AutoForCausalLM.from_quantized(  # so load the QUIK checkpoint like the 'mix' branch above
230 |             model_path, quant_file, fuse_layers=True,
231 |             mix = True,
232 |             cache = cache
233 |         )
234 | print(model)
235 | ppl = Perplexity(model, tokenizer, args.dataset_path, args.dataset_name,
236 | args.split, args.text_column, args.eval_accuracy)
237 | allppl = ppl.calculate_perplexity(args.n_ctx, args.n_batch)
238 |
239 | data = pd.DataFrame(allppl)
240 | try:
241 | os.mkdir("output")
242 | except:
243 | pass
244 | data.to_csv("output/ppl_batchsize"+str(args.n_ctx)+"_"+args.model_type+"_"+model_path.split('/')[-1]+".csv" + str(args.fp_features_num))
245 |
246 |
247 |
248 |
249 |
--------------------------------------------------------------------------------
/examples/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | data/
3 |
--------------------------------------------------------------------------------
/examples/basic_generate.py:
--------------------------------------------------------------------------------
1 | from awq import AutoAWQForCausalLM
2 | from transformers import AutoTokenizer, TextStreamer
3 |
4 | quant_path = "TheBloke/Mistral-7B-OpenOrca-AWQ"
5 |
6 | # Load model
7 | model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, safetensors=True)
8 | tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
9 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
10 |
11 | # Convert prompt to tokens
12 | prompt_template = """\
13 | <|im_start|>system
14 | You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!<|im_end|>
15 | <|im_start|>user
16 | {prompt}<|im_end|>
17 | <|im_start|>assistant"""
18 |
19 | prompt = "You're standing on the surface of the Earth. "\
20 | "You walk one mile south, one mile west and one mile north. "\
21 | "You end up exactly where you started. Where are you?"
22 |
23 | tokens = tokenizer(
24 | prompt_template.format(prompt=prompt),
25 | return_tensors='pt'
26 | ).input_ids.cuda()
27 |
28 | # Generate output
29 | generation_output = model.generate(
30 | tokens,
31 | streamer=streamer,
32 | max_new_tokens=512
33 | )
--------------------------------------------------------------------------------
/examples/basic_quant.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "7"
4 | os.environ["WORLD_SIZE"] = "1"
5 | from awq import AutoAWQForCausalLM
6 | from transformers import AutoTokenizer
7 |
8 | # model_path = '/mnt/data/zhongrx/Llama-2-13b-hf'
9 | # quant_path = '/mnt/data/chenyd/Llama-2-13b-awq'
10 |
11 |
12 | model_path = '/mnt/data/huangkz/huggingface/hub/models--facebook--opt-30b/snapshots/ceea0a90ac0f6fae7c2c34bcb40477438c152546'
13 | quant_path = '/mnt/data/chenyd/models--facebook--opt-30b-awq'
14 |
15 | quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
16 | print(quant_path)
17 | # Load model
18 | # NOTE: pass safetensors=True to load safetensors
19 | model = AutoAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True})
20 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
21 |
22 | # Quantize
23 | model.quantize(tokenizer, quant_config=quant_config)
24 |
25 | # Save quantized model
26 | # NOTE: pass safetensors=True to save quantized model weights as safetensors
27 | model.save_quantized(quant_path)
28 | tokenizer.save_pretrained(quant_path)
29 |
30 | print(f'Model is quantized and saved at "{quant_path}"')
--------------------------------------------------------------------------------
/examples/basic_quant_mix.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This example uses the mixquant quantization toolkit to perform mixed-precision quantization
4 | import os
5 | os.environ["WORLD_SIZE"] = "1"
6 |
7 |
8 | import sys
9 | from mixquant import AutoForCausalLM
10 | from transformers import AutoTokenizer
11 |
12 | import argparse
13 | parser = argparse.ArgumentParser(description="Calculate Perplexity for a model.")
14 | parser.add_argument("--model_path", type=str, help="Model path")
15 | parser.add_argument("--quant_file", type=str, help="quant_file Model path")
16 | parser.add_argument("--w_bit", type=int, default=8, help="weight bit")
17 | args = parser.parse_args()
18 |
19 | model_path = args.model_path
20 | quant_path = args.quant_file
21 | quant_config = { "w_bit": args.w_bit, "version": "MIX" }
22 | print(quant_path)
23 | # Load model
24 | # NOTE: pass safetensors=True to load safetensors
25 | model = AutoForCausalLM.from_pretrained(model_path, mix = True, **{"low_cpu_mem_usage": True},device_map='cpu')
26 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
27 |
28 | print(model)
29 | # Quantize
30 | model.quantize_mix(tokenizer, quant_config=quant_config)
31 |
32 | # Save quantized model
33 | # NOTE: pass safetensors=True to save quantized model weights as safetensors
34 | model.save_quantized(quant_path)
35 | tokenizer.save_pretrained(quant_path)
36 |
37 | print(f'Mix Model is quantized and saved at "{quant_path}"')
--------------------------------------------------------------------------------
/examples/basic_quant_opt.py:
--------------------------------------------------------------------------------
1 | from awq import OptAWQForCausalLM
2 | from transformers import AutoTokenizer
3 | import os
4 | os.environ["CUDA_VISIBLE_DEVICES"] = "7"
5 | os.environ["WORLD_SIZE"] = "1"
6 | model_path = '/mnt/data/huangkz/huggingface/hub/models--facebook--opt-30b'
7 | quant_path = '/mnt/data/chenyd/models--facebook--opt-30b-awq'
8 | quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
9 | print(quant_path)
10 | # Load model
11 | # NOTE: pass safetensors=True to load safetensors
12 | model = OptAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True})
13 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
14 |
15 | # Quantize
16 | model.quantize(tokenizer, quant_config=quant_config)
17 |
18 | # Save quantized model
19 | # NOTE: pass safetensors=True to save quantized model weights as safetensors
20 | model.save_quantized(quant_path)
21 | tokenizer.save_pretrained(quant_path)
22 |
23 | print(f'Model is quantized and saved at "{quant_path}"')
--------------------------------------------------------------------------------
/examples/basic_quant_quik.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 |
4 | os.environ["WORLD_SIZE"] = "1"
5 | import sys
6 | from mixquant import AutoForCausalLM
7 | from transformers import AutoTokenizer
8 |
9 | import argparse
10 | parser = argparse.ArgumentParser(description="Calculate Perplexity for a model.")
11 | parser.add_argument("--model_path", type=str, help="Model path")
12 | parser.add_argument("--quant_file", type=str, help="quant_file Model path")
13 | parser.add_argument("--w_bit", type=int, default=4, help="weight bit")
14 | args = parser.parse_args()
15 |
16 | model_path = args.model_path
17 | quant_path = args.quant_file
18 | quant_config = { "w_bit": args.w_bit, "version": "QUIK" }
19 | print(quant_path)
20 | # Load model
21 | # NOTE: pass safetensors=True to load safetensors
22 | model = AutoForCausalLM.from_pretrained(model_path, mix = True, **{"low_cpu_mem_usage": True},device_map='cpu')
23 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
24 |
25 | print(model)
26 | # Quantize
27 | model.quantize_quik(tokenizer, quant_config=quant_config)
28 |
29 | # Save quantized model
30 | # NOTE: pass safetensors=True to save quantized model weights as safetensors
31 | model.save_quantized(quant_path)
32 | tokenizer.save_pretrained(quant_path)
33 |
34 | print(f'QUIK Model is quantized and saved at "{quant_path}"')
--------------------------------------------------------------------------------
/examples/basic_safetensors_generate.py:
--------------------------------------------------------------------------------
1 | from awq import AutoAWQForCausalLM
2 | from transformers import AutoTokenizer, TextStreamer
3 |
4 | quant_path = "casperhansen/opt-125m-awq"
5 |
6 | # Load model
7 | model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, safetensors=True)
8 | tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
9 | streamer = TextStreamer(tokenizer, skip_special_tokens=True)
10 |
11 | # Convert prompt to tokens
12 | prompt_template = """\
13 | A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
14 |
15 | USER: {prompt}
16 | ASSISTANT:"""
17 |
18 | tokens = tokenizer(
19 | prompt_template.format(prompt="How are you today?"),
20 | return_tensors='pt'
21 | ).input_ids.cuda()
22 |
23 | # Generate output
24 | generation_output = model.generate(
25 | tokens,
26 | streamer=streamer,
27 | max_new_tokens=512
28 | )
29 |
--------------------------------------------------------------------------------
/examples/benchkernels.sh:
--------------------------------------------------------------------------------
1 | for i in {0..39}; do
2 | # echo "evaluate ${file}" >> eval_out
3 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py $i result/llama13b-1.csv LLama-2-13b q >>temp13
4 | done
5 |
6 | for i in {0..31}; do
7 | # echo "evaluate ${file}" >> eval_out
8 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py $i result/llama7b-1.csv LLama-2-7b q >>temp7
9 | done
10 |
11 | for i in {0..31}; do
12 | # echo "evaluate ${file}" >> eval_out
13 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py $i result/llama70b-1.csv LLama-2-70b q >>temp70
14 | done
15 |
16 | for i in {0..31}; do
17 | # echo "evaluate ${file}" >> eval_out
18 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py $i result/llama70b-down.csv LLama-2-70b down
19 | done
20 |
21 | for i in {0..31}; do
22 | # echo "evaluate ${file}" >> eval_out
23 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py $i result/llama70b-up.csv LLama-2-70b up
24 | done
25 |
--------------------------------------------------------------------------------
/examples/benchlayers.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama13b-bd-unschedule.csv LLama-2-13b q
4 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 20 result/llama13b-bd-unschedule.csv LLama-2-13b q
5 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 39 result/llama13b-bd-unschedule.csv LLama-2-13b q
6 |
7 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama7b-bd-unschedule.csv LLama-2-7b q
8 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 result/llama7b-bd-unschedule.csv LLama-2-7b q
9 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama7b-bd-unschedule.csv LLama-2-7b q
10 |
11 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama70b-bd-unschedule.csv LLama-2-70b q
12 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 result/llama70b-bd-unschedule.csv LLama-2-70b q
13 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama70b-bd-unschedule.csv LLama-2-70b q
14 |
15 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama7b-up-1.csv LLama-2-7b up
16 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 result/llama7b-up-1.csv LLama-2-7b up
17 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama7b-up-1.csv LLama-2-7b up
18 |
19 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama7b-down-1.csv LLama-2-7b down
20 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 result/llama7b-down-1.csv LLama-2-7b down
21 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama7b-down-1.csv LLama-2-7b down
22 |
23 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama70b-down-1.csv LLama-2-70b down
24 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 result/llama70b-down-1.csv LLama-2-70b down
25 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama70b-down-1.csv LLama-2-70b down
26 |
27 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama13b-down-1.csv LLama-2-13b down
28 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 20 result/llama13b-down-1.csv LLama-2-13b down
29 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 39 result/llama13b-down-1.csv LLama-2-13b down
30 |
31 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama70b-up-1.csv LLama-2-70b up
32 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 result/llama70b-up-1.csv LLama-2-70b up
33 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama70b-up-1.csv LLama-2-70b up
34 |
35 |
--------------------------------------------------------------------------------
/examples/benchmark.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "6"
4 | os.environ["WORLD_SIZE"] = "1"
5 | import time
6 | import torch
7 | import argparse
8 | import numpy as np
9 | import pandas as pd
10 | from awq import AutoAWQForCausalLM
11 | from transformers import AutoTokenizer
12 | from torch.cuda import OutOfMemoryError
13 |
14 | def warmup(model):
15 | warm_up = torch.randn((4096,4096)).to(next(model.parameters()).device)
16 | torch.mm(warm_up,warm_up)
17 |
18 |
19 | def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
20 | from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils
21 | config = LlamaConfig.from_pretrained(model)
22 |
23 | def noop(*args, **kwargs):
24 | pass
25 |
26 | torch.nn.init.kaiming_uniform_ = noop
27 | torch.nn.init.uniform_ = noop
28 | torch.nn.init.normal_ = noop
29 |
30 | torch.set_default_dtype(torch.half)
31 | modeling_utils._init_weights = False
32 | torch.set_default_dtype(torch.half)
33 | model = LlamaForCausalLM(config)
34 | torch.set_default_dtype(torch.float)
35 | if eval:
36 | model = model.eval()
37 | layers = find_layers(model)
38 | for name in ['lm_head']:
39 | if name in layers:
40 | del layers[name]
41 | quant.make_quant_linear(model, layers, wbits, groupsize)
42 |
43 | del layers
44 |
45 | print('Loading model ...')
46 | if checkpoint.endswith('.safetensors'):
47 | from safetensors.torch import load_file as safe_load
48 | model.load_state_dict(safe_load(checkpoint))
49 | else:
50 | model.load_state_dict(torch.load(checkpoint))
51 |
52 | if eval:
53 | quant.make_quant_attn(model)
54 | quant.make_quant_norm(model)
55 | if fused_mlp:
56 | quant.make_fused_mlp(model)
57 |
58 | if warmup_autotune:
59 | quant.autotune_warmup_linear(model, transpose=not (eval))
60 | if eval and fused_mlp:
61 | quant.autotune_warmup_fused(model)
62 | model.seqlen = 2048
63 | print('Done.')
64 |
65 | return model
66 |
67 | def generate(model, input_ids, n_generate):
68 | context_time = 0
69 | generate_time = []
70 |
71 | with torch.inference_mode():
72 | for i in range(n_generate):
73 | torch.cuda.synchronize()
74 | start = time.time()
75 |
76 | if i == 0:
77 | # prefill context
78 | inputs = torch.as_tensor(input_ids, device=next(model.parameters()).device)
79 | else:
80 | # decode tokens
81 | inputs = torch.as_tensor(token, device=next(model.parameters()).device)
82 |
83 | out = model(inputs, use_cache=True)
84 |
85 | torch.cuda.synchronize()
86 | token = out[0][:, -1].max(1)[1].unsqueeze(1)
87 |
88 | if i == 0:
89 | context_time += time.time() - start
90 | else:
91 | generate_time.append(time.time() - start)
92 |
93 | return context_time, generate_time
94 |
95 | def run_round(model_path, quant_file, n_generate, input_ids, batch_size, safetensors):
96 | print(f" -- Loading model...")
97 | model = AutoAWQForCausalLM.from_quantized(
98 | model_path, quant_file, fuse_layers=True,
99 | max_new_tokens=n_generate, batch_size=batch_size,
100 | safetensors=safetensors
101 | )
102 |
103 | print(f" -- Warming up...")
104 | warmup(model)
105 |
106 | print(f" -- Generating {n_generate} tokens, {input_ids.shape[1]} in context...")
107 |
108 | try:
109 | context_time, generate_time = generate(model, input_ids, n_generate)
110 | successful_generate = True
111 | except RuntimeError as ex:
112 | if 'cuda out of memory' in str(ex).lower():
113 | successful_generate = False
114 | else:
115 | raise RuntimeError(ex)
116 |
117 | device = next(model.parameters()).device
118 | memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
119 | memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100
120 |
121 | if successful_generate:
122 | # number of tokens in context / time for processing context * batch size
123 | prefill_tokens_per_second = input_ids.shape[1] / context_time * batch_size
124 | # 1 second / median time per token in seconds * batch size
125 | decode_tokens_per_second = 1 / np.median(generate_time) * batch_size
126 |
127 | print(f" ** Speed (Prefill): {prefill_tokens_per_second:.2f} tokens/second")
128 | print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second")
129 | print(f" ** Max Memory (VRAM): {memory_used:.2f} GB ({memory_pct:.2f}%)")
130 | else:
131 | prefill_tokens_per_second = 'OOM'
132 | decode_tokens_per_second = 'OOM'
133 |
134 | return {
135 | "Batch Size": batch_size,
136 | "Prefill Length": input_ids.shape[1],
137 | "Decode Length": n_generate,
138 | "Prefill tokens/s": prefill_tokens_per_second,
139 | "Decode tokens/s": decode_tokens_per_second,
140 | "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)"
141 | }, model.quant_config["version"]
142 |
143 | def main(args):
144 | rounds = [
145 | {"context": 32, "n_generate": 32},
146 | {"context": 64, "n_generate": 64},
147 | {"context": 128, "n_generate": 128},
148 | {"context": 256, "n_generate": 256},
149 | {"context": 512, "n_generate": 512},
150 | {"context": 1024, "n_generate": 1024},
151 | {"context": 2048, "n_generate": 2048},
152 | ]
153 |
154 | all_stats = []
155 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
156 |
157 | for settings in rounds:
158 | input_ids = torch.randint(0, tokenizer.vocab_size, (args.batch_size, settings["context"])).cuda()
159 |
160 | stats, model_version = run_round(
161 | args.model_path,
162 | args.quant_file,
163 | settings["n_generate"],
164 | input_ids,
165 | args.batch_size,
166 | args.safetensors
167 | )
168 |
169 | all_stats.append(stats)
170 |
171 | if stats["Prefill tokens/s"] == 'OOM':
172 | break
173 |
174 | df = pd.DataFrame(all_stats)
175 | print('GPU:', torch.cuda.get_device_name())
176 | print('Model:', args.model_path)
177 | print('Version:', model_version)
178 | print(df.to_markdown(index=False))
179 | df.to_csv(args.quant_file.split("/")[-1]+".csv"+str(args.batch_size))
180 |
181 | if __name__ == "__main__":
182 |
183 | """
184 | python examples/benchmark.py --model_path /mnt/data/zhongrx/Llama-2-7b-hf --quant_file /mnt/data/chenyd/Llama-2-7b-awq
185 |
186 | """
187 | parser = argparse.ArgumentParser()
188 | parser.add_argument("--model_path", type=str, default="casperhansen/vicuna-7b-v1.5-awq", help="path to the model")
189 | parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename")
190 | parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation")
191 | parser.add_argument("--safetensors", default=False, action="store_true", help="Use for enabling safetensors")
192 | args = parser.parse_args()
193 |
194 | main(args)
--------------------------------------------------------------------------------
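
A minimal usage sketch for the benchmark above, assuming it is run at the bottom of examples/benchmark.py so that run_round is in scope; both checkpoint paths are placeholders and must point at an AWQ-quantized model directory and its weights file:

    import torch
    from transformers import AutoTokenizer

    model_path = "/path/to/Llama-2-7b-hf"   # placeholder
    quant_file = "awq_model_w4_g128.pt"     # placeholder

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    input_ids = torch.randint(0, tokenizer.vocab_size, (1, 64)).cuda()

    # one prefill/decode round at batch size 1; returns the stats dict and quant version
    stats, version = run_round(model_path, quant_file, 64, input_ids,
                               batch_size=1, safetensors=False)
    print(version, stats)
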
/examples/breakdown.sh:
--------------------------------------------------------------------------------
1 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 0 result/llama70b-up-1.csv LLama-2-70b up
2 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 15 result/llama70b-up-1.csv LLama-2-70b up
3 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchbitsand.py 31 result/llama70b-up-1.csv LLama-2-70b up
4 |
--------------------------------------------------------------------------------
/examples/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from lm_eval import evaluator
3 | from awq import AutoAWQForCausalLM
4 | from transformers import AutoTokenizer
5 | from awq.utils.lm_eval_adaptor import LMEvalAdaptor
6 |
7 | def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot, task_use_pretrained):
8 | """
9 | Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
10 | """
11 | # Load model
12 | if task_use_pretrained:
13 | model = AutoAWQForCausalLM.from_pretrained(model_path)
14 | else:
15 |
16 | model = AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=False)
17 |
18 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
19 |
20 | # Load adapter
21 | lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)
22 |
23 | # Evaluate perplexity of quantized model
24 | results = evaluator.simple_evaluate(
25 | model=lm_eval_model,
26 | tasks=tasks.split(','),
27 | batch_size=task_batch_size,
28 | no_cache=True,
29 | num_fewshot=task_n_shot,
30 | )
31 |
32 | print(evaluator.make_table(results))
33 |
34 | if __name__ == '__main__':
35 | """
36 | - Run perplexity of quantized model:
37 | python examples/eval.py --model_path /mnt/data/zhongrx/Llama-2-7b-hf --quant_file /mnt/data/chenyd/Llama-2-7b-awq
38 |
39 | - Run perplexity unquantized FP16 model:
40 | python examples/eval.py --use_pretrained --model_path lmsys/vicuna-7b-v1.5
41 | """
42 |
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument('--model_path', type=str, help='Path to hf model')
45 | parser.add_argument('--quant_file', default='', type=str, help='Path to quantized AWQ model file')
46 | parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
47 | parser.add_argument("--use_pretrained", default=False, action='store_true',
48 | help="Pass '--use_pretrained' to use a pretrained model running FP16")
49 |     parser.add_argument('--tasks', type=str, default='wikitext',
50 |                         help='Tasks to evaluate; separate multiple tasks with a comma. '
51 |                              'See https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md')
52 | parser.add_argument('--batch_size', type=int, default=1)
53 | parser.add_argument('--n_shot', type=int, default=0)
54 | args = parser.parse_args()
55 |
56 | run_eval(args.model_path, args.quant_file, args.device,
57 | args.tasks, args.batch_size, args.n_shot, args.use_pretrained)
--------------------------------------------------------------------------------
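
A minimal sketch of calling run_eval directly instead of through the CLI (same module scope assumed; both paths are placeholders):

    run_eval(model_path="/path/to/Llama-2-7b-hf",   # placeholder HF checkpoint
             quant_file="/path/to/Llama-2-7b-awq",  # placeholder AWQ weights
             device="cuda:0",
             tasks="wikitext",
             task_batch_size=1,
             task_n_shot=0,
             task_use_pretrained=False)
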
/examples/mmlu.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | if [ $2 == a100 ]
4 | then
5 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python "
6 | else
7 | CMD="srun -p twills -A h100 --gres=gpu:h100:1 --export=ALL python"
8 | fi
9 | set -ex
10 | quantpath=/home/dataset/quant/quant
11 | modelpath=/home/dataset
12 |
13 |
14 | models=( "Llama-2-7b" )
15 | ngpu=1
16 |
17 |
18 | data_type=$3
19 | if [ ${data_type} == mix8 ]
20 | then
21 | bit=${data_type:3:3}
22 | for model in "${models[@]}"
23 | do
24 | echo ${model}
25 |
26 | CUDA_VISIBLE_DEVICES=$1 ${CMD} mmlu.py \
27 | --model_type ${data_type} --hf_model_dir ${quantpath}${bit}/${model}
28 |
29 | done
30 | fi
31 |
32 | if [ ${data_type} == fp16 ]
33 | then
34 |
35 | for model in "${models[@]}"
36 | do
37 | echo ${model}
38 |
39 | CUDA_VISIBLE_DEVICES=$1 ${CMD} mmlu.py \
40 | --model_type ${data_type} --hf_model_dir ${modelpath}/${model}
41 |
42 | done
43 | fi
44 |
45 |
--------------------------------------------------------------------------------
/examples/overhead.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python testoverhead.py
--------------------------------------------------------------------------------
/examples/smooth_quant_get_act.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from datasets import load_dataset
5 | import functools
6 | from collections import defaultdict
7 |
8 | from functools import partial
9 | import numpy as np
10 | from tqdm import tqdm
11 |
12 |
13 | def get_act_scales(model, tokenizer, dataset_path, num_samples=512, seq_len=512):
14 | model.eval()
15 | device = next(model.parameters()).device
16 | act_scales = {}
17 |
18 | def stat_tensor(name, tensor):
19 | hidden_dim = tensor.shape[-1]
20 | tensor = tensor.view(-1, hidden_dim).abs().detach()
21 |         coming_max = torch.max(tensor, dim=0)[0].float().cpu()
22 |         if name in act_scales:
23 |             act_scales[name] = torch.max(act_scales[name], coming_max)
24 |         else:
25 |             act_scales[name] = coming_max
26 |
27 | def stat_input_hook(m, x, y, name):
28 | if isinstance(x, tuple):
29 | x = x[0]
30 | stat_tensor(name, x)
31 |
32 | hooks = []
33 | for name, m in model.named_modules():
34 | if isinstance(m, nn.Linear):
35 | hooks.append(
36 | m.register_forward_hook(functools.partial(stat_input_hook, name=name))
37 | )
38 |
39 | dataset = load_dataset("json", data_files=dataset_path, split="train")
40 | dataset = dataset.shuffle(seed=42)
41 |
42 | for i in tqdm(range(num_samples)):
43 | input_ids = tokenizer(
44 | dataset[i]["text"], return_tensors="pt", max_length=seq_len, truncation=True
45 | ).input_ids.to(device)
46 | model(input_ids)
47 |
48 | for h in hooks:
49 | h.remove()
50 |
51 | return act_scales
52 |
53 |
54 | @torch.no_grad()
55 | def get_static_decoder_layer_scales(
56 | model,
57 | tokenizer,
58 | dataset_path,
59 | num_samples=512,
60 | seq_len=512,
61 | ):
62 | model.eval()
63 | device = next(model.parameters()).device
64 |
65 | act_dict = defaultdict(dict)
66 |
67 | def stat_io_hook(m, x, y, name):
68 | if isinstance(x, tuple):
69 | x = x[0]
70 | if name not in act_dict or "input" not in act_dict[name]:
71 | act_dict[name]["input"] = x.detach().abs().max().item()
72 | else:
73 | act_dict[name]["input"] = max(
74 | act_dict[name]["input"], x.detach().abs().max().item()
75 | )
76 | if isinstance(y, tuple):
77 | y = y[0]
78 | if name not in act_dict or "output" not in act_dict[name]:
79 | act_dict[name]["output"] = y.detach().abs().max().item()
80 | else:
81 | act_dict[name]["output"] = max(
82 | act_dict[name]["output"], y.detach().abs().max().item()
83 | )
84 |
85 | hooks = []
86 | for name, m in model.named_modules():
87 | if isinstance(m, torch.nn.Linear):
88 | hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))
89 |
90 | print("Collecting activation scales...")
91 | pbar = tqdm(range(num_samples))
92 | dataset = load_dataset("json", data_files=dataset_path, split="train")
93 | dataset = dataset.shuffle(seed=42)
94 | for i in pbar:
95 | input_ids = tokenizer(
96 | dataset[i]["text"], return_tensors="pt", max_length=seq_len, truncation=True
97 | ).input_ids.to(device)
98 | model(input_ids)
99 | mean_scale = np.mean([v["input"] for v in act_dict.values()])
100 | pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
101 | for hook in hooks:
102 | hook.remove()
103 |
104 | decoder_layer_scales = []
105 | for idx in range(model.config.num_hidden_layers):
106 | scale_dict = {}
107 | scale_dict["attn_input_scale"] = (
108 | act_dict[f"model.decoder.layers.{idx}.self_attn.q_proj"]["input"] / 127
109 | )
110 | scale_dict["q_output_scale"] = (
111 | act_dict[f"model.decoder.layers.{idx}.self_attn.q_proj"]["output"] / 127
112 | )
113 | scale_dict["k_output_scale"] = (
114 | act_dict[f"model.decoder.layers.{idx}.self_attn.k_proj"]["output"] / 127
115 | )
116 | scale_dict["v_output_scale"] = (
117 | act_dict[f"model.decoder.layers.{idx}.self_attn.v_proj"]["output"] / 127
118 | )
119 | scale_dict["out_input_scale"] = (
120 | act_dict[f"model.decoder.layers.{idx}.self_attn.out_proj"]["input"] / 127
121 | )
122 | scale_dict["fc1_input_scale"] = (
123 | act_dict[f"model.decoder.layers.{idx}.fc1"]["input"] / 127
124 | )
125 | scale_dict["fc2_input_scale"] = (
126 | act_dict[f"model.decoder.layers.{idx}.fc2"]["input"] / 127
127 | )
128 | decoder_layer_scales.append(scale_dict)
129 |
130 | return decoder_layer_scales, act_dict
131 |
132 |
133 | import torch
134 | import os
135 |
136 | from transformers import (
137 | AutoModelForCausalLM,
138 | AutoTokenizer,
139 | )
140 | import argparse
141 |
142 |
143 | def build_model_and_tokenizer(model_name):
144 |
145 | print(model_name)
146 | tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512, trust_remote_code=True)
147 | kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"}
148 | model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
149 | return model, tokenizer
150 |
151 |
152 | def parse_args():
153 | parser = argparse.ArgumentParser()
154 | parser.add_argument('--model-name', type=str,
155 | default='facebook/opt-1.3b', help='model name')
156 | parser.add_argument('--output-path', type=str, default='act_scales/opt-1.3b.pt',
157 | help='where to save the act scales')
158 | parser.add_argument('--dataset-path', type=str, default='dataset/val.jsonl.zst',
159 | help='location of the calibration dataset, we use the validation set of the Pile dataset')
160 | parser.add_argument('--num-samples', type=int, default=512)
161 | parser.add_argument('--seq-len', type=int, default=512)
162 | args = parser.parse_args()
163 | return args
164 |
165 | if __name__ == '__main__':
166 |
167 |
168 |
169 |
170 | args = parse_args()
171 | model, tokenizer = build_model_and_tokenizer(args.model_name)
172 |
173 | act_scales = get_act_scales(model, tokenizer, args.dataset_path,
174 | args.num_samples, args.seq_len)
175 |
176 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
177 | torch.save(act_scales, args.output_path)
178 |
179 |
180 |
181 |
--------------------------------------------------------------------------------
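
A small follow-up sketch, not part of the script, showing how the saved scales are typically consumed: load the .pt file and keep the top-k channels per linear layer as FP16 outliers, mirroring the fp_feature_idx selection in generate.py. The path and fp_features_num value below are placeholders.

    import torch

    act_scales = torch.load("act_scales/opt-1.3b.pt")    # {module_name: per-channel max |activation|}
    fp_features_num = 128                                 # placeholder outlier budget
    fp_feature_idx = {
        name: torch.sort(scales)[1][-fp_features_num:]    # indices of the largest channels
        for name, scales in act_scales.items()
    }
    print(next(iter(fp_feature_idx.items())))
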
/examples/testgenerate.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import os
5 | import time
6 | import psutil
7 | import random
8 | import torch
9 | import numpy as np
10 | from torch.nn.parameter import Parameter
11 |
12 | from transformers import AutoTokenizer, LlamaModel, LlamaForCausalLM, LlamaTokenizer, AutoConfig, AutoModelForCausalLM
13 |
14 |
15 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
16 |
17 | def set_random_seed(seed):
18 | random.seed(seed)
19 | torch.manual_seed(seed)
20 | torch.backends.cudnn.deterministic = True
21 | torch.cuda.manual_seed_all(seed)
22 | np.random.seed(seed)
23 |
24 |
25 | def test_from_fp16():
26 | torch.set_printoptions(precision=6, sci_mode=False)
27 | torch.set_grad_enabled(False)
28 | set_random_seed(1)
29 |
30 | # model_name = '/root/data/models/2023/llama-13B-v1/'
31 | model_name = '/mnt/octave/data/chenyidong/checkpoint/Llama-2-7b'
32 | MAX_NEW_TOKENS = 32
33 |
34 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
35 | config = AutoConfig.from_pretrained(model_name)
36 | # config.num_hidden_layers = 1
37 |
38 | free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
39 | max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-1}GB'
40 |
41 | n_gpus = torch.cuda.device_count()
42 | max_memory = {i: max_memory for i in range(n_gpus)}
43 |
44 | model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16)
45 | model.eval()
46 |
47 | # from eetq.utils import eet_accelerator
48 | # eet_accelerator(model, quantize=True, fused_attn=True, dev="cuda:0")
49 | model.to("cuda:0")
50 | # for k, v in model.state_dict().items():
51 | # print(k, v.shape, v.dtype, v.device)
52 | # torch.save(model, "eetq_llama13B_model_fused_attn_v2.pt")
53 |
54 | prompt_template = "[INST] {prompt} [/INST]"
55 |
56 | prompt = "You're standing on the surface of the Earth. "\
57 | "You walk one mile south, one mile west and one mile north. "\
58 | "You end up exactly where you started. Where are you?"
59 |
60 |
61 |     text = prompt_template.format(prompt=prompt)
62 |     inputs = tokenizer(text, return_tensors="pt").to(model.device)
63 |     output_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)
64 |     print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
65 | test_from_fp16()
--------------------------------------------------------------------------------
/figures/awq32.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/figures/awq32.gif
--------------------------------------------------------------------------------
/figures/awq512.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/figures/awq512.gif
--------------------------------------------------------------------------------
/figures/mixq32.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/figures/mixq32.gif
--------------------------------------------------------------------------------
/figures/mixq512.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/figures/mixq512.gif
--------------------------------------------------------------------------------
/figures/textmixq.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/figures/textmixq.jpg
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | os.environ["WORLD_SIZE"] = "1"
5 | import argparse
6 | import pandas as pd
7 | import torch
8 | from utils.utils import Perplexity
9 | from transformers import AutoTokenizer
10 |
11 |
12 |
13 |
14 |
15 |
16 | def get_fp_features_num(module: torch.nn.Linear, args):
17 | fp_features_num = args.fp_features_num
18 | if args.fp_features_frac is not None:
19 | fp_features_num = max(int(module.in_features * args.fp_features_frac), fp_features_num)
20 | return fp_features_num
21 | def llama_replace_with_kernels(model, args):  # NOTE: unused helper; requires external modelutils/quant_sim/qlinear modules and precomputed save_dict/act_scales
22 | import modelutils
23 | layers = model.model.layers
24 | shared_inputs = {}
25 |
26 | assert not args.w_asym, 'Benchmarking only supports symmetric weight quantization!'
27 | print("Replace with INT4 kernels.")
28 | for i in range(len(layers)):
29 | opt_block = layers[i]
30 | sequential = [
31 | ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
32 | ['self_attn.o_proj'],
33 | ['mlp.up_proj', 'mlp.gate_proj'],
34 | ['mlp.down_proj']
35 | ]
36 | full = modelutils.find_layers(opt_block)
37 | for j, layer_group in enumerate(sequential):
38 | subset = {n: full[n] for n in layer_group}
39 | shared_inputs[f"{i}.{j}"] = qlinear.SharedQuantizedInput(len(layer_group))
40 | for name in subset:
41 | layer = subset[name]
42 | if 'lm_head' in name or 'rotary_emb' in name:
43 | continue
44 | is_quantized = False
45 | bits = 16
46 | fp_features = 0
47 | import quant_sim
48 | import qlinear
49 | if isinstance(layer, quant_sim.ActQuantWrapper):
50 | if layer.quantizer.configured:
51 | is_quantized = True
52 | bits = layer.quantizer.bits
53 | fp_features = layer.fp_features_num
54 | layer = layer.module
55 | layer_weight = layer.weight.data
56 |
57 | layer_scale = save_dict['model.layers.{}.{}.scale'.format(i, name)]
58 | if fp_features == 0:
59 | fp_feature_idx = None
60 | else:
61 | print('---------------save act_scales----------------')
62 | layer_act_scales = act_scales['model.layers.{}.{}'.format(i, name)]
63 | fp_feature_idx = torch.sort(layer_act_scales)[1][-fp_features:]
64 |
65 | if is_quantized:
66 | int_mod = qlinear.MixedQLinear.from_float(layer, layer_weight, layer_scale,
67 | shared_inputs[f"{i}.{j}"], fp_feature_idx,
68 | bits=bits)
69 | else:
70 | int_mod = layer
71 | modelutils.replace_single_mod_opt(opt_block, name, int_mod)
72 |
73 |
74 |
75 |
76 |
77 |
78 | if __name__ == "__main__":
79 |     """
80 |     Example usage.
81 |
82 |     Evaluate an unquantized FP16 model:
83 |         CUDA_VISIBLE_DEVICES=0 WORLD_SIZE=1 python generate.py --model_type fp16 \
84 |             --model_path /home/dataset/Llama-2-7b
85 |
86 |     Evaluate a MixQ 8-bit quantized model:
87 |         CUDA_VISIBLE_DEVICES=0 WORLD_SIZE=1 python generate.py --model_type mix \
88 |             --model_path /home/dataset/quant/quant8/Llama-2-7b \
89 |             --quant_file /home/dataset/quant/quant8/Llama-2-7b
90 |
91 |     Change the dataset:
92 |         python generate.py --dataset_path wikitext
93 |
94 |     """
95 | parser = argparse.ArgumentParser(description="Calculate Perplexity for a model.")
96 | parser.add_argument("--model_path", type=str, help="Model path")
97 | parser.add_argument("--quant_file", type=str, help="quant_file Model path")
98 |
99 | parser.add_argument("--model_type", type=str, default='bitsandbytesfp16')
100 |
101 |
102 | parser.add_argument("--n_ctx", type=int, default=256, help="Context size.")
103 | parser.add_argument("--n_batch", type=int, default=256, help="Batch size.")
104 | parser.add_argument("--dataset_path", type=str, default='wikitext', help="Path to the dataset.")
105 | parser.add_argument("--dataset_name", type=str, default=None, help="Name of the dataset.")
106 | parser.add_argument("--split", type=str, default='test', help="Dataset split to use.")
107 | parser.add_argument("--text_column", type=str, default='text', help="Column in the dataset containing the text.")
108 | parser.add_argument("--per_gpu_max_memory", type=int, default=None, help="Max memory used in each GPU.")
109 |     parser.add_argument("--cpu_max_memory", type=int, default=None, help="Max memory used in CPU.")
110 |
111 |     parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file")
112 |     parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer")
113 |     parser.add_argument("--trust_remote_code", action="store_true", help="Whether to trust remote code")
114 |     parser.add_argument("--disable_exllama", action="store_true", help="Whether to disable the exllama kernel")
115 |
116 |
117 | # Weight Quantization Params:
118 | parser.add_argument('--w_bits', type=int, default=16, choices=[4, 8, 16])
119 |
120 |
121 | parser.add_argument('--int8_down_proj', action='store_true', help='Use INT8 for Down Projection')
122 | parser.add_argument('--fp_features_frac', type=float, default=None, help='Fraction of features to keep in FP16.')
123 | parser.add_argument("--fp_features_num", type=int, default=1, help="outliers")
124 |
125 | parser.add_argument('--eval_accuracy', type=bool, default=True)
126 | parser.add_argument('--eval_throughput', type=bool, default=False)
127 |
128 |
129 | args = parser.parse_args()
130 |
131 | if args.eval_throughput is True:
132 | args.eval_accuracy = False
133 |
134 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
135 |
136 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=args.use_fast_tokenizer, trust_remote_code=True)
137 | if not tokenizer.pad_token_id:
138 | tokenizer.pad_token_id = tokenizer.eos_token_id
139 | ppl = Perplexity(None, tokenizer, args.dataset_path, args.dataset_name, args.split, args.text_column, args.eval_accuracy)
140 |
141 |
142 | model_path = args.model_path
143 | quant_file = args.quant_file
144 |
145 | if args.model_type == 'bitsandbytesfp16':
146 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
147 | print(f" -- Loading model fp16...")
148 | # model = transformers.LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16,
149 | # device_map='auto')
150 | model = AutoModelForCausalLM.from_pretrained(
151 | model_path, torch_dtype=torch.bfloat16,
152 | device_map='auto', trust_remote_code=True
153 | )
154 |
155 | model = model.to('cuda')
156 | print(model)
157 |
158 | if args.model_type == 'bitsandbytesmix4':
159 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
160 | print(f" -- Loading model mix4...")
161 |
162 | n_gpus = torch.cuda.device_count()
163 | max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
164 | max_memory = {i: max_memory for i in range(n_gpus)}
165 | quantization_config = BitsAndBytesConfig(
166 | load_in_4bit=True,
167 |             bnb_4bit_compute_dtype=torch.float16,
168 |             bnb_4bit_quant_type="nf4",
169 | )
170 | model = AutoModelForCausalLM.from_pretrained(
171 | model_path,
172 | device_map='auto',
173 | max_memory=max_memory,
174 | quantization_config=quantization_config
175 | )
176 | if args.model_type == 'bitsandbytes':
177 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
178 | print(f" -- Loading model mix bit8...")
179 |
180 | n_gpus = torch.cuda.device_count()
181 | max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
182 | max_memory = {i: max_memory for i in range(n_gpus)}
183 | quantization_config = BitsAndBytesConfig(
184 | load_in_8bit=True,
185 | llm_int8_threshold=6.0,
186 | llm_int8_has_fp16_weight=False,
187 | )
188 | model = AutoModelForCausalLM.from_pretrained(
189 | model_path,
190 | device_map='auto',
191 | max_memory=max_memory,
192 | quantization_config=quantization_config
193 | )
194 |
195 |
196 | if args.model_type == 'awq':
197 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
198 | print(f" -- Loading model awq...")
199 |
200 |
201 | from awq import AutoAWQForCausalLM
202 | model = AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True, mix = False)
203 |
204 |
205 | if args.model_type == 'mix':
206 | from mixquant.Cache import MixLibCache
207 | from mixquant import AutoForCausalLM
208 | cache = MixLibCache(args.n_batch)
209 |
210 |
211 | model = AutoForCausalLM.from_quantized(
212 | model_path, quant_file, fuse_layers=True,
213 | mix = True, cache = cache
214 | )
215 |
216 | if args.model_type == 'fp16':
217 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
218 | model = AutoModelForCausalLM.from_pretrained(
219 | model_path, torch_dtype=torch.float16,
220 | device_map='auto', trust_remote_code=True
221 | )
222 |
223 | #model = model.to('cuda')
224 |
225 | if args.model_type == 'quik':
226 |         from mixquant import AutoForCausalLM
227 |         from mixquant.Cache import MixLibCache
228 |         cache = MixLibCache(args.n_batch)
229 |         model = AutoForCausalLM.from_quantized(
230 |             model_path, quant_file, fuse_layers=True,
231 |             batch_size=args.n_batch,
232 |             mix = True, cache = cache
233 |         )
234 | print(model)
235 | ppl = Perplexity(model, tokenizer, args.dataset_path, args.dataset_name,
236 | args.split, args.text_column, args.eval_accuracy)
237 | allppl = ppl.calculate_perplexity(args.n_ctx, args.n_batch)
238 |
239 |
240 |
241 |
242 |
--------------------------------------------------------------------------------
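
A condensed sketch of the 'mix' branch above as a standalone script; the checkpoint path is a placeholder and the snippet assumes it is run from the repository root so utils.utils and mixquant are importable:

    import os
    os.environ["WORLD_SIZE"] = "1"

    from transformers import AutoTokenizer
    from mixquant import AutoForCausalLM
    from mixquant.Cache import MixLibCache
    from utils.utils import Perplexity

    model_path = "/home/dataset/quant/quant8/Llama-2-7b"   # placeholder quantized checkpoint

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    cache = MixLibCache(256)                                # sized to n_batch
    model = AutoForCausalLM.from_quantized(model_path, model_path,
                                           fuse_layers=True, mix=True, cache=cache)

    ppl = Perplexity(model, tokenizer, "wikitext", None, "test", "text", True)
    print(ppl.calculate_perplexity(256, 256))               # n_ctx, n_batch
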
/generate.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | if [ $2 == a100 ]
4 | then
5 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python "
6 | fi
7 |
8 | if [ $2 == h100 ]
9 | then
10 | CMD="srun -p twills -A h100 --gres=gpu:h100:1 --export=ALL python"
11 | fi
12 |
13 | export http_proxy=127.0.0.1:7890
14 | export https_proxy=127.0.0.1:7890
15 | set -x
16 |
17 | quantpath=/home/dataset/quant/quant
18 | modelpath=/home/dataset
19 |
20 | for batch in 32
21 | #for batch in 1
22 |
23 | do
24 | for seq in 1024
25 | do
26 | ##model_type=Aquila2
27 | #model_type=opt
28 | #model_type=Mistral
29 | #model_type=gpt-j
30 | #model_type=falcon
31 | model_type=$3
32 |
33 |
34 |
35 | data_types=( "mix" )
36 | for bit in 8
37 | do
38 | for data_type in "${data_types[@]}"
39 | do
40 | model=${model_type}
41 | echo ${model}
42 | rm -r ${quantpath}${bit}/${model}/model.safetensors
43 | CUDA_VISIBLE_DEVICES=$1 ${CMD} generate.py --model_type ${data_type} --model_path \
44 | ${quantpath}${bit}/${model} \
45 | --quant_file ${quantpath}${bit}/${model} \
46 | --n_batch ${batch} --n_ctx ${batch} --dataset_path /home/chenyidong/checkpoint/dataset
47 | done
48 | done
49 |
50 |
51 | # data_types=( "quik" )
52 | # for bit in 4
53 | # do
54 | # for data_type in "${data_types[@]}"
55 | # do
56 | # model=${model_type}
57 | # echo ${model}
58 | # CUDA_VISIBLE_DEVICES=$1 ${CMD} generate.py --model_type ${data_type} --model_path \
59 | # ${quantpath}quik${bit}/${model} \
60 | # --quant_file ${quantpath}quik${bit}/${model} \
61 | # --batch_size ${batch} --bit ${bit} --dataset_path /home/chenyidong/checkpoint/dataset
62 | # done
63 | # done
64 | # data_types=( "bitsandbytes" )
65 | # for data_type in "${data_types[@]}"
66 | # do
67 | # model=${model_type}
68 |
69 | # echo ${model}
70 | # CUDA_VISIBLE_DEVICES=$1 ${CMD} benchflops.py --model_type ${data_type} --model_path \
71 | # ${modelpath}/${model} \
72 | # --quant_file ${modelpath}/${model} --batch_size ${batch} --dataset_path /home/chenyidong/checkpoint/dataset
73 |
74 | # done
75 | # data_types=( "awq" )
76 | # for data_type in "${data_types[@]}"
77 | # do
78 | # for model in "${models[@]}"
79 | # do
80 | # echo ${model}
81 | # CUDA_VISIBLE_DEVICES=$1 ${CMD} benchflops.py --model_type ${data_type} --model_path \
82 | # ${quantpath}/awqquant/${model} \
83 | # --quant_file ${quantpath}t/awqquant/${model} --batch_size ${batch}
84 | # done
85 | # done
86 |
87 |
88 |
89 | done
90 | done
91 |
--------------------------------------------------------------------------------
/get_act.sh:
--------------------------------------------------------------------------------
1 |
2 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python "
3 | set -x
4 |
5 | model=( Llama-2-7b )
6 | model=( Aquila2-7b )
7 | model=( Baichuan2-7b )
8 | $CMD examples/smooth_quant_get_act.py --model-name /mnt/octave/data/chenyidong/checkpoint/${model} \
9 | --output-path /home/chenyidong/SC3/MixQ/src/act_scales/${model}.pt --dataset-path /home/chenyidong/val.jsonl.zst
10 |
11 |
--------------------------------------------------------------------------------
/mixquant/Cache.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 |
5 | class MixLibCache:
6 | def __init__(self, inputdim = 1024, sigma = 6, bit = 8, eval_ppl = False, locality = False):
7 | self.device = 'cuda'
8 | self.x_scale = torch.zeros((inputdim,1),dtype=torch.float16).to('cuda')
9 |
10 | self.sigma = torch.zeros((1,1),dtype=torch.float16).to('cuda')
11 | self.zeros = torch.zeros((inputdim,12288*3),dtype=torch.float16).to('cuda')
12 | self.sigma[0] = sigma
13 |
14 |
15 | self.ind = None
16 | self.shape = None
17 | self.activation_outliers = None
18 | self.is_prefill = False
19 | self.bit = bit
20 |
21 | self.max_outliers = 256
22 | self.stop = 2
23 |
24 | self.eval_ppl = eval_ppl
25 | self.locality = locality
26 | def do_bench_cudagraph(self, fn):
27 | if torch.cuda.current_stream() == torch.cuda.default_stream():
28 | raise RuntimeError("Cannot capture graph in default stream. Please use side stream in benchmark code.")
29 | # warmup
30 | for i in range(10):
31 | fn()
32 | g = torch.cuda.CUDAGraph()
33 | with torch.cuda.graph(g):
34 | fn()
35 | torch.cuda.synchronize()
36 |
37 |
38 | return g
39 |
40 |
41 |
42 | class MLPCache:
43 | def __init__(self, max_batch_size = 4096):
44 | self.device = 'cuda'
45 | self.x_scale = torch.zeros((max_batch_size,1),dtype=torch.float16).to('cuda')
46 | self.ind = None
47 | self.shape = None
48 | self.activation_outliers = None
49 |
50 |
--------------------------------------------------------------------------------
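
A minimal sketch of do_bench_cudagraph, assuming a CUDA device: the capture must happen on a side stream (the method raises on the default stream), and the returned CUDAGraph can be replayed afterwards.

    import torch
    from mixquant.Cache import MixLibCache

    cache = MixLibCache(inputdim=1024, sigma=6, bit=8)

    a = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
    b = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")

    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):                    # capture refuses the default stream
        graph = cache.do_bench_cudagraph(lambda: torch.mm(a, b))
    graph.replay()                                          # re-run the captured matmul
    torch.cuda.synchronize()
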
/mixquant/__init__.py:
--------------------------------------------------------------------------------
1 | from mixquant.models.auto import AutoForCausalLM
--------------------------------------------------------------------------------
/mixquant/int8fusedkernel.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import triton
4 | import triton.language as tl
5 |
6 | from triton import Config, autotune, cdiv, heuristics, jit
7 | from triton import language as tl
8 | from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
9 | # `triton.jit`'ed functions can be auto-tuned with the `triton.autotune` decorator, which consumes a list of candidate `Config`s and a `key` of argument names that trigger re-tuning.
10 | def upcast_if_fp8(a):
11 | if "fp8" in str(a):
12 | return torch.float16
13 | return a
14 |
15 | _ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32]
16 |
17 | def get_higher_dtype(a, b):
18 | a = upcast_if_fp8(a)
19 | b = upcast_if_fp8(b)
20 | if a is b:
21 | return a
22 |
23 | assert a in _ordered_datatypes
24 | assert b in _ordered_datatypes
25 |
26 | for d in _ordered_datatypes:
27 | if a is d:
28 | return b
29 | if b is d:
30 | return a
31 |
32 |
33 | def init_to_zero(name):
34 | return lambda nargs: nargs[name].zero_()
35 |
36 |
37 | def get_configs_io_bound():
38 | configs = []
39 | for num_stages in [2, 3, 4, 5, 6]:
40 | for block_m in [16, 32]:
41 | for block_k in [32, 64]:
42 | for block_kfp in [32, 64]:
43 | for block_n in [32, 64, 128, 256]:
44 | num_warps = 2 if block_n <= 64 else 4
45 | configs.append(
46 | Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,'BLOCK_Kfp': block_kfp, 'SPLIT_K': 1},
47 | num_stages=num_stages, num_warps=num_warps))
48 | # split_k
49 | for split_k in [2, 4, 8, 16]:
50 | configs.append(
51 | Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'BLOCK_Kfp': block_kfp, 'SPLIT_K': split_k},
52 | num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
53 | return configs
54 |
55 |
56 |
57 |
58 |
59 | @autotune(
60 | configs=[
61 | # basic configs for compute-bound matmuls
62 | Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),
63 | Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),
64 | Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),
65 | Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),
66 | Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),
67 | Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),
68 | Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),
69 | Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),
70 | Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=5, num_warps=2),
71 | # good for int8
72 | Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),
73 | Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),
74 | Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),
75 | Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),
76 | Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),
77 | Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),
78 | Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),
79 | Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),
80 | Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=5, num_warps=2),
81 | ] + get_configs_io_bound(),
82 | key=['M', 'N'],
83 | prune_configs_by={
84 | 'early_config_prune': early_config_prune,
85 | 'perf_model': estimate_matmul_time,
86 | 'top_k': 10,
87 | },
88 | )
89 | @heuristics({
90 | 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
91 | })
92 |
93 | @triton.jit
94 | def matmul_kernelint8(x,w, A, B, C, M, N, K, #
95 | stride_am, stride_ak, #
96 | stride_bk, stride_bn, #
97 | stride_cm, stride_cn, #
98 |
99 | Afp, Bfp, Cfp,
100 | Kfp,
101 | stride_amfp, stride_akfp, #
102 | stride_bkfp, stride_bnfp, #
103 | stride_cmfp, stride_cnfp,
104 |
105 | acc_dtype: tl.constexpr, #
106 | allow_tf32: tl.constexpr, #
107 | fp8_fast_accum: tl.constexpr, #
108 | BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
109 | BLOCK_Kfp: tl.constexpr,
110 | GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr #
111 | ):
112 | # matrix multiplication
113 | pid = tl.program_id(0)
114 | pid_z = tl.program_id(1)
115 | grid_m = tl.cdiv(M, BLOCK_M)
116 | grid_n = tl.cdiv(N, BLOCK_N)
117 | # re-order program ID for better L2 performance
118 | width = GROUP_M * grid_n
119 | group_id = pid // width
120 | group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
121 | pid_m = group_id * GROUP_M + (pid % group_size)
122 | pid_n = (pid % width) // (group_size)
123 | # do matrix multiplication
124 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
125 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
126 | ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
127 | rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
128 | rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
129 | # pointers
130 | A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
131 | B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
132 |
133 |
134 | acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
135 |
136 |
137 | for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
138 | if EVEN_K:
139 | a = tl.load(A)
140 | b = tl.load(B)
141 | else:
142 | k_remaining = K - k * (BLOCK_K * SPLIT_K)
143 | _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
144 | a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
145 | b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
146 | acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=False)
147 |
148 | A += BLOCK_K * SPLIT_K * stride_ak
149 | B += BLOCK_K * SPLIT_K * stride_bk
150 |
151 |
152 |
153 |
154 | # rematerialize rm and rn to save registers
155 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
156 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
157 |
158 |
159 | #Cfp = Cfp + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
160 | C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
161 | mask = (rm < M)[:, None] & (rn < N)[None, :]
162 |
163 | # handles write-back with reduction-splitting
164 |
165 | # x_ = tl.load(x + 0)
166 |
167 | # accfp = acc.to(tl.float32)
168 | # accfp *= x_
169 | # accfp = accfp.to(tl.float16)
170 | if SPLIT_K == 1:
171 | tl.store(C, acc, mask=mask)
172 |
173 | else:
174 | tl.atomic_add(C, acc, mask=mask)
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 | def matmulint8_fused_dequant(x,w, a, b, afp, bfp, c, cfp16, M, N, K, Kfp):
184 |
185 | grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
186 |
187 |
188 | allow_tf32=True
189 | fp8_fast_accum=True
190 | matmul_kernelint8[grid](
191 | x, w,
192 | a, b, c, #
193 | M, N, K, #
194 | K, 1, #
195 | 1, K, #
196 | N, 1, #
197 | afp, bfp, cfp16, Kfp[0],
198 | Kfp, 1, #
199 | 1, Kfp, #
200 | N, 1, #
201 | allow_tf32=allow_tf32, #
202 | fp8_fast_accum=fp8_fast_accum, #
203 | GROUP_M=8, acc_dtype=tl.int32,AB_DTYPE=None
204 | )
205 | return c, cfp16
--------------------------------------------------------------------------------
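
A hedged sketch of calling the wrapper above, with layouts inferred from its stride arguments: A is (M, K) int8 row-major, B is the weight in PyTorch (out_features, in_features) layout consumed as its transpose, and C is an (M, N) int32 buffer. The x/w and FP16-path arguments appear unused by the kernel body, so dummy tensors are passed for them; this also assumes a Triton version that still ships triton.ops.matmul_perf_model, as imported above.

    import torch
    from mixquant.int8fusedkernel import matmulint8_fused_dequant

    M, N, K = 16, 4096, 4096
    a = torch.randint(-127, 127, (M, K), dtype=torch.int8, device="cuda")   # activations
    b = torch.randint(-127, 127, (N, K), dtype=torch.int8, device="cuda")   # weight, (out, in) layout
    c = torch.zeros(M, N, dtype=torch.int32, device="cuda")                 # int32 accumulator

    kfp = torch.tensor([0], dtype=torch.int32, device="cuda")               # no FP16 outlier columns
    dummy = torch.zeros(1, dtype=torch.float16, device="cuda")              # unused x/w/*fp arguments

    matmulint8_fused_dequant(dummy, dummy, a, b, dummy, dummy, c, dummy, M, N, K, kfp)
    ref = a.cpu().to(torch.int32) @ b.cpu().to(torch.int32).T               # integer matmul on CPU
    print(torch.equal(c.cpu(), ref))
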
/mixquant/kernel.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | import triton
5 | import triton.language as tl
6 |
7 | from triton import Config, autotune, cdiv, heuristics, jit
8 | from triton import language as tl
9 | from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
10 |
11 |
12 |
13 | def get_configs_fp_io_bound():
14 | configs = []
15 | for num_stages in [2, 3, 4, 5, 6]:
16 | for block_m in [16, 32]:
17 | for block_kfp in [32, 64]:
18 | for block_n in [32, 64, 128, 256]:
19 | num_warps = 2 if block_n <= 64 else 4
20 | configs.append(
21 | Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_kfp, 'SPLIT_K': 1},
22 | num_stages=num_stages, num_warps=num_warps))
23 |
24 | return configs
25 |
26 | @autotune(
27 | configs=[
28 | # basic configs for compute-bound matmuls
29 | Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
30 | Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
31 | Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
32 | Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
33 | Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
34 | Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
35 | Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
36 | Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
37 | Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
38 | # good for int8
39 | Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
40 | Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
41 | Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
42 | Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
43 | Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
44 | Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
45 | Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
46 | Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
47 | Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
48 | ] + get_configs_fp_io_bound(),
49 | key=['M', 'N'],
50 | prune_configs_by={
51 | 'early_config_prune': early_config_prune,
52 | 'perf_model': estimate_matmul_time,
53 | 'top_k': 10,
54 | },
55 | )
56 |
57 | @triton.jit
58 | def matmul_kernelfp16(A, B, C, M, N, K,
59 | stride_amfp, stride_akfp, #
60 | stride_bkfp, stride_bnfp, #
61 | stride_cmfp, stride_cnfp,
62 | BLOCK_M: tl.constexpr,
63 | BLOCK_N: tl.constexpr,
64 | BLOCK_K: tl.constexpr,
65 | SPLIT_K: tl.constexpr,
66 | GROUP_M: tl.constexpr
67 | ):
68 | # matrix multiplication
69 | pid = tl.program_id(0)
70 | pid_z = tl.program_id(1)
71 | grid_m = tl.cdiv(M, BLOCK_M)
72 | grid_n = tl.cdiv(N, BLOCK_N)
73 | # re-order program ID for better L2 performance
74 | width = GROUP_M * grid_n
75 | group_id = pid // width
76 | group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
77 | pid_m = group_id * GROUP_M + (pid % group_size)
78 | pid_n = (pid % width) // (group_size)
79 | # do matrix multiplication
80 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
81 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
82 | ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
83 | rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
84 |
85 | # pointers
86 |
87 | rkfp = tl.arange(0, BLOCK_K)
88 | A = A + (ram[:, None] * stride_amfp + rkfp[None, :] * stride_akfp)
89 | B = B + (rkfp[:, None] * stride_bkfp + rbn[None, :] * stride_bnfp)
90 |
91 | accfp = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
92 |
93 |
94 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
95 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
96 |
97 | rmfp = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
98 | rnfp = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
99 |
100 |
101 | afp = tl.zeros((BLOCK_M, BLOCK_K), dtype=C.dtype.element_ty)
102 | bfp = tl.zeros((BLOCK_K, BLOCK_N), dtype=C.dtype.element_ty)
103 | C = C + (rmfp[:, None] * stride_cmfp + rnfp[None, :] * stride_cnfp)
104 | mask = (rm < M)[:, None] & (rn < N)[None, :]
105 |
106 |
107 |
108 | K_ = tl.load(K + 0)
109 | if K_ == 0:
110 |
111 | return
112 |
113 | maxK = tl.cdiv(K_, BLOCK_K )
114 | for k in range(0, maxK - 1):
115 |
116 | afp = tl.load(A)
117 | bfp = tl.load(B)
118 |
119 | A += BLOCK_K * stride_akfp
120 | B += BLOCK_K * stride_bkfp
121 |
122 | accfp = tl.dot(afp, bfp, accfp, out_dtype=tl.float32, allow_tf32=False)
123 |
124 | k = maxK - 1
125 | if K_ % ( BLOCK_K ) == 0:
126 | afp = tl.load(A)
127 | bfp = tl.load(B)
128 | else:
129 | k_remainingfp = K_ - k * (BLOCK_K )
130 | afp = tl.load(A, mask=rkfp[None, :] < k_remainingfp, other=0.0)
131 | bfp = tl.load(B, mask=rkfp[:, None] < k_remainingfp, other=0.0)
132 |
133 | accfp = tl.dot(afp, bfp, accfp, out_dtype=tl.float32, allow_tf32=False)
134 |
135 | accfp = accfp.to(tl.float16)
136 |
137 | # rematerialize rm and rn to save registers
138 |
139 |
140 | tl.store(C, accfp, mask=mask)
141 |
142 |
143 |
144 |
145 |
146 | def matmulfp16( afp, bfp, cfp16, M, N, K):
147 |
148 | grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
149 |
150 |
151 |
152 | matmul_kernelfp16[grid](
153 | afp, bfp, cfp16,M, N, K,
154 | 1, M, #
155 | N, 1, #
156 | N, 1, #
157 | GROUP_M=8
158 | )
159 |
160 | return
--------------------------------------------------------------------------------
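
A hedged sketch of calling matmulfp16, with layouts inferred from its stride arguments: the activation is passed transposed as a (K, M) contiguous fp16 buffer, B is (K, N), C is (M, N), and K is a one-element CUDA tensor holding the runtime reduction length (the number of FP16 outlier columns), since the kernel reads it with tl.load.

    import torch
    from mixquant.kernel import matmulfp16

    M, N, K_outliers = 8, 4096, 128

    a_t = torch.randn(K_outliers, M, dtype=torch.float16, device="cuda")   # A passed transposed
    b = torch.randn(K_outliers, N, dtype=torch.float16, device="cuda")
    c = torch.zeros(M, N, dtype=torch.float16, device="cuda")
    k = torch.tensor([K_outliers], dtype=torch.int32, device="cuda")       # runtime K, read via tl.load

    matmulfp16(a_t, b, c, M, N, k)
    ref = (a_t.T.float() @ b.float()).half()
    print(torch.allclose(c.float(), ref.float(), atol=1e-1, rtol=1e-2))
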
/mixquant/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .llama import LlamaMixQForCausalLM
2 | from .opt import OptMixForCausalLM
3 | #from .mistral import MistralMixForCausalLM
4 | from .gptj import GPTJMixForCausalLM
5 | from .falcon import FalconMixForCausalLM
6 | from .baichuan import BaichuanMixQForCausalLM
--------------------------------------------------------------------------------
/mixquant/models/aquila.py:
--------------------------------------------------------------------------------
1 | ## Adapted from llama.py
2 | from .base import BaseAWQForCausalLM
3 | from typing import Dict
4 | from transformers.models.llama.modeling_llama import (
5 | LlamaDecoderLayer as AquilaDecoderLayer,
6 | LlamaForCausalLM as AquilaForCausalLM,
7 | LlamaAttention as AquilaAttention,
8 | LlamaRMSNorm as AquilaRMSNorm,
9 | LlamaMLP as AquilaMLP
10 | )
11 |
12 | class AquilaAWQForCausalLM(BaseAWQForCausalLM):
13 | layer_type = "AquilaDecoderLayer"
14 | max_new_tokens_key = "max_position_embeddings"
15 |
16 | @staticmethod
17 | def fuse_layers(model: AquilaForCausalLM, quant_config: Dict):
18 | fuser = AquilaFuser(model, quant_config)
19 | fuser.fuse_attention()
20 | fuser.fuse_rmsnorm()
21 | fuser.fuse_mlp()
22 |
23 | @staticmethod
24 | def get_model_layers(model: AquilaForCausalLM):
25 | return model.model.layers
26 |
27 | @staticmethod
28 | def get_act_for_scaling(module: AquilaDecoderLayer):
29 | return dict(
30 | is_scalable=False
31 | )
32 |
33 | @staticmethod
34 | def move_embed(model: AquilaForCausalLM, device: str):
35 | model.model.embed_tokens = model.model.embed_tokens.to(device)
36 |
37 | @staticmethod
38 | def get_layers_for_scaling(module: AquilaDecoderLayer, input_feat, module_kwargs):
39 | layers = []
40 |
41 | # attention input
42 | layers.append(dict(
43 | prev_op=module.input_layernorm,
44 | layers=[module.self_attn.q_proj,
45 | module.self_attn.k_proj, module.self_attn.v_proj],
46 | inp=input_feat['self_attn.q_proj'],
47 | module2inspect=module.self_attn, kwargs=module_kwargs,
48 | ))
49 |
50 | if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
51 | layers.append(dict(
52 | prev_op=module.self_attn.v_proj,
53 | layers=[module.self_attn.o_proj],
54 | inp=input_feat['self_attn.o_proj'],
55 | ))
56 |
57 | # linear 1
58 | layers.append(dict(
59 | prev_op=module.post_attention_layernorm,
60 | layers=[module.mlp.gate_proj, module.mlp.up_proj],
61 | inp=input_feat['mlp.gate_proj'],
62 | module2inspect=module.mlp,
63 | ))
64 |
65 | # linear 2
66 | layers.append(dict(
67 | prev_op=module.mlp.up_proj,
68 | layers=[module.mlp.down_proj],
69 | inp=input_feat['mlp.down_proj'],
70 | ))
71 |
72 | return layers
73 |
74 |
--------------------------------------------------------------------------------
/mixquant/models/auto.py:
--------------------------------------------------------------------------------
1 | import os
2 | from transformers import AutoConfig
3 | from mixquant.models import *
4 | from mixquant.models.base import BaseForCausalLM
5 |
6 | CAUSAL_LM_MODEL_MAP = {
7 |
8 | "llama": LlamaMixQForCausalLM,
9 | "baichuan": BaichuanMixQForCausalLM,
10 | "aquila": LlamaMixQForCausalLM,
11 | #"mistral": MistralMixForCausalLM,
12 | "gptj" : GPTJMixForCausalLM,
13 | "falcon": FalconMixForCausalLM,
14 | "opt": OptMixForCausalLM
15 | }
16 |
17 | def check_and_get_model_type(model_dir, trust_remote_code=True):
18 | config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
19 | if config.model_type not in CAUSAL_LM_MODEL_MAP.keys():
20 | raise TypeError(f"{config.model_type} isn't supported yet.")
21 | model_type = config.model_type
22 | if config.architectures[0]=="BaichuanForCausalLM":
23 | model_type="baichuan"
24 | return model_type
25 |
26 | class AutoForCausalLM:
27 | def __init__(self):
28 |         raise EnvironmentError('You must instantiate AutoForCausalLM with\n'
29 |                                'AutoForCausalLM.from_quantized or AutoForCausalLM.from_pretrained')
30 |
31 | @classmethod
32 | def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False,
33 | device_map=None, mix = False, **model_init_kwargs) -> BaseForCausalLM:
34 | model_type = check_and_get_model_type(model_path, trust_remote_code)
35 |
36 | return CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
37 | model_path, model_type, trust_remote_code=trust_remote_code, safetensors=safetensors,
38 | device_map=device_map, mix = mix, **model_init_kwargs
39 | )
40 |
41 | @classmethod
42 | def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
43 | trust_remote_code=True, fuse_layers=True,
44 | batch_size=1, safetensors=False,
45 | max_memory=None, offload_folder=None, mix = False, cache = None) -> BaseForCausalLM:
46 |
47 | model_type = check_and_get_model_type(quant_path, trust_remote_code)
48 | os.environ["BATCH_SIZE"] = str(batch_size)
49 | return CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
50 | quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
51 | fuse_layers=fuse_layers, safetensors=safetensors,
52 | max_memory=max_memory, offload_folder=offload_folder, mix = mix, cache = cache
53 | )
54 |
--------------------------------------------------------------------------------
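
A small sketch (placeholder path) of how a checkpoint directory is resolved to its MixQ model class before from_pretrained or from_quantized constructs it:

    from mixquant.models.auto import check_and_get_model_type, CAUSAL_LM_MODEL_MAP

    model_dir = "/home/dataset/Llama-2-7b"            # placeholder local checkpoint
    model_type = check_and_get_model_type(model_dir)  # e.g. "llama"
    print(model_type, CAUSAL_LM_MODEL_MAP[model_type].__name__)
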
/mixquant/models/baichuan.py:
--------------------------------------------------------------------------------
1 | from .base import BaseForCausalLM
2 | from typing import Dict
3 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
4 |
5 | class BaichuanMixQForCausalLM(BaseForCausalLM):
6 | layer_type = "LlamaDecoderLayer"
7 | max_new_tokens_key = "max_position_embeddings"
8 |
9 | @staticmethod
10 | def fuse_layers(model: LlamaForCausalLM, quant_config: Dict, mix = False, cache = None):
11 |
12 | fuser = LlamaFuser(model, quant_config)
13 |
14 | fuser.fuse_attention(MixGemmCache = cache)
15 |
16 | fuser.fuse_mlp(mix, MixGemmCache = cache)
17 | fuser.fuse_rmsnorm(MixGemmCache = cache)
18 |
19 |
20 | for layer in model.model.layers:
21 | layer.input_layernorm.next_layer = layer.self_attn.W_pack
22 | layer.post_attention_layernorm.next_layer = layer.mlp.up_proj_
23 |
24 | @staticmethod
25 | def get_model_layers(model: LlamaForCausalLM):
26 | return model.model.layers
27 |
28 |
29 |
30 | @staticmethod
31 | def move_embed(model: LlamaForCausalLM, device: str):
32 | model.model.embed_tokens = model.model.embed_tokens.to(device)
33 |
34 |
35 | import torch
36 | from typing import List, Tuple, Union
37 | from mixquant.utils.utils import set_module_name
38 | from mixquant.modules.fused.mlp import MixLlamaMLP
39 | from mixquant.modules.fused.attn import QuantAttentionFused, QuantAttentionFusedBaichuan13B
40 | from mixquant.modules.fused.norm import FasterTransformerRMSNorm
41 | from mixquant.modules.linear import MixLinear_GEMM
42 |
43 | from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
44 | import sys
45 |
46 |
47 | #from modeling_baichuan import Attention
48 | class LlamaFuser:
49 | def __init__(self, model, quant_config):
50 | self.model = model
51 | self.quant_config = quant_config
52 |
53 |         #print(model.model.layers[0].self_attn.o_proj)  # check the layout of the model weights
54 |
55 |         # the attention list below must also match Baichuan's Attention modules
56 | self.attention_modules: List[Tuple[str, LlamaAttention]] = [
57 | (name, module) for name, module in self.model.named_modules()
58 | if isinstance(module, LlamaAttention) or "Attention" in str(module.__class__)
59 | ]
60 | #print(self.attention_modules)
61 |
62 | self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
63 | (name, module) for name, module in self.model.named_modules()
64 | if isinstance(module, LlamaRMSNorm) or "RMSNorm" in str(module.__class__)
65 | ]
66 |
67 | self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
68 | (name, module) for name, module in self.model.named_modules()
69 | if isinstance(module, LlamaMLP) or "MLP" in str(module.__class__)
70 | ]
71 |
72 | def fuse_attention(self, MixGemmCache):
73 | for name, module in self.attention_modules:
74 | qkv_layer = self._fuse_qkv(module, MixGemmCache)
75 | try:
76 | num_key_value_heads = module.num_key_value_heads
77 | except:
78 |                 # fall back for Baichuan models, which do not expose num_key_value_heads
79 | num_key_value_heads = 32
80 |
81 | if self.model.config.num_hidden_layers == 40:
82 | attn = QuantAttentionFusedBaichuan13B(
83 | module.hidden_size,
84 | module.num_heads,
85 | num_key_value_heads,
86 | qkv_layer,
87 | module.o_proj,
88 | next(iter(qkv_layer.state_dict().values())).device,
89 | self.model.config.max_new_tokens,
90 | MixGemmCache = MixGemmCache
91 | )
92 |
93 | else:
94 | attn = QuantAttentionFused(
95 | module.hidden_size,
96 | module.num_heads,
97 | num_key_value_heads,
98 | qkv_layer,
99 | module.o_proj,
100 | next(iter(qkv_layer.state_dict().values())).device,
101 | self.model.config.max_new_tokens,
102 | MixGemmCache = MixGemmCache
103 | )
104 | set_module_name(self.model, name, attn)
105 |
106 |     def fuse_attention(self, MixGemmCache):  # NOTE: this redefinition overrides the variant above
107 |
108 | for name, module in self.attention_modules:
109 |
110 | layer_idx = int(name.split('.')[2])
111 | qkv_layer = self._fuse_qkv(module, MixGemmCache)
112 | try:
113 | num_key_value_heads = module.num_key_value_heads
114 | except:
115 |                 # fall back for Baichuan models, which do not expose num_key_value_heads
116 |                 print("could not find the attribute module.num_key_value_heads")
117 | num_key_value_heads = 32
118 | attn = QuantAttentionFused(
119 | module.hidden_size,
120 | module.num_heads,
121 | num_key_value_heads,
122 | qkv_layer,
123 | module.o_proj,
124 | next(iter(qkv_layer.state_dict().values())).device,
125 | self.model.config.max_new_tokens,
126 | MixGemmCache = MixGemmCache,
127 | layer_idx = layer_idx
128 | )
129 | set_module_name(self.model, name, attn)
130 |
131 | def _fuse_qkv(self, module: LlamaAttention,cache):
132 | try:
133 | q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
134 | except:
135 | qkv_layer = module.W_pack
136 | return qkv_layer
137 |
138 |
139 |
140 |         if not isinstance(q_proj, MixLinear_GEMM):
141 |             raise NotImplementedError("q/k/v projections must be MixLinear_GEMM to be fused")
142 |
143 | if isinstance(q_proj, MixLinear_GEMM):
144 | qkv_layer = MixLinear_GEMM(q_proj.in_features,q_proj.out_features + k_proj.out_features + v_proj.out_features,
145 | q_proj.bias is not None,
146 | next(iter(module.state_dict().values())).device,
147 | bit = self.quant_config['w_bit'],
148 | weight_only=False,
149 | cache=cache)
150 |
151 |
152 |
153 | if isinstance(qkv_layer, MixLinear_GEMM):
154 | shapew = qkv_layer.q_weight.shape
155 |
156 | if qkv_layer.weight_only:
157 | qkv_layer.q_weight = torch.cat([q_proj.q_weight, k_proj.q_weight, v_proj.q_weight], dim=1)
158 | qkv_layer.scale_col = torch.cat([q_proj.scale_col, k_proj.scale_col, v_proj.scale_col], dim=0)
159 |
160 | else:
161 | qkv_layer.q_weight = torch.cat([q_proj.q_weight, k_proj.q_weight, v_proj.q_weight], dim=0)
162 | qkv_layer.scale_col = torch.cat([q_proj.scale_col, k_proj.scale_col, v_proj.scale_col], dim=1)
163 | assert shapew[0] == qkv_layer.q_weight.shape[0]
164 | assert shapew[1] == qkv_layer.q_weight.shape[1]
165 | assert shapew[0] == qkv_layer.scale_col.shape[1]
166 | assert 1 == qkv_layer.scale_col.shape[0]
167 | if self.quant_config['w_bit'] == 4:
168 |
169 |
170 |
171 | qkv_layer.weight_cache.copy_(torch.cat([q_proj.weight_cache,
172 | k_proj.weight_cache,
173 | v_proj.weight_cache], dim=0))
174 |
175 |
176 |
177 | qkv_layer.ind.copy_(q_proj.ind)
178 |
179 |
180 |
181 |
182 |
183 |
184 | if q_proj.bias is not None:
185 | raise NotImplementedError
186 | else:
187 | qkv_layer.bias = None
188 |
189 | else:
190 |             raise NotImplementedError("unexpected qkv_layer type")
191 |
192 | q_proj.q_weight = q_proj.q_weight.to('cpu')
193 | k_proj.q_weight = k_proj.q_weight.to('cpu')
194 | v_proj.q_weight = v_proj.q_weight.to('cpu')
195 | q_proj.scale_col = q_proj.scale_col.to('cpu')
196 | k_proj.scale_col = k_proj.scale_col.to('cpu')
197 | v_proj.scale_col = v_proj.scale_col.to('cpu')
198 | torch.cuda.empty_cache()
199 | return qkv_layer
200 |
201 | def fuse_rmsnorm(self, MixGemmCache):
202 | for name, module in self.rmsnorm_modules:
203 | norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon, MixGemmCache)
204 | set_module_name(self.model, name, norm)
205 |
206 | def fuse_mlp(self,mix, MixGemmCache = None):
207 | for name, module in self.mlp_modules:
208 | if mix:
209 | assert MixGemmCache is not None
210 | mlp = MixLlamaMLP(module.gate_proj, module.down_proj, module.up_proj , MixGemmCache)
211 | set_module_name(self.model, name, mlp)
--------------------------------------------------------------------------------
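A note on the `_fuse_qkv` routine above: it concatenates the per-output-channel int8 weights (`q_weight`) and scales (`scale_col`) of the q/k/v projections along the output dimension so that a single GEMM produces all three projections at once. The standalone sketch below (plain torch tensors rather than the repo's `MixLinear_GEMM`; the int8 quantizer is only illustrative) checks that the fused layout reproduces the three separate projections:

import torch

def quantize_per_out_channel(w):
    # symmetric int8 quantization, one scale per output channel,
    # stored as a [1, out_features] row vector (the layout the asserts above expect)
    scale = w.abs().amax(dim=1) / 127.0                        # [out]
    q = torch.clamp((w / scale[:, None]).round(), -127, 127).to(torch.int8)
    return q, scale[None, :]                                   # [out, in], [1, out]

def dequant_matmul(x, q, s):
    return x @ (q.float() * s.t()).t()                         # de-quantize, then GEMM

hidden = 64
x = torch.randn(4, hidden)
wq, wk, wv = (torch.randn(hidden, hidden) for _ in range(3))
parts = [quantize_per_out_channel(w) for w in (wq, wk, wv)]

q_weight  = torch.cat([q for q, _ in parts], dim=0)            # [3*hidden, hidden]
scale_col = torch.cat([s for _, s in parts], dim=1)            # [1, 3*hidden]

fused    = dequant_matmul(x, q_weight, scale_col)
separate = torch.cat([dequant_matmul(x, q, s) for q, s in parts], dim=-1)
assert torch.allclose(fused, separate, atol=1e-4)
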
/mixquant/models/basefuser.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/mixquant/models/basefuser.py
--------------------------------------------------------------------------------
/mixquant/models/bloom.py:
--------------------------------------------------------------------------------
1 | from .base import BaseAWQForCausalLM
2 | from transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomBlock
3 |
4 | class BloomMixForCausalLM(BaseAWQForCausalLM):
5 | layer_type = "BloomBlock"
6 |
7 | @staticmethod
8 | def get_model_layers(model: BloomForCausalLM):
9 | return model.transformer.h
10 |
11 |
12 | @staticmethod
13 | def move_embed(model: BloomForCausalLM, device: str):
14 | model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
15 | model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
16 |
17 |
--------------------------------------------------------------------------------
/mixquant/models/falcon.py:
--------------------------------------------------------------------------------
1 | from .base import BaseForCausalLM
2 | from typing import Dict
3 | from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention
4 | from transformers.models.falcon.configuration_falcon import FalconConfig
5 | from transformers.models.falcon.modeling_falcon import FalconMLP
6 |
7 |
8 | from mixquant.modules.fused.mlp import MixFalconMLP
9 | from mixquant.utils.utils import set_module_name
10 | import torch
11 | from typing import Optional, Tuple, Union, List
12 | from torch import nn
13 | from torch.nn import functional as F
14 |
15 | class FalconMixForCausalLM(BaseForCausalLM):
16 | layer_type = "FalconDecoderLayer"
17 |
18 | @staticmethod
19 | def fuse_layers(model: FalconForCausalLM, quant_config: Dict, mix, cache):
20 | fuser = FalconFuser(model)
21 |
22 |
23 | fuser.fuse_mlp(mix, cache)
24 |
25 | @staticmethod
26 | def get_model_layers(model: FalconForCausalLM):
27 | return model.transformer.h
28 |
29 | @staticmethod
30 | def move_embed(model: FalconForCausalLM, device):
31 | model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
32 |
33 |
34 |
35 |
36 | class FalconFuser:
37 | def __init__(self, model: FalconForCausalLM):
38 | self.model = model
39 |
40 | self.attention_modules = [
41 | (name, module) for name, module in self.model.named_modules()
42 | if "Attention" in str(module.__class__)
43 | ]
44 | self.mlp_modules: List[Tuple[str, FalconMLP]] = [
45 | (name, module) for name, module in self.model.named_modules()
46 | if isinstance(module, FalconMLP) or "MLP" in str(module.__class__)
47 | ]
48 |
49 | def fuse_mlp(self,mix, MixGemmCache = None):
50 | for name, module in self.mlp_modules:
51 | if mix:
52 | assert MixGemmCache is not None
53 | mlp = MixFalconMLP(module.dense_h_to_4h, module.dense_4h_to_h, MixGemmCache)
54 | set_module_name(self.model, name, mlp)
--------------------------------------------------------------------------------
/mixquant/models/gpt_bigcode.py:
--------------------------------------------------------------------------------
1 | from .base import BaseAWQForCausalLM
2 | from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM, GPTBigCodeBlock as OldGptBigCodeBlock
3 |
4 | class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM):
5 | layer_type = "GPTBigCodeBlock"
6 | max_new_tokens_key = "n_positions"
7 |
8 | @staticmethod
9 | def get_model_layers(model: GPTBigCodeForCausalLM):
10 | return model.transformer.h
11 |
12 | @staticmethod
13 | def get_act_for_scaling(module: OldGptBigCodeBlock):
14 | return dict(
15 | is_scalable=True,
16 | scale_name="mlp.act",
17 | scale_layer=module.mlp.act,
18 | scale_shape=module.mlp.c_fc.out_features
19 | )
20 |
21 | @staticmethod
22 | def move_embed(model: GPTBigCodeForCausalLM, device):
23 | model.transformer.wte = model.transformer.wte.to(device)
24 | model.transformer.wpe = model.transformer.wpe.to(device)
25 | model.transformer.drop = model.transformer.drop.to(device)
26 |
27 | @staticmethod
28 | def get_layers_for_scaling(module:OldGptBigCodeBlock, input_feat, module_kwargs):
29 | layers = []
30 |
31 | # attention input
32 | layers.append(dict(
33 | prev_op=module.ln_1,
34 | layers=[module.attn.c_attn],
35 | inp=input_feat['attn.c_attn'],
36 | module2inspect=module.attn,
37 | kwargs=module_kwargs
38 | ))
39 |
40 | # linear 1
41 | layers.append(dict(
42 | prev_op=module.ln_2,
43 | layers=[module.mlp.c_fc],
44 | inp=input_feat['mlp.c_fc'],
45 | module2inspect=module.mlp
46 | ))
47 |
48 | # linear 2
49 | layers.append(dict(
50 | prev_op=module.mlp.act,
51 | layers=[module.mlp.c_proj],
52 | inp=input_feat['mlp.c_proj']
53 | ))
54 |
55 | return layers
56 |
--------------------------------------------------------------------------------
/mixquant/models/gpt_neox.py:
--------------------------------------------------------------------------------
1 | from .base import BaseAWQForCausalLM
2 | from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer, GPTNeoXForCausalLM
3 |
4 | class GPTNeoXAWQForCausalLM(BaseAWQForCausalLM):
5 | layer_type = "GPTNeoXDecoderLayer"
6 | max_new_tokens_key = "max_position_embeddings"
7 |
8 | @staticmethod
9 | def get_model_layers(model: GPTNeoXForCausalLM):
10 | return model.gpt_neox.layers
11 |
12 | @staticmethod
13 | def get_act_for_scaling(module: GPTNeoXLayer):
14 | return dict(
15 | is_scalable=True,
16 | scale_name="mlp.act",
17 | scale_layer=module.mlp.act,
18 | scale_shape=module.mlp.dense_h_to_4h.out_features,
19 | )
20 |
21 | @staticmethod
22 | def move_embed(model: GPTNeoXForCausalLM, device: str):
23 | model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device)
24 |
25 | @staticmethod
26 | def get_layers_for_scaling(module: GPTNeoXLayer, input_feat, module_kwargs):
27 | layers = []
28 |
29 | # attention input
30 | layers.append(dict(
31 | prev_op=module.input_layernorm,
32 | layers=[module.attention.query_key_value],
33 | inp=input_feat['attention.query_key_value'],
34 | ))
35 |
36 | # attention out
37 | # Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
38 | """
39 | layers.append(dict(
40 | prev_op=module.attention.query_key_value,
41 | layers=[module.attention.dense],
42 | inp=input_feat['attention.dense'],
43 | ))
44 | """
45 |
46 | # linear 1
47 | layers.append(dict(
48 | prev_op=module.post_attention_layernorm,
49 | layers=[module.mlp.dense_h_to_4h],
50 | inp=input_feat['mlp.dense_h_to_4h'],
51 | ))
52 |
53 | # linear 2
54 | layers.append(dict(
55 | prev_op=module.mlp.act,
56 | layers=[module.mlp.dense_4h_to_h],
57 | inp=input_feat['mlp.dense_4h_to_h'],
58 | ))
59 |
60 | return layers
61 |
--------------------------------------------------------------------------------
/mixquant/models/gptj.py:
--------------------------------------------------------------------------------
1 | from .base import BaseForCausalLM
2 | from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJBlock, GPTJAttention, GPTJMLP
3 | from typing import List, Tuple, Union
4 | from mixquant.modules.linear import MixLinear_GEMM
5 | import torch
6 | from mixquant.modules.fused.gptj_attn import QuantGPTJAttentionFused
7 | from mixquant.modules.fused.mlp import MixGPTJMLP
8 | from mixquant.utils.utils import set_module_name
9 |
10 |
11 | class GPTJMixForCausalLM(BaseForCausalLM):
12 | layer_type = "GPTJBlock"
13 | max_new_tokens_key = "n_positions"
14 |
15 | @staticmethod
16 | def get_model_layers(model: GPTJForCausalLM):
17 | return model.transformer.h
18 |
19 |
20 | @staticmethod
21 | def move_embed(model: GPTJForCausalLM, device: str):
22 | model.transformer.wte = model.transformer.wte.to(device)
23 |
24 |
25 |
26 |
27 |
28 | @staticmethod
29 | def fuse_layers(model: GPTJForCausalLM, quant_config, mix, cache):
30 | fuser = GPTJFuser(model)
31 |
32 |
33 | fuser.fuse_mlp(mix, cache)
34 | fuser.fuse_attention(MixGemmCache = cache)
35 |
36 |
37 |
38 |
39 |
40 | class GPTJFuser:
41 | def __init__(self, model: GPTJForCausalLM):
42 | self.model = model
43 |
44 | self.attention_modules: List[Tuple[str, GPTJAttention]] = [
45 | (name, module) for name, module in self.model.named_modules()
46 | if isinstance(module, GPTJAttention) or "Attention" in str(module.__class__)
47 | ]
48 | self.mlp_modules: List[Tuple[str, GPTJMLP]] = [
49 | (name, module) for name, module in self.model.named_modules()
50 | if isinstance(module, GPTJMLP) or "MLP" in str(module.__class__)
51 | ]
52 | def fuse_mlp(self,mix, MixGemmCache = None):
53 | for name, module in self.mlp_modules:
54 | if mix:
55 | assert MixGemmCache is not None
56 | mlp = MixGPTJMLP(module, self.model.config, MixGemmCache)
57 | set_module_name(self.model, name, mlp)
58 |
59 |
60 | def _fuse_qkv(self, module,cache):
61 | try:
62 | q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
63 | except:
64 | qkv_layer = module.W_pack
65 | return qkv_layer
66 |
67 |
68 |
69 |         if not isinstance(q_proj, MixLinear_GEMM):
70 |             raise NotImplementedError("q/k/v projections must be MixLinear_GEMM to be fused")
71 |
72 | if isinstance(q_proj, MixLinear_GEMM):
73 | qkv_layer = MixLinear_GEMM(q_proj.in_features,q_proj.out_features + k_proj.out_features + v_proj.out_features,
74 | q_proj.bias is not None,
75 | next(iter(module.state_dict().values())).device,
76 | False,
77 | cache)
78 |
79 |
80 |
81 | if isinstance(qkv_layer, MixLinear_GEMM):
82 | shapew = qkv_layer.q_weight.shape
83 | qkv_layer.q_weight = torch.cat([q_proj.q_weight, k_proj.q_weight, v_proj.q_weight], dim=0)
84 | qkv_layer.scale_col = torch.cat([q_proj.scale_col, k_proj.scale_col, v_proj.scale_col], dim=1)
85 |
86 | assert shapew[0] == qkv_layer.q_weight.shape[0]
87 | assert shapew[1] == qkv_layer.q_weight.shape[1]
88 | assert shapew[0] == qkv_layer.scale_col.shape[1]
89 | assert 1 == qkv_layer.scale_col.shape[0]
90 |
91 | if q_proj.bias is not None:
92 | raise NotImplementedError
93 | else:
94 | qkv_layer.bias = None
95 |
96 | else:
97 |             raise NotImplementedError("unexpected qkv_layer type")
98 |
99 | q_proj.q_weight = q_proj.q_weight.to('cpu')
100 | k_proj.q_weight = k_proj.q_weight.to('cpu')
101 | v_proj.q_weight = v_proj.q_weight.to('cpu')
102 | q_proj.scale_col = q_proj.scale_col.to('cpu')
103 | k_proj.scale_col = k_proj.scale_col.to('cpu')
104 | v_proj.scale_col = v_proj.scale_col.to('cpu')
105 | torch.cuda.empty_cache()
106 | return qkv_layer
107 |
108 |
109 | def fuse_attention(self, MixGemmCache):
110 | for name, module in self.attention_modules:
111 |             qkv_layer = self._fuse_qkv(module, MixGemmCache)
112 |
113 | attn = QuantGPTJAttentionFused(
114 | self.model.config,
115 | module,
116 | qkv_layer,
117 | module.out_proj,
118 | MixGemmCache = MixGemmCache
119 | )
120 | set_module_name(self.model, name, attn)
--------------------------------------------------------------------------------
/mixquant/models/llama.py:
--------------------------------------------------------------------------------
1 | from .base import BaseForCausalLM
2 | from typing import Dict
3 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
4 |
5 | class LlamaMixQForCausalLM(BaseForCausalLM):
6 | layer_type = "LlamaDecoderLayer"
7 | max_new_tokens_key = "max_position_embeddings"
8 |
9 | @staticmethod
10 | def fuse_layers(model: LlamaForCausalLM, quant_config: Dict, mix = False, cache = None):
11 |
12 | fuser = LlamaFuser(model, quant_config)
13 |
14 | fuser.fuse_attention(MixGemmCache = cache)
15 |
16 | fuser.fuse_mlp(mix, MixGemmCache = cache)
17 | fuser.fuse_rmsnorm(MixGemmCache = cache)
18 |
19 |
20 | for layer in model.model.layers:
21 | layer.input_layernorm.next_layer = layer.self_attn.W_pack
22 | layer.post_attention_layernorm.next_layer = layer.mlp.up_proj_
23 |
24 |
25 | @staticmethod
26 | def get_model_layers(model: LlamaForCausalLM):
27 | return model.model.layers
28 |
29 |
30 |
31 | @staticmethod
32 | def move_embed(model: LlamaForCausalLM, device: str):
33 | model.model.embed_tokens = model.model.embed_tokens.to(device)
34 |
35 |
36 | import torch
37 | from typing import List, Tuple, Union
38 | from mixquant.utils.utils import set_module_name
39 | from mixquant.modules.fused.mlp import MixLlamaMLP
40 | from mixquant.modules.fused.attn import QuantAttentionFused
41 | from mixquant.modules.fused.norm import FasterTransformerRMSNorm
42 | from mixquant.modules.linear import MixLinear_GEMM
43 |
44 | from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
45 | import sys
46 |
47 |
48 | #from modeling_baichuan import Attention
49 | class LlamaFuser:
50 | def __init__(self, model, quant_config):
51 | self.model = model
52 | self.quant_config = quant_config
53 |
54 |         #print(model.model.layers[0].self_attn.o_proj)  # check the layout of the model weights
55 |
56 |         # also need to match Baichuan's Attention modules
57 | self.attention_modules: List[Tuple[str, LlamaAttention]] = [
58 | (name, module) for name, module in self.model.named_modules()
59 | if isinstance(module, LlamaAttention) or "Attention" in str(module.__class__)
60 | ]
61 | #print(self.attention_modules)
62 |
63 | self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
64 | (name, module) for name, module in self.model.named_modules()
65 | if isinstance(module, LlamaRMSNorm) or "RMSNorm" in str(module.__class__)
66 | ]
67 |
68 | self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
69 | (name, module) for name, module in self.model.named_modules()
70 | if isinstance(module, LlamaMLP) or "MLP" in str(module.__class__)
71 | ]
72 |
73 | def fuse_attention(self, MixGemmCache):
74 |
75 | for name, module in self.attention_modules:
76 |
77 | layer_idx = int(name.split('.')[2])
78 | qkv_layer = self._fuse_qkv(module, MixGemmCache)
79 | try:
80 | num_key_value_heads = module.num_key_value_heads
81 |             except AttributeError:
82 |                 # fall back for Baichuan models, which do not expose this attribute
83 |                 print("attribute module.num_key_value_heads not found, defaulting to 32")
84 |                 num_key_value_heads = 32
85 | attn = QuantAttentionFused(
86 | module.hidden_size,
87 | module.num_heads,
88 | num_key_value_heads,
89 | qkv_layer,
90 | module.o_proj,
91 | next(iter(qkv_layer.state_dict().values())).device,
92 | self.model.config.max_new_tokens,
93 | MixGemmCache = MixGemmCache,
94 | layer_idx = layer_idx
95 | )
96 | set_module_name(self.model, name, attn)
97 |
98 | def _fuse_qkv(self, module: LlamaAttention,cache):
99 | try:
100 | q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
101 | except:
102 | qkv_layer = module.W_pack
103 | return qkv_layer
104 |
105 |
106 |
107 |         if not isinstance(q_proj, MixLinear_GEMM):
108 |             raise NotImplementedError("q/k/v projections must be MixLinear_GEMM to be fused")
109 |
110 | if isinstance(q_proj, MixLinear_GEMM):
111 | qkv_layer = MixLinear_GEMM(q_proj.in_features,q_proj.out_features + k_proj.out_features + v_proj.out_features,
112 | q_proj.bias is not None,
113 | next(iter(module.state_dict().values())).device,
114 | bit = self.quant_config['w_bit'],
115 | weight_only=False,
116 | cache=cache)
117 |
118 |
119 |
120 | if isinstance(qkv_layer, MixLinear_GEMM):
121 | shapew = qkv_layer.q_weight.shape
122 |
123 | if qkv_layer.weight_only:
124 | qkv_layer.q_weight = torch.cat([q_proj.q_weight, k_proj.q_weight, v_proj.q_weight], dim=1)
125 | qkv_layer.scale_col = torch.cat([q_proj.scale_col, k_proj.scale_col, v_proj.scale_col], dim=0)
126 |
127 | else:
128 | qkv_layer.q_weight = torch.cat([q_proj.q_weight, k_proj.q_weight, v_proj.q_weight], dim=0)
129 | qkv_layer.scale_col = torch.cat([q_proj.scale_col, k_proj.scale_col, v_proj.scale_col], dim=1)
130 | assert shapew[0] == qkv_layer.q_weight.shape[0]
131 | assert shapew[1] == qkv_layer.q_weight.shape[1]
132 | assert shapew[0] == qkv_layer.scale_col.shape[1]
133 | assert 1 == qkv_layer.scale_col.shape[0]
134 | if self.quant_config['w_bit'] == 4:
135 |
136 |
137 |
138 | qkv_layer.weight_cache.copy_(torch.cat([q_proj.weight_cache,
139 | k_proj.weight_cache,
140 | v_proj.weight_cache], dim=0))
141 |
142 |
143 |
144 | qkv_layer.ind.copy_(q_proj.ind)
145 |
146 |
147 |
148 |
149 |
150 |
151 | if q_proj.bias is not None:
152 | raise NotImplementedError
153 | else:
154 | qkv_layer.bias = None
155 |
156 | else:
157 |             raise NotImplementedError("unexpected qkv_layer type")
158 |
159 | q_proj.q_weight = q_proj.q_weight.to('cpu')
160 | k_proj.q_weight = k_proj.q_weight.to('cpu')
161 | v_proj.q_weight = v_proj.q_weight.to('cpu')
162 | q_proj.scale_col = q_proj.scale_col.to('cpu')
163 | k_proj.scale_col = k_proj.scale_col.to('cpu')
164 | v_proj.scale_col = v_proj.scale_col.to('cpu')
165 | torch.cuda.empty_cache()
166 | return qkv_layer
167 |
168 | def fuse_rmsnorm(self, MixGemmCache):
169 | for name, module in self.rmsnorm_modules:
170 | norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon, MixGemmCache)
171 | set_module_name(self.model, name, norm)
172 |
173 | def fuse_mlp(self,mix, MixGemmCache = None):
174 | for name, module in self.mlp_modules:
175 | if mix:
176 | assert MixGemmCache is not None
177 | mlp = MixLlamaMLP(module.gate_proj, module.down_proj, module.up_proj , MixGemmCache)
178 | set_module_name(self.model, name, mlp)
--------------------------------------------------------------------------------
/mixquant/models/mistral.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from .base import BaseForCausalLM
3 | from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralForCausalLM
4 |
5 | class MistralMixForCausalLM(BaseForCausalLM):
6 | layer_type = "MistralDecoderLayer"
7 | max_new_tokens_key = "max_position_embeddings"
8 |
9 | @staticmethod
10 | def fuse_layers(model: MistralForCausalLM, quant_config: Dict, mix, cache):
11 | fuser = MistralFuser(model, quant_config)
12 | fuser.fuse_attention(cache)
13 | fuser.fuse_rmsnorm()
14 | fuser.fuse_mlp()
15 |
16 | @staticmethod
17 | def get_model_layers(model: MistralForCausalLM):
18 | return model.model.layers
19 |
20 | @staticmethod
21 | def get_act_for_scaling(module: MistralDecoderLayer):
22 | return dict(
23 | is_scalable=False
24 | )
25 |
26 | @staticmethod
27 | def move_embed(model: MistralForCausalLM, device: str):
28 | model.model.embed_tokens = model.model.embed_tokens.to(device)
29 |
30 |
31 |
32 | import torch
33 | from typing import List, Tuple, Union
34 | from mixquant.utils.utils import set_module_name
35 | from mixquant.modules.fused.mlp import MixLlamaMLP
36 | from mixquant.modules.fused.mistral_attn import MistralQuantAttentionFused
37 | from mixquant.modules.fused.norm import FasterTransformerRMSNorm
38 | from mixquant.modules.linear import MixLinear_GEMM
39 |
40 | from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRMSNorm, MistralMLP
41 |
42 | class MistralFuser:
43 | def __init__(self, model, quant_config):
44 | self.model = model
45 | self.quant_config = quant_config
46 |
47 | self.attention_modules: List[Tuple[str, MistralAttention]] = [
48 | (name, module) for name, module in self.model.named_modules()
49 | if isinstance(module, MistralAttention)
50 | ]
51 |
52 | self.rmsnorm_modules: List[Tuple[str, MistralRMSNorm]] = [
53 | (name, module) for name, module in self.model.named_modules()
54 | if isinstance(module, MistralRMSNorm)
55 | ]
56 |
57 | self.mlp_modules: List[Tuple[str, MistralMLP]] = [
58 | (name, module) for name, module in self.model.named_modules()
59 | if isinstance(module, MistralMLP)
60 | ]
61 |
62 | def fuse_attention(self,cache):
63 | for name, module in self.attention_modules:
64 | qkv_layer = self._fuse_qkv(module, cache)
65 | attn = MistralQuantAttentionFused(
66 | module.hidden_size,
67 | module.num_heads,
68 | module.num_key_value_heads,
69 | qkv_layer,
70 | module.o_proj,
71 | next(iter(qkv_layer.state_dict().values())).device,
72 | self.model.config.max_new_tokens,
73 | MixGemmCache = cache,
74 | config = self.model.config
75 | )
76 | set_module_name(self.model, name, attn)
77 |
78 | def _fuse_qkv(self, module: MistralAttention, cache):
79 | try:
80 | q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
81 | except:
82 | qkv_layer = module.W_pack
83 | return qkv_layer
84 |
85 |
86 |
87 | if not isinstance(q_proj, MixLinear_GEMM) :
88 | raise NotImplementedError
89 |
90 | if isinstance(q_proj, MixLinear_GEMM):
91 | qkv_layer = MixLinear_GEMM(q_proj.in_features,q_proj.out_features + k_proj.out_features + v_proj.out_features,
92 | q_proj.bias is not None,
93 | next(iter(module.state_dict().values())).device,
94 | False,
95 | cache)
96 |
97 |
98 |
99 | if isinstance(qkv_layer, MixLinear_GEMM):
100 | shapew = qkv_layer.q_weight.shape
101 | qkv_layer.q_weight = torch.cat([q_proj.q_weight, k_proj.q_weight, v_proj.q_weight], dim=0)
102 | qkv_layer.scale_col = torch.cat([q_proj.scale_col, k_proj.scale_col, v_proj.scale_col], dim=1)
103 |
104 | assert shapew[0] == qkv_layer.q_weight.shape[0]
105 | assert shapew[1] == qkv_layer.q_weight.shape[1]
106 | assert shapew[0] == qkv_layer.scale_col.shape[1]
107 | assert 1 == qkv_layer.scale_col.shape[0]
108 |
109 | if q_proj.bias is not None:
110 | raise NotImplementedError
111 | else:
112 | qkv_layer.bias = None
113 |
114 | else:
115 | raise NotImplementedError
116 |
117 | q_proj.q_weight = q_proj.q_weight.to('cpu')
118 | k_proj.q_weight = k_proj.q_weight.to('cpu')
119 | v_proj.q_weight = v_proj.q_weight.to('cpu')
120 | q_proj.scale_col = q_proj.scale_col.to('cpu')
121 | k_proj.scale_col = k_proj.scale_col.to('cpu')
122 | v_proj.scale_col = v_proj.scale_col.to('cpu')
123 | torch.cuda.empty_cache()
124 | return qkv_layer
125 |
126 | def fuse_rmsnorm(self):
127 | for name, module in self.rmsnorm_modules:
128 | norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon)
129 | set_module_name(self.model, name, norm)
130 |
131 | def fuse_mlp(self):
132 | for name, module in self.mlp_modules:
133 | mlp = MixLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
134 | set_module_name(self.model, name, mlp)
135 |
--------------------------------------------------------------------------------
/mixquant/models/mpt.py:
--------------------------------------------------------------------------------
1 | from .base import BaseAWQForCausalLM
2 | from typing import Dict
3 | from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM
4 |
5 | class MptAWQForCausalLM(BaseAWQForCausalLM):
6 | layer_type = "MPTBlock"
7 | max_new_tokens_key = "max_seq_len"
8 |
9 | @staticmethod
10 | def fuse_layers(model: MptForCausalLM, quant_config: Dict):
11 | fuser = MptFuser(model)
12 | fuser.fuse_transformer()
13 |
14 | @staticmethod
15 | def get_model_layers(model: MptForCausalLM):
16 | return model.transformer.blocks
17 |
18 | @staticmethod
19 | def get_act_for_scaling(module: OldMptBlock):
20 | return dict(
21 | is_scalable=True,
22 | scale_name="ffn.act",
23 | scale_layer=module.ffn.act,
24 | scale_shape=module.ffn.up_proj.out_features
25 | )
26 |
27 | @staticmethod
28 | def move_embed(model: MptForCausalLM, device: str):
29 | model.transformer.wte = model.transformer.wte.to(device)
30 | model.transformer.emb_drop = model.transformer.emb_drop.to(device)
31 |
32 | @staticmethod
33 | def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs):
34 | layers = []
35 |
36 | # attention input
37 | layers.append(dict(
38 | prev_op=module.norm_1,
39 | layers=[module.attn.Wqkv],
40 | inp=input_feat['attn.Wqkv'],
41 | module2inspect=module.attn,
42 | kwargs=module_kwargs
43 | ))
44 |
45 | # attention output
46 | layers.append(dict(
47 | prev_op=module.attn.Wqkv,
48 | layers=[module.attn.out_proj],
49 | inp=input_feat['attn.out_proj']
50 | ))
51 |
52 | # linear 1
53 | layers.append(dict(
54 | prev_op=module.norm_2,
55 | layers=[module.ffn.up_proj],
56 | inp=input_feat['ffn.up_proj'],
57 | module2inspect=module.ffn
58 | ))
59 |
60 | # linear 2
61 | layers.append(dict(
62 | prev_op=module.ffn.act,
63 | layers=[module.ffn.down_proj],
64 | inp=input_feat['ffn.down_proj']
65 | ))
66 |
67 | return layers
68 |
69 | from typing import List, Tuple
70 | from awq.utils.utils import set_module_name
71 | from awq.modules.fused.block import MPTBlock
72 | from awq.modules.fused.model import MPTModel
73 |
74 | class MptFuser:
75 | def __init__(self, model: MptForCausalLM):
76 | self.model = model
77 |
78 | self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
79 | (name, module) for name, module in self.model.named_modules()
80 | if 'mptblock' in module.__class__.__name__.lower()
81 | ]
82 |
83 | def fuse_transformer(self):
84 | blocks = []
85 |
86 | module: OldMptBlock
87 | for module in self.model.transformer.blocks:
88 | blocks.append(MPTBlock(
89 | self.model.config.d_model,
90 | self.model.config.n_heads,
91 | module.attn.Wqkv,
92 | module.attn.out_proj,
93 | module.ffn,
94 | module.norm_1,
95 | module.norm_2,
96 | next(iter(module.state_dict().values())).device,
97 | self.model.config.max_new_tokens
98 | ))
99 |
100 | self.model.transformer = MPTModel(
101 | self.model.config.vocab_size,
102 | blocks,
103 | self.model.transformer.wte,
104 | self.model.transformer.norm_f,
105 | )
--------------------------------------------------------------------------------
/mixquant/models/opt.py:
--------------------------------------------------------------------------------
1 | from .base import BaseForCausalLM
2 | from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTDecoderLayer
3 |
4 | class OptMixForCausalLM(BaseForCausalLM):
5 | layer_type = "OPTDecoderLayer"
6 | max_new_tokens_key = "max_position_embeddings"
7 |
8 | @staticmethod
9 | def get_model_layers(model: OPTForCausalLM):
10 | return model.model.decoder.layers
11 |
12 |
13 | @staticmethod
14 | def move_embed(model: OPTForCausalLM, device: str):
15 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
16 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
17 |
18 |
--------------------------------------------------------------------------------
/mixquant/models/sample.py:
--------------------------------------------------------------------------------
1 | # Illustrative pseudocode: sketches the intended codegen / quant annotations for mixed-precision kernels; not runnable Python as written.
2 | import torch
3 |
4 | @codegen(dtype = [torch.float16, torch.int8 , torch.int32], codegen = "cutlass")
5 | def mixgemm(a, bint8, ind):
6 | afp = a[:,ind]
7 | bfp = bint8[:,ind].to(torch.float16).scale()
8 | a[:,ind] = 0
9 | aint8 = a.to(torch.int8)
10 | c = (aint8 * bint8).to(torch.float16).scale() + afp * bfp
11 | out = torch.relu(c)
12 | return out
13 |
14 |
15 |
16 | class Attention:
17 | def __init__(self,weight,oproj):
18 | self.qweight = weight
19 | self.oproj = oproj
20 |
21 |
22 | @quant([function=self.qweight, type="Mix",activation="A8",weight="W8",codegen = "cutlass"],
23 | [function=self.oproj, type="Weight",activation="A16",weight="W8",codegen = "cutlass"])
24 | def attention(self,hidden_state):
25 |
26 | qx = self.qweight(hidden_state)
27 | o = self.oproj(qx)
28 |
29 | return o
--------------------------------------------------------------------------------
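The `mixgemm` sketch above captures the core mixed-precision idea: a few outlier activation channels stay in fp16 while the rest go through an int8 GEMM, and the two partial products are summed. Below is a hedged, plain-PyTorch reference of that decomposition (per-tensor scales and the outlier index set are illustrative choices, not the repo's CUTLASS kernels):

import torch

def mix_gemm_reference(a, w, outlier_ind):
    """C = A @ W.T with outlier activation channels kept in full precision."""
    # full-precision path: the outlier columns of A against the matching columns of W
    a_fp = a[:, outlier_ind]                                   # [m, n_outliers]
    w_fp = w[:, outlier_ind]                                   # [n, n_outliers]

    # int8 path: zero the outliers so they are not counted twice, then quantize
    a_rest = a.clone()
    a_rest[:, outlier_ind] = 0
    a_scale = a_rest.abs().max() / 127.0
    w_scale = w.abs().max() / 127.0
    a_q = (a_rest / a_scale).round().clamp(-127, 127)
    w_q = (w / w_scale).round().clamp(-127, 127)

    c_int8 = (a_q @ w_q.t()) * (a_scale * w_scale)             # de-quantized int8 product
    c_fp = a_fp @ w_fp.t()                                     # exact outlier contribution
    return c_int8 + c_fp

m, k, n = 4, 64, 32
a, w = torch.randn(m, k), torch.randn(n, k)
ind = torch.tensor([3, 17, 42])                                # illustrative outlier channels
err = (mix_gemm_reference(a, w, ind) - a @ w.t()).abs().max()
print(f"max abs error vs. fp32 GEMM: {err.item():.4f}")        # quantization error only
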
/mixquant/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/mixquant/modules/__init__.py
--------------------------------------------------------------------------------
/mixquant/modules/act.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class ScaledActivation(nn.Module):
4 | def __init__(self, module, scales):
5 | super().__init__()
6 | self.act = module
7 | self.scales = nn.Parameter(scales.data)
8 |
9 | def forward(self, x):
10 | return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
11 |
--------------------------------------------------------------------------------
/mixquant/modules/fused/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/mixquant/modules/fused/__init__.py
--------------------------------------------------------------------------------
/mixquant/modules/fused/block.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.nn as nn
3 | from awq.modules.fused.attn import QuantAttentionFused
4 |
5 | class MPTBlock(nn.Module):
6 | def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
7 | super().__init__()
8 | self.n_heads = n_heads
9 | self.n_kv_heads = 0
10 | self.hidden_size = hidden_size
11 | self.norm_1 = norm_1
12 | self.attn = QuantAttentionFused(
13 | hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
14 | dev=dev, max_seq_len=max_seq_len, use_alibi=True
15 | ).to(dev)
16 | self.norm_2 = norm_2
17 | self.ffn = mpt_mlp.to(dev)
18 |
19 | def forward(
20 | self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
21 | ):
22 | norm_out = self.norm_1(hidden_states)
23 | attn_output, _, past_key_value = self.attn.forward(
24 | hidden_states=norm_out,
25 | past_key_value=past_key_value,
26 | attention_mask=attention_mask,
27 | position_ids=None,
28 | output_attentions=False,
29 | use_cache=True
30 | )
31 |
32 | h = hidden_states + attn_output
33 | out = h + self.ffn.forward(self.norm_2(h))
34 | return out, None, past_key_value
35 |
36 | class FalconDecoderLayer(nn.Module):
37 | def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len,
38 | input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
39 | super().__init__()
40 | self.n_heads = n_heads
41 | self.n_kv_heads = 8 if new_decoder_arch else 0
42 | self.hidden_size = hidden_size
43 | self.new_decoder_arch = new_decoder_arch
44 |
45 | if new_decoder_arch:
46 | attention_shapes = None
47 | else:
48 | attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads)
49 |
50 | # TODO: Falcon has ALiBi implemented but which model uses it?
51 | self.attn = QuantAttentionFused(
52 | hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
53 | dev=dev, max_seq_len=max_seq_len, use_alibi=False,
54 | attention_shapes=attention_shapes
55 | ).to(dev)
56 |
57 | if new_decoder_arch:
58 | self.ln_attn = ln_attn # before attention
59 | self.ln_mlp = ln_mlp # before mlp
60 | else:
61 | self.input_layernorm = input_layernorm # before attention
62 |
63 | self.mlp = mlp
64 |
65 | def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
66 | batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
67 |
68 | self.attention_shapes = {
69 | # following fastertransformer definition
70 | "cache_v": (batch_size, 1, max_seq_len, head_dim,),
71 | # 8: pack 8 fp16 in FT, if fp32 then use 4
72 | "cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,),
73 | "xqkv_view": (n_heads+2, head_dim),
74 | "xq_slice": lambda xqkv: xqkv[:, :, :-2],
75 | "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
76 | "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
77 | "xq_view": (n_heads, head_dim),
78 | "xk_view": (1, head_dim),
79 | "xv_view": (1, head_dim),
80 | "xk_reshape": (1, head_dim // 8, 8),
81 | "single_xq_view": (n_heads, head_dim),
82 | "single_xk_view": (1, head_dim),
83 | "single_xv_view": (1, head_dim)
84 | }
85 |
86 | return self.attention_shapes
87 |
88 | def forward(
89 | self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
90 | ):
91 | if self.new_decoder_arch:
92 | layernorm_out = self.ln_attn(hidden_states)
93 | mlp_layernorm_out = self.ln_mlp(hidden_states)
94 | else:
95 | layernorm_out = self.input_layernorm(hidden_states)
96 |
97 | attn_output, _, past_key_value = self.attn.forward(
98 | hidden_states=layernorm_out,
99 | past_key_value=past_key_value,
100 | attention_mask=attention_mask,
101 | position_ids=None,
102 | output_attentions=False,
103 | use_cache=True
104 | )
105 |
106 | h_attn = hidden_states + attn_output
107 |
108 | if self.new_decoder_arch:
109 | h_mlp = self.mlp.forward(mlp_layernorm_out)
110 | else:
111 | h_mlp = self.mlp.forward(layernorm_out)
112 |
113 | out = h_attn + h_mlp
114 |
115 | return out, None, past_key_value
116 |
--------------------------------------------------------------------------------
/mixquant/modules/fused/cache.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class WindowedCache:
4 | def __init__(self, cache_v_shape, cache_k_shape, device):
5 | """
6 | The window size is the same as the max_new_tokens. The window will
7 | automatically roll once max_new_tokens is exceeded.
8 | """
9 | # [batch_size, n_kv_heads, max_seq_len, head_dim]
10 | self.v = torch.zeros(cache_v_shape).to(device).half()
11 | # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
12 | self.k = torch.zeros(cache_k_shape).to(device).half()
13 |
14 | def get_kv(self, batch_size, start_pos, seqlen, head_dim):
15 | xv = self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
16 | xk = self.k[:batch_size, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous()
17 | xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()
18 |
19 | return xv, xk
20 |
21 | def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
22 | self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store
23 | self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store
24 |
25 | def roll_kv(self, roll_len, start_pos):
26 | # Roll only the necessary part of the cache to the left
27 | self.v[:, :, :-roll_len, :] = self.v[:, :, roll_len:, :]
28 | self.k[:, :, :, :-roll_len, :] = self.k[:, :, :, roll_len:, :]
29 |
30 | # Zero out the new part
31 | self.v[:, :, -roll_len:, :] = 0
32 | self.k[:, :, :, -roll_len:, :] = 0
33 |
34 | return start_pos - roll_len
35 |
36 | def to(self, device):
37 | self.k = self.k.to(device)
38 | self.v = self.v.to(device)
39 |
--------------------------------------------------------------------------------
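A short usage sketch for `WindowedCache` (the shapes follow the comments in its constructor; `from mixquant.modules.fused.cache import WindowedCache` is assumed to be importable, and the sizes are arbitrary):

import torch
from mixquant.modules.fused.cache import WindowedCache

batch, n_kv_heads, max_seq_len, head_dim = 1, 4, 16, 64
cache = WindowedCache(
    cache_v_shape=(batch, n_kv_heads, max_seq_len, head_dim),
    cache_k_shape=(batch, n_kv_heads, head_dim // 8, max_seq_len, 8),
    device="cpu",
)

# store a 10-token prefill, then read it back in attention layout
seqlen = 10
values = torch.randn(batch, n_kv_heads, seqlen, head_dim)
keys   = torch.randn(batch, n_kv_heads, head_dim // 8, seqlen, 8)
cache.update_kv(values, keys, batch_size=batch, start_pos=0, seqlen=seqlen)
xv, xk = cache.get_kv(batch, start_pos=0, seqlen=seqlen, head_dim=head_dim)
print(xv.shape, xk.shape)         # both [batch, seqlen, n_kv_heads, head_dim]

# once the window fills up, roll the oldest 4 positions out and continue at position 6
new_start = cache.roll_kv(roll_len=4, start_pos=seqlen)
print(new_start)                  # 6
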
/mixquant/modules/fused/gptj_attn.py:
--------------------------------------------------------------------------------
1 | from transformers.utils.import_utils import (
2 | is_torch_fx_proxy,
3 | )
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 | from typing import Optional, Tuple, Union
9 |
10 |
11 |
12 |
13 | def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
14 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
15 | sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
16 | return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
17 | @torch.fx.wrap
18 | def get_embed_positions(embed_positions, position_ids):
19 | return embed_positions.to(position_ids.device).repeat(position_ids.shape[0], 1, 1)
20 |
21 |
22 | def GPTJrotate_every_two(x: torch.Tensor) -> torch.Tensor:
23 | x1 = x[:, :, :, ::2]
24 | x2 = x[:, :, :, 1::2]
25 | x = torch.stack((-x2, x1), dim=-1)
26 | return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
27 |
28 |
29 | def GPTJapply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
30 | sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
31 | cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
32 | return (tensor * cos) + (GPTJrotate_every_two(tensor) * sin)
33 |
34 | class QuantGPTJAttentionFused(nn.Module):
35 |
36 |
37 |
38 |
39 | def __init__(self, config,module,qkv_layer,out_proj,MixGemmCache):
40 | super().__init__()
41 |
42 | max_positions = config.max_position_embeddings
43 | self.register_buffer(
44 | "bias",
45 | torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
46 | 1, 1, max_positions, max_positions
47 | ),
48 | persistent=False,
49 | )
50 | self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
51 |
52 | self.attn_dropout = nn.Dropout(config.attn_pdrop)
53 | self.resid_dropout = nn.Dropout(config.resid_pdrop)
54 |
55 | self.embed_dim = config.hidden_size
56 | self.hidden_size = config.hidden_size
57 | self.num_attention_heads = config.num_attention_heads
58 | self.head_dim = self.embed_dim // self.num_attention_heads
59 | if self.head_dim * self.num_attention_heads != self.embed_dim:
60 | raise ValueError(
61 | f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
62 | f" `num_attention_heads`: {self.num_attention_heads})."
63 | )
64 | self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
65 |
66 | self.qkv_proj = qkv_layer
67 | self.out_proj = out_proj
68 | self.MixGemmCache = MixGemmCache
69 | self.rotary_dim = config.rotary_dim
70 | pos_embd_dim = self.rotary_dim or self.embed_dim
71 | self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
72 |
73 | def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
74 | """
75 | Splits hidden dim into attn_head_size and num_attention_heads
76 | """
77 | new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
78 | tensor = tensor.view(new_shape)
79 | if rotary:
80 | return tensor
81 | if len(tensor.shape) == 5:
82 | return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features)
83 | elif len(tensor.shape) == 4:
84 | return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
85 | else:
86 | raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
87 |
88 | def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
89 | """
90 | Merges attn_head_size dim and num_attn_heads dim into hidden dim
91 | """
92 | if len(tensor.shape) == 5:
93 | tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
94 | elif len(tensor.shape) == 4:
95 | tensor = tensor.permute(0, 2, 1, 3).contiguous()
96 | else:
97 | raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
98 | new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
99 | return tensor.view(new_shape)
100 |
101 | def _attn(
102 | self,
103 | query,
104 | key,
105 | value,
106 | attention_mask=None,
107 | head_mask=None,
108 | ):
109 | # compute causal mask from causal mask buffer
110 | self.bias = self.bias.to(query.device)
111 | query_length, key_length = query.size(-2), key.size(-2)
112 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
113 |
114 | # Keep the attention weights computation in fp32 to avoid overflow issues
115 | query = query.to(torch.float32)
116 | key = key.to(torch.float32)
117 |
118 | attn_weights = torch.matmul(query, key.transpose(-1, -2))
119 |
120 | mask_value = torch.finfo(attn_weights.dtype).min
121 | # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
122 | # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
123 | mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
124 | attn_weights = torch.where(causal_mask, attn_weights, mask_value)
125 |
126 | attn_weights = attn_weights / self.scale_attn
127 |
128 | if attention_mask is not None:
129 | # Apply the attention mask
130 | attn_weights = attn_weights + attention_mask
131 |
132 | attn_weights = nn.functional.softmax(attn_weights, dim=-1)
133 | attn_weights = attn_weights.to(value.dtype)
134 |
135 |
136 | attn_weights = self.attn_dropout(attn_weights)
137 |
138 | # Mask heads if we want to
139 | if head_mask is not None:
140 | attn_weights = attn_weights * head_mask
141 |
142 | attn_output = torch.matmul(attn_weights, value)
143 |
144 | return attn_output, attn_weights
145 |
146 | def _get_embed_positions(self, position_ids):
147 | embed_positions = self.embed_positions
148 | if embed_positions.device != position_ids.device:
149 | embed_positions = embed_positions.to(position_ids.device)
150 | self.embed_positions = embed_positions
151 | return embed_positions.repeat(position_ids.shape[0], 1, 1)
152 |
153 | def forward(
154 | self,
155 | hidden_states: torch.FloatTensor,
156 | layer_past: Optional[Tuple[torch.Tensor]] = None,
157 | attention_mask: Optional[torch.FloatTensor] = None,
158 | position_ids: Optional[torch.LongTensor] = None,
159 | head_mask: Optional[torch.FloatTensor] = None,
160 | use_cache: Optional[bool] = False,
161 | output_attentions: Optional[bool] = False,
162 | ) -> Union[
163 | Tuple[torch.Tensor, Tuple[torch.Tensor]],
164 | Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
165 | ]:
166 |
167 |
168 |
169 | proj = self.qkv_proj(hidden_states, self.MixGemmCache)
170 |
171 |
172 | proj = (
173 | proj.unflatten(-1, (3, self.hidden_size))
174 | .unsqueeze(0)
175 | .transpose(0, -2)
176 | .squeeze(-2)
177 | )
178 | query = proj[0]
179 | key = proj[1]
180 | value = proj[2]
181 |
182 |
183 |
184 | query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
185 | key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
186 | value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
187 |
188 | if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
189 | # The logic to conditionally copy to GPU could not be traced, so we do this
190 | # every time in the torch.fx case
191 | embed_positions = get_embed_positions(self.embed_positions, position_ids)
192 | else:
193 | embed_positions = self._get_embed_positions(position_ids)
194 |
195 | repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
196 | sincos = torch.gather(embed_positions, 1, repeated_position_ids)
197 | sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
198 |
199 | if self.rotary_dim is not None:
200 | k_rot = key[:, :, :, : self.rotary_dim]
201 | k_pass = key[:, :, :, self.rotary_dim :]
202 |
203 | q_rot = query[:, :, :, : self.rotary_dim]
204 | q_pass = query[:, :, :, self.rotary_dim :]
205 |
206 | k_rot = GPTJapply_rotary_pos_emb(k_rot, sin, cos)
207 | q_rot = GPTJapply_rotary_pos_emb(q_rot, sin, cos)
208 |
209 | key = torch.cat([k_rot, k_pass], dim=-1)
210 | query = torch.cat([q_rot, q_pass], dim=-1)
211 | else:
212 | key = GPTJapply_rotary_pos_emb(key, sin, cos)
213 | query = GPTJapply_rotary_pos_emb(query, sin, cos)
214 |
215 | key = key.permute(0, 2, 1, 3)
216 | query = query.permute(0, 2, 1, 3)
217 |
218 | if layer_past is not None:
219 | past_key = layer_past[0]
220 | past_value = layer_past[1]
221 | key = torch.cat((past_key, key), dim=-2)
222 | value = torch.cat((past_value, value), dim=-2)
223 |
224 | if use_cache is True:
225 | present = (key, value)
226 | else:
227 | present = None
228 |
229 | # compute self-attention: V x Softmax(QK^T)
230 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
231 |
232 | attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
233 | attn_output = self.out_proj(attn_output)
234 | attn_output = self.resid_dropout(attn_output)
235 |
236 | outputs = (attn_output, present)
237 | if output_attentions:
238 | outputs += (attn_weights,)
239 |
240 | return outputs # a, present, (attentions)
--------------------------------------------------------------------------------
/mixquant/modules/fused/mlp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | import torch
4 | from mixquant.Cache import MixLibCache, MLPCache
5 | import mixlib
6 |
7 |
8 | class MixFalconMLP(nn.Module):
9 |
10 | def __init__(
11 | self,
12 | dense_h_to_4h,
13 | dense_4h_to_h,
14 | MixGemmCache = None
15 | ):
16 | super().__init__()
17 |
18 |
19 |
20 | self.dense_h_to_4h = dense_h_to_4h
21 | self.dense_4h_to_h = dense_4h_to_h
22 | self.act = nn.GELU()
23 | self.MixGemmCache = MixGemmCache
24 |
25 |
26 | def forward(self, x):
27 |
28 | x = self.act(self.dense_h_to_4h(x, self.MixGemmCache))
29 |
30 |
31 | x = self.dense_4h_to_h(x, self.MixGemmCache,True)
32 |
33 | return x
34 |
35 |
36 | import time
37 | class MixLlamaMLP(nn.Module):
38 |
39 | def __init__(
40 | self,
41 | gate_proj,
42 | down_proj,
43 | up_proj,
44 | MixGemmCache = None
45 | ):
46 | super().__init__()
47 |
48 |
49 |
50 | self.down_proj_ = down_proj
51 | self.gate_proj_ = gate_proj
52 | self.up_proj_ = up_proj
53 | self.out_features = down_proj.out_features
54 | self.MLPCache = MixGemmCache
55 |
56 |
57 | def forward(self, x):
58 |
59 |
60 |
61 | up_output = self.up_proj_(x, self.MLPCache)
62 | gate_output = self.gate_proj_.forward_without_preconditionFusedSilu(x, self.MLPCache)
63 |
64 | gate_output *= up_output
65 |
66 |
67 | y = self.down_proj_(gate_output,None,True)
68 |
69 |
70 | return y
71 |
72 |
73 | from transformers.activations import ACT2FN
74 | class MixGPTJMLP(nn.Module):
75 | def __init__(self, module, config, MixGemmCache = None): # in MLP: intermediate_size= 4 * embed_dim
76 | super().__init__()
77 |
78 |
79 | self.fc_in = module.fc_in
80 | self.fc_out = module.fc_out
81 |
82 | self.act = ACT2FN[config.activation_function]
83 | self.dropout = nn.Dropout(config.resid_pdrop)
84 |
85 |
86 | self.MLPCache = MLPCache()
87 |
88 | def forward(self, hidden_states) -> torch.FloatTensor:
89 |
90 | hidden_states = self.fc_in(hidden_states, self.MLPCache)
91 | hidden_states = self.act(hidden_states)
92 | hidden_states = self.fc_out(hidden_states, self.MLPCache)
93 | hidden_states = self.dropout(hidden_states)
94 | return hidden_states
--------------------------------------------------------------------------------
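For reference, `MixLlamaMLP.forward` above follows the standard LLaMA SwiGLU dataflow, with the SiLU fused into the quantized gate projection kernel. A plain `nn.Linear` equivalent (a sketch of the math only, not the quantized path) looks like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 64, 172
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj   = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

def llama_mlp_reference(x):
    # down( silu(gate(x)) * up(x) ): what the fused module computes with
    # quantized projections and a kernel-fused SiLU
    return down_proj(F.silu(gate_proj(x)) * up_proj(x))

y = llama_mlp_reference(torch.randn(2, 8, hidden_size))
print(y.shape)                    # torch.Size([2, 8, 64])
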
/mixquant/modules/fused/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import List
4 | from mixquant.modules.fused.block import MPTBlock, FalconDecoderLayer
5 | from transformers.modeling_outputs import BaseModelOutputWithPast
6 |
7 | class MPTModel(nn.Module):
8 | def __init__(self, vocab_size, blocks, wte, norm_f):
9 | super().__init__()
10 | self.vocab_size = vocab_size
11 | self.wte = wte
12 | self.blocks: List[MPTBlock] = nn.ModuleList(blocks)
13 | self.norm_f = norm_f
14 | self.attn_uses_sequence_id = False
15 | self.prefix_lm = False
16 |
17 | @torch.inference_mode()
18 | def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
19 | _bsz, seqlen = input_ids.shape
20 | h = self.wte(input_ids)
21 |
22 | mask = None
23 | if seqlen > 1:
24 | mask = torch.full(
25 | (1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device
26 | )
27 | mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h)
28 |
29 | for layer in self.blocks:
30 | h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
31 | h = self.norm_f(h)
32 |
33 | return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
34 |
35 | class FalconModel(nn.Module):
36 | def __init__(self, vocab_size, blocks, word_embeddings, ln_f):
37 | super().__init__()
38 | self.vocab_size = vocab_size
39 | self.word_embeddings = word_embeddings
40 | self.blocks: List[FalconDecoderLayer] = nn.ModuleList(blocks)
41 | self.ln_f = ln_f
42 | self.attn_uses_sequence_id = False
43 | self.prefix_lm = False
44 |
45 | @torch.inference_mode()
46 | def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
47 | # NOTE: falcon input ids contain full context
48 | # after context is processed, slice to latest token
49 | if self.blocks[0].attn.start_pos != 0 and input_ids.shape[-1] != 1:
50 | input_ids = input_ids[:, self.blocks[0].attn.start_pos:]
51 |
52 | _bsz, seqlen = input_ids.shape
53 | h = self.word_embeddings(input_ids)
54 |
55 | mask = None
56 | if seqlen > 1:
57 | mask = torch.full(
58 | (1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device
59 | )
60 | mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h)
61 |
62 | for layer in self.blocks:
63 | h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
64 | h = self.ln_f(h)
65 |
66 | return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
67 |
--------------------------------------------------------------------------------
/mixquant/modules/fused/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import mixlib
4 |
5 |
6 | class FasterTransformerRMSNorm(nn.Module):
7 | def __init__(self, weight, eps=1e-6, cache = None):
8 | super().__init__()
9 | self.weight = weight.cuda().to(torch.float16)
10 | self.variance_epsilon = eps
11 | self.cache = cache
12 | self.next_layer = None
13 |
14 | @torch.no_grad()
15 | def forward(self, x):
16 |
17 |
18 | output = torch.empty_like(x)
19 |
20 | if self.next_layer is None:
21 | mixlib.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
22 |
23 | else:
24 | if self.next_layer.bit == 8:
25 | self.cache.activation_outliers, self.cache.q_xcache = mixlib.layernorm_forward_cuda_extract_outliers(x,
26 | self.weight,
27 | output, self.variance_epsilon,
28 | self.next_layer.ind, self.cache.x_scale)
29 | elif self.next_layer.bit == 4:
30 | self.cache.activation_outliers, self.cache.q_xcache = mixlib.layernorm_forward_cuda_extract_outliers_int4(x,
31 | self.weight,
32 | output, self.variance_epsilon,
33 | self.next_layer.ind, self.cache.x_scale)
34 |
35 |
36 | else:
37 | raise NotImplementedError
38 |
39 | return output
40 |
--------------------------------------------------------------------------------
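For orientation, the `next_layer is None` branch of `FasterTransformerRMSNorm` above computes a standard RMSNorm, while the other branches additionally quantize the output and extract outlier channels for the following layer. A plain-PyTorch sketch of the RMSNorm math only (without the `mixlib` kernels), under that assumption:

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # normalize by the root-mean-square over the hidden dimension, then apply the learned scale
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return weight * x_normed.to(x.dtype)

x = torch.randn(2, 8, 64, dtype=torch.float16)
w = torch.ones(64, dtype=torch.float16)
print(rmsnorm_reference(x, w).shape)   # torch.Size([2, 8, 64])
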
/mixquant/quantize/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/mixquant/quantize/__init__.py
--------------------------------------------------------------------------------
/mixquant/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qcompiler/MIXQ/03e7c5cc663b5a698c21cd9a3aef06df498d55b7/mixquant/utils/__init__.py
--------------------------------------------------------------------------------
/mixquant/utils/calib_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | from typing import List, Union
4 | from datasets import load_dataset
5 |
6 | def get_calib_dataset(data: Union[str, List[str]] = "pileval",
7 | tokenizer=None, n_samples=512, block_size=512,
8 | split="train", text_column="text"):
9 | if isinstance(data, str):
10 | if data == "pileval":
11 | dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
12 | else:
13 | dataset = load_dataset(data, split=split)
14 | text_column = 'question'
15 |
16 | dataset = dataset.shuffle(seed=42)
17 |
18 | elif isinstance(data, list):
19 | dataset = [{text_column: text} for text in data]
20 | else:
21 | raise NotImplementedError(
22 |             "Either pass a string to a huggingface dataset or a list "
23 | "that is preprocessed with one sample of text per element.")
24 |
25 | samples = []
26 | n_run = 0
27 |
28 | for data in dataset:
29 | line = data[text_column]
30 | line = line.strip()
31 | line_encoded = tokenizer.encode(line)
32 | if len(line_encoded) > 512:
33 | continue
34 | sample = torch.tensor([line_encoded])
35 | if sample.numel() == 0:
36 | continue
37 | samples.append(sample)
38 | n_run += 1
39 | if n_run == n_samples:
40 | break
41 | # now concatenate all samples and split according to block size
42 | cat_samples = torch.cat(samples, dim=1)
43 | n_split = cat_samples.shape[1] // block_size
44 | logging.debug(f" * Split into {n_split} blocks")
45 | return [cat_samples[:, i*block_size:(i+1)*block_size] for i in range(n_split)]
46 |
--------------------------------------------------------------------------------
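A minimal usage sketch for `get_calib_dataset`, passing an in-memory list of texts so no dataset download is needed (the GPT-2 tokenizer and the texts are illustrative choices, and the `mixquant` package is assumed to be importable):

from transformers import AutoTokenizer
from mixquant.utils.calib_data import get_calib_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
texts = [
    "mixed precision keeps outlier channels in fp16 " * 8,
    "while the remaining channels run through int8 GEMMs " * 8,
]
blocks = get_calib_dataset(data=texts, tokenizer=tokenizer,
                           n_samples=2, block_size=32)
print(len(blocks), blocks[0].shape)    # each block is a [1, 32] tensor of token ids
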
/mixquant/utils/fused_utils.py:
--------------------------------------------------------------------------------
1 |
2 | def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim):
3 | if attention_shapes is not None:
4 |         pass  # caller-provided shapes are used as-is
5 |
6 | elif n_kv_heads == 0:
7 | attention_shapes = {
8 | # following fastertransformer definition
9 | "cache_v": (cache_batch_size, n_heads, max_seq_len, head_dim,),
10 | # 8: pack 8 fp16 in FT, if fp32 then use 4
11 | "cache_k": (cache_batch_size, n_heads, head_dim // 8, max_seq_len, 8,),
12 | "xqkv_view": (-1, n_heads, head_dim),
13 | "xq_slice": lambda xqkv: xqkv[:, :, 0],
14 | "xk_slice": lambda xqkv: xqkv[:, :, 1],
15 | "xv_slice": lambda xqkv: xqkv[:, :, 2],
16 | "xq_view": (n_heads, head_dim),
17 | "xk_view": (n_heads, head_dim),
18 | "xv_view": (n_heads, head_dim),
19 | "xk_reshape": (n_heads, head_dim // 8, 8),
20 | "single_xq_view": (n_heads, head_dim),
21 | "single_xk_view": (n_heads, head_dim),
22 | "single_xv_view": (n_heads, head_dim)
23 | }
24 |
25 | else:
26 | attention_shapes = {
27 | # following fastertransformer definition
28 | "cache_v": (cache_batch_size, n_kv_heads, max_seq_len, head_dim,),
29 | # 8: pack 8 fp16 in FT, if fp32 then use 4
30 | "cache_k": (cache_batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8,),
31 | "xqkv_view": (n_heads + n_kv_heads * 2, head_dim),
32 | "xq_slice": lambda xqkv: xqkv[:, :, 0 : n_heads],
33 | "xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)],
34 | "xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads :],
35 | "xq_view": (n_heads, head_dim),
36 | "xk_view": (n_kv_heads, head_dim),
37 | "xv_view": (n_kv_heads, head_dim),
38 | "xk_reshape": (n_kv_heads, head_dim // 8, 8),
39 | "single_xq_view": (n_heads, head_dim),
40 | "single_xk_view": (n_kv_heads, head_dim),
41 | "single_xv_view": (n_kv_heads, head_dim)
42 | }
43 |
44 | return attention_shapes
--------------------------------------------------------------------------------
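A short sketch of how get_attention_shapes might be called for a grouped-query-attention layout; the head counts below are illustrative and not tied to any particular checkpoint.

from mixquant.utils.fused_utils import get_attention_shapes

# Illustrative GQA configuration: 32 query heads sharing 8 KV heads.
shapes = get_attention_shapes(attention_shapes=None, max_seq_len=2048,
                              cache_batch_size=1, n_heads=32,
                              n_kv_heads=8, head_dim=128)

print(shapes["cache_v"])  # (1, 8, 2048, 128): batch, kv_heads, seq, head_dim
print(shapes["cache_k"])  # (1, 8, 16, 2048, 8): head_dim packed in groups of 8 fp16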
/mixquant/utils/lm_eval_adaptor.py:
--------------------------------------------------------------------------------
1 | import transformers
2 | import torch
3 | from lm_eval.base import BaseLM
4 | import fnmatch
5 | import logging
6 |
7 | class LMEvalAdaptor(BaseLM):
8 |
9 | def __init__(self, model_name, model, tokenizer, device, batch_size=1, max_length=-1):
10 | super().__init__()
11 |
12 | assert isinstance(batch_size, int)
13 |
14 | self.model_name = model_name
15 | self.model = model.to(device)
16 | self.model.eval()
17 |
18 | self.tokenizer = tokenizer
19 |
20 | # assert isinstance(self.tokenizer, (
21 | # transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
22 | # transformers.T5Tokenizer, transformers.T5TokenizerFast,
23 | # )), "this tokenizer has not been checked for compatibility yet!"
24 |
25 | self.vocab_size = self.tokenizer.vocab_size
26 |
27 | self._batch_size = batch_size
28 |
29 | self._max_length = max_length
30 |
31 | @property
32 | def eot_token_id(self):
33 | # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
34 | return self.tokenizer.eos_token_id
35 |
36 | @property
37 | def max_length(self):
38 | if self._max_length != -1:
39 | return self._max_length
40 | if hasattr(self.model.config, 'n_ctx'):
41 | return self.model.config.n_ctx
42 | elif hasattr(self.model.config, 'max_position_embeddings'):
43 | return self.model.config.max_position_embeddings
44 | elif hasattr(self.model.config, 'n_positions'):
45 | return self.model.config.n_positions
46 | elif 'bloom' in self.model_name:
47 | return 2048
48 | elif 'llama' in self.model_name:
49 | return 2048 # TODO: did not check this
50 | elif 'mpt' in self.model_name:
51 | return 2048
52 | elif 'falcon' in self.model_name:
53 | return 2048
54 | else:
55 | logging.debug(self.model.config)
56 | raise NotImplementedError
57 |
58 | @property
59 | def max_gen_toks(self):
60 | return 256
61 |
62 | @property
63 | def batch_size(self):
64 | return self._batch_size
65 |
66 | @property
67 | def device(self):
68 | return "cuda"
69 |
70 | def tok_encode(self, string: str):
71 | return self.tokenizer.encode(string, add_special_tokens=False)
72 |
73 | def tok_decode(self, tokens):
74 | return self.tokenizer.decode(tokens)
75 |
76 | def _model_call(self, inps):
77 | """
78 | inps: a torch tensor of shape [batch, sequence]
79 | the size of sequence may vary from call to call
80 |
81 | returns: a torch tensor of shape [batch, sequence, vocab] with the
82 | logits returned from the model
83 | """
84 | with torch.no_grad():
85 | if isinstance(self.model, transformers.models.t5.modeling_t5.T5ForConditionalGeneration):
86 | dec_inps = torch.cat(
87 | [
88 | torch.tensor(
89 | self.model.generation_config.decoder_start_token_id,
90 | )
91 | .tile(len(inps), 1)
92 | .to(inps),
93 | inps,
94 | ],
95 | dim=1,
96 | )
97 |
98 | kwargs = {"decoder_input_ids": dec_inps,}
99 | else:
100 | kwargs = {}
101 | out = self.model(inps, **kwargs)[0]
102 | if "opt" in self.model_name: # there are a few extra tokens in opt, which we should omit
103 | return out[:, :, :50257]
104 | else:
105 | return out # [:, :, :self.tokenizer.vocab_size]
106 |
107 | def _model_generate(self, context, max_length, eos_token_id):
108 | return self.model.generate(
109 | context,
110 | max_length=max_length,
111 | eos_token_id=eos_token_id,
112 | do_sample=False
113 | )
114 |
115 |
--------------------------------------------------------------------------------
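A hedged sketch of wiring LMEvalAdaptor to a Hugging Face model. It assumes a CUDA device and an lm-eval-harness release that still ships lm_eval.base.BaseLM, as imported above; the checkpoint name is purely illustrative.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from mixquant.utils.lm_eval_adaptor import LMEvalAdaptor

name = "facebook/opt-125m"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(name)

lm = LMEvalAdaptor(name, model, tokenizer, device="cuda", batch_size=1)
ids = torch.tensor([lm.tok_encode("The capital of France is")], device=lm.device)
logits = lm._model_call(ids)                        # [batch, sequence, vocab]
print(lm.tok_decode([int(logits[0, -1].argmax())]))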
/mixquant/utils/module down_weight_only.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | eightbit_only_name = ["down_proj", "fc_out"]
3 |
4 | weight_only_map = {
5 | "GPTJForCausalLM": ["fc_out"],
6 | "LlamaForCausalLM": ["down_proj"],
7 | "AquilaForCausalLM": ["down_proj"],
8 | "BaichuanForCausalLM": ["down_proj"],
9 | "MistralForCausalLM": ["down_proj"],
10 | "FalconForCausalLM" : [],
11 |
12 | }
13 |
14 | def get_named_linears(module):
15 | return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}
16 |
17 | def get_op_by_name(module, op_name):
18 | # get the op by its name relative to the module
19 | for name, m in module.named_modules():
20 | if name == op_name:
21 | return m
22 | raise ValueError(f"Cannot find op {op_name} in module {module}")
23 |
24 |
25 | def set_op_by_name(layer, name, new_module):
26 | levels = name.split('.')
27 | if len(levels) > 1:
28 | mod_ = layer
29 | for l_idx in range(len(levels)-1):
30 | if levels[l_idx].isdigit():
31 | mod_ = mod_[int(levels[l_idx])]
32 | else:
33 | mod_ = getattr(mod_, levels[l_idx])
34 | setattr(mod_, levels[-1], new_module)
35 | else:
36 | setattr(layer, name, new_module)
37 |
38 |
39 | def get_op_name(module, op):
40 | # get the name of the op relative to the module
41 | for name, m in module.named_modules():
42 | if m is op:
43 | return name
44 | raise ValueError(f"Cannot find op {op} in module {module}")
45 |
46 |
47 | def append_str_prefix(x, prefix):
48 | if isinstance(x, str):
49 | return prefix + x
50 | elif isinstance(x, tuple):
51 | return tuple([append_str_prefix(y, prefix) for y in x])
52 | elif isinstance(x, list):
53 | return [append_str_prefix(y, prefix) for y in x]
54 | else:
55 | return x
--------------------------------------------------------------------------------
/mixquant/utils/module.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | eightbit_only_name = ["down_proj", "o_proj", "fc_out"]
3 |
4 | weight_only_map = {
5 | "GPTJForCausalLM": ["fc_out"],
6 | "LlamaForCausalLM": [],
7 | "AquilaForCausalLM": [],
8 | "BaichuanForCausalLM": [],
9 | "MistralForCausalLM": [],
10 | "FalconForCausalLM" : [],
11 |
12 | }
13 |
14 | def get_named_linears(module):
15 | return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}
16 |
17 | def get_op_by_name(module, op_name):
18 | # get the op by its name relative to the module
19 | for name, m in module.named_modules():
20 | if name == op_name:
21 | return m
22 | raise ValueError(f"Cannot find op {op_name} in module {module}")
23 |
24 |
25 | def set_op_by_name(layer, name, new_module):
26 | levels = name.split('.')
27 | if len(levels) > 1:
28 | mod_ = layer
29 | for l_idx in range(len(levels)-1):
30 | if levels[l_idx].isdigit():
31 | mod_ = mod_[int(levels[l_idx])]
32 | else:
33 | mod_ = getattr(mod_, levels[l_idx])
34 | setattr(mod_, levels[-1], new_module)
35 | else:
36 | setattr(layer, name, new_module)
37 |
38 |
39 | def get_op_name(module, op):
40 | # get the name of the op relative to the module
41 | for name, m in module.named_modules():
42 | if m is op:
43 | return name
44 | raise ValueError(f"Cannot find op {op} in module {module}")
45 |
46 |
47 | def append_str_prefix(x, prefix):
48 | if isinstance(x, str):
49 | return prefix + x
50 | elif isinstance(x, tuple):
51 | return tuple([append_str_prefix(y, prefix) for y in x])
52 | elif isinstance(x, list):
53 | return [append_str_prefix(y, prefix) for y in x]
54 | else:
55 | return x
--------------------------------------------------------------------------------
/mixquant/utils/module_.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | eightbit_only_name = ["down_proj", "fc_out"]
3 |
4 | weight_only_map = {
5 | "GPTJForCausalLM": ["fc_out"],
6 | "LlamaForCausalLM": [],
7 | "AquilaForCausalLM": ["down_proj"],
8 | "BaichuanForCausalLM": ["down_proj"],
9 | "MistralForCausalLM": ["down_proj"],
10 | "FalconForCausalLM" : [],
11 |
12 | }
13 |
14 | def get_named_linears(module):
15 | return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}
16 |
17 | def get_op_by_name(module, op_name):
18 | # get the op by its name relative to the module
19 | for name, m in module.named_modules():
20 | if name == op_name:
21 | return m
22 | raise ValueError(f"Cannot find op {op_name} in module {module}")
23 |
24 |
25 | def set_op_by_name(layer, name, new_module):
26 | levels = name.split('.')
27 | if len(levels) > 1:
28 | mod_ = layer
29 | for l_idx in range(len(levels)-1):
30 | if levels[l_idx].isdigit():
31 | mod_ = mod_[int(levels[l_idx])]
32 | else:
33 | mod_ = getattr(mod_, levels[l_idx])
34 | setattr(mod_, levels[-1], new_module)
35 | else:
36 | setattr(layer, name, new_module)
37 |
38 |
39 | def get_op_name(module, op):
40 | # get the name of the op relative to the module
41 | for name, m in module.named_modules():
42 | if m is op:
43 | return name
44 | raise ValueError(f"Cannot find op {op} in module {module}")
45 |
46 |
47 | def append_str_prefix(x, prefix):
48 | if isinstance(x, str):
49 | return prefix + x
50 | elif isinstance(x, tuple):
51 | return tuple([append_str_prefix(y, prefix) for y in x])
52 | elif isinstance(x, list):
53 | return [append_str_prefix(y, prefix) for y in x]
54 | else:
55 | return x
--------------------------------------------------------------------------------
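The three module*.py files above share identical helpers and differ only in their eightbit_only_name / weight_only_map tables. A toy sketch of the dotted-name utilities follows; the module structure is made up purely for illustration.

import torch.nn as nn
from mixquant.utils.module import get_named_linears, set_op_by_name, get_op_name

class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.down_proj = nn.Linear(8, 8)

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = ToyMLP()

block = ToyBlock()
print(get_named_linears(block))                         # {'mlp.down_proj': Linear(...)}
set_op_by_name(block, "mlp.down_proj", nn.Identity())   # swap in a replacement module
print(get_op_name(block, block.mlp.down_proj))          # 'mlp.down_proj'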
/mixquant/utils/parallel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import gc
4 | import logging
5 |
6 |
7 | def auto_parallel(args):
8 | model_size = args.model_path.split("-")[-1]
9 | if model_size.endswith("m"):
10 | model_gb = 1
11 | else:
12 | model_gb = float(model_size[:-1])
13 | if model_gb < 20:
14 | n_gpu = 1
15 | elif model_gb < 50:
16 | n_gpu = 4
17 | else:
18 | n_gpu = 8
19 | args.parallel = n_gpu > 1
20 | cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
21 | if isinstance(cuda_visible_devices, str):
22 | cuda_visible_devices = cuda_visible_devices.split(",")
23 | else:
24 | cuda_visible_devices = list(range(8))
25 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
26 | [str(dev) for dev in cuda_visible_devices[:n_gpu]])
27 |     logging.debug("CUDA_VISIBLE_DEVICES: %s", os.environ["CUDA_VISIBLE_DEVICES"])
28 | return cuda_visible_devices
29 |
--------------------------------------------------------------------------------
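A small sketch of auto_parallel; the checkpoint path is illustrative, and only its trailing size suffix matters ("13b" parses to 13, which selects the single-GPU branch).

import argparse
from mixquant.utils.parallel import auto_parallel

# Illustrative path; only the "-13b" suffix is parsed to pick a GPU count.
args = argparse.Namespace(model_path="/data/checkpoints/Llama-2-13b", parallel=False)
visible = auto_parallel(args)
print(args.parallel)  # False: a 13B model fits the single-GPU branch
print(visible[:1])    # parsed CUDA_VISIBLE_DEVICES list (env var is trimmed to n_gpu entries)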
/mixquant/utils/utils.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import torch
3 | import accelerate
4 |
5 |
6 | def get_module_by_name_suffix(model, module_name: str):
7 | for name, module in model.named_modules():
8 | if name.endswith(module_name):
9 | return module
10 |
11 | def simple_dispatch_model(model, device_map):
12 | from accelerate.hooks import add_hook_to_module, AlignDevicesHook
13 |
14 | if "" in device_map:
15 | d = device_map[""]
16 | model = model.to(torch.device(d))
17 | model.hf_device_map = device_map
18 | return model
19 |
20 | tied_params = accelerate.utils.modeling.find_tied_parameters(model)
21 | if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
22 | main_device = "cpu"
23 | else:
24 | main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
25 |
26 | cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
27 | prev_hook = None
28 | for idx, (n, d) in enumerate(cpu_offload_group):
29 | m = get_module_by_name_suffix(model, n)
30 | _, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook)
31 | # set first cpu offload module's prev_module_hook to the last cpu offload module's hook
32 | if len(cpu_offload_group) > 1:
33 | get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook
34 |
35 | for n, d in device_map.items():
36 | m = get_module_by_name_suffix(model, n)
37 | if d != "cpu":
38 | d = torch.device(d)
39 | hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
40 | add_hook_to_module(m, hook)
41 | accelerate.utils.modeling.retie_parameters(model, tied_params)
42 | model.hf_device_map = device_map
43 |
44 | return model
45 |
46 | def set_module_name(model, name, value):
47 | if '.' in name:
48 | parent_name = name.rsplit('.', 1)[0]
49 | child_name = name[len(parent_name) + 1:]
50 | parent = model.get_submodule(parent_name)
51 | else:
52 | parent_name = ''
53 | parent = model
54 | child_name = name
55 |
56 | setattr(parent, child_name, value)
57 |
58 | def clear_memory(weight=None):
59 | if weight is not None:
60 | del weight
61 | gc.collect()
62 | torch.cuda.empty_cache()
63 |
64 | def compute_memory_used_pct(device):
65 | memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
66 | memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100
67 | return memory_pct
--------------------------------------------------------------------------------
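A minimal sketch of the dispatch helpers above, assuming a CUDA device. The single-entry device_map short-circuits simple_dispatch_model into a plain .to(device); the toy model is only for illustration.

import torch.nn as nn
from mixquant.utils.utils import (simple_dispatch_model, set_module_name,
                                  compute_memory_used_pct)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU())        # toy model
model = simple_dispatch_model(model, device_map={"": "cuda:0"})
set_module_name(model, "1", nn.Identity())                 # replace the ReLU by name
print(f"{compute_memory_used_pct('cuda:0'):.2f}% peak memory used on the device")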
/mmlu.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | CMD="srun -N 1 --gres=gpu:4090:1 --pty python "
4 | set -ex
5 | basepath=/home/chenyidong/data/mixqdata
6 | _dataset_path=/code/checkpoint/dataset
7 |
8 |
9 | data_type=$1
10 |
11 |
12 |
13 | models=( "falcon-7b" "vicuna-7b" "chatglm2-6b" )
14 | ngpu=1
15 | if [ ${data_type} == mix8 ]
16 | then
17 | for model in "${models[@]}"
18 | do
19 |
20 | echo ${model}
21 | bit=${data_type:3:3}
22 | CUDA_VISIBLE_DEVICES=0 ${CMD} examples/mmlu.py --model_type ${data_type} \
23 | --hf_model_dir ${basepath}/quant8/${model} \
24 | --data_dir ${basepath}/data/data
25 |
26 | done
27 | fi
28 |
29 |
30 | if [ ${data_type} == fp16 ] || [ ${data_type} == bitsandbytes ]
31 | then
32 | for model in "${models[@]}"
33 | do
34 | echo ${model}
35 | export TRANSFORMERS_VERBOSITY=error
36 |         CUDA_VISIBLE_DEVICES=0  ${CMD} examples/mmlu.py --model_type ${data_type} \
37 |         --hf_model_dir  ${basepath}/${model}  \
38 | --data_dir ${basepath}/data/data
39 | done
40 | fi
41 |
42 |
43 | if [ ${data_type} == awq ]
44 | then
45 |
46 | for model in "${models[@]}"
47 | do
48 | echo ${model}
49 |
50 | CUDA_VISIBLE_DEVICES=0 ${CMD} examples/mmlu.py --model_type ${data_type} \
51 | --hf_model_dir ${basepath}/${model}-AWQ --data_dir ${basepath}/data/data
52 |
53 | done
54 |
55 | fi
56 |
57 |
58 |
--------------------------------------------------------------------------------
/quant.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=$PYTHONPATH:/home/chenyidong/MIXQ
2 |
3 | if [ $2 == a100 ]
4 | then
5 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python "
6 | #CMD=" python "
7 | fi
8 | if [ $2 == direct ]
9 | then
10 | CMD=" python "
11 | #CMD=" python "
12 | fi
13 |
14 | if [ $2 == h100 ]
15 | then
16 | CMD="srun -p twills -A h100 --gres=gpu:h100:1 --export=ALL python"
17 | fi
18 | if [ $2 == 4090 ]
19 | then
20 | CMD=" srun -N 1 --gres=gpu:4090:1 --pty python"
21 | fi
22 | #CMD=" python"
23 |
24 | set -x
25 |
26 | # model=65
27 | # CUDA_VISIBLE_DEVICES=$1 http_proxy=127.0.0.1:7890 https_proxy=127.0.0.1:7890 ${CMD} \
28 | # python examples/basic_quant_mix.py \
29 | # --model_path /home/dataset/llama-2/checkpoint/Llama-${model}b \
30 | # --quant_file /home/dataset/llama-2/checkpoint/quant/Llama-${model}b
31 |
32 |
33 | models=( "Baichuan2-7b" "Baichuan2-13b" "Aquila2-7b" "Llama-2-7b" "Mistral-7b" )
34 | models=( "llama-2-hf" )
35 | models=( "falcon-40b" )
36 | models=( "Llama-2-7b" "falcon-7b" "vicuna-7b" "chatglm2-6b" )
37 | quantpath=/home/chenyidong/data/mixqdata/quant
38 | modelpath=/home/chenyidong/data/mixqdata
39 |
40 | for bit in 8
41 | do
42 | for model in "${models[@]}"
43 | do
44 | echo ${model}
45 | ${CMD} examples/basic_quant_mix.py \
46 | --model_path ${modelpath}/${model} \
47 | --quant_file ${quantpath}${bit}/${model} --w_bit ${bit}
48 | done
49 | done
50 |
51 |
52 | # for bit in 4
53 | # do
54 | # for model in "${models[@]}"
55 | # do
56 | # echo ${model}
57 | # ${CMD} \
58 | # examples/basic_quant_quik.py \
59 | # --model_path ${modelpath}/${model} \
60 | # --quant_file ${quantpath}quik${bit}/${model} --w_bit ${bit}
61 | # done
62 | # done
63 |
64 |
65 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets==2.14.7
2 | torch==2.4.0
3 | flash_attn==2.5.8
4 |
--------------------------------------------------------------------------------
/runalltokens.sh:
--------------------------------------------------------------------------------
1 | set -x
2 | srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python benchalltokens.py --model_type mix --model_path /home/dataset/quant/quant8/Llama-2-7b --quant_file /home/dataset/quant/quant8/Llama-2-7b --batch_size 512 --bit 8 --dataset_path /home/chenyidong/checkpoint/dataset
3 |
--------------------------------------------------------------------------------
/runlatency.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python "
4 | export http_proxy=127.0.0.1:7890
5 | export https_proxy=127.0.0.1:7890
6 | set -x
7 |
8 |
9 | bit=4
10 | for batch in 512
11 | #for batch in 1
12 |
13 | do
14 | for seq in 64
15 | do
16 | ##model_type=Aquila2
17 | #model_type=opt
18 | #model_type=Mistral
19 | #model_type=gpt-j
20 | #model_type=falcon
21 | model_type=Llama-2
22 |
23 |
24 | models=( "Llama-2-7b" )
25 | data_types=( "mix" )
26 |
27 | for data_type in "${data_types[@]}"
28 | do
29 | for model in "${models[@]}"
30 | do
31 | echo ${model}
32 | CUDA_VISIBLE_DEVICES=$1 http_proxy=127.0.0.1:7890 https_proxy=127.0.0.1:7890 \
33 | ${CMD} benchlatency.py --model_type ${data_type} --model_path \
34 | /home/dataset/quant${bit}/${model} \
35 | --quant_file /home/dataset/quant${bit}/${model} \
36 | --batch_size ${batch} --bit ${bit}
37 |
38 | done
39 | done
40 |
41 | data_types=( "fp16" , "bitsandbytes" )
42 | for data_type in "${data_types[@]}"
43 | do
44 | for model in "${models[@]}"
45 | do
46 | echo ${model}
47 | CUDA_VISIBLE_DEVICES=$1 http_proxy=127.0.0.1:7890 https_proxy=127.0.0.1:7890 \
48 | ${CMD} benchflops.py --model_type ${data_type} --model_path \
49 | /mnt/octave/data/chenyidong/checkpoint/${model} \
50 | --quant_file /mnt/octave/data/chenyidong/checkpoint/${model} --batch_size ${batch}
51 |
52 |
53 | done
54 | done
55 | data_types=( "awq" )
56 | for data_type in "${data_types[@]}"
57 | do
58 | for model in "${models[@]}"
59 | do
60 | echo ${model}
61 | CUDA_VISIBLE_DEVICES=$1 http_proxy=127.0.0.1:7890 https_proxy=127.0.0.1:7890 \
62 | ${CMD} benchflops.py --model_type ${data_type} --model_path \
63 | /mnt/octave/data/chenyidong/checkpoint/awqquant/${model} \
64 | --quant_file /mnt/octave/data/chenyidong/checkpoint/awqquant/${model} --batch_size ${batch}
65 | done
66 | done
67 |
68 |
69 |
70 | done
71 | done
72 |
--------------------------------------------------------------------------------
/runppl.sh:
--------------------------------------------------------------------------------
1 | if [ $2 == a100 ]
2 | then
3 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python "
4 | #CMD=" python "
5 | fi
6 | if [ $2 == direct ]
7 | then
8 | CMD=" python "
9 | #CMD=" python "
10 | fi
11 |
12 | if [ $2 == h100 ]
13 | then
14 | CMD="srun -p twills -A h100 --gres=gpu:h100:1 --export=ALL python"
15 | fi
16 | if [ $2 == 4090 ]
17 | then
18 | CMD=" srun -N 1 --gres=gpu:4090:1 --pty python"
19 | fi
20 | set -x
21 |
22 | quantpath=/home/dataset/quant
23 | modelpath=/home/dataset
24 | dataset_path=/home/dataset/quant/checkpoint/dataset
25 |
26 | model=$3
27 | data_type=$4
28 | down_weight_only=0
29 |
30 | for batch in 512
31 | do
32 | for seq in 64
33 | do
34 |
35 |
36 | if [ ${data_type} == fp16 ]
37 | then
38 |
39 |
40 | echo ${model}
41 | CUDA_VISIBLE_DEVICES=$1 ${CMD} evalppl.py --fp_features_num 128 --model_type ${data_type} --model_path \
42 | ${modelpath}/${model} \
43 | --quant_file ${modelpath}/${model} \
44 | --n_ctx $batch --n_batch $batch --eval_accuracy True --dataset_path ${dataset_path}
45 |
46 |
47 |
48 | fi
49 |
50 | if [ ${data_type} == bitsandbytes ]
51 | then
52 |
53 |
54 | echo ${model}
55 | CUDA_VISIBLE_DEVICES=$1 ${CMD} evalppl.py --fp_features_num 128 --model_type ${data_type} --model_path \
56 | ${modelpath}/${model} \
57 | --quant_file ${modelpath}/${model} \
58 | --n_ctx $batch --n_batch $batch --eval_accuracy True --dataset_path ${dataset_path}
59 |
60 |
61 |
62 | fi
63 |
64 | if [ ${data_type} == awq ]
65 | then
66 | pip install transformers==4.35
67 |
68 | echo ${model}
69 | CUDA_VISIBLE_DEVICES=$1 \
70 | ${CMD} evalppl.py --model_type ${data_type} --model_path \
71 | ${quantpath}/awqquant/${model} \
72 | --quant_file ${quantpath}/awqquant/${model} \
73 | --n_ctx $batch --n_batch $batch --dataset_path ${dataset_path} --eval_accuracy True
74 |
75 | pip install transformers==4.41.2
76 |
77 | fi
78 |
79 | if [ ${data_type} == mix8 ]
80 | then
81 | bit=8
82 | echo "---------run mix 8--------"
83 |
84 | echo ${model}
85 | if [ ${down_weight_only} == 1 ]
86 | then
87 | rm -r ${quantpath}/quant${bit}/down_weight_only/${model}/model.safetensors
88 | CUDA_VISIBLE_DEVICES=$1 \
89 | ${CMD} evalppl.py --model_type ${data_type} --model_path \
90 | ${quantpath}/quant${bit}/down_weight_only/${model} \
91 | --quant_file ${quantpath}/quant${bit}/down_weight_only/${model} \
92 | --n_ctx ${batch} --n_batch $batch --dataset_path ${dataset_path} --eval_accuracy True
93 | fi
94 | if [ ${down_weight_only} == 0 ]
95 | then
96 | rm -r ${quantpath}/quant${bit}/${model}/model.safetensors
97 | CUDA_VISIBLE_DEVICES=$1 \
98 | ${CMD} evalppl.py --model_type ${data_type} --model_path \
99 | ${quantpath}/quant${bit}/${model} \
100 | --quant_file ${quantpath}/quant${bit}/${model} \
101 | --n_ctx ${batch} --n_batch $batch --dataset_path ${dataset_path} --eval_accuracy True
102 | fi
103 |
104 | fi
105 | if [ ${data_type} == mix4 ]
106 | then
107 | bit=4
108 | echo "---------run mix 4--------"
109 |
110 | rm -r ${quantpath}/quant${bit}/down_weight_only/${model}/model.safetensors
111 | echo ${model}
112 | CUDA_VISIBLE_DEVICES=$1 \
113 | ${CMD} evalppl.py --model_type ${data_type} --model_path \
114 | ${quantpath}/quant${bit}/down_weight_only/${model} \
115 | --quant_file ${quantpath}/quant${bit}/down_weight_only/${model} \
116 | --n_ctx ${batch} --n_batch $batch --dataset_path ${dataset_path} --eval_accuracy True
117 |
118 |
119 |
120 | fi
121 |
122 | if [ ${data_type} == quik ]
123 | then
124 | bit=4
125 | echo "---------run quik 4--------"
126 |
127 |
128 | echo ${model}
129 | CUDA_VISIBLE_DEVICES=$1 \
130 | ${CMD} evalppl.py --model_type ${data_type} --model_path \
131 | ${quantpath}/quantquik${bit}/${model} \
132 | --quant_file ${quantpath}/quantquik${bit}/${model} \
133 | --n_ctx ${batch} --n_batch $batch --dataset_path ${dataset_path} --eval_accuracy True
134 |
135 |
136 |
137 | fi
138 | done
139 | done
--------------------------------------------------------------------------------
/runthroughput.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | if [ $2 == a100 ]
4 | then
5 | CMD=" srun -N 1 --pty --gres=gpu:a100:1 -p octave -A public python "
6 | fi
7 |
8 | if [ $2 == h100 ]
9 | then
10 | CMD="srun -p twills -A h100 --gres=gpu:h100:1 --export=ALL python"
11 | fi
12 |
13 | export http_proxy=127.0.0.1:7890
14 | export https_proxy=127.0.0.1:7890
15 | set -x
16 |
17 | quantpath=/home/dataset/quant/quant
18 | modelpath=/home/dataset
19 |
20 | for batch in 32
21 | #for batch in 1
22 |
23 | do
24 | for seq in 1024
25 | do
26 | ##model_type=Aquila2
27 | #model_type=opt
28 | #model_type=Mistral
29 | #model_type=gpt-j
30 | #model_type=falcon
31 | model_type=$3
32 |
33 |
34 |
35 | # data_types=( "mix" )
36 | # for bit in 8
37 | # do
38 | # for data_type in "${data_types[@]}"
39 | # do
40 | # model=${model_type}
41 | # echo ${model}
42 | # rm -r ${quantpath}${bit}/${model}/model.safetensors
43 | # CUDA_VISIBLE_DEVICES=$1 ${CMD} benchflops.py --model_type ${data_type} --model_path \
44 | # ${quantpath}${bit}/${model} \
45 | # --quant_file ${quantpath}${bit}/${model} \
46 | # --batch_size ${batch} --bit ${bit} --dataset_path /home/chenyidong/checkpoint/dataset
47 | # done
48 | # done
49 |
50 |
51 | # data_types=( "quik" )
52 | # for bit in 4
53 | # do
54 | # for data_type in "${data_types[@]}"
55 | # do
56 | # model=${model_type}
57 | # echo ${model}
58 | # CUDA_VISIBLE_DEVICES=$1 ${CMD} benchflops.py --model_type ${data_type} --model_path \
59 | # ${quantpath}quik${bit}/${model} \
60 | # --quant_file ${quantpath}quik${bit}/${model} \
61 | # --batch_size ${batch} --bit ${bit} --dataset_path /home/chenyidong/checkpoint/dataset
62 | # done
63 | # done
64 | data_types=( "fp16" "bitsandbytes" )
65 | for data_type in "${data_types[@]}"
66 | do
67 | model=${model_type}
68 |
69 | echo ${model}
70 | CUDA_VISIBLE_DEVICES=$1 ${CMD} benchflops.py --model_type ${data_type} --model_path \
71 | ${modelpath}/${model} \
72 | --quant_file ${modelpath}/${model} --batch_size ${batch} --dataset_path /home/chenyidong/checkpoint/dataset
73 |
74 | done
75 | # data_types=( "awq" )
76 | # for data_type in "${data_types[@]}"
77 | # do
78 | # for model in "${models[@]}"
79 | # do
80 | # echo ${model}
81 | # CUDA_VISIBLE_DEVICES=$1 ${CMD} benchflops.py --model_type ${data_type} --model_path \
82 | # ${quantpath}/awqquant/${model} \
83 | #      --quant_file ${quantpath}/awqquant/${model} --batch_size ${batch}
84 | # done
85 | # done
86 |
87 |
88 |
89 | done
90 | done
91 |
--------------------------------------------------------------------------------
/utils/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .perplexity_utils import Perplexity
--------------------------------------------------------------------------------
/utils/utils/exllama_utils.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import torch
3 |
4 | def exllama_set_max_input_length(model, max_input_length: int):
5 | """
6 | This method does not necessarily require `model` to inherit from BaseGPTQForCausalLM.
7 |
8 | When using the exllama backend with act-order, it is necessary to initialize a buffer that depends on the maximum expected input length. In case the
9 | default used (EXLLAMA_DEFAULT_MAX_INPUT_LENGTH) is too short, this method can be called to extend the buffer size without reloading the whole model.
10 | """
11 |
12 | # The import is set here to avoid a global import. Arguably this is quite ugly, it would be better to have lazy loading.
13 | from exllama_kernels import prepare_buffers, cleanup_buffers_cuda
14 |
15 | if not model.quantize_config.desc_act:
16 | raise ValueError("The method exllama_set_max_input_length should be called only when using the exllama backend **with act-order**.")
17 |
18 | device_to_buffers_size = {}
19 | for device, buffers in model.device_to_buffers.items():
20 | device_to_buffers_size[device] = {"max_dq_buffer_size": buffers["max_dq_buffer_size"], "max_inner_outer_dim": buffers["max_inner_outer_dim"]}
21 |
22 | # For an unknown reason calling just `del model.device_to_buffers` raises an AttributeError.
23 | for key in list(model.device_to_buffers.keys()):
24 | del model.device_to_buffers[key]
25 | model.device_to_buffers = None
26 | del model.device_to_buffers
27 |
28 | gc.collect()
29 | torch.cuda.empty_cache()
30 | cleanup_buffers_cuda()
31 |
32 | device_to_buffers = {}
33 | for device, buffers_size in device_to_buffers_size.items():
34 | # The temp_state buffer is required to reorder X in the act-order case.
35 | # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
36 | device_to_buffers[device] = {
37 | "temp_state": torch.zeros((max_input_length, buffers_size["max_inner_outer_dim"]), dtype=torch.float16, device=device),
38 | "temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device),
39 | "max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
40 | "max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
41 | }
42 |
43 | prepare_buffers(device, device_to_buffers[device]["temp_state"], device_to_buffers[device]["temp_dq"])
44 |
45 | # Buffers need to be persistent to avoid any bug.
46 | model.device_to_buffers = device_to_buffers
47 |
48 | return model
49 |
--------------------------------------------------------------------------------
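A hedged usage sketch for exllama_set_max_input_length. The checkpoint path is hypothetical, auto-gptq is assumed as the loader, and the function only applies to an act-order GPTQ model served by the exllama backend (otherwise it raises).

from auto_gptq import AutoGPTQForCausalLM
from utils.utils.exllama_utils import exllama_set_max_input_length

# Hypothetical act-order (desc_act=True) GPTQ checkpoint using the exllama kernels.
model = AutoGPTQForCausalLM.from_quantized("/path/to/gptq-model", device="cuda:0")
# Re-allocate the exllama buffers for longer prompts without reloading the model.
model = exllama_set_max_input_length(model, max_input_length=4096)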
/utils/utils/import_utils.py:
--------------------------------------------------------------------------------
1 | from packaging.version import parse as parse_version
2 | from logging import getLogger
3 | import torch
4 |
5 | try:
6 | import triton
7 |
8 | TRITON_AVAILABLE = True
9 | except ImportError:
10 | TRITON_AVAILABLE = False
11 |
12 | try:
13 | import autogptq_cuda_256
14 | import autogptq_cuda_64
15 |
16 | AUTOGPTQ_CUDA_AVAILABLE = True
17 | except ImportError:
18 | AUTOGPTQ_CUDA_AVAILABLE = False
19 |
20 |
21 | try:
22 | import exllama_kernels
23 |
24 | EXLLAMA_KERNELS_AVAILABLE = True
25 | except ImportError:
26 | EXLLAMA_KERNELS_AVAILABLE = False
27 |
28 | try:
29 | import exllamav2_kernels
30 |
31 | EXLLAMAV2_KERNELS_AVAILABLE = True
32 | except ImportError:
33 | EXLLAMAV2_KERNELS_AVAILABLE = False
34 |
35 | try:
36 | import cQIGen as qinfer
37 |
38 | QIGEN_AVAILABLE = True
39 | except ImportError:
40 | QIGEN_AVAILABLE = False
41 |
42 | logger = getLogger(__name__)
43 |
44 |
45 | def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = True, disable_exllamav2:bool = False, use_qigen: bool = False):
46 | if use_qigen:
47 | from ..nn_modules.qlinear.qlinear_qigen import QuantLinear
48 | else:
49 | if use_triton:
50 | if torch.version.hip:
51 | logger.warning("Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. Please use use_triton=False.")
52 |
53 | from ..nn_modules.qlinear.qlinear_triton import QuantLinear
54 | else:
55 | if bits == 4 and not disable_exllamav2 and EXLLAMAV2_KERNELS_AVAILABLE:
56 | from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
57 | elif bits == 4 and not disable_exllama and EXLLAMA_KERNELS_AVAILABLE:
58 | from ..nn_modules.qlinear.qlinear_exllama import QuantLinear
59 | elif not desc_act or group_size == -1:
60 | from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear
61 | else:
62 | from ..nn_modules.qlinear.qlinear_cuda import QuantLinear
63 |
64 | return QuantLinear
65 |
66 |
67 | def compare_transformers_version(
68 | version: str = "v4.28.0",
69 | op: str = "eq"
70 | ):
71 | assert op in ["eq", "lt", "le", "gt", "ge"]
72 |
73 | from transformers import __version__
74 |
75 | return getattr(parse_version(__version__), f"__{op}__")(parse_version(version))
76 |
77 |
78 | def compare_pytorch_version(
79 | version: str = "v2.0.0",
80 | op: str = "eq"
81 | ):
82 | assert op in ["eq", "lt", "le", "gt", "ge"]
83 |
84 | from torch import __version__
85 |
86 | return getattr(parse_version(__version__), f"__{op}__")(parse_version(version))
87 |
--------------------------------------------------------------------------------
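A short sketch of the availability flags and version helpers above, assuming the repository root is on sys.path; the version thresholds are illustrative.

from utils.utils.import_utils import (TRITON_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE,
                                      compare_transformers_version,
                                      compare_pytorch_version)

print("triton kernels available:", TRITON_AVAILABLE)
print("exllama kernels available:", EXLLAMA_KERNELS_AVAILABLE)
print("torch >= 2.0:", compare_pytorch_version("v2.0.0", op="ge"))
print("transformers >= 4.35:", compare_transformers_version("v4.35.0", op="ge"))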
/utils/utils/perplexity_utils.py:
--------------------------------------------------------------------------------
1 | # This code is taken from GPTQ.
2 | import sys
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from datasets import load_dataset
7 | from transformers import AutoTokenizer, AutoModelForCausalLM
8 | import os
9 |
10 | class Perplexity:
11 | """
12 | A class for calculating the perplexity of a language model.
13 | """
14 |
15 | def __init__(self, model, tokenizer, dataset_path='wikitext', dataset_name=None,
16 | split='test', text_column='text',
17 | eval_accuracy = True):
18 | """
19 | Calculate perplexity using the same method as seen in llama.cpp.
20 |
21 | Parameters
22 | ----------
23 | model : AutoModelForCausalLM
24 | The language model for which the perplexity is calculated.
25 | tokenizer : AutoTokenizer
26 | The tokenizer corresponding to the model.
27 |         eval_accuracy : bool, optional
28 |             Whether to accumulate the negative log likelihood and report perplexity;
29 |             if False, the loop only runs the forward passes. Default is True.
30 | dataset_path : str, optional
31 | The path to the dataset on the Hugging Face dataset hub. Default is 'wikitext'.
32 | dataset_name : str, optional
33 | The name of the dataset. Default is None.
34 | split : str, optional
35 | The split of the dataset to use. Default is 'test'.
36 | text_column : str, optional
37 | The name of the column in the dataset that contains the text data. Default is 'text'.
38 | """
39 | self._model = model
40 | self._tokenizer = tokenizer
41 | self._dataset_path = dataset_path
42 | self._dataset_name = dataset_name
43 | self._split = split
44 | self._text_column = text_column
45 | self._text = self._prepare_data()
46 | self.eval_accuracy = eval_accuracy
47 |
48 |
49 |
50 | def _get_device(self):
51 | if torch.backends.mps.is_available():
52 | return 'mps'
53 | elif torch.cuda.is_available():
54 | return 'cuda:0'
55 | else:
56 | return 'cpu'
57 |
58 | def _prepare_data(self):
59 | """
60 | Prepares the dataset by loading and formatting.
61 |
62 | Returns
63 | -------
64 | str
65 | The formatted dataset as a single string.
66 | """
67 | _dataset_name = 'wikitext-2-raw-v1'
68 | _dataset_path = self._dataset_path
69 | if _dataset_path == 'wikitext':
70 | _dataset_name = 'wikitext-2-raw-v1'
71 | if _dataset_path == 'c4':
72 | _dataset_name = 'realnewslike'
73 | # Load the dataset
74 | print(_dataset_path)
75 |
76 | #data = load_dataset(data_dir=_dataset_path, name='wikitext-2-raw-v1', split=_split)
77 | data = load_dataset(os.path.join(_dataset_path,"wikitext"), name='wikitext-2-raw-v1',
78 | cache_dir=".cache", split=self._split)
79 | # Format the text column of the dataset
80 | text_list = [' \n' if s == '' else s for s in data[self._text_column]]
81 | return ''.join(text_list)
82 |
83 | @staticmethod
84 | def softmax(logits):
85 | """
86 | Static method for applying the softmax function.
87 |
88 | Parameters
89 | ----------
90 | logits : np.ndarray
91 | The input to the softmax function.
92 |
93 | Returns
94 | -------
95 | np.ndarray
96 | The output of the softmax function.
97 | """
98 | e_x = np.exp(logits - np.max(logits))
99 | return e_x / e_x.sum(axis=0)
100 |
101 | def calculate_perplexity(self, n_ctx=512, n_batch=512):
102 | """
103 | Calculates the perplexity of the language model.
104 |
105 | Parameters
106 | ----------
107 | n_ctx : int
108 | The context size.
109 | n_batch : int
110 | The batch size.
111 |
112 | Returns
113 | -------
114 | list
115 | The list of perplexity scores calculated.
116 | """
117 | # Tokenize the text
118 | self._tokenizer.model_max_length = sys.maxsize
119 | tokens = self._tokenizer(self._text, truncation=False, return_tensors='pt').input_ids.to('cuda')
120 |
121 | nll = 0.0 # Negative log likelihood
122 | count = 0 # Counter for processed tokens
123 | curr_ppl = 0
124 | all_perplexity = []
125 |
126 | with tqdm(range(len(tokens[0]) // n_ctx), desc="Perplexity: - ") as progress:
127 | for i in progress:
128 | # Process each batch of tokens
129 | nll, count = self._process_batch(i, n_ctx, n_batch, tokens, nll, count)
130 |
131 |
132 | # Calculate and display the current perplexity
133 | if self.eval_accuracy is True:
134 | curr_ppl = np.exp(nll / count)
135 | else:
136 | curr_ppl = 0
137 | all_perplexity.append(curr_ppl)
138 | progress.set_description(f"Perplexity: {curr_ppl:.4f}")
139 | #print("-----------------------")
140 |
141 |
142 | return all_perplexity
143 |
144 | def _process_batch(self, i, n_ctx, n_batch, tokens, nll, count):
145 | """
146 | Processes each batch of tokens.
147 |
148 | Parameters
149 | ----------
150 | i : int
151 | The batch index.
152 | n_ctx : int
153 | The context size.
154 | n_batch : int
155 | The batch size.
156 | tokens : torch.Tensor
157 | The tokenized text.
158 | nll : float
159 | The current negative log likelihood.
160 | count : int
161 | The current count of processed tokens.
162 |
163 | Returns
164 | -------
165 | float
166 | The updated negative log likelihood.
167 | int
168 | The updated count of processed tokens.
169 | """
170 | start = i * n_ctx
171 | end = start + n_ctx
172 |
173 | num_batches = (n_ctx + n_batch - 1) // n_batch
174 |
175 | logits = []
176 | if self._tokenizer.bos_token_id is None:
177 | self._tokenizer.bos_token_id = 11
178 | for j in range(num_batches):
179 | batch_start = start + j * n_batch
180 | batch_size = min(end - batch_start, n_batch)
181 |
182 | token_org = tokens[0][batch_start].item()
183 |
184 | if j == 0:
185 | # Replace the first token with the BOS token
186 | tokens[0][batch_start] = self._tokenizer.bos_token_id
187 |
188 | # Compute the logits for the current batch of tokens
189 | batch_logits = self._compute_batch_logits(tokens, batch_start, batch_size)
190 |
191 | tokens[0][batch_start] = token_org
192 |
193 | logits.append(batch_logits)
194 |
195 | # We rely on the fact that attention in the forward pass only looks at previous
196 | # tokens here, so the logits returned for each token are an accurate representation
197 | # of what the model would have predicted at that point.
198 | #
199 | # Example, we have a context window of 512, we will compute perplexity for each of the
200 | # last 256 tokens. Then, we split the input up into context window size chunks to
201 | # process the entire prompt.
202 | if self.eval_accuracy is True:
203 | for j in range(min(512, n_ctx // 2), n_ctx - 1):
204 | tok_logits = logits[0][0][j]
205 | # Compute the probability of the next token
206 | prob = torch.softmax(tok_logits,dim=-1)[tokens[0][start + j + 1]]
207 | prob = prob.to(torch.float16).cpu().numpy()
208 | # Update the negative log likelihood and the count of processed tokens
209 | nll += -np.log(prob, where=prob>0)
210 | count += 1
211 |
212 | return nll, count
213 |
214 | def _compute_batch_logits(self, tokens, batch_start, batch_size):
215 | """
216 | Computes the logits for a batch of tokens.
217 |
218 | Parameters
219 | ----------
220 | tokens : torch.Tensor
221 | The tokenized text.
222 | batch_start : int
223 | The start index of the batch.
224 | batch_size : int
225 | The size of the batch.
226 |
227 | Returns
228 | -------
229 | torch.Tensor
230 | The logits for the batch of tokens.
231 | """
232 | # Compute the logits without keeping track of gradients
233 | with torch.no_grad():
234 | outputs = self._model(tokens[:, batch_start:batch_start+batch_size])
235 | return outputs.logits.detach()
236 |
--------------------------------------------------------------------------------
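A hedged end-to-end sketch of the Perplexity class. The model name and dataset root are illustrative, a CUDA device is assumed, and dataset_path must point at a local directory holding a wikitext copy (the class joins it with "wikitext" before calling load_dataset, matching the --dataset_path arguments in runppl.sh).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.utils import Perplexity

name = "facebook/opt-125m"                      # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float16).cuda()
tokenizer = AutoTokenizer.from_pretrained(name)

# Hypothetical local dataset root containing a wikitext copy.
ppl = Perplexity(model, tokenizer, dataset_path="/data/mixqdata/dataset")
scores = ppl.calculate_perplexity(n_ctx=512, n_batch=512)
print(f"final perplexity: {scores[-1]:.4f}")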