├── flexible_quant
│   ├── flexible_quant
│   │   ├── __init__.py
│   │   └── vanilla_quantizer.py
│   └── setup.py
├── .gitmodules
├── .gitignore
├── helper_scripts
│   ├── preset_parser.py
│   ├── sh_gen_gaokaobench.py
│   ├── sh_gen_baseline.py
│   ├── sh_gen_lmeval.py
│   ├── create_table.py
│   └── sh_gen_presets.py
├── calibration_presets.old
│   ├── Qwen2.5-7B-Instruct_KVTuner4_0.yaml
│   ├── Qwen2.5-7B-Instruct_KVTuner4_1.yaml
│   ├── Qwen2.5-7B-Instruct_KVTuner6_0.yaml
│   ├── Qwen2.5-7B-Instruct_KVTuner6_1.yaml
│   ├── Mistral-7B-Instruct-v0.3_KVTuner4_0.yaml
│   ├── Mistral-7B-Instruct-v0.3_KVTuner4_1.yaml
│   ├── Mistral-7B-Instruct-v0.3_KVTuner6_0.yaml
│   ├── Mistral-7B-Instruct-v0.3_KVTuner6_1.yaml
│   ├── Meta-Llama-3.1-8B-Instruct_KVTuner4_0.yaml
│   ├── Meta-Llama-3.1-8B-Instruct_KVTuner4_1.yaml
│   ├── Meta-Llama-3.1-8B-Instruct_KVTuner6_0.yaml
│   ├── Meta-Llama-3.1-8B-Instruct_KVTuner6_1.yaml
│   ├── Qwen2.5-3B-Instruct-AWQ_KVTuner4_0.yaml
│   ├── Qwen2.5-3B-Instruct-AWQ_KVTuner4_1.yaml
│   ├── Qwen2.5-3B-Instruct-AWQ_KVTuner6_0.yaml
│   └── Qwen2.5-3B-Instruct-AWQ_KVTuner6_1.yaml
├── calibration_presets
│   ├── Qwen2.5-7B-Instruct_kivi_KVTuner4_0.yaml
│   ├── Qwen2.5-7B-Instruct_kivi_KVTuner4_1.yaml
│   ├── Qwen2.5-7B-Instruct_kivi_KVTuner6_0.yaml
│   ├── Qwen2.5-7B-Instruct_kivi_KVTuner6_1.yaml
│   ├── Qwen2.5-7B-Instruct_pertoken_KVTuner4_0.yaml
│   ├── Qwen2.5-7B-Instruct_pertoken_KVTuner4_1.yaml
│   ├── Qwen2.5-7B-Instruct_pertoken_KVTuner6_0.yaml
│   ├── Qwen2.5-7B-Instruct_pertoken_KVTuner6_1.yaml
│   ├── Meta-Llama-3.1-8B-Instruct_kivi_KVTuner4_0.yaml
│   ├── Meta-Llama-3.1-8B-Instruct_kivi_KVTuner4_1.yaml
│   ├── Meta-Llama-3.1-8B-Instruct_kivi_KVTuner6_0.yaml
│   ├── Meta-Llama-3.1-8B-Instruct_kivi_KVTuner6_1.yaml
│   ├── Mistral-7B-Instruct-v0.3_kivi_KVTuner4_0.yaml
│   ├── Mistral-7B-Instruct-v0.3_kivi_KVTuner4_1.yaml
│   ├── Mistral-7B-Instruct-v0.3_kivi_KVTuner6_0.yaml
│   ├── Mistral-7B-Instruct-v0.3_kivi_KVTuner6_1.yaml
│   ├── Mistral-7B-Instruct-v0.3_pertoken_KVTuner4_0.yaml
│   ├── Mistral-7B-Instruct-v0.3_pertoken_KVTuner4_1.yaml
│   ├── Mistral-7B-Instruct-v0.3_pertoken_KVTuner6_0.yaml
│   ├── Mistral-7B-Instruct-v0.3_pertoken_KVTuner6_1.yaml
│   ├── Meta-Llama-3.1-8B-Instruct_pertoken_KVTuner4_0.yaml
│   ├── Meta-Llama-3.1-8B-Instruct_pertoken_KVTuner4_1.yaml
│   ├── Meta-Llama-3.1-8B-Instruct_pertoken_KVTuner6_0.yaml
│   ├── Meta-Llama-3.1-8B-Instruct_pertoken_KVTuner6_1.yaml
│   ├── Qwen2.5-3B-Instruct_kivi_KVTuner4_0.yaml
│   ├── Qwen2.5-3B-Instruct_kivi_KVTuner4_1.yaml
│   ├── Qwen2.5-3B-Instruct_kivi_KVTuner6_0.yaml
│   ├── Qwen2.5-3B-Instruct_kivi_KVTuner6_1.yaml
│   ├── Qwen2.5-3B-Instruct_pertoken_KVTuner4_0.yaml
│   ├── Qwen2.5-3B-Instruct_pertoken_KVTuner4_1.yaml
│   ├── Qwen2.5-3B-Instruct_pertoken_KVTuner6_0.yaml
│   └── Qwen2.5-3B-Instruct_pertoken_KVTuner6_1.yaml
├── config
│   ├── meta-llama_Meta-Llama-3-8B-Instruct_k8_v4._per_layer.yaml
│   └── meta-llama_Meta-Llama-3-8B-Instruct_k8_v4._per_head.yaml
├── flexible_quant_example.py
├── benckmarks
│   ├── example_gsm8k_cot_manyshot.py
│   ├── evals
│   │   └── gsm8k_utils.py
│   ├── gaokao_bench_obj.py
│   ├── test_gaokaobench.py
│   └── pred_longbench.py
├── search_optuna_vanilla.py
├── search_brute_force.py
├── README.md
└── search_optuna_adaptive.py

/flexible_quant/flexible_quant/__init__.py:
--------------------------------------------------------------------------------
from . import (
    vanilla_quantizer,
    flexible_quantized_cache,
)
--------------------------------------------------------------------------------
/flexible_quant/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name="flexible_quant",
    packages=find_packages(),
    install_requires=["transformers"],
)
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "lm-evaluation-harness-X"]
	path = lm-evaluation-harness-X
	url = git@github.com:cmd2001/lm-evaluation-harness-X.git
[submodule "GAOKAO-Bench"]
	path = GAOKAO-Bench
	url = git@github.com:OpenLMLab/GAOKAO-Bench.git
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
outputs
logs
cached_models
*.egg-info
build
*.so
out_*
*.sh
*.jsonl
*.pt
notebook
dataset
output_samples
*.pdf
*.json
*.jsonl
third_party
speed_logs
pred
pred_e
.locks/
models_storage
--------------------------------------------------------------------------------
/helper_scripts/preset_parser.py:
--------------------------------------------------------------------------------
import os
import yaml


def merge_kv(nbits_key, nbits_value):
    # Label a layer's precision, e.g. "K8V4", or "KV4" when key and value bits match.
    if nbits_key != nbits_value:
        return f"K{nbits_key}V{nbits_value}"
    return f"KV{nbits_key}"


# Maps a precision label (e.g. "K8V4") to the layer ids that use it;
# reset before each preset is parsed in the loop below.
kv_to_layer = {}


def get_precision(filename: str):
    # Average number of bits per cached key/value tensor across all layers of a preset.
    with open(filename, 'r') as f:
        data = yaml.load(f, Loader=yaml.FullLoader)
    ret = 0
    for layer_id, v in data.items():
        ret += v['nbits_key'] + v['nbits_value']
        if merge_kv(v['nbits_key'], v['nbits_value']) not in kv_to_layer:
            kv_to_layer[merge_kv(v['nbits_key'], v['nbits_value'])] = []
        kv_to_layer[merge_kv(v['nbits_key'], v['nbits_value'])].append(layer_id)
    ret /= len(data) * 2
    return ret


calibration_presets = os.listdir('./calibration_presets')

for preset in calibration_presets:
    kv_to_layer = {}
    if 'Mistral' in preset:
        continue
    print(f'Precision for {preset}: {get_precision(os.path.join("./calibration_presets", preset))}')
    print(kv_to_layer)
--------------------------------------------------------------------------------
/calibration_presets.old/Qwen2.5-7B-Instruct_KVTuner4_0.yaml:
--------------------------------------------------------------------------------
1 | 0: 2 | nbits_key: 8 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 2 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 4 54 |
nbits_value: 4 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 2 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 2 85 | -------------------------------------------------------------------------------- /calibration_presets.old/Qwen2.5-7B-Instruct_KVTuner4_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 2 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 2 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 2 85 | -------------------------------------------------------------------------------- /calibration_presets.old/Qwen2.5-7B-Instruct_KVTuner6_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 2 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 8 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 8 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 8 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 8 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 8 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 8 67 | 22: 68 | nbits_key: 8 69 | nbits_value: 8 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 2 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 8 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 2 85 | 
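--------------------------------------------------------------------------------
Note on the preset format: every file under calibration_presets/ and calibration_presets.old/ maps a decoder-layer index to the bit widths used for that layer's cached keys and values (nbits_key / nbits_value). A minimal sketch of loading one of these presets and summarising it with PyYAML follows; the snippet is illustrative only and is not a file from this repository, and the average it prints is the same quantity computed by helper_scripts/preset_parser.py.

import yaml

# Illustrative only: read one per-layer KV-cache preset and summarise it.
preset_path = "calibration_presets.old/Qwen2.5-7B-Instruct_KVTuner6_0.yaml"

with open(preset_path) as f:
    # {layer_id: {"nbits_key": int, "nbits_value": int}, ...}
    preset = yaml.safe_load(f)

total_bits = sum(v["nbits_key"] + v["nbits_value"] for v in preset.values())
avg_bits = total_bits / (2 * len(preset))  # average bits per cached key/value tensor
print(f"{len(preset)} layers, average KV precision: {avg_bits:.2f} bits")

for layer_id in sorted(preset):
    cfg = preset[layer_id]
    print(f"layer {layer_id:2d}: K{cfg['nbits_key']}V{cfg['nbits_value']}")
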
-------------------------------------------------------------------------------- /calibration_presets.old/Qwen2.5-7B-Instruct_KVTuner6_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 8 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 8 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 8 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 8 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 8 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 8 27 | nbits_value: 8 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 8 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 8 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 8 75 | nbits_value: 8 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 8 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 2 85 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-7B-Instruct_kivi_KVTuner4_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 2 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 8 10 | 3: 11 | nbits_key: 2 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 8 22 | 7: 23 | nbits_key: 2 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 2 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 2 33 | nbits_value: 2 34 | 11: 35 | nbits_key: 8 36 | nbits_value: 8 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 2 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 2 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 8 49 | 16: 50 | nbits_key: 2 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 8 55 | 18: 56 | nbits_key: 2 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 2 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 2 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 2 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 2 84 | nbits_value: 2 85 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-7B-Instruct_kivi_KVTuner4_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 2 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 8 10 | 3: 11 | nbits_key: 2 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 
20 | nbits_key: 8 21 | nbits_value: 8 22 | 7: 23 | nbits_key: 2 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 2 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 2 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 2 33 | nbits_value: 2 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 2 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 2 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 2 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 2 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 2 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 2 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 2 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 2 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 2 84 | nbits_value: 2 85 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-7B-Instruct_kivi_KVTuner6_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 2 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 8 10 | 3: 11 | nbits_key: 2 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 2 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 8 24 | nbits_value: 8 25 | 8: 26 | nbits_key: 2 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 8 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 8 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 8 43 | 14: 44 | nbits_key: 8 45 | nbits_value: 8 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 8 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 8 58 | 19: 59 | nbits_key: 8 60 | nbits_value: 8 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 8 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 8 67 | 22: 68 | nbits_key: 8 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 8 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 8 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 2 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 8 85 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-7B-Instruct_kivi_KVTuner6_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 2 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 8 10 | 3: 11 | nbits_key: 2 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 2 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 2 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 8 22 | 7: 23 | nbits_key: 8 24 | nbits_value: 8 25 | 8: 26 | nbits_key: 2 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 8 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 8 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 2 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 8 43 | 14: 44 | nbits_key: 8 45 | nbits_value: 8 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 8 51 | 
nbits_value: 8 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 8 58 | 19: 59 | nbits_key: 8 60 | nbits_value: 8 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 8 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 8 67 | 22: 68 | nbits_key: 2 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 2 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 2 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 2 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 2 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 8 85 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-7B-Instruct_pertoken_KVTuner4_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 2 85 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-7B-Instruct_pertoken_KVTuner4_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 8 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 2 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 8 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 8 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 2 82 | 
27: 83 | nbits_key: 8 84 | nbits_value: 8 85 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-7B-Instruct_pertoken_KVTuner6_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 8 24 | nbits_value: 8 25 | 8: 26 | nbits_key: 2 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 8 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 8 34 | 11: 35 | nbits_key: 8 36 | nbits_value: 8 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 8 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 8 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 8 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 8 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 8 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 8 67 | 22: 68 | nbits_key: 8 69 | nbits_value: 8 70 | 23: 71 | nbits_key: 8 72 | nbits_value: 8 73 | 24: 74 | nbits_key: 2 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 8 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 4 85 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-7B-Instruct_pertoken_KVTuner6_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 8 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 8 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 8 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 8 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 2 85 | -------------------------------------------------------------------------------- /config/meta-llama_Meta-Llama-3-8B-Instruct_k8_v4._per_layer.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 4 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 8 15 | 
nbits_value: 4 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 8 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 8 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 8 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 8 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 8 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 8 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 8 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 8 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 4 85 | 28: 86 | nbits_key: 8 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 4 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 4 -------------------------------------------------------------------------------- /calibration_presets.old/Mistral-7B-Instruct-v0.3_KVTuner4_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 2 3 | nbits_value: 4 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 8 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 8 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 8 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 4 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 4 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets.old/Mistral-7B-Instruct-v0.3_KVTuner4_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 8 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 8 22 | 7: 23 | 
nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 2 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 2 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 2 97 | -------------------------------------------------------------------------------- /calibration_presets.old/Mistral-7B-Instruct-v0.3_KVTuner6_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 8 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 8 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 8 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 8 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 8 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 8 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 8 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 8 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 8 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 4 85 | 28: 86 | nbits_key: 8 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 4 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets.old/Mistral-7B-Instruct-v0.3_KVTuner6_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 2 3 | nbits_value: 4 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 8 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 8 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 4 30 | 
nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 8 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 4 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 4 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets.old/Meta-Llama-3.1-8B-Instruct_KVTuner4_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 2 97 | -------------------------------------------------------------------------------- /calibration_presets.old/Meta-Llama-3.1-8B-Instruct_KVTuner4_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 
| 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 2 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 2 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 4 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets.old/Meta-Llama-3.1-8B-Instruct_KVTuner6_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 8 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 8 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 8 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 8 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 4 85 | 28: 86 | nbits_key: 8 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets.old/Meta-Llama-3.1-8B-Instruct_KVTuner6_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 8 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 8 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 8 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 4 43 | 14: 44 | 
nbits_key: 8 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 4 85 | 28: 86 | nbits_key: 8 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 4 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets/Meta-Llama-3.1-8B-Instruct_kivi_KVTuner4_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 2 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 2 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 2 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 2 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 2 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 2 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 2 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 2 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 2 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 2 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 4 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 2 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets/Meta-Llama-3.1-8B-Instruct_kivi_KVTuner4_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 2 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 2 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 2 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 
| nbits_value: 2 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 2 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 2 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 4 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 2 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets/Meta-Llama-3.1-8B-Instruct_kivi_KVTuner6_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 2 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 8 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 8 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 8 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 8 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 8 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 8 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 2 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 2 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 8 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 8 72 | nbits_value: 8 73 | 24: 74 | nbits_key: 8 75 | nbits_value: 8 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 8 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 8 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 8 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 8 97 | -------------------------------------------------------------------------------- /calibration_presets/Meta-Llama-3.1-8B-Instruct_kivi_KVTuner6_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 2 3 | nbits_value: 4 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 8 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 8 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 2 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 8 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 2 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 4 
58 | 19: 59 | nbits_key: 8 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 8 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 4 85 | 28: 86 | nbits_key: 8 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 4 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets/Mistral-7B-Instruct-v0.3_kivi_KVTuner4_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 2 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 2 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 2 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 8 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 8 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 8 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 8 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 2 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 2 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 2 94 | 31: 95 | nbits_key: 2 96 | nbits_value: 2 97 | -------------------------------------------------------------------------------- /calibration_presets/Mistral-7B-Instruct-v0.3_kivi_KVTuner4_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 8 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 2 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 2 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | 
nbits_key: 2 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 2 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 2 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 2 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 2 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 2 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 2 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 2 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 2 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 2 93 | nbits_value: 2 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 8 97 | -------------------------------------------------------------------------------- /calibration_presets/Mistral-7B-Instruct-v0.3_kivi_KVTuner6_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 2 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 2 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 8 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 8 36 | nbits_value: 8 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 8 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 8 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 8 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 8 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 8 60 | nbits_value: 8 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 8 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 8 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 8 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 8 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 8 85 | 28: 86 | nbits_key: 8 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 8 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 2 96 | nbits_value: 2 97 | -------------------------------------------------------------------------------- /calibration_presets/Mistral-7B-Instruct-v0.3_kivi_KVTuner6_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 8 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 8 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 8 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 8 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 2 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 8 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 8 28 | 9: 29 | nbits_key: 2 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 8 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 2 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 8 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | 
nbits_value: 4 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 2 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 2 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 8 97 | -------------------------------------------------------------------------------- /calibration_presets/Mistral-7B-Instruct-v0.3_pertoken_KVTuner4_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 4 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 2 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 8 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 2 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets/Mistral-7B-Instruct-v0.3_pertoken_KVTuner4_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 8 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 8 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 2 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 8 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 
2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 2 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets/Mistral-7B-Instruct-v0.3_pertoken_KVTuner6_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 2 3 | nbits_value: 4 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 8 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 8 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 8 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 8 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 8 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 8 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 8 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 4 85 | 28: 86 | nbits_key: 8 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 4 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets/Mistral-7B-Instruct-v0.3_pertoken_KVTuner6_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 8 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 8 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 8 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 8 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 8 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 8 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 8 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 4 85 | 28: 
86 | nbits_key: 8 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 4 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets/Meta-Llama-3.1-8B-Instruct_pertoken_KVTuner4_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 2 97 | -------------------------------------------------------------------------------- /calibration_presets/Meta-Llama-3.1-8B-Instruct_pertoken_KVTuner4_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 2 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 2 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 8 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 2 91 | 30: 92 | 
nbits_key: 4 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 2 97 | -------------------------------------------------------------------------------- /calibration_presets/Meta-Llama-3.1-8B-Instruct_pertoken_KVTuner6_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 8 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 8 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 8 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 8 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 2 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 2 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 8 72 | nbits_value: 8 73 | 24: 74 | nbits_key: 8 75 | nbits_value: 8 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 4 85 | 28: 86 | nbits_key: 8 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 8 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 4 97 | -------------------------------------------------------------------------------- /calibration_presets/Meta-Llama-3.1-8B-Instruct_pertoken_KVTuner6_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 4 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 8 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 8 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 8 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 8 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 8 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 8 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 8 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 4 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 4 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 4 97 | 
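All of the calibration presets above and below share one schema: the top-level YAML keys are decoder-layer indices, and each layer maps to an nbits_key / nbits_value pair giving the KV-cache bit-widths for that layer. The sketch below shows one way such a preset could be loaded and handed to the flexible quantized cache. It is a minimal sketch, not part of the repository: the imports and constructor come from flexible_quant_example.py further down, the per_layer_config / per_layer_config_path keywords are taken from the commented-out line in that example, and the preset path simply points at the file listed just above, so treat the exact keyword names as illustrative rather than guaranteed.

import yaml
from flexible_quant.flexible_quantized_cache import (
    FlexibleQuantizedCacheConfig,
    FlexibleVanillaQuantizedCache,
)

# One of the presets in this directory (layer_id -> {"nbits_key": int, "nbits_value": int}).
preset_path = "calibration_presets/Meta-Llama-3.1-8B-Instruct_pertoken_KVTuner6_1.yaml"

with open(preset_path) as f:
    per_layer = yaml.safe_load(f)

# Average KV precision of the preset; each layer contributes one key and one value tensor.
avg_bits = sum(v["nbits_key"] + v["nbits_value"] for v in per_layer.values()) / (2 * len(per_layer))
print(f"average KV bits for {preset_path}: {avg_bits:.2f}")

# Build the cache from the preset. Keyword names follow the (commented) per-layer example in
# flexible_quant_example.py and may need adjusting for other versions of the package.
cache_config = FlexibleQuantizedCacheConfig(
    nbits_key=4, nbits_value=4, asym=True,
    axis_key=0, axis_value=0, device="cuda", q_group_size=-1,
    per_layer_config=True, per_layer_config_path=preset_path,
)
past_key_values = FlexibleVanillaQuantizedCache(cache_config=cache_config)

The same per-layer average is what search_optuna_vanilla.py later reports as tot_scale when it proposes candidate configurations, so this snippet can also be used to sanity-check the effective precision of any preset in this directory.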
-------------------------------------------------------------------------------- /calibration_presets.old/Qwen2.5-3B-Instruct-AWQ_KVTuner4_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 8 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 8 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 2 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 2 97 | 32: 98 | nbits_key: 4 99 | nbits_value: 2 100 | 33: 101 | nbits_key: 8 102 | nbits_value: 4 103 | 34: 104 | nbits_key: 4 105 | nbits_value: 2 106 | 35: 107 | nbits_key: 4 108 | nbits_value: 4 109 | -------------------------------------------------------------------------------- /calibration_presets.old/Qwen2.5-3B-Instruct-AWQ_KVTuner4_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 8 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 8 45 | nbits_value: 4 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 8 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 2 91 | 30: 92 | 
nbits_key: 4 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 4 97 | 32: 98 | nbits_key: 4 99 | nbits_value: 2 100 | 33: 101 | nbits_key: 4 102 | nbits_value: 4 103 | 34: 104 | nbits_key: 4 105 | nbits_value: 4 106 | 35: 107 | nbits_key: 8 108 | nbits_value: 4 109 | -------------------------------------------------------------------------------- /calibration_presets.old/Qwen2.5-3B-Instruct-AWQ_KVTuner6_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 4 97 | 32: 98 | nbits_key: 4 99 | nbits_value: 4 100 | 33: 101 | nbits_key: 4 102 | nbits_value: 4 103 | 34: 104 | nbits_key: 8 105 | nbits_value: 4 106 | 35: 107 | nbits_key: 4 108 | nbits_value: 2 109 | -------------------------------------------------------------------------------- /calibration_presets.old/Qwen2.5-3B-Instruct-AWQ_KVTuner6_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 8 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 8 67 | 22: 68 | nbits_key: 8 69 | nbits_value: 8 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | 
nbits_key: 8 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 4 97 | 32: 98 | nbits_key: 4 99 | nbits_value: 4 100 | 33: 101 | nbits_key: 8 102 | nbits_value: 4 103 | 34: 104 | nbits_key: 8 105 | nbits_value: 4 106 | 35: 107 | nbits_key: 4 108 | nbits_value: 2 109 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-3B-Instruct_kivi_KVTuner4_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 8 7 | 2: 8 | nbits_key: 2 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 2 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 2 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 2 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 2 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 2 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 2 97 | 32: 98 | nbits_key: 4 99 | nbits_value: 2 100 | 33: 101 | nbits_key: 4 102 | nbits_value: 2 103 | 34: 104 | nbits_key: 4 105 | nbits_value: 4 106 | 35: 107 | nbits_key: 4 108 | nbits_value: 4 109 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-3B-Instruct_kivi_KVTuner4_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 8 7 | 2: 8 | nbits_key: 2 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 2 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 2 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 2 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 
4 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 4 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 2 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 2 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 2 97 | 32: 98 | nbits_key: 4 99 | nbits_value: 4 100 | 33: 101 | nbits_key: 4 102 | nbits_value: 4 103 | 34: 104 | nbits_key: 2 105 | nbits_value: 4 106 | 35: 107 | nbits_key: 2 108 | nbits_value: 4 109 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-3B-Instruct_kivi_KVTuner6_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 4 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 8 7 | 2: 8 | nbits_key: 2 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 2 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 2 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 2 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 2 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 2 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 2 97 | 32: 98 | nbits_key: 4 99 | nbits_value: 2 100 | 33: 101 | nbits_key: 4 102 | nbits_value: 2 103 | 34: 104 | nbits_key: 4 105 | nbits_value: 4 106 | 35: 107 | nbits_key: 4 108 | nbits_value: 4 109 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-3B-Instruct_kivi_KVTuner6_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 8 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 4 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 8 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 8 22 | 7: 23 | nbits_key: 8 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 2 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 8 37 | 12: 38 | nbits_key: 2 39 | 
nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 8 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 4 57 | nbits_value: 2 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 8 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 8 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 4 84 | nbits_value: 2 85 | 28: 86 | nbits_key: 2 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 4 90 | nbits_value: 2 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 2 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 2 97 | 32: 98 | nbits_key: 8 99 | nbits_value: 4 100 | 33: 101 | nbits_key: 8 102 | nbits_value: 4 103 | 34: 104 | nbits_key: 2 105 | nbits_value: 4 106 | 35: 107 | nbits_key: 2 108 | nbits_value: 4 109 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-3B-Instruct_pertoken_KVTuner4_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 4 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 4 21 | nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 4 54 | nbits_value: 2 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 8 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 8 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 8 91 | 30: 92 | nbits_key: 4 93 | nbits_value: 2 94 | 31: 95 | nbits_key: 4 96 | nbits_value: 2 97 | 32: 98 | nbits_key: 4 99 | nbits_value: 2 100 | 33: 101 | nbits_key: 4 102 | nbits_value: 4 103 | 34: 104 | nbits_key: 4 105 | nbits_value: 2 106 | 35: 107 | nbits_key: 4 108 | nbits_value: 2 109 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-3B-Instruct_pertoken_KVTuner4_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 2 4 | 1: 5 | nbits_key: 4 6 | nbits_value: 2 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 4 12 | nbits_value: 2 13 | 4: 14 | nbits_key: 4 15 | nbits_value: 2 16 | 5: 17 | nbits_key: 4 18 | nbits_value: 2 19 | 6: 20 | nbits_key: 4 21 | 
nbits_value: 2 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 4 25 | 8: 26 | nbits_key: 4 27 | nbits_value: 2 28 | 9: 29 | nbits_key: 4 30 | nbits_value: 2 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 4 37 | 12: 38 | nbits_key: 4 39 | nbits_value: 2 40 | 13: 41 | nbits_key: 4 42 | nbits_value: 2 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 4 48 | nbits_value: 2 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 4 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 4 63 | nbits_value: 2 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 4 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 4 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 4 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 4 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 4 97 | 32: 98 | nbits_key: 4 99 | nbits_value: 4 100 | 33: 101 | nbits_key: 4 102 | nbits_value: 4 103 | 34: 104 | nbits_key: 8 105 | nbits_value: 4 106 | 35: 107 | nbits_key: 4 108 | nbits_value: 2 109 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-3B-Instruct_pertoken_KVTuner6_0.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | nbits_key: 8 3 | nbits_value: 8 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 8 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 8 33 | nbits_value: 4 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 4 58 | 19: 59 | nbits_key: 8 60 | nbits_value: 4 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 4 66 | nbits_value: 2 67 | 22: 68 | nbits_key: 4 69 | nbits_value: 2 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 8 75 | nbits_value: 4 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 8 81 | nbits_value: 4 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 4 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 4 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 4 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 4 97 | 32: 98 | nbits_key: 4 99 | nbits_value: 2 100 | 33: 101 | nbits_key: 8 102 | nbits_value: 4 103 | 34: 104 | nbits_key: 8 105 | nbits_value: 4 106 | 35: 107 | nbits_key: 4 108 | nbits_value: 2 109 | -------------------------------------------------------------------------------- /calibration_presets/Qwen2.5-3B-Instruct_pertoken_KVTuner6_1.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | 
nbits_key: 8 3 | nbits_value: 4 4 | 1: 5 | nbits_key: 8 6 | nbits_value: 4 7 | 2: 8 | nbits_key: 4 9 | nbits_value: 2 10 | 3: 11 | nbits_key: 8 12 | nbits_value: 4 13 | 4: 14 | nbits_key: 8 15 | nbits_value: 4 16 | 5: 17 | nbits_key: 8 18 | nbits_value: 4 19 | 6: 20 | nbits_key: 8 21 | nbits_value: 4 22 | 7: 23 | nbits_key: 4 24 | nbits_value: 2 25 | 8: 26 | nbits_key: 8 27 | nbits_value: 4 28 | 9: 29 | nbits_key: 8 30 | nbits_value: 4 31 | 10: 32 | nbits_key: 4 33 | nbits_value: 2 34 | 11: 35 | nbits_key: 4 36 | nbits_value: 2 37 | 12: 38 | nbits_key: 8 39 | nbits_value: 4 40 | 13: 41 | nbits_key: 8 42 | nbits_value: 4 43 | 14: 44 | nbits_key: 4 45 | nbits_value: 2 46 | 15: 47 | nbits_key: 8 48 | nbits_value: 4 49 | 16: 50 | nbits_key: 4 51 | nbits_value: 2 52 | 17: 53 | nbits_key: 8 54 | nbits_value: 8 55 | 18: 56 | nbits_key: 8 57 | nbits_value: 8 58 | 19: 59 | nbits_key: 4 60 | nbits_value: 2 61 | 20: 62 | nbits_key: 8 63 | nbits_value: 4 64 | 21: 65 | nbits_key: 8 66 | nbits_value: 4 67 | 22: 68 | nbits_key: 8 69 | nbits_value: 4 70 | 23: 71 | nbits_key: 4 72 | nbits_value: 2 73 | 24: 74 | nbits_key: 4 75 | nbits_value: 2 76 | 25: 77 | nbits_key: 4 78 | nbits_value: 2 79 | 26: 80 | nbits_key: 4 81 | nbits_value: 2 82 | 27: 83 | nbits_key: 8 84 | nbits_value: 8 85 | 28: 86 | nbits_key: 4 87 | nbits_value: 2 88 | 29: 89 | nbits_key: 8 90 | nbits_value: 8 91 | 30: 92 | nbits_key: 8 93 | nbits_value: 8 94 | 31: 95 | nbits_key: 8 96 | nbits_value: 8 97 | 32: 98 | nbits_key: 4 99 | nbits_value: 2 100 | 33: 101 | nbits_key: 4 102 | nbits_value: 2 103 | 34: 104 | nbits_key: 8 105 | nbits_value: 8 106 | 35: 107 | nbits_key: 4 108 | nbits_value: 2 109 | -------------------------------------------------------------------------------- /flexible_quant_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from flexible_quant.flexible_quantized_cache import FlexibleQuantizedCacheConfig, FlexibleHQQQuantizedCache, FlexibleVanillaQuantizedCache 3 | from transformers import AutoTokenizer, AutoModelForCausalLM, QuantizedCacheConfig, HQQQuantizedCache, QuantoQuantizedCache 4 | from datasets import load_dataset 5 | 6 | # CACHE_DIR = "./models_storage" 7 | # model_name = 'Qwen/Qwen2.5-3B-Instruct-AWQ' 8 | # model_name = 'Qwen/Qwen2.5-7B-Instruct' 9 | model_name = 'meta-llama/Meta-Llama-3-8B' 10 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda() 11 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True) 12 | 13 | # Quanto from huggingface is not working at all 14 | # ValueError("shift must be specified for qtypes lower than 8-bit") 15 | 16 | # cache_config = FlexibleQuantizedCacheConfig(nbits_key=4, nbits_value=4, asym=True, axis_key=0, axis_value=0, device='cuda', per_layer_config=True, per_layer_config_path='config/meta-llama_Meta-Llama-3-8B-Instruct_k8_v4_per_layer.yaml') 17 | cache_config = FlexibleQuantizedCacheConfig(nbits_key=4, nbits_value=4, asym=True, axis_key=0, axis_value=0, device='cuda', q_group_size=-1) 18 | # past_key_values = FlexibleHQQQuantizedCache(cache_config=cache_config) # it seems in HQQ, 0 for per-token and 1 for per-channel 19 | past_key_values = FlexibleVanillaQuantizedCache(cache_config=cache_config) 20 | 21 | # cache_config = QuantizedCacheConfig(nbits=4, axis_key=0, axis_value=0, device='cuda') 22 | # past_key_values = QuantoQuantizedCache(cache_config=cache_config) 23 | 24 | dataset = load_dataset('gsm8k', 'main') 25 | 26 | prompt = '' 
27 | for i in range(5): 28 | prompt += 'Question: ' + dataset['train'][i]['question'] + '\nAnswer: ' + dataset['train'][i]['answer'] + '\n' 29 | prompt += "Question: John takes care of 10 dogs. Each dog takes .5 hours a day to walk and take care of their business. How many hours a week does he spend taking care of dogs?" 30 | inputs = tokenizer(prompt, return_tensors="pt").input_ids.cuda() 31 | print('======') 32 | 33 | outputs = model.generate(inputs, past_key_values=past_key_values, use_cache=True, max_new_tokens=256) 34 | 35 | # config_str = f"# prompt tokens: {inputs.shape[1]}, K bit: {config.k_bits}, v_bits: {config.v_bits}, group_size: {config.group_size}, residual_length: {config.residual_length}" 36 | config_str = f"# prompt tokens: {inputs.shape[1]}" 37 | 38 | print(prompt + "\n" + "=" * 10 + f'\n{config_str}\n' + "=" * 10 + "\nExample Output:") 39 | print(tokenizer.decode(outputs[0].tolist()[inputs.shape[1]:], skip_special_tokens=True)) 40 | -------------------------------------------------------------------------------- /helper_scripts/sh_gen_gaokaobench.py: -------------------------------------------------------------------------------- 1 | # tasks in total: ceval-valid,mmlu,triviaqa,race,truthfulqa,gsm8k 2 | # models in total: 3 | # Qwen/Qwen2.5-3B-Instruct-AWQ 4 | # meta-llama/Meta-Llama-3-8B-Instruct,Qwen/Qwen2.5-7B-Instruct 5 | # mistralai/Mistral-7B-v0.3,Qwen/Qwen2.5-Math-7B-Instruct 6 | # mistralai/Mistral-7B-v0.3,Qwen/Qwen2.5-3B-Instruct-AWQ 7 | 8 | # only test: meta-llama/Meta-Llama-3-8B-Instruct, Qwen/Qwen2.5-7B-Instruct, mistralai/Mistral-7B-v0.3 9 | 10 | 11 | command_template_vanliia = 'python3 gaokao_bench_obj.py --device cuda:1 --model_name {0} --k_bits {1} --v_bits {2} --residual_length 32 --group_size 32 --quantizer Vanilla --axis_key 1 --axis_value 0' 12 | command_template_hqq = 'python3 gaokao_bench_obj.py --device cuda:0 --model_name {0} --k_bits {1} --v_bits {2} --residual_length 32 --group_size 32 --quantizer HQQ --axis_key 1 --axis_value 0' 13 | 14 | log_filename = "GAOKAO-Bench_{0}_Q_{1}_k{2}_v{3}.log" 15 | 16 | kv_config = [ 17 | [8, 8], 18 | [8, 4], 19 | [8, 2], 20 | [4, 8], 21 | [4, 4], 22 | [4, 2], 23 | [2, 4], 24 | [2, 2], 25 | ] 26 | nshots = [0, 4, 8, 16] 27 | import argparse 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--models', type=str, required=True) 31 | parser.add_argument('--filename', type=str, required=True) 32 | args = parser.parse_args() 33 | 34 | models = args.models.split(',') 35 | out_filename = args.filename 36 | 37 | out_filename_gpu0 = out_filename.replace('.sh', '_gpu0.sh') 38 | out_filename_gpu1 = out_filename.replace('.sh', '_gpu1.sh') 39 | 40 | with open(out_filename_gpu0, 'w+') as f0, open(out_filename_gpu1, 'w+') as f1: 41 | f0, f1 = f1, f0 42 | # run vanilla on gpu0, hqq on gpu1 43 | f0.write("export NCCL_IB_DISABLE=1\nexport NCCL_P2P_DISABLE=1\n\n") 44 | f1.write("export NCCL_IB_DISABLE=1\nexport NCCL_P2P_DISABLE=1\n\n") 45 | for model in models: 46 | for kv in kv_config: 47 | nbits_key, nbits_value = kv 48 | command_vanilla = command_template_vanliia.format(model, nbits_key, nbits_value) 49 | command_hqq = command_template_hqq.format(model, nbits_key, nbits_value) 50 | logfile_vanilla = log_filename.format(model.replace('/', '_'), 'Vanilla', nbits_key, nbits_value) 51 | logfile_hqq = log_filename.format(model.replace('/', '_'), 'HQQ', nbits_key, nbits_value) 52 | f0.write(command_vanilla + ' | tee ' + logfile_vanilla) 53 | f0.write('\n') 54 | f0.write('\n') 55 | f1.write(command_hqq + ' | tee ' + 
logfile_hqq) 56 | f1.write('\n') 57 | f1.write('\n') 58 | f0.close() 59 | f1.close() 60 | 61 | import os 62 | os.system('chmod +x {}'.format(out_filename_gpu0)) 63 | os.system('chmod +x {}'.format(out_filename_gpu1)) 64 | -------------------------------------------------------------------------------- /flexible_quant/flexible_quant/vanilla_quantizer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | 5 | def quant_sym(x: torch.tensor, scaling: torch.tensor, nbits: int): 6 | q_max, q_min = 2 ** (nbits - 1) - 1, -2 ** (nbits - 1) 7 | return torch.round(x / scaling.unsqueeze(1)).clip(q_min, q_max).to(torch.int8) 8 | 9 | def dequant_sym(x: torch.tensor, scaling: torch.tensor, target_dtype: torch.dtype): 10 | return x * scaling.unsqueeze(1).to(target_dtype) 11 | 12 | def quant_asym(x: torch.tensor, scaling: torch.tensor, zeros: torch.tensor, nbits: int): 13 | q_max, q_min = 2 ** (nbits - 1) - 1, -2 ** (nbits - 1) 14 | return (torch.round(x / scaling.unsqueeze(1) - zeros.unsqueeze(1))).clip(q_min, q_max).to(torch.int8) 15 | 16 | def dequant_asym(x: torch.tensor, scaling: torch.tensor, zeros: torch.tensor, target_dtype: torch.dtype): 17 | return (x + zeros.unsqueeze(1)) * scaling.unsqueeze(1).to(target_dtype) 18 | 19 | 20 | class VanillaQuantizeMeta: 21 | def __init__(self, nbits, asym, compute_dtype): 22 | self.nbits = nbits 23 | # self.group_size = group_size 24 | # self.axis = axis # 1 for per-channel, 0 for per-token 25 | self.asym = asym 26 | self.compute_dtype = compute_dtype 27 | 28 | 29 | class VanillaQuantizedTensor: 30 | def __init__(self, tensor, scaling, zeros, original_shape, axis, meta: VanillaQuantizeMeta): 31 | self.tensor = tensor 32 | self.scaling = scaling 33 | self.zeros = zeros 34 | self.original_shape = original_shape 35 | self.axis = axis 36 | self.meta = meta 37 | 38 | def dequantize(self): 39 | if self.meta.asym: 40 | dequant = dequant_asym(self.tensor, self.scaling, self.zeros, self.meta.compute_dtype) 41 | else: 42 | dequant = dequant_sym(self.tensor, self.scaling, self.meta.compute_dtype) 43 | dequant = dequant.view(self.original_shape) 44 | if self.axis == 1: 45 | max_dim = len(self.original_shape) - 1 46 | dequant = dequant.transpose(max_dim - 1, max_dim) 47 | return dequant 48 | 49 | class VanillaQuantizer: 50 | def __init__(self, nbits, asym, compute_dtype): 51 | self.meta = VanillaQuantizeMeta(nbits, asym, compute_dtype) 52 | 53 | def quantize(self, tensor, q_group_size, axis): 54 | if axis == 1: 55 | max_dim = len(tensor.shape) - 1 56 | tensor = tensor.transpose(max_dim - 1, max_dim) 57 | if q_group_size == -1: 58 | assert axis == 0 # must be per-token 59 | q_group_size = tensor.shape[-1] # take the last dimension 60 | rs = tensor.reshape(-1, q_group_size) 61 | 62 | q_max, q_min = 2 ** (self.meta.nbits - 1) - 1, -2 ** (self.meta.nbits - 1) 63 | 64 | if self.meta.asym: 65 | _max, _min = rs.max(dim=1).values, rs.min(dim=1).values 66 | scale = (_max - _min).clamp(min=1e-5).div(q_max - q_min) 67 | zeros = (_min / scale).round() - q_min 68 | quant = quant_asym(rs, scale, zeros, self.meta.nbits) 69 | else: 70 | scale = rs.abs().max(dim=1).values.clamp(min=1e-5).div(q_max) 71 | zeros = None 72 | quant = quant_sym(rs, scale, self.meta.nbits) 73 | 74 | return VanillaQuantizedTensor(quant, scale, zeros, tensor.shape, axis, self.meta) 75 | -------------------------------------------------------------------------------- /helper_scripts/sh_gen_baseline.py: 
-------------------------------------------------------------------------------- 1 | model_args_template_pertoken = "pretrained={},nbits_key={},nbits_value={},residual_length=0,q_group_size=-1,axis_key=0,axis_value=0,trust_remote_code=True,dtype=bfloat16,force_quant=False,quantilizer=vanilla" 2 | model_args_template_kivi = "pretrained={},nbits_key={},nbits_value={},residual_length=32,q_group_size=32,axis_key=1,axis_value=0,trust_remote_code=True,dtype=bfloat16,force_quant=False,quantilizer=vanilla" 3 | 4 | # per_layer_config_path is yaml file path 5 | 6 | command_fewshot_template = '''accelerate launch -m lm_eval --model hf-quant \\ 7 | --model_args {} \\ 8 | --tasks {} \\ 9 | --batch_size 1 \\ 10 | --num_fewshot {} \\ 11 | --limit 200 \\ 12 | --output_path lmeval_results/{} \\ 13 | | tee {}.log''' 14 | 15 | TASKS = [ 16 | { 17 | 'filename': 'gsm8k', 18 | 'tasks': ['gsm8k'], 19 | 'nshots': [4], 20 | }, 21 | ] 22 | 23 | STANDARD_KV_CONFIG = ['kv8', 'k8v4', 'k4v8', 'k8v2', 'kv4', 'k4v2', 'k2v4', 'kv2'] 24 | 25 | 26 | def extract_kv_config(config_str: str): 27 | if len(config_str) == 3: 28 | return int(config_str[2]), int(config_str[2]) 29 | return int(config_str[1]), int(config_str[3]) 30 | 31 | import argparse 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--models', type=str, required=True) 35 | parser.add_argument('--filename', type=str, required=False, default='run.sh') 36 | parser.add_argument('--quant_scheme', type=str, required=False, default='pertoken') 37 | 38 | args = parser.parse_args() 39 | quant_scheme = args.quant_scheme 40 | model_args_template = model_args_template_pertoken if quant_scheme == 'pertoken' else model_args_template_kivi 41 | 42 | models = args.models.split(',') 43 | out_filename = args.filename 44 | 45 | # naming scheme: {model}_pertoken_{task_filename}_{nshot}_{kv_config or bf16 or id}. 46 | 47 | def get_filename(model, quant_scheme, task_filename, nshots, kvconfig_or_bf16_or_id): 48 | return f'{model}_{quant_scheme}_{task_filename}_{nshots}_{kvconfig_or_bf16_or_id}' 49 | 50 | tot_commands = 0 51 | tot_time = 0 52 | with open(out_filename, 'w+') as f: 53 | f.write("export NCCL_IB_DISABLE=1\nexport NCCL_P2P_DISABLE=1\nexport HF_ALLOW_CODE_EVAL=1\nexport TRANSFORMERS_CACHE=./models_storage\n\n") 54 | for model in models: 55 | filename_model = model.replace('/', '_') + f'_{quant_scheme}_baseline_limit200' 56 | f.write(f'# ======== {model} standard kv configs ========\n') 57 | for kv_config in STANDARD_KV_CONFIG: 58 | for task_preset in TASKS: 59 | for nshot in task_preset['nshots']: 60 | nbits_key, nbits_value = extract_kv_config(kv_config) 61 | filename = get_filename(filename_model, quant_scheme, task_preset['filename'], nshot, kv_config) 62 | model_arg = model_args_template.format(model, nbits_key, nbits_value) 63 | command = command_fewshot_template.format(model_arg, ','.join(task_preset['tasks']), nshot, filename, filename) 64 | f.write(command) 65 | tot_commands += 1 66 | tot_time += 10 67 | f.write('\n') 68 | f.write('\n') 69 | 70 | import os 71 | os.system('chmod +x {}'.format(out_filename)) 72 | 73 | print(f'Generated {tot_commands} commands in {out_filename}.') 74 | print(f'Estimated total running time (on dual RTX 4090): {tot_time} minutes. 
aka {tot_time/60} hours.') -------------------------------------------------------------------------------- /helper_scripts/sh_gen_lmeval.py: -------------------------------------------------------------------------------- 1 | # tasks in total: ceval-valid,mmlu,triviaqa,race,truthfulqa,gsm8k 2 | # models in total: 3 | # meta-llama/Llama-2-7b-chat-hf,Qwen/Qwen2.5-3B-Instruct-AWQ 4 | # meta-llama/Meta-Llama-3-8B-Instruct,Qwen/Qwen2.5-7B-Instruct 5 | # mistralai/Mistral-7B-v0.3,Qwen/Qwen2.5-Math-7B-Instruct 6 | # Qwen/Qwen2.5-7B-Instruct,Qwen/Qwen2.5-Math-7B-Instruct 7 | 8 | # FOR WS-9: meta-llama/Meta-Llama-3-8B-Instruct,mistralai/Mistral-7B-v0.3 9 | # FOR WS-13: Qwen/Qwen2.5-7B-Instruct,Qwen/Qwen2.5-Math-7B-Instruct,Qwen/Qwen2.5-3B-Instruct-AWQ 10 | 11 | model_args_template = "{},nbits_key={},nbits_value={},residual_length=32,q_group_size=32,axis_key=0,axis_value=1,trust_remote_code=True,dtype=bfloat16,force_quant=True,quantilizer=vanilla" 12 | 13 | # model_args_template = "{},nbits_key={},nbits_value={},residual_length=0,q_group_size=-1,axis_key=0,axis_value=0,trust_remote_code=True,dtype=bfloat16,force_quant=True,quantilizer=vanilla" 14 | 15 | model_args_template_bf16 = "{},trust_remote_code=True,dtype=bfloat16" 16 | 17 | command_template = '''accelerate launch -m lm_eval --model hf-quant \\ 18 | --model_args pretrained={} \\ 19 | --tasks {} \\ 20 | --batch_size {} \\ 21 | --output_path lmeval_results/{} \\ 22 | | tee {}.log''' 23 | 24 | command_fewshot_template = '''accelerate launch -m lm_eval --model hf-quant \\ 25 | --model_args pretrained={} \\ 26 | --tasks {} \\ 27 | --batch_size {} \\ 28 | --num_fewshot {} \\ 29 | --output_path lmeval_results/{} \\ 30 | | tee {}.log''' 31 | 32 | 33 | kv_config = [ 34 | [8, 8], 35 | [8, 4], 36 | [8, 2], 37 | [4, 8], 38 | [4, 4], 39 | [4, 2], 40 | [2, 4], 41 | [2, 2], 42 | ] 43 | nshots = [0, 4, 8, 16] 44 | import argparse 45 | 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('--tasks', type=str, required=False, default="ceval-valid,mmlu,triviaqa,race,truthfulqa,gsm8k") 48 | parser.add_argument('--models', type=str, required=True) 49 | parser.add_argument('--filename', type=str, required=False, default='run.sh') 50 | parser.add_argument('--bf16', action='store_true', default=False) 51 | 52 | args = parser.parse_args() 53 | bf16 = args.bf16 54 | 55 | tasks = args.tasks.split(',') 56 | models = args.models.split(',') 57 | out_filename = args.filename 58 | 59 | with open(out_filename, 'w+') as f: 60 | f.write("export NCCL_IB_DISABLE=1\nexport NCCL_P2P_DISABLE=1\n\n") 61 | for model in models: 62 | # batch_size = 16 63 | # if '7B' in model or '8B' in model or '7b' in model or '8b' in model: 64 | # batch_size = 4 65 | batch_size = 1 66 | filename_model = model.replace('/', '_') 67 | for kv in kv_config: 68 | nbits_key, nbits_value = kv 69 | model_args = model_args_template.format(model, nbits_key, nbits_value) 70 | if bf16: 71 | model_args = model_args_template_bf16.format(model) 72 | task_fewshots = [task for task in tasks if task == 'gsm8k'] 73 | task_others = [task for task in tasks if task != 'gsm8k'] 74 | if task_others: 75 | task_others_str = ','.join(task_others) 76 | filename = f'{filename_model}_others_k{nbits_key}_v{nbits_value}' 77 | if bf16: 78 | filename = f'{filename_model}_others_bf16' 79 | command = command_template.format(model_args, task_others_str, batch_size, filename, filename) 80 | if bf16: 81 | command = command.replace('--model hf-quant', '--model hf') 82 | f.write(command) 83 | f.write('\n') 84 | f.write('\n') 85 | 
for task in task_fewshots: 86 | for nshot in nshots: 87 | filename = f'{filename_model}_{task}_k{nbits_key}_v{nbits_value}_n{nshot}' 88 | if bf16: 89 | filename = f'{filename_model}_{task}_bf16_n{nshot}' 90 | # command = command_template.format(model_args + (',n-shot={}'.format(nshot)), task, batch_size, filename, filename) 91 | command = command_fewshot_template.format(model_args, task, batch_size, nshot, filename, filename) 92 | if bf16: 93 | command = command.replace('--model hf-quant', '--model hf') 94 | f.write(command) 95 | f.write('\n') 96 | f.write('\n') 97 | if bf16: 98 | break 99 | 100 | import os 101 | os.system('chmod +x {}'.format(out_filename)) 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /benckmarks/example_gsm8k_cot_manyshot.py: -------------------------------------------------------------------------------- 1 | # LLaMA model with KIVI 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | import torch 5 | import random 6 | import argparse 7 | import torch 8 | from flexible_quant.flexible_quantized_cache import FlexibleQuantizedCacheConfig, FlexibleHQQQuantizedCache, FlexibleVanillaQuantizedCache 9 | from transformers import AutoTokenizer, AutoModelForCausalLM, QuantizedCacheConfig, HQQQuantizedCache, QuantoQuantizedCache 10 | from datasets import load_dataset 11 | from transformers import LlamaConfig, AutoTokenizer, LlamaForCausalLM 12 | from datasets import load_dataset 13 | from evals.gsm8k_utils import * 14 | 15 | # For reproducibility 16 | random.seed(0) 17 | torch.manual_seed(0) 18 | 19 | def parse_args(args=None): 20 | parser = argparse.ArgumentParser() 21 | # parser.add_argument('--model_name', type=str, default="meta-llama/Llama-2-7b-hf") 22 | # parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-3B-Instruct-AWQ") 23 | # parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-7B-Instruct") 24 | parser.add_argument('--model_name', type=str, default="") 25 | parser.add_argument('--nshots', type=int, default=5) 26 | parser.add_argument('--k_bits', type=int, default=8) 27 | parser.add_argument('--v_bits', type=int, default=8) 28 | parser.add_argument('--residual_length', type=int, default=128) 29 | parser.add_argument('--group_size', type=int, default=64) 30 | parser.add_argument('--asym', type=bool, default=True) 31 | # in HQQ, 0 for per-channel, 1 for per-token 32 | parser.add_argument('--axis_key', type=int, default=0) 33 | parser.add_argument('--axis_value', type=int, default=1) 34 | return parser.parse_args(args) 35 | 36 | def args_to_str(args): 37 | ret = "" 38 | for arg in vars(args): 39 | ret += f"{arg}: {getattr(args, arg)}\n" 40 | return ret 41 | 42 | 43 | if __name__ == "__main__": 44 | args = parse_args() 45 | model_name = args.model_name 46 | num_cot = args.nshots 47 | 48 | # asym only works for VanillaQuantizedCache 49 | cache_config = FlexibleQuantizedCacheConfig(nbits_key=args.k_bits, nbits_value=args.v_bits, residual_length=args.residual_length, q_group_size=args.group_size, 50 | asym=args.asym, axis_key=args.axis_key, axis_value=args.axis_value, device='cuda') 51 | 52 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda() 53 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True) 54 | 55 | dataset = load_dataset('gsm8k', 'main') 56 | 57 | answers = [] 58 | num_testcases = len(dataset['test']) 59 | for idx, _question_answer in enumerate(dataset['test']): 60 | # past_key_values = 
FlexibleVanillaQuantizedCache(cache_config=cache_config) 61 | past_key_values = FlexibleVanillaQuantizedCache(cache_config=cache_config) 62 | 63 | prompt = build_prompt_from_trainset(dataset['train'], _question_answer["question"], num_cot, COT_FLAG) 64 | 65 | inputs = tokenizer(prompt, return_tensors="pt").input_ids.cuda()\ 66 | 67 | output = model.generate(inputs, past_key_values=past_key_values, use_cache=True, max_new_tokens=256) 68 | # config_str = f"# prompt tokens: {inputs.shape[1]}, K bit: {cache_config.nbits_key}, v_bits: {cache_config.nbits_value}, num_cot: {num_cot} group_size: {cache_config.q_group_size}, residual_length: {cache_config.residual_length}" 69 | config_str = args_to_str(args) 70 | model_completion = tokenizer.decode(output[0].tolist()[inputs.shape[1]:], skip_special_tokens=True) 71 | 72 | model_answer = clean_answer(model_completion) 73 | is_cor = is_correct(model_answer, _question_answer["answer"]) 74 | answers.append(is_cor) 75 | 76 | print("\n\n" + "=" * 88 + "\n\t\t{} / {}-th testcase".format(idx, num_testcases)) 77 | 78 | if idx % 50 == 0: 79 | print(prompt + "\n\n\n" + "=" * 10 + f'\n{config_str}\n' + f"model_name : {args.model_name}\n" + "=" * 10 + "\nExample Output:") 80 | else: 81 | print(_question_answer["question"] + "\n\n\n" + "=" * 10 + f'\n{config_str}\n' + f"model_name : {args.model_name}\n" + "=" * 10 + "\nExample Output:") 82 | print(model_completion) 83 | print("\nTarget answer: {}".format(_question_answer["answer"])) 84 | print("\n=== Is correct: {}".format(is_cor)) 85 | 86 | print( 87 | f"Num of total question: {len(answers)}, " 88 | f"Correct num: {sum(answers)}, " 89 | f"Accuracy: {float(sum(answers))/len(answers)}." 90 | ) 91 | 92 | print("Final result summary:\n") 93 | print( 94 | f"Num of total question: {len(answers)}, " 95 | f"Correct num: {sum(answers)}, " 96 | f"Accuracy: {float(sum(answers))/len(answers)}." 97 | ) 98 | 99 | for i in range(5): 100 | str_ = 'Question: ' + dataset['test'][i]['question'] + '\nAnswer: ' + dataset['test'][i]['answer'] + "\n" 101 | print("TEST [{}]: \n{}".format(i, str_)) -------------------------------------------------------------------------------- /helper_scripts/create_table.py: -------------------------------------------------------------------------------- 1 | # Quant. 
method & Precision & CEVAL-VALID & MMLU & TriviaQA & RACE & TruthfulQA & GSM8K & GSM8K 4-shot & GSM8K 8-shot & GSM8K 16-shot \\ \hline 2 | # \multicolumn{11}{c}{\textbf{Mistral-7B-Instruct-v0.3}} \\ \hline 3 | # \multirow{5}{*}{KIVI} 4 | # & KV8 & 0.4368 & 0.5904 & 0.3246 & 0.4622 & 0.5435 & - & - & - & - \\ 5 | # & K8V4 & 0.4368 & 0.5904 & 0.3243 & 0.4622 & 0.5483 & - & - & - & - \\ 6 | # & K8V2 & 0.4368 & 0.5904 & 0.3208 & 0.4622 & 0.5459 & - & - & - & - \\ 7 | # & K4V8 & 0.4368 & 0.5904 & 0.3242 & 0.4622 & 0.5349 & - & - & - & - \\ 8 | # & KV4 & 0.4368 & 0.5904 & 0.3245 & 0.4622 & 0.5373 & - & - & - & - \\ 9 | # & K4V2 & 0.4368 & 0.5904 & 0.3199 & 0.4622 & 0.5398 & - & - & - & - \\ 10 | # & K2V4 & 0.4368 & 0.5904 & 0.3231 & 0.4622 & 0.5471 & - & - & - & - \\ 11 | # & KV2 & 0.4368 & 0.5904 & 0.3190 & 0.4622 & 0.5300 & - & - & - & - \\ \hline 12 | 13 | # CEVAL-VALID & MMLU & TriviaQA & RACE & TruthfulQA & GSM8K & GSM8K 4-shot & GSM8K 8-shot & GSM8K 16-shot 14 | datasets = ['CEVAL-VALID', 'MMLU', 'TriviaQA', 'RACE', 'TruthfulQA', 'GSM8K', 'GSM8K 4-shot', 'GSM8K 8-shot', 'GSM8K 16-shot'] 15 | # KV config 16 | kv_configs = [[8, 8], [8, 4], [8, 2], [4, 8], [4, 4], [4, 2], [2, 4], [2, 2]]\ 17 | 18 | def KV_config_str(k_nbit, v_nbit): 19 | if k_nbit == v_nbit: 20 | return f'KV{k_nbit}' 21 | else: 22 | return f'K{k_nbit}V{v_nbit}' 23 | 24 | # |ceval-valid | 2|none | |acc |↑ |0.2585|± |0.0119| 25 | def extrace_value(lines: list, start_str: str): 26 | for line in lines: 27 | if line.startswith(start_str): 28 | rem = line.split(start_str)[1] 29 | return rem.split('|')[0].strip() 30 | return '-' 31 | 32 | import argparse 33 | import os 34 | # add param: model_name 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--model_name', type=str, required=True) 37 | args = parser.parse_args() 38 | 39 | model_name = args.model_name.replace('/', '_') 40 | display_model_name = args.model_name.split('/')[-1] 41 | 42 | # create empty table 43 | table = {} 44 | for kv_config in kv_configs: 45 | k_nbit, v_nbit = kv_config 46 | table[KV_config_str(k_nbit, v_nbit)] = {dataset: '-' for dataset in datasets} 47 | for dataset in datasets: 48 | if not dataset.startswith('GSM8K'): 49 | log_filename = f'{model_name}_others_k{k_nbit}_v{v_nbit}.log' 50 | # test if file exists 51 | if not os.path.exists(log_filename): 52 | continue 53 | with open(log_filename, 'r') as f: 54 | lines = f.readlines() 55 | if dataset == 'CEVAL-VALID': 56 | table[KV_config_str(k_nbit, v_nbit)][dataset] = extrace_value(lines, '|ceval-valid | 2|none | |acc |↑ |') 57 | elif dataset == 'MMLU': 58 | table[KV_config_str(k_nbit, v_nbit)][dataset] = extrace_value(lines, '|mmlu | 2|none | |acc |↑ |') 59 | elif dataset == 'TriviaQA': 60 | table[KV_config_str(k_nbit, v_nbit)][dataset] = extrace_value(lines, '|triviaqa | 3|remove_whitespace| 0|exact_match|↑ |') 61 | elif dataset == 'RACE': 62 | table[KV_config_str(k_nbit, v_nbit)][dataset] = extrace_value(lines, '|race | 2|none | 0|acc |↑ |') 63 | elif dataset == 'TruthfulQA': 64 | table[KV_config_str(k_nbit, v_nbit)][dataset] = extrace_value(lines, '|truthfulqa_gen | 3|none | 0|bleu_acc |↑ |') 65 | else: 66 | n_shot = 0 67 | if dataset == 'GSM8K 4-shot': 68 | n_shot = 4 69 | elif dataset == 'GSM8K 8-shot': 70 | n_shot = 8 71 | elif dataset == 'GSM8K 16-shot': 72 | n_shot = 16 73 | log_filename = f'{model_name}_gsm8k_k{k_nbit}_v{v_nbit}_n{n_shot}.log' 74 | if not os.path.exists(log_filename): 75 | continue 76 | with open(log_filename, 'r') as f: 77 | lines = f.readlines() 78 | if n_shot != 16: 
79 | table[KV_config_str(k_nbit, v_nbit)][dataset] = extrace_value(lines, f'|gsm8k| 3|flexible-extract| {n_shot}|exact_match|↑ |') 80 | else: 81 | table[KV_config_str(k_nbit, v_nbit)][dataset] = extrace_value(lines, f'|gsm8k| 3|flexible-extract| {n_shot}|exact_match|↑ |') 82 | # print(table) 83 | 84 | print('\\multicolumn{11}{c}{\\textbf{' + display_model_name + '}} \\\\ \\hline') 85 | print('\\multirow{5}{*}{KIVI}') 86 | for kv_config in kv_configs: 87 | k_nbit, v_nbit = kv_config 88 | print('& ' + KV_config_str(k_nbit, v_nbit), end=' ') 89 | for dataset in datasets: 90 | print('&', table[KV_config_str(k_nbit, v_nbit)][dataset], end=' ') 91 | print('\\\\', end='') 92 | if kv_config == kv_configs[-1]: 93 | print('\\hline') 94 | else: 95 | print() 96 | -------------------------------------------------------------------------------- /search_optuna_vanilla.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | import torch 4 | import random 5 | import argparse 6 | import torch 7 | import optuna 8 | import lm_eval 9 | from lm_eval.models.huggingface_quant import HFLM_Quant 10 | import logging 11 | import sys 12 | 13 | # For reproducibility 14 | random.seed(0) 15 | torch.manual_seed(0) 16 | CACHE_DIR = "./models_storage" 17 | 18 | TEMPLATE_KV_QUANT_CONFIG = [ 19 | {'nbits_key': 8, 'nbits_value': 8}, 20 | {'nbits_key': 8, 'nbits_value': 4}, 21 | {'nbits_key': 4, 'nbits_value': 4}, 22 | {'nbits_key': 4, 'nbits_value': 2}, 23 | {'nbits_key': 2, 'nbits_value': 2}, 24 | ] 25 | 26 | # THIS IS FOR PER_TOKEN QUANTIZATION 27 | LLAMA3_IMPORTANT_LAYERS = [0, 3, 5, 7, 12, 15, 22, 26, 30, 31] 28 | LLAMA3_MEDIUM_LAYERS = [6, 8, 9, 10, 11, 13, 14, 25, 27, 28, 29] 29 | 30 | QWEN_IMPORTANT_LAYERS = [0, 18, 20, 27, 29, 35] 31 | QWEN_MEDIUM_LAYERS = [3, 4, 5] 32 | 33 | global_args = {} 34 | model = None 35 | tokenizer = None 36 | dataset = None 37 | 38 | def parse_args(args=None): 39 | parser = argparse.ArgumentParser() 40 | # parser.add_argument('--model_name', type=str, default="meta-llama/Llama-2-7b-hf") 41 | # parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-3B-Instruct-AWQ") 42 | # parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-7B-Instruct") 43 | parser.add_argument('--model_name', type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct") 44 | parser.add_argument('--residual_length', type=int, default=0) 45 | parser.add_argument('--group_size', type=int, default=-1) 46 | parser.add_argument('--asym', type=bool, default=True) 47 | # in Vanilla, 0 for per-token, 1 for per-channel, we have to use per-channel there as residual_length is 0 48 | parser.add_argument('--axis_key', type=int, default=0) 49 | parser.add_argument('--axis_value', type=int, default=0) 50 | parser.add_argument('--limit', type=int, default=20) 51 | parser.add_argument('--num_fewshots', type=int, default=4) 52 | parser.add_argument('--max_per_layer_scale', type=int, default=8) 53 | parser.add_argument('--n_trials', type=int, default=100) 54 | parser.add_argument('--device', type=str, default="cuda") 55 | return parser.parse_args(args) 56 | 57 | 58 | 59 | def run_gsm8k(residual_length: int, group_size: int, asym: bool, axis_key: int, axis_value: int, per_layer_config: dict, model_name: str, num_fewshots: int, limit: int, device: str): 60 | results = lm_eval.simple_evaluate( 61 | model='hf-quant', 62 | model_args={ 63 | 'pretrained': model_name, 64 | 'nbits_key': -1, 65 | 'nbits_value': -1, 66 | 
'residual_length': residual_length, 67 | 'q_group_size': group_size, 68 | 'asym': asym, 69 | 'axis_key': axis_key, 70 | 'axis_value': axis_value, 71 | 'dtype': torch.bfloat16, 72 | 'force_quant': True, 73 | 'per_layer_quant': True, 74 | 'per_layer_config': per_layer_config, 75 | 'quantilizer': 'vanilla', 76 | }, 77 | tasks=["gsm8k"], 78 | num_fewshot=num_fewshots, 79 | limit=limit, 80 | device=device 81 | ) 82 | print(results['results']['gsm8k']['exact_match,flexible-extract']) 83 | return float(results['results']['gsm8k']['exact_match,flexible-extract']) 84 | 85 | def objective(trial): 86 | tot_layers = 32 if 'llama' in model.lower() else 36 87 | 88 | per_layer_config = {} 89 | tot_scale = 0 90 | 91 | for layer in range(0, tot_layers): 92 | config_current_layer = trial.suggest_int('layer_{}'.format(layer), 0, len(TEMPLATE_KV_QUANT_CONFIG) - 1) 93 | per_layer_config[layer] = TEMPLATE_KV_QUANT_CONFIG[config_current_layer] 94 | tot_scale += per_layer_config[layer]['nbits_key'] + per_layer_config[layer]['nbits_value'] 95 | 96 | # Constraints which are considered feasible if less than or equal to zero. 97 | tot_scale /= tot_layers * 2 98 | c = tot_scale - global_args['max_per_layer_scale'] 99 | 100 | print('constraints:', c) 101 | 102 | trial.set_user_attr('constraints', (c, )) 103 | 104 | accuracy = run_gsm8k(global_args['residual_length'], global_args['group_size'], global_args['asym'], global_args['axis_key'], global_args['axis_value'], per_layer_config, 105 | global_args['model_name'], global_args['num_fewshots'], global_args['limit'], global_args['device']) 106 | 107 | return accuracy, tot_scale 108 | 109 | def constraints(trial): 110 | return trial.user_attrs["constraints"] 111 | 112 | if __name__ == "__main__": 113 | args = parse_args() 114 | model = args.model_name 115 | 116 | global_args['model_name'] = args.model_name 117 | global_args['residual_length'] = args.residual_length 118 | global_args['group_size'] = args.group_size 119 | global_args['asym'] = args.asym 120 | global_args['axis_key'] = args.axis_key 121 | global_args['axis_value'] = args.axis_value 122 | global_args['limit'] = args.limit 123 | global_args['num_fewshots'] = args.num_fewshots 124 | global_args['device'] = args.device 125 | global_args['max_per_layer_scale'] = args.max_per_layer_scale 126 | 127 | optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout)) 128 | study_name = "{}_gsm8k_l{}_search_{}_m{}_brute_force_{}".format(model.replace("/", "_"), args.limit, args.device.replace(":", ""), args.max_per_layer_scale, 'per_token' if args.group_size else 'kivi') 129 | storage_name = "sqlite:///{}.db".format(study_name) 130 | sampler = optuna.samplers.NSGAIISampler(constraints_func=constraints) 131 | study = optuna.create_study(directions=["maximize", "minimize"], study_name=study_name, storage=storage_name, sampler=sampler) 132 | study.optimize(objective, n_trials=args.n_trials) 133 | 134 | # print(study.best_params) 135 | # print(study.best_value) 136 | -------------------------------------------------------------------------------- /search_brute_force.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | import torch 4 | import random 5 | import argparse 6 | import torch 7 | import lm_eval 8 | from lm_eval.models.huggingface_quant import HFLM_Quant 9 | 10 | # For reproducibility 11 | random.seed(0) 12 | torch.manual_seed(0) 13 | CACHE_DIR = "./models_storage" 14 | 15 | TEMPLATE_KV_QUANT_CONFIG = [ 16 
| {'nbits_key': 8, 'nbits_value': 8}, 17 | {'nbits_key': 8, 'nbits_value': 4}, 18 | {'nbits_key': 4, 'nbits_value': 4}, 19 | {'nbits_key': 4, 'nbits_value': 2}, 20 | {'nbits_key': 2, 'nbits_value': 2}, 21 | ] 22 | 23 | LLAMA3_IMPORTANT_LAYERS = [0, 3, 5, 7, 12, 15, 22, 26, 30, 31] 24 | LLAMA3_MEDIUM_LAYERS = [6, 8, 9, 10, 11, 13, 14, 25, 27, 28, 29] 25 | 26 | QWEN_IMPORTANT_LAYERS = [0, 18, 20, 27, 29, 35] 27 | QWEN_MEDIUM_LAYERS = [3, 4, 5] 28 | 29 | global_args = {} 30 | model = None 31 | tokenizer = None 32 | dataset = None 33 | 34 | def parse_args(args=None): 35 | parser = argparse.ArgumentParser() 36 | # parser.add_argument('--model_name', type=str, default="meta-llama/Llama-2-7b-hf") 37 | parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-3B-Instruct-AWQ") 38 | # parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-7B-Instruct") 39 | # parser.add_argument('--model_name', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct") 40 | parser.add_argument('--residual_length', type=int, default=32) 41 | parser.add_argument('--group_size', type=int, default=32) 42 | parser.add_argument('--asym', type=bool, default=True) 43 | # in Vanilla, 0 for per-token, 1 for per-channel, we have to use per-channel there as residual_length is 0 44 | parser.add_argument('--axis_key', type=int, default=1) 45 | parser.add_argument('--axis_value', type=int, default=0) 46 | parser.add_argument('--limit', type=int, default=200) 47 | parser.add_argument('--num_fewshots', type=int, default=0) 48 | parser.add_argument('--device', type=str, default="cuda:0") 49 | return parser.parse_args(args) 50 | 51 | 52 | 53 | def run_gsm8k(residual_length: int, group_size: int, asym: bool, axis_key: int, axis_value: int, per_layer_config: dict, model_name: str, num_fewshots: int, limit: int, device: str): 54 | if limit != -1: 55 | results = lm_eval.simple_evaluate( 56 | model='hf-quant', 57 | model_args={ 58 | 'pretrained': model_name, 59 | 'nbits_key': -1, 60 | 'nbits_value': -1, 61 | 'residual_length': residual_length, 62 | 'q_group_size': group_size, 63 | 'asym': asym, 64 | 'axis_key': axis_key, 65 | 'axis_value': axis_value, 66 | 'dtype': torch.bfloat16, 67 | 'force_quant': True, 68 | 'per_layer_quant': True, 69 | 'per_layer_config': per_layer_config, 70 | 'quantilizer': 'vanilla', 71 | }, 72 | tasks=["gsm8k"], 73 | num_fewshot=num_fewshots, 74 | limit=limit, 75 | device=device 76 | ) 77 | else: 78 | results = lm_eval.simple_evaluate( 79 | model='hf-quant', 80 | model_args={ 81 | 'pretrained': model_name, 82 | 'nbits_key': -1, 83 | 'nbits_value': -1, 84 | 'residual_length': residual_length, 85 | 'q_group_size': group_size, 86 | 'asym': asym, 87 | 'axis_key': axis_key, 88 | 'axis_value': axis_value, 89 | 'dtype': torch.bfloat16, 90 | 'force_quant': True, 91 | 'per_layer_quant': True, 92 | 'per_layer_config': per_layer_config, 93 | 'quantilizer': 'vanilla', 94 | }, 95 | tasks=["gsm8k"], 96 | num_fewshot=num_fewshots, 97 | device=device 98 | ) 99 | print(results['results']['gsm8k']['exact_match,flexible-extract']) 100 | return float(results['results']['gsm8k']['exact_match,flexible-extract']) 101 | 102 | def build_per_layer_config(model: str, config_high: int, config_medium: int, config_low: int): 103 | important_layers = [] 104 | if 'llama' in model.lower(): 105 | important_layers = LLAMA3_IMPORTANT_LAYERS 106 | medium_layers = LLAMA3_MEDIUM_LAYERS 107 | if 'qwen' in model.lower(): 108 | important_layers = QWEN_IMPORTANT_LAYERS 109 | medium_layers = QWEN_MEDIUM_LAYERS 110 | per_layer_config = 
{} 111 | tot_scale = 0 112 | tot_layers = 32 if 'llama' in model.lower() else 36 113 | for layer in range(0, tot_layers): 114 | if layer in important_layers: 115 | per_layer_config[layer] = TEMPLATE_KV_QUANT_CONFIG[config_high] 116 | elif layer in medium_layers: 117 | per_layer_config[layer] = TEMPLATE_KV_QUANT_CONFIG[config_medium] 118 | else: 119 | per_layer_config[layer] = TEMPLATE_KV_QUANT_CONFIG[config_low] 120 | tot_scale += per_layer_config[layer]['nbits_key'] + per_layer_config[layer]['nbits_value'] 121 | tot_scale /= tot_layers * 2 122 | return per_layer_config, tot_scale 123 | 124 | 125 | if __name__ == "__main__": 126 | args = parse_args() 127 | model_name = args.model_name 128 | 129 | global_args['model_name'] = model_name 130 | global_args['residual_length'] = args.residual_length 131 | global_args['group_size'] = args.group_size 132 | global_args['asym'] = args.asym 133 | global_args['axis_key'] = args.axis_key 134 | global_args['axis_value'] = args.axis_value 135 | global_args['limit'] = args.limit 136 | global_args['num_fewshots'] = args.num_fewshots 137 | global_args['device'] = args.device 138 | 139 | print(global_args) 140 | valid_params = [] 141 | for profile_high in range(5): 142 | for profile_medium in range(profile_high, 5): 143 | for profile_low in range(profile_medium, 5): 144 | valid_params.append((profile_high, profile_medium, profile_low)) 145 | 146 | for profile_high, profile_medium, profile_low in valid_params: 147 | per_layer_config, tot_scale = build_per_layer_config(args.model_name, profile_high, profile_medium, profile_low) 148 | accuracy = run_gsm8k(global_args['residual_length'], global_args['group_size'], global_args['asym'], global_args['axis_key'], global_args['axis_value'], per_layer_config, global_args['model_name'], global_args['num_fewshots'], global_args['limit'], global_args['device']) 149 | print(f"Profile: {profile_high}, {profile_medium}, {profile_low}, Accuracy: {accuracy}, Scale: {tot_scale}") 150 | print("") -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KVTuner: Sensitivity-Aware Layer-wise Mixed Precision KV Cache Quantization for Efficient and Nearly Lossless LLM Inference 2 | 3 | Official implementation of the ICML25 paper: KVTuner: Sensitivity-Aware Layer-wise Mixed Precision KV Cache Quantization for Efficient and Nearly Lossless LLM Inference 4 | 5 | ## Installation 6 | ```sh 7 | cd flexible_quant 8 | pip install -e . 9 | ``` 10 | 11 | ## Run Example codes 12 | ```sh 13 | python3 flexible_quant_example.py 14 | ``` 15 | Then you will run a simple example from `GSM8K` with `meta-llama/Meta-Llama-3-8B` and `KV4` quantization. 16 | 17 | Change line 17 in `flexible_quant_example.py` to run different quantization methods. 18 | 19 | ### Run GSM8K 20 | ```bash 21 | cd benchmarks 22 | # GSM8K K8V4 with KiVi quantization scheme 23 | python3 example_gsm8k_cot_manyshot.py --model_name="mistralai/Mistral-7B-Instruct-v0.2" --k_bits=8 --v_bits=4 --residual_length=32 --group_size=32 --axis_key=1 --axis_value=0 24 | # GSM8K K8V4 with Per-Token quantization scheme 25 | python3 example_gsm8k_cot_manyshot.py --model_name="mistralai/Mistral-7B-Instruct-v0.2" --k_bits=8 --v_bits=4 --residual_length=0 --group_size=-1 --axis_key=0 --axis_value=0 26 | ``` 27 | 28 | ##### Parameters 29 | 30 | - `model_name`: the model name from Hugging Face model hub. 31 | - `nshots`: the number of shots for the few-shot inference. 
32 | - `k_bits`: the quantization precision (in bits) for keys. 33 | - `v_bits`: the quantization precision (in bits) for values. 34 | - `asym`: whether to use asymmetric quantization. 35 | - `residual_length`: the number of most recent (residual) tokens that are kept unquantized. Must be a multiple of `group_size`; use 0 for per-token quantization. 36 | - `group_size`: the group size for quantization; use -1 for per-token quantization. 37 | - `axis_key`: the axis for key quantization: 0 for per-token, 1 for per-channel. 38 | - `axis_value`: the axis for value quantization: 0 for per-token, 1 for per-channel. 39 | 40 | 41 | ### Run LongBench 42 | ```sh 43 | cd benchmarks 44 | PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python3 pred_longbench.py 45 | ``` 46 | 47 | Same parameters as GSM8K. 48 | 49 | ### Run lm-eval 50 | 51 | This repo provides a modified version of `lm-eval` that supports evaluation with the quantized KV cache. 52 | 53 | Refer to `lm-evaluation-harness-X/lm_eval/models/huggingface_quant.py` for the `hf-quant` model wrapper. 54 | 55 | ## Use `FlexibleQuantizedCache` in your code 56 | ```python 57 | # Define your model 58 | from transformers import AutoTokenizer, AutoModelForCausalLM 59 | model_name = 'meta-llama/Meta-Llama-3-8B' 60 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="float16").cuda() 61 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True) 62 | 63 | # Define the cache 64 | from flexible_quant.flexible_quantized_cache import FlexibleQuantizedCacheConfig, FlexibleVanillaQuantizedCache 65 | cache_config = FlexibleQuantizedCacheConfig(nbits_key=4, nbits_value=4, asym=True, axis_key=0, axis_value=0, device='cuda', q_group_size=-1) 66 | # FlexibleVanillaQuantizedCache is the default; you can switch to FlexibleHQQQuantizedCache or FlexibleQuantoQuantizedCache 67 | past_key_values = FlexibleVanillaQuantizedCache(cache_config=cache_config) 68 | 69 | # Prompt and generate 70 | prompt = '''The quick brown fox jumps over the lazy dog.''' 71 | inputs = tokenizer(prompt, return_tensors="pt").input_ids.cuda() 72 | outputs = model.generate(inputs, past_key_values=past_key_values, use_cache=True, max_new_tokens=256) 73 | ``` 74 | 75 | 76 | ## FlexibleQuantizedCacheConfig 77 | 78 | ```python 79 | """ 80 | Configuration for flexible quantized cache. 81 | 82 | Attributes: 83 | backend (str): Backend for quantization. Options: "quanto", "hqq", "vanilla". 84 | nbits (Optional[int]): Precision for both key and value. Used if `nbits_key` and `nbits_value` are not set. 85 | For per-layer or per-head quantization, set `nbits` to -1. 86 | nbits_key (Optional[int]): Precision for key quantization. For per-layer or per-head quantization, set to -1. 87 | nbits_value (Optional[int]): Precision for value quantization. For per-layer or per-head quantization, set to -1. 88 | axis_key (Optional[int]): Axis for key quantization. In Vanilla mode: 89 | - 0: Per-token quantization 90 | - 1: Per-channel quantization 91 | axis_value (Optional[int]): Axis for value quantization. In Vanilla mode: 92 | - 0: Per-token quantization 93 | - 1: Per-channel quantization 94 | asym (Optional[bool]): Whether to use asymmetric quantization. Works only for Vanilla mode. 95 | q_group_size (Optional[int]): Group size for quantization. Use -1 for per-token quantization. 96 | residual_length (Optional[int]): Length of residual tokens that are not quantized. 97 | Must be a multiple of `q_group_size`. Use 0 for per-token quantization.
98 | compute_dtype (Optional[torch.dtype]): Compute dtype for the model. Default: `torch.float16`. 99 | device (Optional[str]): Device for the cache. Default: `"cpu"`. 100 | force_quant (Optional[bool]): Whether to quantize the cache during the pre-filling stage. 101 | per_layer_quant (Optional[bool]): Whether to use per-layer quantization. 102 | per_layer_config (Optional[Dict[str, Any]]): If `per_layer_quant` is True, provides the quantization config 103 | for each layer. Alternatively, use `per_layer_config_path`. 104 | per_layer_config_path (Optional[str]): Path to the quantization config for each layer. 105 | Used if `per_layer_quant` is True. 106 | per_head_quant (Optional[bool]): Whether to use per-head quantization. 107 | per_head_config (Optional[Dict[str, Any]]): If `per_head_quant` is True, provides the quantization config 108 | for each head. Alternatively, use `per_head_config_path`. 109 | per_head_config_path (Optional[str]): Path to the quantization config for each head. 110 | Used if `per_head_quant` is True. 111 | """ 112 | ``` 113 | 114 | ### Example for per_layer_config 115 | ```python 116 | per_layer_config = { 117 | {n_layer}: { 118 | 'nbits_key': 4, 119 | 'nbits_value': 4, 120 | }, 121 | # ... 122 | ``` 123 | 124 | ### Example for per_head_config 125 | ```python 126 | per_head_config = { 127 | {n_layer}: { 128 | {head_idx}: { 129 | 'nbits_key': 4, 130 | 'nbits_value': 4, 131 | }, 132 | # ... 133 | }, 134 | # ... 135 | ``` 136 | 137 | ## Citation 138 | ```bibtex 139 | @misc{li2025kvtunersensitivityawarelayerwisemixed, 140 | title={KVTuner: Sensitivity-Aware Layer-wise Mixed Precision KV Cache Quantization for Efficient and Nearly Lossless LLM Inference}, 141 | author={Xing Li and Zeyu Xing and Yiming Li and Linping Qu and Hui-Ling Zhen and Wulong Liu and Yiwu Yao and Sinno Jialin Pan and Mingxuan Yuan}, 142 | year={2025}, 143 | eprint={2502.04420}, 144 | archivePrefix={arXiv}, 145 | primaryClass={cs.LG}, 146 | url={https://arxiv.org/abs/2502.04420}, 147 | } 148 | ``` 149 | -------------------------------------------------------------------------------- /benckmarks/evals/gsm8k_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import random 3 | 4 | ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)") 5 | INVALID_ANS = "[invalid]" 6 | 7 | N_SHOT = 8 8 | COT_FLAG = True 9 | DEBUG = False 10 | # ANSWER_TRIGGER = "The answer is" 11 | ANSWER_TRIGGER = "#### " 12 | 13 | 14 | def extract_answer_from_output(completion): 15 | match = ANS_RE.search(completion) 16 | if match: 17 | match_str = match.group(1).strip() 18 | match_str = match_str.replace(",", "") 19 | return match_str 20 | else: 21 | return INVALID_ANS 22 | 23 | 24 | def is_correct(model_answer, answer): 25 | gt_answer = extract_answer_from_output(answer) 26 | assert gt_answer != INVALID_ANS 27 | return model_answer == gt_answer 28 | 29 | 30 | def create_demo_text(n_shot=8, cot_flag=True): 31 | question, chain, answer = [], [], [] 32 | question.append( 33 | "There are 15 trees in the grove. " 34 | "Grove workers will plant trees in the grove today. " 35 | "After they are done, there will be 21 trees. " 36 | "How many trees did the grove workers plant today?" 37 | ) 38 | chain.append( 39 | "There are 15 trees originally. " 40 | "Then there were 21 trees after some more were planted. " 41 | "So there must have been 21 - 15 = 6." 
42 | ) 43 | answer.append("6") 44 | 45 | question.append( 46 | "If there are 3 cars in the parking lot and 2 more cars arrive, " 47 | "how many cars are in the parking lot?" 48 | ) 49 | chain.append("There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.") 50 | answer.append("5") 51 | 52 | question.append( 53 | "Leah had 32 chocolates and her sister had 42. If they ate 35, " 54 | "how many pieces do they have left in total?" 55 | ) 56 | chain.append( 57 | "Originally, Leah had 32 chocolates. " 58 | "Her sister had 42. So in total they had 32 + 42 = 74. " 59 | "After eating 35, they had 74 - 35 = 39." 60 | ) 61 | answer.append("39") 62 | 63 | question.append( 64 | "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason " 65 | "has 12 lollipops. How many lollipops did Jason give to Denny?" 66 | ) 67 | chain.append( 68 | "Jason started with 20 lollipops. Then he had 12 after giving some " 69 | "to Denny. So he gave Denny 20 - 12 = 8." 70 | ) 71 | answer.append("8") 72 | 73 | question.append( 74 | "Shawn has five toys. For Christmas, he got two toys each from his " 75 | "mom and dad. How many toys does he have now?" 76 | ) 77 | chain.append( 78 | "Shawn started with 5 toys. If he got 2 toys each from his mom and " 79 | "dad, then that is 4 more toys. 5 + 4 = 9." 80 | ) 81 | answer.append("9") 82 | 83 | question.append( 84 | "There were nine computers in the server room. Five more computers " 85 | "were installed each day, from monday to thursday. " 86 | "How many computers are now in the server room?" 87 | ) 88 | chain.append( 89 | "There were originally 9 computers. For each of 4 days, 5 more " 90 | "computers were added. So 5 * 4 = 20 computers were added. " 91 | "9 + 20 is 29." 92 | ) 93 | answer.append("29") 94 | 95 | question.append( 96 | "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On " 97 | "wednesday, he lost 2 more. " 98 | "How many golf balls did he have at the end of wednesday?" 99 | ) 100 | chain.append( 101 | "Michael started with 58 golf balls. After losing 23 on tuesday, " 102 | "he had 58 - 23 = 35. After losing 2 more, " 103 | "he had 35 - 2 = 33 golf balls." 104 | ) 105 | answer.append("33") 106 | 107 | question.append( 108 | "Olivia has $23. She bought five bagels for $3 each. " 109 | "How much money does she have left?" 110 | ) 111 | chain.append( 112 | "Olivia had 23 dollars. " 113 | "5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. " 114 | "So she has 23 - 15 dollars left. 23 - 15 is 8." 115 | ) 116 | answer.append("8") 117 | 118 | # randomize order of the examples ... 119 | index_list = list(range(len(question))) 120 | random.shuffle(index_list) 121 | 122 | # Concatenate demonstration examples ... 123 | demo_text = "" 124 | for i in index_list[:n_shot]: 125 | if cot_flag: 126 | demo_text += ( 127 | "Q: " 128 | + question[i] 129 | + "\nA: " 130 | + chain[i] 131 | + " " 132 | + ANSWER_TRIGGER 133 | + " " 134 | + answer[i] 135 | + ".\n\n" 136 | ) 137 | else: 138 | demo_text += ( 139 | "Question: " 140 | + question[i] 141 | + "\nAnswer: " 142 | + ANSWER_TRIGGER 143 | + " " 144 | + answer[i] 145 | + ".\n\n" 146 | ) 147 | return demo_text 148 | 149 | 150 | def create_demo_text_from_trainset(trainset, n_shot, cot_flag): 151 | question, chain, answer = [], [], [] 152 | for idx in range(n_shot): 153 | question.append(trainset[idx]["question"]) 154 | chain.append(trainset[idx]["answer"]) 155 | 156 | # Concatenate demonstration examples ... 
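    # Added note (illustrative, not in the original file): each demonstration emitted
    # below looks like
    #   Question: <trainset question>
    #   Answer: <trainset answer, whose text already ends with "#### <number>">
    # so, unlike create_demo_text() above, the final-answer trigger comes directly from
    # the GSM8K training split rather than from ANSWER_TRIGGER. As written, the `answer`
    # list initialised above is never filled, and the cot_flag=False branch appends
    # "Answer:" with no answer text, so the non-CoT formatting looks incomplete.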
157 | demo_text = "" 158 | for i in range(n_shot): 159 | if cot_flag: 160 | demo_text += ( 161 | "Question: " 162 | + question[i] 163 | + "\nAnswer: " 164 | + chain[i] 165 | + ".\n\n" 166 | ) 167 | else: 168 | demo_text += ( 169 | "Question: " 170 | + question[i] 171 | + "\nAnswer: " 172 | + ".\n\n" 173 | ) 174 | return demo_text 175 | 176 | 177 | def build_prompt(input_text, n_shot, cot_flag): 178 | demo = create_demo_text(n_shot, cot_flag) 179 | input_text_prompt = demo + "Q: " + input_text + "\n" + "A:" 180 | return input_text_prompt 181 | 182 | 183 | def build_prompt_from_trainset(trainset, input_text, n_shot, cot_flag): 184 | print("n_shot: {}, cot_flag: {}".format(n_shot, cot_flag)) 185 | demo = create_demo_text_from_trainset(trainset, n_shot, cot_flag) 186 | input_text_prompt = demo + "Question: " + input_text + "\n" + "Answer:" 187 | return input_text_prompt 188 | 189 | 190 | def clean_answer(model_pred): 191 | model_pred = model_pred.lower() 192 | preds = model_pred.split(ANSWER_TRIGGER.lower()) 193 | answer_flag = True if len(preds) > 1 else False 194 | if answer_flag: 195 | # Pick first answer with flag 196 | pred = preds[1] 197 | else: 198 | # Pick last number without flag 199 | pred = preds[-1] 200 | 201 | pred = pred.replace(",", "") 202 | pred = [s for s in re.findall(r"-?\d+\.?\d*", pred)] 203 | 204 | if len(pred) == 0: 205 | return INVALID_ANS 206 | 207 | if answer_flag: 208 | # choose the first element in list 209 | pred = pred[0] 210 | else: 211 | # choose the last element in list 212 | pred = pred[-1] 213 | 214 | # (For arithmetic tasks) if a word ends with period, it will be omitted ... 215 | if pred[-1] == ".": 216 | pred = pred[:-1] 217 | 218 | return pred 219 | -------------------------------------------------------------------------------- /benckmarks/gaokao_bench_obj.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | import os 4 | import tqdm 5 | import argparse 6 | from flexible_quant.flexible_quantized_cache import FlexibleQuantizedCacheConfig, FlexibleHQQQuantizedCache, FlexibleVanillaQuantizedCache 7 | from transformers import AutoTokenizer, AutoModelForCausalLM, QuantizedCacheConfig, HQQQuantizedCache, QuantoQuantizedCache 8 | from datasets import load_dataset 9 | from transformers import LlamaConfig, AutoTokenizer, LlamaForCausalLM 10 | import torch 11 | import random 12 | # from accelerate import Accelerator 13 | 14 | # accelerator = Accelerator() 15 | # device = accelerator.device 16 | device = None 17 | 18 | import importlib 19 | bench_function = importlib.import_module("GAOKAO-Bench.Bench.bench_function") 20 | 21 | # For reproducibility 22 | random.seed(0) 23 | torch.manual_seed(0) 24 | CACHE_DIR = "./models_storage" 25 | 26 | def get_dtype(str): 27 | if str == "bfloat16": 28 | return torch.bfloat16 29 | elif str == "float16": 30 | return torch.float16 31 | elif str == "float32": 32 | return torch.float32 33 | else: 34 | raise ValueError(f"Unsupported dtype {str}") 35 | 36 | def parse_args(args=None): 37 | parser = argparse.ArgumentParser() 38 | # parser.add_argument('--model_name', type=str, default="meta-llama/Llama-2-7b-hf") 39 | # parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-3B-Instruct-AWQ") 40 | # parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-7B-Instruct") 41 | parser.add_argument('--model_name', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct") 42 | # parser.add_argument('--model_name', type=str, default="") 43 | # 
parser.add_argument('--nshots', type=int, default=5) 44 | parser.add_argument('--dtype', type=str, default="bfloat16") 45 | parser.add_argument('--k_bits', type=int, default=8) 46 | parser.add_argument('--v_bits', type=int, default=8) 47 | parser.add_argument('--residual_length', type=int, default=128) 48 | parser.add_argument('--group_size', type=int, default=64) 49 | parser.add_argument('--asym', type=bool, default=True) 50 | parser.add_argument('--quantizer', type=str, default="Vanilla") 51 | # in HQQ, 0 for per-channel, 1 for per-token 52 | # in Vanilla, 0 for per-token, 1 for per-channel 53 | parser.add_argument('--axis_key', type=int, default=0) 54 | parser.add_argument('--axis_value', type=int, default=0) 55 | parser.add_argument('--per_layer_quant', type=bool, default=False) 56 | parser.add_argument('--per_layer_config_path', type=str, default="") 57 | parser.add_argument('--limit', type=int, default=-1) 58 | parser.add_argument('--device', type=str, default="cuda") 59 | return parser.parse_args(args) 60 | 61 | tests_all = [] 62 | 63 | if __name__ == "__main__": 64 | args = parse_args() 65 | print(args) 66 | 67 | device = torch.device(args.device) 68 | 69 | model = AutoModelForCausalLM.from_pretrained(args.model_name, cache_dir=CACHE_DIR, torch_dtype=get_dtype(args.dtype)).to(device) 70 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False, trust_remote_code=True) 71 | 72 | with open("GAOKAO-Bench/Bench/Obj_Prompt.json", "r") as f: 73 | examples = json.load(f)['examples'] 74 | f.close() 75 | 76 | for example in examples: 77 | directory = "GAOKAO-Bench/Data/Objective_Questions" 78 | 79 | keyword = example['keyword'] 80 | question_type = example['type'] 81 | zero_shot_prompt_text = example['prefix_prompt'] 82 | print('Building data for keyword:', keyword) 83 | 84 | filepath = os.path.join(directory, f"{keyword}.json") 85 | with open(filepath, "r") as f: 86 | data = json.load(f) 87 | f.close() 88 | 89 | data = data['example'] 90 | example_num = len(data) 91 | 92 | for i in tqdm.tqdm(range(example_num)): 93 | if question_type in ["single_choice", "five_out_of_seven", "multi_question_choice", "multi_choice"]: 94 | index = data[i]['index'] 95 | question = data[i]['question'].strip() + '\n' 96 | year = data[i]['year'] 97 | category = data[i]['category'] 98 | score = data[i]['score'] 99 | standard_answer = data[i]['answer'] 100 | answer_length = len(standard_answer) 101 | analysis = data[i]['analysis'] 102 | prompt = zero_shot_prompt_text 103 | 104 | current_test_dict = { 105 | 'index': index, 106 | 'type': question_type, 107 | 'year': year, 108 | 'category': category, 109 | 'score': score, 110 | 'question': question, 111 | 'standard_answer': standard_answer, 112 | 'analysis': analysis, 113 | 'prompt': prompt 114 | } 115 | tests_all.append(current_test_dict) 116 | elif question_type in ["subjective", "cloze"]: 117 | raise NotImplementedError('subjective and cloze question types are not supported') 118 | elif question_type == 'correction': 119 | raise NotImplementedError('correction question type is not supported') 120 | # now init LLM 121 | cache_config = FlexibleQuantizedCacheConfig(nbits_key=args.k_bits, nbits_value=args.v_bits, residual_length=args.residual_length, q_group_size=args.group_size, 122 | asym=args.asym, axis_key=args.axis_key, axis_value=args.axis_value, device=device, compute_dtype=get_dtype(args.dtype), 123 | per_layer_quant=args.per_layer_quant, per_layer_config_path=args.per_layer_config_path) 124 | 125 | # tests_all = tests_all[:3] 126 | 
print('Running tests') 127 | results = [] 128 | # for test in tqdm.tqdm(tests): 129 | if args.limit != -1: 130 | tests_all = tests_all[:args.limit] 131 | idx, correct = 0, 0 132 | for test in tqdm.tqdm(tests_all): 133 | past_key_values = FlexibleVanillaQuantizedCache(cache_config=cache_config) if args.quantizer == 'Vanilla' else FlexibleHQQQuantizedCache(cache_config=cache_config) 134 | prompt = test['prompt'] + test['question'] 135 | inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device) 136 | output = model.generate(inputs, past_key_values=past_key_values, use_cache=True, max_new_tokens=256, pad_token_id=None, eos_token_id=None) 137 | model_completion = tokenizer.decode(output[0].tolist()[inputs.shape[1]:], skip_special_tokens=True) 138 | test['model_completion'] = model_completion 139 | model_answer = bench_function.extract_choice_answer(model_completion, test['type'], len(test['standard_answer'])) 140 | test['model_answer'] = model_answer 141 | test['is_correct'] = model_answer == test['standard_answer'] 142 | results.append(test) 143 | idx += 1 144 | if test['is_correct']: 145 | correct += 1 146 | if idx % 50 == 0: 147 | print(f"Num of total question: {idx}, Correct num: {correct}, Accuracy: {float(correct)/idx}") 148 | if idx % 50 == 1: 149 | print('promot:', prompt) 150 | print('===') 151 | print('model output:', model_answer) 152 | print('===') 153 | print('standard answer:', test['standard_answer']) 154 | print('===') 155 | print('is correct:', test['is_correct']) 156 | print('====================') 157 | 158 | print(f"Num of total question: {idx}, Correct num: {correct}, Accuracy: {float(correct)/idx}") 159 | filename_out = f"GAOKAO-Bench_{args.model_name.replace('/', '_')}_Q_{args.quantizer}_k{args.k_bits}_v{args.v_bits}_r{args.residual_length}_g{args.group_size}.json" 160 | with open(filename_out, 'w') as f: 161 | json.dump(results, f) 162 | f.close() -------------------------------------------------------------------------------- /helper_scripts/sh_gen_presets.py: -------------------------------------------------------------------------------- 1 | model_args_template_pertoken = "pretrained={},nbits_key={},nbits_value={},residual_length=0,q_group_size=-1,axis_key=0,axis_value=0,trust_remote_code=True,dtype=bfloat16,force_quant=False,quantilizer=vanilla" 2 | model_args_template_pertoken_perlayer = "pretrained={},nbits_key=-1,nbits_value=-1,residual_length=0,q_group_size=-1,axis_key=0,axis_value=0,trust_remote_code=True,dtype=bfloat16,force_quant=False,quantilizer=vanilla,per_layer_quant=True,per_layer_config_path={}" 3 | # per_layer_config_path is yaml file path 4 | 5 | model_args_template_kivi = "pretrained={},nbits_key={},nbits_value={},residual_length=32,q_group_size=32,axis_key=1,axis_value=0,trust_remote_code=True,dtype=bfloat16,force_quant=False,quantilizer=vanilla" 6 | model_args_template_kivi_perlayer = "pretrained={},nbits_key=-1,nbits_value=-1,residual_length=32,q_group_size=32,axis_key=1,axis_value=0,trust_remote_code=True,dtype=bfloat16,force_quant=False,quantilizer=vanilla,per_layer_quant=True,per_layer_config_path={}" 7 | 8 | 9 | model_args_template_bf16 = "pretrained={},trust_remote_code=True,dtype=bfloat16" 10 | 11 | command_non_fewshot_template = '''accelerate launch -m lm_eval --model hf-quant \\ 12 | --model_args {} \\ 13 | --tasks {} \\ 14 | --batch_size 1 \\ 15 | --confirm_run_unsafe_code \\ 16 | --output_path lmeval_results/{} \\ 17 | | tee {}.log''' 18 | 19 | command_fewshot_template = '''accelerate launch -m lm_eval --model hf-quant \\ 20 | 
--model_args {} \\ 21 | --tasks {} \\ 22 | --batch_size 1 \\ 23 | --num_fewshot {} \\ 24 | --output_path lmeval_results/{} \\ 25 | | tee {}.log''' 26 | 27 | command_fewshot_as_multiturn = '''accelerate launch -m lm_eval --model hf-quant \\ 28 | --model_args {} \\ 29 | --tasks {} \\ 30 | --batch_size 1 \\ 31 | --num_fewshot {} \\ 32 | --fewshot_as_multiturn \\ 33 | --apply_chat_template \\ 34 | --output_path lmeval_results/{} \\ 35 | | tee {}.log''' 36 | 37 | TASKS = [ 38 | # { 39 | # 'filename': 'leaderboard_musr', 40 | # 'tasks': ['leaderboard_musr'], 41 | # 'nshots': [0], 42 | # }, 43 | # { 44 | # 'filename': 'gpqa_extended', 45 | # 'tasks': ['gpqa_extended_n_shot', 'gpqa_extended_generative_n_shot'], 46 | # 'nshots': [5, 10, 20], 47 | # }, 48 | # gpqa_extended gets OOM on RTX 4090 49 | { 50 | 'filename': 'gsm8k', 51 | 'tasks': ['gsm8k'], 52 | 'nshots': [4, 8, 16], 53 | }, 54 | { 55 | 'filename': 'gsm8k_multiturn', 56 | 'tasks': ['gsm8k'], 57 | 'nshots': [4, 8, 16], 58 | 'fewshot_as_multiturn': True, 59 | }, 60 | # { 61 | # 'filename': 'humaneval', 62 | # 'tasks': ['humaneval'], 63 | # 'nshots': [-1], 64 | # } 65 | ] 66 | 67 | STANDARD_KV_CONFIG = ['kv8', 'k8v4', 'k4v8', 'kv4', 'k4v2', 'kv2'] 68 | 69 | def get_calibration_filepath(model: str, quant_scheme: str = 'pertoken'): 70 | model_name = model.split('/')[-1] 71 | path = './calibration_presets' 72 | import os 73 | if not os.path.exists(path): 74 | return [] 75 | files = os.listdir(path) 76 | files = [f for f in files if model_name in f] 77 | # filename like: modelname_KVTuner{4/6}_{id}.yaml 78 | if quant_scheme != 'pertoken': 79 | files = [f for f in files if quant_scheme in f] 80 | ret = [] 81 | for f in files: 82 | full_path = os.path.join(path, f) 83 | fid = '_'.join(f.split('_')[-2:]).replace('.yaml', '') 84 | print(f'Found calibration file {f} with id {fid}') 85 | ret.append((full_path, fid)) 86 | return ret 87 | 88 | def extract_kv_config(config_str: str): 89 | if len(config_str) == 3: 90 | return int(config_str[2]), int(config_str[2]) 91 | return int(config_str[1]), int(config_str[3]) 92 | 93 | import argparse 94 | 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument('--models', type=str, required=True) 97 | parser.add_argument('--filename', type=str, required=False, default='run.sh') 98 | parser.add_argument('--quant_scheme', type=str, required=False, default='pertoken', choices=['pertoken', 'kivi']) 99 | parser.add_argument('--baseline_only', action='store_true', required=False, default=False) 100 | parser.add_argument('--kvturner_only', action='store_true', required=False, default=False) 101 | 102 | 103 | args = parser.parse_args() 104 | 105 | models = args.models.split(',') 106 | out_filename = args.filename 107 | quant_scheme = args.quant_scheme 108 | model_args_template = model_args_template_pertoken if quant_scheme == 'pertoken' else model_args_template_kivi 109 | model_args_template_perlayer = model_args_template_pertoken_perlayer if quant_scheme == 'pertoken' else model_args_template_kivi_perlayer 110 | 111 | # naming scheme: {model}_pertoken_{task_filename}_{nshot}_{kv_config or bf16 or id}. 
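# Added illustrative note (not generated by this script): for
# --models=meta-llama/Meta-Llama-3.1-8B-Instruct --quant_scheme=pertoken, the standard
# kv8 GSM8K 4-shot entry expands roughly to
#   accelerate launch -m lm_eval --model hf-quant \
#     --model_args pretrained=meta-llama/Meta-Llama-3.1-8B-Instruct,nbits_key=8,nbits_value=8,residual_length=0,q_group_size=-1,axis_key=0,axis_value=0,trust_remote_code=True,dtype=bfloat16,force_quant=False,quantilizer=vanilla \
#     --tasks gsm8k --batch_size 1 --num_fewshot 4 \
#     --output_path lmeval_results/<filename> | tee <filename>.log
# where <filename> follows the naming scheme above. Note that get_filename() below
# re-inserts quant_scheme even though filename_model already ends with it, so the
# scheme tag currently appears twice in generated filenames.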
112 | 113 | def get_filename(model, task_filename, nshots, kvconfig_or_bf16_or_id): 114 | return f'{model}_{quant_scheme}_{task_filename}_{nshots}_{kvconfig_or_bf16_or_id}' 115 | 116 | tot_commands = 0 117 | tot_time = 0 118 | with open(out_filename, 'w+') as f: 119 | f.write("export NCCL_IB_DISABLE=1\nexport NCCL_P2P_DISABLE=1\nexport HF_ALLOW_CODE_EVAL=1\nexport TRANSFORMERS_CACHE=./models_storage\n\n") 120 | for model in models: 121 | calibration_files = get_calibration_filepath(model, quant_scheme) 122 | filename_model = model.replace('/', '_') + f'_{quant_scheme}' 123 | # first, run bf16 124 | if not args.kvturner_only and args.quant_scheme == 'pertoken': 125 | f.write(f'# ======== {model} bf16 ========\n') 126 | for task_preset in TASKS: 127 | for nshot in task_preset['nshots']: 128 | filename = get_filename(filename_model, task_preset['filename'], nshot, 'bf16') 129 | model_arg = model_args_template_bf16.format(model) 130 | if task_preset.get('fewshot_as_multiturn', False): 131 | command = command_fewshot_as_multiturn.format(model_arg, ','.join(task_preset['tasks']), nshot, filename, filename) 132 | elif nshot != -1: 133 | command = command_fewshot_template.format(model_arg, ','.join(task_preset['tasks']), nshot, filename, filename) 134 | else: 135 | command = command_non_fewshot_template.format(model_arg, ','.join(task_preset['tasks']), filename, filename) 136 | command = command.replace('hf-quant', 'hf') 137 | f.write(command) 138 | tot_commands += 1 139 | tot_time += 25 if 'gsm8k' in task_preset['filename'] else 10 140 | f.write('\n') 141 | f.write('\n') 142 | f.write('\n\n\n') 143 | if not args.baseline_only: 144 | # then, run kv configs 145 | f.write(f'# ======== {model} kv calibration ========\n') 146 | for (calibration_file, calibration_file_id) in calibration_files: 147 | for task_preset in TASKS: 148 | for nshot in task_preset['nshots']: 149 | filename = get_filename(filename_model, task_preset['filename'], nshot, calibration_file_id) 150 | model_arg = model_args_template_perlayer.format(model, calibration_file) 151 | if task_preset.get('fewshot_as_multiturn', False): 152 | command = command_fewshot_as_multiturn.format(model_arg, ','.join(task_preset['tasks']), nshot, filename, filename) 153 | elif nshot != -1: 154 | command = command_fewshot_template.format(model_arg, ','.join(task_preset['tasks']), nshot, filename, filename) 155 | else: 156 | command = command_non_fewshot_template.format(model_arg, ','.join(task_preset['tasks']), filename, filename) 157 | f.write(command) 158 | tot_commands += 1 159 | tot_time += 25 if 'gsm8k' in task_preset['filename'] else 10 160 | f.write('\n') 161 | f.write('\n') 162 | f.write('\n\n\n') 163 | if not args.kvturner_only: 164 | # standard kv configs 165 | f.write(f'# ======== {model} standard kv configs ========\n') 166 | for kv_config in STANDARD_KV_CONFIG: 167 | for task_preset in TASKS: 168 | for nshot in task_preset['nshots']: 169 | nbits_key, nbits_value = extract_kv_config(kv_config) 170 | filename = get_filename(filename_model, task_preset['filename'], nshot, kv_config) 171 | model_arg = model_args_template.format(model, nbits_key, nbits_value) 172 | if task_preset.get('fewshot_as_multiturn', False): 173 | command = command_fewshot_as_multiturn.format(model_arg, ','.join(task_preset['tasks']), nshot, filename, filename) 174 | elif nshot != -1: 175 | command = command_fewshot_template.format(model_arg, ','.join(task_preset['tasks']), nshot, filename, filename) 176 | else: 177 | command = 
command_non_fewshot_template.format(model_arg, ','.join(task_preset['tasks']), filename, filename) 178 | f.write(command) 179 | tot_commands += 1 180 | tot_time += 25 if 'gsm8k' in task_preset['filename'] else 10 181 | f.write('\n') 182 | f.write('\n') 183 | 184 | import os 185 | os.system('chmod +x {}'.format(out_filename)) 186 | 187 | print(f'Generated {tot_commands} commands in {out_filename}.') 188 | print(f'Estimated total running time (on dual RTX 4090): {tot_time} minutes. aka {tot_time/60} hours.') 189 | 190 | 191 | 192 | -------------------------------------------------------------------------------- /benckmarks/test_gaokaobench.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import tqdm 4 | import argparse 5 | from flexible_quant.flexible_quantized_cache import FlexibleQuantizedCacheConfig, FlexibleHQQQuantizedCache, FlexibleVanillaQuantizedCache 6 | from transformers import AutoTokenizer, AutoModelForCausalLM 7 | from transformers import LlamaConfig, AutoTokenizer, LlamaForCausalLM 8 | import torch 9 | import random 10 | 11 | import importlib 12 | bench_function = importlib.import_module("GAOKAO-Bench.Bench.bench_function") 13 | 14 | 15 | # For reproducibility 16 | random.seed(0) 17 | torch.manual_seed(0) 18 | CACHE_DIR = "./models_storage" 19 | 20 | TEMPLATE_KV_QUANT_CONFIG = [ 21 | {'nbits_key': 8, 'nbits_value': 8}, 22 | {'nbits_key': 8, 'nbits_value': 4}, 23 | {'nbits_key': 4, 'nbits_value': 4}, 24 | {'nbits_key': 4, 'nbits_value': 2}, 25 | {'nbits_key': 2, 'nbits_value': 2}, 26 | ] 27 | 28 | LLAMA3_IMPORTANT_LAYERS = [0, 3, 5, 7, 12, 15, 22, 26, 30, 31] 29 | LLAMA3_MEDIUM_LAYERS = [6, 8, 9, 10, 11, 13, 14, 25, 27, 28, 29] 30 | 31 | QWEN_IMPORTANT_LAYERS = [0, 18, 20, 27, 29, 35] 32 | QWEN_MEDIUM_LAYERS = [3, 4, 5] 33 | 34 | global_args = {} 35 | 36 | def get_dtype(str): 37 | if str == "bfloat16": 38 | return torch.bfloat16 39 | elif str == "float16": 40 | return torch.float16 41 | elif str == "float32": 42 | return torch.float32 43 | else: 44 | raise ValueError(f"Unsupported dtype {str}") 45 | 46 | def parse_args(args=None): 47 | parser = argparse.ArgumentParser() 48 | # parser.add_argument('--model_name', type=str, default="meta-llama/Llama-2-7b-hf") 49 | parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-3B-Instruct-AWQ") 50 | parser.add_argument('--front_filename', type=str, default="front_profiles_qwen2.txt") 51 | # parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-7B-Instruct") 52 | # parser.add_argument('--model_name', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct") 53 | parser.add_argument('--residual_length', type=int, default=32) 54 | parser.add_argument('--group_size', type=int, default=32) 55 | parser.add_argument('--asym', type=bool, default=True) 56 | # in Vanilla, 0 for per-token, 1 for per-channel, we have to use per-channel there as residual_length is 0 57 | parser.add_argument('--axis_key', type=int, default=1) 58 | parser.add_argument('--axis_value', type=int, default=0) 59 | # parser.add_argument('--limit', type=int, default=200) 60 | # parser.add_argument('--num_fewshots', type=int, default=0) 61 | parser.add_argument('--device', type=str, default="cuda:0") 62 | return parser.parse_args(args) 63 | 64 | # results = lm_eval.simple_evaluate( 65 | # model='hf-quant', 66 | # model_args={ 67 | # 'pretrained': model_name, 68 | # 'nbits_key': -1, 69 | # 'nbits_value': -1, 70 | # 'residual_length': residual_length, 71 | # 'q_group_size': group_size, 
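# (Added note: this commented-out block appears to be leftover from the lm_eval-based
#  run_gsm8k() in search_brute_force.py; the GAOKAO evaluation below drives
#  model.generate() directly via run_gaokaobench() instead.)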
72 | # 'asym': asym, 73 | # 'axis_key': axis_key, 74 | # 'axis_value': axis_value, 75 | # 'dtype': torch.bfloat16, 76 | # 'force_quant': True, 77 | # 'per_layer_quant': True, 78 | # 'per_layer_config': per_layer_config, 79 | # 'quantilizer': 'vanilla', 80 | # }, 81 | # tasks=["gsm8k"], 82 | # num_fewshot=num_fewshots, 83 | # limit=limit, 84 | # device=device 85 | # ) 86 | # else: 87 | # results = lm_eval.simple_evaluate( 88 | # model='hf-quant', 89 | # model_args={ 90 | # 'pretrained': model_name, 91 | # 'nbits_key': -1, 92 | # 'nbits_value': -1, 93 | # 'residual_length': residual_length, 94 | # 'q_group_size': group_size, 95 | # 'asym': asym, 96 | # 'axis_key': axis_key, 97 | # 'axis_value': axis_value, 98 | # 'dtype': torch.bfloat16, 99 | # 'force_quant': True, 100 | # 'per_layer_quant': True, 101 | # 'per_layer_config': per_layer_config, 102 | # 'quantilizer': 'vanilla', 103 | # }, 104 | # tasks=["gsm8k"], 105 | # num_fewshot=num_fewshots, 106 | # device=device 107 | # ) 108 | # print(results['results']['gsm8k']['exact_match,flexible-extract']) 109 | # return float(results['results']['gsm8k']['exact_match,flexible-extract']) 110 | 111 | def run_gaokaobench(residual_length: int, group_size: int, asym: bool, axis_key: int, axis_value: int, per_layer_config: dict, model_name: str, device: str, dtype: str): 112 | device = torch.device(args.device) 113 | 114 | model = AutoModelForCausalLM.from_pretrained(args.model_name, cache_dir=CACHE_DIR, torch_dtype=get_dtype(dtype)).to(device) 115 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False, trust_remote_code=True) 116 | 117 | with open("GAOKAO-Bench/Bench/Obj_Prompt.json", "r") as f: 118 | examples = json.load(f)['examples'] 119 | f.close() 120 | 121 | tests_all = [] 122 | 123 | for example in examples: 124 | directory = "GAOKAO-Bench/Data/Objective_Questions" 125 | 126 | keyword = example['keyword'] 127 | question_type = example['type'] 128 | zero_shot_prompt_text = example['prefix_prompt'] 129 | print('Building data for keyword:', keyword) 130 | 131 | filepath = os.path.join(directory, f"{keyword}.json") 132 | with open(filepath, "r") as f: 133 | data = json.load(f) 134 | f.close() 135 | 136 | data = data['example'] 137 | example_num = len(data) 138 | 139 | 140 | for i in tqdm.tqdm(range(example_num)): 141 | if question_type in ["single_choice", "five_out_of_seven", "multi_question_choice", "multi_choice"]: 142 | index = data[i]['index'] 143 | question = data[i]['question'].strip() + '\n' 144 | year = data[i]['year'] 145 | category = data[i]['category'] 146 | score = data[i]['score'] 147 | standard_answer = data[i]['answer'] 148 | answer_length = len(standard_answer) 149 | analysis = data[i]['analysis'] 150 | prompt = zero_shot_prompt_text 151 | 152 | current_test_dict = { 153 | 'index': index, 154 | 'type': question_type, 155 | 'year': year, 156 | 'category': category, 157 | 'score': score, 158 | 'question': question, 159 | 'standard_answer': standard_answer, 160 | 'analysis': analysis, 161 | 'prompt': prompt 162 | } 163 | tests_all.append(current_test_dict) 164 | elif question_type in ["subjective", "cloze"]: 165 | raise NotImplementedError('subjective and cloze question types are not supported') 166 | elif question_type == 'correction': 167 | raise NotImplementedError('correction question type is not supported') 168 | # now init LLM 169 | cache_config = FlexibleQuantizedCacheConfig(residual_length=residual_length, q_group_size=group_size, 170 | asym=asym, axis_key=axis_key, axis_value=axis_value, device=device, 
compute_dtype=get_dtype(dtype), 171 | per_layer_quant=True, per_layer_config=per_layer_config) 172 | 173 | # tests_all = tests_all[:3] 174 | print('Running tests') 175 | results = [] 176 | # for test in tqdm.tqdm(tests): 177 | idx, correct = 0, 0 178 | for test in tqdm.tqdm(tests_all): 179 | past_key_values = FlexibleVanillaQuantizedCache(cache_config=cache_config) 180 | prompt = test['prompt'] + test['question'] 181 | inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device) 182 | output = model.generate(inputs, past_key_values=past_key_values, use_cache=True, max_new_tokens=256, pad_token_id=None, eos_token_id=None) 183 | model_completion = tokenizer.decode(output[0].tolist()[inputs.shape[1]:], skip_special_tokens=True) 184 | test['model_completion'] = model_completion 185 | model_answer = bench_function.extract_choice_answer(model_completion, test['type'], len(test['standard_answer'])) 186 | test['model_answer'] = model_answer 187 | test['is_correct'] = model_answer == test['standard_answer'] 188 | results.append(test) 189 | idx += 1 190 | if test['is_correct']: 191 | correct += 1 192 | # if idx % 50 == 0: 193 | # print(f"Num of total question: {idx}, Correct num: {correct}, Accuracy: {float(correct)/idx}") 194 | # if idx % 50 == 1: 195 | # print('promot:', prompt) 196 | # print('===') 197 | # print('model output:', model_answer) 198 | # print('===') 199 | # print('standard answer:', test['standard_answer']) 200 | # print('===') 201 | # print('is correct:', test['is_correct']) 202 | # print('====================') 203 | 204 | print(f"Num of total question: {idx}, Correct num: {correct}, Accuracy: {float(correct)/idx}") 205 | 206 | return float(correct)/idx 207 | # filename_out = f"GAOKAO-Bench_{args.model_name.replace('/', '_')}_Q_{args.quantizer}_k{args.k_bits}_v{args.v_bits}_r{args.residual_length}_g{args.group_size}.json" 208 | # with open(filename_out, 'w') as f: 209 | # json.dump(results, f) 210 | # f.close() 211 | 212 | def build_per_layer_config(model: str, config_high: int, config_medium: int, config_low: int): 213 | important_layers = [] 214 | if 'llama' in model.lower(): 215 | important_layers = LLAMA3_IMPORTANT_LAYERS 216 | medium_layers = LLAMA3_MEDIUM_LAYERS 217 | if 'qwen' in model.lower(): 218 | important_layers = QWEN_IMPORTANT_LAYERS 219 | medium_layers = QWEN_MEDIUM_LAYERS 220 | per_layer_config = {} 221 | tot_scale = 0 222 | tot_layers = 32 if 'llama' in model.lower() else 36 223 | for layer in range(0, tot_layers): 224 | if layer in important_layers: 225 | per_layer_config[layer] = TEMPLATE_KV_QUANT_CONFIG[config_high] 226 | elif layer in medium_layers: 227 | per_layer_config[layer] = TEMPLATE_KV_QUANT_CONFIG[config_medium] 228 | else: 229 | per_layer_config[layer] = TEMPLATE_KV_QUANT_CONFIG[config_low] 230 | tot_scale += per_layer_config[layer]['nbits_key'] + per_layer_config[layer]['nbits_value'] 231 | tot_scale /= tot_layers * 2 232 | return per_layer_config, tot_scale 233 | 234 | 235 | if __name__ == "__main__": 236 | args = parse_args() 237 | model_name = args.model_name 238 | 239 | global_args['model_name'] = model_name 240 | global_args['residual_length'] = args.residual_length 241 | global_args['group_size'] = args.group_size 242 | global_args['asym'] = args.asym 243 | global_args['axis_key'] = args.axis_key 244 | global_args['axis_value'] = args.axis_value 245 | # global_args['limit'] = args.limit 246 | # global_args['num_fewshots'] = args.num_fewshots 247 | global_args['device'] = args.device 248 | 249 | print(global_args) 250 | 251 | with 
open(args.front_filename, "r") as f: 252 | front_profiles = f.readlines() 253 | f.close() 254 | 255 | for front_profile in front_profiles: 256 | # 2.0 0.01061410159211524 4, 4, 4 257 | print('running profile:', front_profile) 258 | profile_high, profile_medium, profile_low = map(int, front_profile.replace(',', '').split(' ')[-3:]) 259 | print('profile:', profile_high, profile_medium, profile_low) 260 | per_layer_config, tot_scale = build_per_layer_config(args.model_name, profile_high, profile_medium, profile_low) 261 | accuracy = run_gaokaobench(global_args['residual_length'], global_args['group_size'], global_args['asym'], global_args['axis_key'], global_args['axis_value'], per_layer_config, global_args['model_name'], global_args['device'], 'bfloat16') 262 | print(f"Profile: {profile_high}, {profile_medium}, {profile_low}, Accuracy: {accuracy}, Scale: {tot_scale}") 263 | print("") 264 | # valid_params = [] 265 | # for profile_high in range(5): 266 | # for profile_medium in range(profile_high, 5): 267 | # for profile_low in range(profile_medium, 5): 268 | # valid_params.append((profile_high, profile_medium, profile_low)) 269 | 270 | # for profile_high, profile_medium, profile_low in valid_params: 271 | # per_layer_config, tot_scale = build_per_layer_config(args.model_name, profile_high, profile_medium, profile_low) 272 | # accuracy = run_gsm8k(global_args['residual_length'], global_args['group_size'], global_args['asym'], global_args['axis_key'], global_args['axis_value'], per_layer_config, global_args['model_name'], global_args['num_fewshots'], global_args['limit'], global_args['device']) 273 | # print(f"Profile: {profile_high}, {profile_medium}, {profile_low}, Accuracy: {accuracy}, Scale: {tot_scale}") 274 | # print("") -------------------------------------------------------------------------------- /search_optuna_adaptive.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | import torch 4 | import random 5 | import argparse 6 | import torch 7 | import optuna 8 | import lm_eval 9 | from lm_eval.models.huggingface_quant import HFLM_Quant 10 | import logging 11 | import sys 12 | 13 | # For reproducibility 14 | random.seed(0) 15 | torch.manual_seed(0) 16 | CACHE_DIR = "./models_storage" 17 | 18 | 19 | 20 | LAYER_GROUPING_CONFIG = { 21 | 'Meta-Llama-3.1-8B-Instruct': { 22 | 'per-token-asym': [[0], [1, 2, 3, 4, 7, 13, 18, 25, 27, 31], [5, 6, 12, 21, 26, 28], [8, 9, 10, 11, 14, 15, 16, 17, 20, 30], [19, 22], [23, 24, 29]], 23 | 'per-channel-asym': [[0], [1, 2, 3, 7, 29, 31], [4, 25, 27], [5, 21, 23, 24], [6, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 22, 26, 28, 30], [13, 17]], 24 | }, 25 | 'Mistral-7B-Instruct-v0.3': { 26 | 'per-token-asym': [[0], [1, 2], [3, 4, 23, 31], [5, 6], [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 26, 27, 28, 29, 30]], 27 | 'per-channel-asym': [[0, 1, 31], [2, 3, 4], [6, 27, 29], [7, 8, 10, 18], [9, 14], [5, 21, 22, 23, 24, 25, 26, 28, 30], [11, 12, 13, 15, 17, 19, 20], [16]], 28 | }, 29 | 'Qwen2.5-3B-Instruct': { 30 | 'per-token-asym': [[0], [1, 3, 4, 5, 6, 8, 9, 12, 13, 15, 20], [2, 14, 23, 35], [7, 11, 16, 25, 28, 32], [10, 19, 24, 26, 33], [17, 30, 31, 34], [21, 22], [18, 27, 29]], 31 | 'per-channel-asym': [[0, 1], [2, 4], [34, 35], [3, 6, 11, 13, 23], [5, 7, 25, 32, 33], [8, 16, 18, 21, 22, 24, 26, 27, 30], [9, 10, 14, 15, 17, 19, 20, 29, 31], [12, 28]], 32 | }, 33 | 'Qwen2.5-7B-Instruct': { 34 | 'per-token-asym': [[0], [1, 2, 4, 5, 25], [6, 19], 
[7, 10, 11, 15, 23], [8, 24], [9, 12, 16, 17, 18, 21, 22, 26], [14, 20], [3, 13, 27]], 35 | 'per-channel-asym': [[0, 2], [1, 3], [4, 5, 12, 22, 23, 24, 25], [7, 9, 10, 13, 14, 16, 18, 19, 20, 21, 27], [8, 26], [11, 15, 17], [6]], 36 | }, 37 | 'Qwen2.5-14B-Instruct': { 38 | 'per-token-asym': [[0, 1, 2, 6, 11, 12, 19, 23, 24, 25, 41], [3, 4, 5, 8], [7, 10, 15], [9, 13, 14, 31, 38, 39], [16, 17, 18, 20, 21, 27, 28, 30, 32, 33, 34, 35, 36, 37, 40, 42, 43, 44, 46, 47], [22, 26, 29, 45]], 39 | 'per-channel-asym': [[0, 2], [1, 3, 4], [5, 6, 8, 9, 12], [7, 10, 13, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35, 36, 37, 38, 44, 45, 46, 47], [11, 25, 41, 42], [14, 39, 40, 43], [22, 34]], 40 | }, 41 | 'Qwen2.5-32B-Instruct': { 42 | 'per-token-asym': [[0, 2, 11, 12, 15, 33, 54, 57], [1, 5, 7, 8, 9, 10, 13, 14, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 55, 56, 58, 59, 60, 61, 62, 63], [3, 4], [6, 16]], 43 | 'per-channel-asym': [[0, 1, 2, 3, 4], [11], [5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 32], [13, 15, 17, 22, 24, 25, 29, 30, 31, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62], [63]] 44 | } 45 | } 46 | 47 | SPECIAL_LAYERS = { 48 | 'Meta-Llama-3.1-8B-Instruct': { 49 | 'per-token-asym': { 50 | (0,): ['KV8', 'K4V8', 'KV4', 'K4V2', 'KV2'], 51 | }, 52 | 'per-channel-asym': { 53 | (0,): ['KV8', 'K4V8', 'KV4', 'K2V4', 'KV2'], 54 | (1, 2, 3, 7, 29, 31): ['KV8', 'K4V8', 'KV4', 'K4V2', 'KV2'], 55 | }, 56 | }, 57 | 'Mistral-7B-Instruct-v0.3': { 58 | 'per-token-asym': { 59 | (0,): ['KV8', 'K4V8', 'KV4', 'K2V4', 'KV2'], 60 | }, 61 | 'per-channel-asym': { 62 | # (0,): ['KV8', 'K4V8', 'KV4', 'K2V4', 'KV2'], 63 | (0, 1, 31): ['KV8', 'K4V8', 'KV4', 'K2V4', 'KV2'], # fix: grouping 64 | (2, 3, 4, 6, 7, 8, 9, 10, 14, 18, 27, 29): ['KV8', 'K4V8', 'KV4', 'K4V2', 'KV2'], 65 | }, 66 | }, 67 | 'Qwen2.5-3B-Instruct': { 68 | 'per-token-asym': { 69 | (0,): ['KV8', 'K8V4', 'K8V2', 'K4V2', 'KV2'], 70 | (18, 27, 29): ['KV8', 'K8V4', 'K8V2', 'KV4', 'K4V2', 'KV2'], 71 | }, 72 | 'per-channel-asym': { 73 | (0, 1, 2, 4, 34, 35): ['KV8', 'K4V8', 'KV4', 'K2V4', 'KV2'], 74 | (3, 6, 11, 13, 23): ['KV8', 'K4V8', 'KV4', 'K4V2', 'KV2'], 75 | }, 76 | }, 77 | 'Qwen2.5-7B-Instruct': { 78 | 'per-token-asym': { 79 | (0,): ['KV8', 'K8V4', 'K8V2', 'K4V2', 'KV2'], 80 | (3, 13, 27): ['KV8', 'K8V4', 'K8V2', 'KV4', 'K4V2', 'KV2'], 81 | }, 82 | 'per-channel-asym': { 83 | (0, 1, 2, 3): ['KV8', 'K4V8', 'KV4', 'K2V4', 'KV2'], 84 | (6,): ['KV8', 'K4V8', 'KV4', 'K4V2', 'KV2'], 85 | }, 86 | }, 87 | 'Qwen2.5-14B-Instruct': { 88 | 'per-token-asym': { 89 | # no layers listed 90 | }, 91 | 'per-channel-asym': { 92 | (0, 1, 2, 3, 4): ['KV8', 'K4V8', 'KV4', 'K2V4', 'KV2'], 93 | (5, 6, 8, 9, 12): ['KV8', 'K4V8', 'KV4', 'K4V2', 'KV2'], 94 | }, 95 | }, 96 | 'Qwen2.5-32B-Instruct': { 97 | 'per-token-asym': { 98 | # no layers listed 99 | }, 100 | 'per-channel-asym': { 101 | (0, 1, 2, 3, 4, 11): ['KV8', 'K4V8', 'KV4', 'K2V4', 'KV2'], 102 | (5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 19, 20, 21, 23, 26, 27, 28, 32): ['KV8', 'K4V8', 'KV4', 'K4V2', 'KV2'], 103 | (63,): ['KV8', 'K8V4', 'KV4', 'K2V4', 'KV2'], 104 | }, 105 | }, 106 | } 107 | 108 | 109 | TOT_LAYER = { 110 | 'Meta-Llama-3.1-8B-Instruct': 32, 111 | 'Mistral-7B-Instruct-v0.3': 32, 112 | 'Qwen2.5-3B-Instruct': 36, 113 | 'Qwen2.5-7B-Instruct': 28, 114 | 'Qwen2.5-14B-Instruct': 48, 115 | 
'Qwen2.5-32B-Instruct': 64, 116 | } 117 | 118 | STANDARD_KV_QUANT_CONFIG = ['KV8', 'K8V4', 'KV4', 'K4V2', 'KV2'] 119 | 120 | global_args = {} 121 | model = None 122 | tokenizer = None 123 | dataset = None 124 | 125 | num_fewshots = None 126 | limit = None 127 | device = None 128 | 129 | quant_scheme = None 130 | max_per_layer_scale = None 131 | 132 | current_layer_grouping = [] 133 | current_special_layers = {} 134 | current_grouping_quant_template = [] 135 | current_tot_layers = -1 136 | debug_constraint = False 137 | 138 | def parse_args(args=None): 139 | parser = argparse.ArgumentParser() 140 | # parser.add_argument('--model_name', type=str, default="meta-llama/Llama-2-7b-hf") 141 | # parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-3B-Instruct-AWQ") 142 | # parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-7B-Instruct") 143 | # parser.add_argument('--model_name', type=str, default="mistralai/Mistral-7B-Instruct-v0.3") 144 | parser.add_argument('--model_name', type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct") 145 | # parser.add_argument('--residual_length', type=int, default=0) 146 | # parser.add_argument('--group_size', type=int, default=-1) 147 | parser.add_argument('--quant_scheme', type=str, default="per-token-asym") # per-token-asym or per-channel-asym 148 | parser.add_argument('--asym', type=bool, default=True) 149 | # in Vanilla, 0 for per-token, 1 for per-channel, we have to use per-channel there as residual_length is 0 150 | parser.add_argument('--axis_key', type=int, default=0) 151 | parser.add_argument('--axis_value', type=int, default=0) 152 | parser.add_argument('--limit', type=int, default=20) 153 | parser.add_argument('--num_fewshots', type=int, default=4) 154 | parser.add_argument('--max_per_layer_scale', type=str, default='8') 155 | parser.add_argument('--n_trials', type=int, default=100) 156 | parser.add_argument('--device', type=str, default="cuda") 157 | parser.add_argument('--debug_constraint', default=False, action='store_true') 158 | return parser.parse_args(args) 159 | 160 | 161 | def parse_quant_config(quant_config: str): 162 | if len(quant_config) == 3: 163 | precision = int(quant_config[2]) 164 | return {'nbits_key': precision, 'nbits_value': precision} 165 | precision_key = int(quant_config[1]) 166 | precision_value = int(quant_config[3]) 167 | return {'nbits_key': precision_key, 'nbits_value': precision_value} 168 | 169 | def prepare_layer_grouping_config(model_name: str, quant_scheme: str): 170 | model_name = model_name.split('/')[-1] 171 | model_name = model_name.replace('-AWQ', '') # Qwen2.5-3B-Instruct-AWQ -> Qwen2.5-3B-Instruct 172 | global current_layer_grouping, current_special_layers, current_grouping_quant_template, current_tot_layers 173 | current_layer_grouping = LAYER_GROUPING_CONFIG[model_name][quant_scheme] 174 | current_special_layers = SPECIAL_LAYERS[model_name][quant_scheme] 175 | current_tot_layers = TOT_LAYER[model_name] 176 | # check if current_special_layers breaks the current_layer_grouping 177 | for group in current_layer_grouping: 178 | group_quant_template = STANDARD_KV_QUANT_CONFIG 179 | for layer in group: 180 | for special_layer in current_special_layers.keys(): 181 | if layer in special_layer: 182 | group_quant_template = current_special_layers[special_layer] 183 | for other_layer in group: 184 | if not other_layer in special_layer: 185 | raise ValueError("Special layer {} breaks the layer grouping for model {}, quant scheme {}".format(special_layer, model_name, quant_scheme)) 186 | 
if debug_constraint: 187 | group_quant_template = [i for i in group_quant_template if i != 'KV2'] # remove KV2 188 | current_grouping_quant_template.append(group_quant_template) 189 | 190 | def run_gsm8k(per_layer_config: dict, model_name: str, num_fewshots: int, limit: int, device: str): 191 | results = lm_eval.simple_evaluate( 192 | model='hf-quant', 193 | model_args={ 194 | 'pretrained': model_name, 195 | 'nbits_key': -1, 196 | 'nbits_value': -1, 197 | 'residual_length': 32 if quant_scheme == 'per-channel-asym' else 0, 198 | 'q_group_size': 32 if quant_scheme == 'per-channel-asym' else -1, 199 | 'asym': True, 200 | 'axis_key': 1 if quant_scheme == 'per-channel-asym' else 0, 201 | 'axis_value': 0, 202 | 'dtype': torch.bfloat16, 203 | 'force_quant': False, 204 | 'per_layer_quant': True, 205 | 'per_layer_config': per_layer_config, 206 | 'quantilizer': 'vanilla', 207 | 'device_map': 'auto', 208 | 'parallelize': True, 209 | }, 210 | tasks=["gsm8k"], 211 | num_fewshot=num_fewshots, 212 | limit=limit, 213 | # device=device 214 | ) 215 | print(results['results']['gsm8k']['exact_match,flexible-extract']) 216 | return float(results['results']['gsm8k']['exact_match,flexible-extract']) 217 | 218 | 219 | def build_per_layer_config(config_list: list): 220 | per_layer_config = {} 221 | tot_scale = 0 222 | for i, config in enumerate(config_list): 223 | layers = current_layer_grouping[i] 224 | quant_config = parse_quant_config(current_grouping_quant_template[i][config]) 225 | for layer in layers: 226 | per_layer_config[layer] = quant_config 227 | tot_scale += (quant_config['nbits_key'] + quant_config['nbits_value']) * len(layers) 228 | tot_scale /= current_tot_layers * 2 229 | return per_layer_config, tot_scale 230 | 231 | 232 | def objective(trial): 233 | config_list = [] 234 | for i in range(0, len(current_layer_grouping)): 235 | config_current_layer = trial.suggest_int('group_{}'.format(i), 0, len(current_grouping_quant_template[i]) - 1) 236 | config_list.append(config_current_layer) 237 | 238 | per_layer_config, tot_scale = build_per_layer_config(config_list) 239 | 240 | # Constraints which are considered feasible if less than or equal to zero.
241 | 242 | c = tot_scale - max_per_layer_scale 243 | print('c = ', c) 244 | 245 | if not debug_constraint: 246 | trial.set_user_attr('constraints', (c, )) 247 | 248 | accuracy = run_gsm8k(per_layer_config, model, num_fewshots, limit, device) 249 | 250 | c2 = 0.6 - accuracy 251 | 252 | if debug_constraint: 253 | print('c2 = ', c2) 254 | trial.set_user_attr('constraints', (c, c2)) 255 | 256 | 257 | return accuracy, tot_scale 258 | 259 | def constraints(trial): 260 | return trial.user_attrs["constraints"] 261 | 262 | if __name__ == "__main__": 263 | args = parse_args() 264 | model = args.model_name 265 | quant_scheme = args.quant_scheme 266 | max_per_layer_scale = float(args.max_per_layer_scale) 267 | num_fewshots = args.num_fewshots 268 | limit = args.limit 269 | device = args.device 270 | debug_constraint = args.debug_constraint 271 | 272 | 273 | optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout)) 274 | study_name = "OPTUNA_SEARCH_ADAPTIVE_{}_GSM8K_FIRST{}_{}SHOTS_MAXSCALE{}_SCHEME{}".format(model.replace("/", "_"), limit, num_fewshots, max_per_layer_scale, quant_scheme) 275 | storage_name = "sqlite:///{}.db".format(study_name) 276 | sampler = optuna.samplers.NSGAIISampler(constraints_func=constraints) 277 | study = optuna.create_study(directions=["maximize", "minimize"], study_name=study_name, storage=storage_name, sampler=sampler) 278 | 279 | print(args) 280 | print('Preparing layer grouping config...') 281 | prepare_layer_grouping_config(model, quant_scheme) 282 | print('Layer grouping: ', current_layer_grouping) 283 | print('Special layers: ', current_special_layers) 284 | print('Grouping quant template: ', current_grouping_quant_template) 285 | print('Total layers: ', current_tot_layers) 286 | 287 | study.optimize(objective, n_trials=args.n_trials) -------------------------------------------------------------------------------- /config/meta-llama_Meta-Llama-3-8B-Instruct_k8_v4._per_head.yaml: -------------------------------------------------------------------------------- 1 | 0: 2 | 0: 3 | nbits_key: 8 4 | nbits_value: 4 5 | 1: 6 | nbits_key: 8 7 | nbits_value: 4 8 | 2: 9 | nbits_key: 8 10 | nbits_value: 4 11 | 3: 12 | nbits_key: 8 13 | nbits_value: 4 14 | 4: 15 | nbits_key: 8 16 | nbits_value: 4 17 | 5: 18 | nbits_key: 8 19 | nbits_value: 4 20 | 6: 21 | nbits_key: 8 22 | nbits_value: 4 23 | 7: 24 | nbits_key: 8 25 | nbits_value: 4 26 | 1: 27 | 0: 28 | nbits_key: 8 29 | nbits_value: 4 30 | 1: 31 | nbits_key: 8 32 | nbits_value: 4 33 | 2: 34 | nbits_key: 8 35 | nbits_value: 4 36 | 3: 37 | nbits_key: 8 38 | nbits_value: 4 39 | 4: 40 | nbits_key: 8 41 | nbits_value: 4 42 | 5: 43 | nbits_key: 8 44 | nbits_value: 4 45 | 6: 46 | nbits_key: 8 47 | nbits_value: 4 48 | 7: 49 | nbits_key: 8 50 | nbits_value: 4 51 | 2: 52 | 0: 53 | nbits_key: 8 54 | nbits_value: 4 55 | 1: 56 | nbits_key: 8 57 | nbits_value: 4 58 | 2: 59 | nbits_key: 8 60 | nbits_value: 4 61 | 3: 62 | nbits_key: 8 63 | nbits_value: 4 64 | 4: 65 | nbits_key: 8 66 | nbits_value: 4 67 | 5: 68 | nbits_key: 8 69 | nbits_value: 4 70 | 6: 71 | nbits_key: 8 72 | nbits_value: 4 73 | 7: 74 | nbits_key: 8 75 | nbits_value: 4 76 | 3: 77 | 0: 78 | nbits_key: 8 79 | nbits_value: 4 80 | 1: 81 | nbits_key: 8 82 | nbits_value: 4 83 | 2: 84 | nbits_key: 8 85 | nbits_value: 4 86 | 3: 87 | nbits_key: 8 88 | nbits_value: 4 89 | 4: 90 | nbits_key: 8 91 | nbits_value: 4 92 | 5: 93 | nbits_key: 8 94 | nbits_value: 4 95 | 6: 96 | nbits_key: 8 97 | nbits_value: 4 98 | 7: 99 | nbits_key: 8 100 | nbits_value: 4 101 | 4: 102 | 
0: 103 | nbits_key: 8 104 | nbits_value: 4 105 | 1: 106 | nbits_key: 8 107 | nbits_value: 4 108 | 2: 109 | nbits_key: 8 110 | nbits_value: 4 111 | 3: 112 | nbits_key: 8 113 | nbits_value: 4 114 | 4: 115 | nbits_key: 8 116 | nbits_value: 4 117 | 5: 118 | nbits_key: 8 119 | nbits_value: 4 120 | 6: 121 | nbits_key: 8 122 | nbits_value: 4 123 | 7: 124 | nbits_key: 8 125 | nbits_value: 4 126 | 5: 127 | 0: 128 | nbits_key: 8 129 | nbits_value: 4 130 | 1: 131 | nbits_key: 8 132 | nbits_value: 4 133 | 2: 134 | nbits_key: 8 135 | nbits_value: 4 136 | 3: 137 | nbits_key: 8 138 | nbits_value: 4 139 | 4: 140 | nbits_key: 8 141 | nbits_value: 4 142 | 5: 143 | nbits_key: 8 144 | nbits_value: 4 145 | 6: 146 | nbits_key: 8 147 | nbits_value: 4 148 | 7: 149 | nbits_key: 8 150 | nbits_value: 4 151 | 6: 152 | 0: 153 | nbits_key: 8 154 | nbits_value: 4 155 | 1: 156 | nbits_key: 8 157 | nbits_value: 4 158 | 2: 159 | nbits_key: 8 160 | nbits_value: 4 161 | 3: 162 | nbits_key: 8 163 | nbits_value: 4 164 | 4: 165 | nbits_key: 8 166 | nbits_value: 4 167 | 5: 168 | nbits_key: 8 169 | nbits_value: 4 170 | 6: 171 | nbits_key: 8 172 | nbits_value: 4 173 | 7: 174 | nbits_key: 8 175 | nbits_value: 4 176 | 7: 177 | 0: 178 | nbits_key: 8 179 | nbits_value: 4 180 | 1: 181 | nbits_key: 8 182 | nbits_value: 4 183 | 2: 184 | nbits_key: 8 185 | nbits_value: 4 186 | 3: 187 | nbits_key: 8 188 | nbits_value: 4 189 | 4: 190 | nbits_key: 8 191 | nbits_value: 4 192 | 5: 193 | nbits_key: 8 194 | nbits_value: 4 195 | 6: 196 | nbits_key: 8 197 | nbits_value: 4 198 | 7: 199 | nbits_key: 8 200 | nbits_value: 4 201 | 8: 202 | 0: 203 | nbits_key: 8 204 | nbits_value: 4 205 | 1: 206 | nbits_key: 8 207 | nbits_value: 4 208 | 2: 209 | nbits_key: 8 210 | nbits_value: 4 211 | 3: 212 | nbits_key: 8 213 | nbits_value: 4 214 | 4: 215 | nbits_key: 8 216 | nbits_value: 4 217 | 5: 218 | nbits_key: 8 219 | nbits_value: 4 220 | 6: 221 | nbits_key: 8 222 | nbits_value: 4 223 | 7: 224 | nbits_key: 8 225 | nbits_value: 4 226 | 9: 227 | 0: 228 | nbits_key: 8 229 | nbits_value: 4 230 | 1: 231 | nbits_key: 8 232 | nbits_value: 4 233 | 2: 234 | nbits_key: 8 235 | nbits_value: 4 236 | 3: 237 | nbits_key: 8 238 | nbits_value: 4 239 | 4: 240 | nbits_key: 8 241 | nbits_value: 4 242 | 5: 243 | nbits_key: 8 244 | nbits_value: 4 245 | 6: 246 | nbits_key: 8 247 | nbits_value: 4 248 | 7: 249 | nbits_key: 8 250 | nbits_value: 4 251 | 10: 252 | 0: 253 | nbits_key: 8 254 | nbits_value: 4 255 | 1: 256 | nbits_key: 8 257 | nbits_value: 4 258 | 2: 259 | nbits_key: 8 260 | nbits_value: 4 261 | 3: 262 | nbits_key: 8 263 | nbits_value: 4 264 | 4: 265 | nbits_key: 8 266 | nbits_value: 4 267 | 5: 268 | nbits_key: 8 269 | nbits_value: 4 270 | 6: 271 | nbits_key: 8 272 | nbits_value: 4 273 | 7: 274 | nbits_key: 8 275 | nbits_value: 4 276 | 11: 277 | 0: 278 | nbits_key: 8 279 | nbits_value: 4 280 | 1: 281 | nbits_key: 8 282 | nbits_value: 4 283 | 2: 284 | nbits_key: 8 285 | nbits_value: 4 286 | 3: 287 | nbits_key: 8 288 | nbits_value: 4 289 | 4: 290 | nbits_key: 8 291 | nbits_value: 4 292 | 5: 293 | nbits_key: 8 294 | nbits_value: 4 295 | 6: 296 | nbits_key: 8 297 | nbits_value: 4 298 | 7: 299 | nbits_key: 8 300 | nbits_value: 4 301 | 12: 302 | 0: 303 | nbits_key: 8 304 | nbits_value: 4 305 | 1: 306 | nbits_key: 8 307 | nbits_value: 4 308 | 2: 309 | nbits_key: 8 310 | nbits_value: 4 311 | 3: 312 | nbits_key: 8 313 | nbits_value: 4 314 | 4: 315 | nbits_key: 8 316 | nbits_value: 4 317 | 5: 318 | nbits_key: 8 319 | nbits_value: 4 320 | 6: 321 | nbits_key: 8 322 | nbits_value: 4 323 | 
7: 324 | nbits_key: 8 325 | nbits_value: 4 326 | 13: 327 | 0: 328 | nbits_key: 8 329 | nbits_value: 4 330 | 1: 331 | nbits_key: 8 332 | nbits_value: 4 333 | 2: 334 | nbits_key: 8 335 | nbits_value: 4 336 | 3: 337 | nbits_key: 8 338 | nbits_value: 4 339 | 4: 340 | nbits_key: 8 341 | nbits_value: 4 342 | 5: 343 | nbits_key: 8 344 | nbits_value: 4 345 | 6: 346 | nbits_key: 8 347 | nbits_value: 4 348 | 7: 349 | nbits_key: 8 350 | nbits_value: 4 351 | 14: 352 | 0: 353 | nbits_key: 8 354 | nbits_value: 4 355 | 1: 356 | nbits_key: 8 357 | nbits_value: 4 358 | 2: 359 | nbits_key: 8 360 | nbits_value: 4 361 | 3: 362 | nbits_key: 8 363 | nbits_value: 4 364 | 4: 365 | nbits_key: 8 366 | nbits_value: 4 367 | 5: 368 | nbits_key: 8 369 | nbits_value: 4 370 | 6: 371 | nbits_key: 8 372 | nbits_value: 4 373 | 7: 374 | nbits_key: 8 375 | nbits_value: 4 376 | 15: 377 | 0: 378 | nbits_key: 8 379 | nbits_value: 4 380 | 1: 381 | nbits_key: 8 382 | nbits_value: 4 383 | 2: 384 | nbits_key: 8 385 | nbits_value: 4 386 | 3: 387 | nbits_key: 8 388 | nbits_value: 4 389 | 4: 390 | nbits_key: 8 391 | nbits_value: 4 392 | 5: 393 | nbits_key: 8 394 | nbits_value: 4 395 | 6: 396 | nbits_key: 8 397 | nbits_value: 4 398 | 7: 399 | nbits_key: 8 400 | nbits_value: 4 401 | 16: 402 | 0: 403 | nbits_key: 8 404 | nbits_value: 4 405 | 1: 406 | nbits_key: 8 407 | nbits_value: 4 408 | 2: 409 | nbits_key: 8 410 | nbits_value: 4 411 | 3: 412 | nbits_key: 8 413 | nbits_value: 4 414 | 4: 415 | nbits_key: 8 416 | nbits_value: 4 417 | 5: 418 | nbits_key: 8 419 | nbits_value: 4 420 | 6: 421 | nbits_key: 8 422 | nbits_value: 4 423 | 7: 424 | nbits_key: 8 425 | nbits_value: 4 426 | 17: 427 | 0: 428 | nbits_key: 8 429 | nbits_value: 4 430 | 1: 431 | nbits_key: 8 432 | nbits_value: 4 433 | 2: 434 | nbits_key: 8 435 | nbits_value: 4 436 | 3: 437 | nbits_key: 8 438 | nbits_value: 4 439 | 4: 440 | nbits_key: 8 441 | nbits_value: 4 442 | 5: 443 | nbits_key: 8 444 | nbits_value: 4 445 | 6: 446 | nbits_key: 8 447 | nbits_value: 4 448 | 7: 449 | nbits_key: 8 450 | nbits_value: 4 451 | 18: 452 | 0: 453 | nbits_key: 8 454 | nbits_value: 4 455 | 1: 456 | nbits_key: 8 457 | nbits_value: 4 458 | 2: 459 | nbits_key: 8 460 | nbits_value: 4 461 | 3: 462 | nbits_key: 8 463 | nbits_value: 4 464 | 4: 465 | nbits_key: 8 466 | nbits_value: 4 467 | 5: 468 | nbits_key: 8 469 | nbits_value: 4 470 | 6: 471 | nbits_key: 8 472 | nbits_value: 4 473 | 7: 474 | nbits_key: 8 475 | nbits_value: 4 476 | 19: 477 | 0: 478 | nbits_key: 8 479 | nbits_value: 4 480 | 1: 481 | nbits_key: 8 482 | nbits_value: 4 483 | 2: 484 | nbits_key: 8 485 | nbits_value: 4 486 | 3: 487 | nbits_key: 8 488 | nbits_value: 4 489 | 4: 490 | nbits_key: 8 491 | nbits_value: 4 492 | 5: 493 | nbits_key: 8 494 | nbits_value: 4 495 | 6: 496 | nbits_key: 8 497 | nbits_value: 4 498 | 7: 499 | nbits_key: 8 500 | nbits_value: 4 501 | 20: 502 | 0: 503 | nbits_key: 8 504 | nbits_value: 4 505 | 1: 506 | nbits_key: 8 507 | nbits_value: 4 508 | 2: 509 | nbits_key: 8 510 | nbits_value: 4 511 | 3: 512 | nbits_key: 8 513 | nbits_value: 4 514 | 4: 515 | nbits_key: 8 516 | nbits_value: 4 517 | 5: 518 | nbits_key: 8 519 | nbits_value: 4 520 | 6: 521 | nbits_key: 8 522 | nbits_value: 4 523 | 7: 524 | nbits_key: 8 525 | nbits_value: 4 526 | 21: 527 | 0: 528 | nbits_key: 8 529 | nbits_value: 4 530 | 1: 531 | nbits_key: 8 532 | nbits_value: 4 533 | 2: 534 | nbits_key: 8 535 | nbits_value: 4 536 | 3: 537 | nbits_key: 8 538 | nbits_value: 4 539 | 4: 540 | nbits_key: 8 541 | nbits_value: 4 542 | 5: 543 | nbits_key: 8 544 | 
nbits_value: 4 545 | 6: 546 | nbits_key: 8 547 | nbits_value: 4 548 | 7: 549 | nbits_key: 8 550 | nbits_value: 4 551 | 22: 552 | 0: 553 | nbits_key: 8 554 | nbits_value: 4 555 | 1: 556 | nbits_key: 8 557 | nbits_value: 4 558 | 2: 559 | nbits_key: 8 560 | nbits_value: 4 561 | 3: 562 | nbits_key: 8 563 | nbits_value: 4 564 | 4: 565 | nbits_key: 8 566 | nbits_value: 4 567 | 5: 568 | nbits_key: 8 569 | nbits_value: 4 570 | 6: 571 | nbits_key: 8 572 | nbits_value: 4 573 | 7: 574 | nbits_key: 8 575 | nbits_value: 4 576 | 23: 577 | 0: 578 | nbits_key: 8 579 | nbits_value: 4 580 | 1: 581 | nbits_key: 8 582 | nbits_value: 4 583 | 2: 584 | nbits_key: 8 585 | nbits_value: 4 586 | 3: 587 | nbits_key: 8 588 | nbits_value: 4 589 | 4: 590 | nbits_key: 8 591 | nbits_value: 4 592 | 5: 593 | nbits_key: 8 594 | nbits_value: 4 595 | 6: 596 | nbits_key: 8 597 | nbits_value: 4 598 | 7: 599 | nbits_key: 8 600 | nbits_value: 4 601 | 24: 602 | 0: 603 | nbits_key: 8 604 | nbits_value: 4 605 | 1: 606 | nbits_key: 8 607 | nbits_value: 4 608 | 2: 609 | nbits_key: 8 610 | nbits_value: 4 611 | 3: 612 | nbits_key: 8 613 | nbits_value: 4 614 | 4: 615 | nbits_key: 8 616 | nbits_value: 4 617 | 5: 618 | nbits_key: 8 619 | nbits_value: 4 620 | 6: 621 | nbits_key: 8 622 | nbits_value: 4 623 | 7: 624 | nbits_key: 8 625 | nbits_value: 4 626 | 25: 627 | 0: 628 | nbits_key: 8 629 | nbits_value: 4 630 | 1: 631 | nbits_key: 8 632 | nbits_value: 4 633 | 2: 634 | nbits_key: 8 635 | nbits_value: 4 636 | 3: 637 | nbits_key: 8 638 | nbits_value: 4 639 | 4: 640 | nbits_key: 8 641 | nbits_value: 4 642 | 5: 643 | nbits_key: 8 644 | nbits_value: 4 645 | 6: 646 | nbits_key: 8 647 | nbits_value: 4 648 | 7: 649 | nbits_key: 8 650 | nbits_value: 4 651 | 26: 652 | 0: 653 | nbits_key: 8 654 | nbits_value: 4 655 | 1: 656 | nbits_key: 8 657 | nbits_value: 4 658 | 2: 659 | nbits_key: 8 660 | nbits_value: 4 661 | 3: 662 | nbits_key: 8 663 | nbits_value: 4 664 | 4: 665 | nbits_key: 8 666 | nbits_value: 4 667 | 5: 668 | nbits_key: 8 669 | nbits_value: 4 670 | 6: 671 | nbits_key: 8 672 | nbits_value: 4 673 | 7: 674 | nbits_key: 8 675 | nbits_value: 4 676 | 27: 677 | 0: 678 | nbits_key: 8 679 | nbits_value: 4 680 | 1: 681 | nbits_key: 8 682 | nbits_value: 4 683 | 2: 684 | nbits_key: 8 685 | nbits_value: 4 686 | 3: 687 | nbits_key: 8 688 | nbits_value: 4 689 | 4: 690 | nbits_key: 8 691 | nbits_value: 4 692 | 5: 693 | nbits_key: 8 694 | nbits_value: 4 695 | 6: 696 | nbits_key: 8 697 | nbits_value: 4 698 | 7: 699 | nbits_key: 8 700 | nbits_value: 4 701 | 28: 702 | 0: 703 | nbits_key: 8 704 | nbits_value: 4 705 | 1: 706 | nbits_key: 8 707 | nbits_value: 4 708 | 2: 709 | nbits_key: 8 710 | nbits_value: 4 711 | 3: 712 | nbits_key: 8 713 | nbits_value: 4 714 | 4: 715 | nbits_key: 8 716 | nbits_value: 4 717 | 5: 718 | nbits_key: 8 719 | nbits_value: 4 720 | 6: 721 | nbits_key: 8 722 | nbits_value: 4 723 | 7: 724 | nbits_key: 8 725 | nbits_value: 4 726 | 29: 727 | 0: 728 | nbits_key: 8 729 | nbits_value: 4 730 | 1: 731 | nbits_key: 8 732 | nbits_value: 4 733 | 2: 734 | nbits_key: 8 735 | nbits_value: 4 736 | 3: 737 | nbits_key: 8 738 | nbits_value: 4 739 | 4: 740 | nbits_key: 8 741 | nbits_value: 4 742 | 5: 743 | nbits_key: 8 744 | nbits_value: 4 745 | 6: 746 | nbits_key: 8 747 | nbits_value: 4 748 | 7: 749 | nbits_key: 8 750 | nbits_value: 4 751 | 30: 752 | 0: 753 | nbits_key: 8 754 | nbits_value: 4 755 | 1: 756 | nbits_key: 8 757 | nbits_value: 4 758 | 2: 759 | nbits_key: 8 760 | nbits_value: 4 761 | 3: 762 | nbits_key: 8 763 | nbits_value: 4 764 | 4: 765 | 
nbits_key: 8 766 | nbits_value: 4 767 | 5: 768 | nbits_key: 8 769 | nbits_value: 4 770 | 6: 771 | nbits_key: 8 772 | nbits_value: 4 773 | 7: 774 | nbits_key: 8 775 | nbits_value: 4 776 | 31: 777 | 0: 778 | nbits_key: 8 779 | nbits_value: 4 780 | 1: 781 | nbits_key: 8 782 | nbits_value: 4 783 | 2: 784 | nbits_key: 8 785 | nbits_value: 4 786 | 3: 787 | nbits_key: 8 788 | nbits_value: 4 789 | 4: 790 | nbits_key: 8 791 | nbits_value: 4 792 | 5: 793 | nbits_key: 8 794 | nbits_value: 4 795 | 6: 796 | nbits_key: 8 797 | nbits_value: 4 798 | 7: 799 | nbits_key: 8 800 | nbits_value: 4 -------------------------------------------------------------------------------- /benckmarks/pred_longbench.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datasets import load_dataset 3 | import torch 4 | import json 5 | from transformers import AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM 6 | from tqdm import tqdm 7 | import numpy as np 8 | import random 9 | import argparse 10 | import torch.distributed as dist 11 | import torch.multiprocessing as mp 12 | from flexible_quant.flexible_quantized_cache import FlexibleQuantizedCacheConfig, FlexibleHQQQuantizedCache, FlexibleVanillaQuantizedCache 13 | 14 | dataset2prompt = { 15 | "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 16 | "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 17 | "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 18 | "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", 19 | "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 20 | "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 21 | "musique": "Answer the question based on the given passages. 
Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 22 | "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", 23 | "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", 24 | "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", 25 | "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", 26 | "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", 27 | "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", 28 | "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", 29 | "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", 30 | "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", 31 | "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", 32 | "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ", 33 | "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:", 34 | "lcc": "Please complete the code given below. \n{context}Next line of code:\n", 35 | "repobench-p": "Please complete the code given below. 
\n{context}{input}Next line of code:\n" 36 | } 37 | 38 | dataset2maxlen = { 39 | "narrativeqa": 128, 40 | "qasper": 128, 41 | "multifieldqa_en": 64, 42 | "multifieldqa_zh": 64, 43 | "hotpotqa": 32, 44 | "2wikimqa": 32, 45 | "musique": 32, 46 | "dureader": 128, 47 | "gov_report": 512, 48 | "qmsum": 512, 49 | "multi_news": 512, 50 | "vcsum": 512, 51 | "trec": 64, 52 | "triviaqa": 32, 53 | "samsum": 128, 54 | "lsht": 64, 55 | "passage_count": 32, 56 | "passage_retrieval_en": 32, 57 | "passage_retrieval_zh": 32, 58 | "lcc": 64, 59 | "repobench-p": 64 60 | } 61 | 62 | CACHE_DIR = "./models_storage" 63 | 64 | def parse_args(args=None): 65 | parser = argparse.ArgumentParser() 66 | # parser.add_argument('--model', type=str, default=None, choices=["llama2-7b-chat-4k", "longchat-v1.5-7b-32k", "xgen-7b-8k", "internlm-7b-8k", "chatglm2-6b", "chatglm2-6b-32k", "chatglm3-6b-32k", "vicuna-v1.5-7b-16k"]) 67 | # parser.add_argument('--model', type=str, default="Qwen/Qwen2.5-3B-Instruct-AWQ") 68 | parser.add_argument('--model', type=str, default="Qwen/Qwen2.5-7B-Instruct") 69 | parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") 70 | parser.add_argument('--k_bits', type=int, default=8) 71 | parser.add_argument('--v_bits', type=int, default=8) 72 | parser.add_argument('--residual_length', type=int, default=128) 73 | parser.add_argument('--group_size', type=int, default=64) 74 | parser.add_argument('--asym', type=bool, default=True) 75 | # in HQQ, 0 for per-channel, 1 for per-token 76 | parser.add_argument('--axis_key', type=int, default=0) 77 | parser.add_argument('--axis_value', type=int, default=1) 78 | parser.add_argument('--max_length', type=int, default=7500) 79 | return parser.parse_args(args) 80 | 81 | # This is the customized building prompt for chat models 82 | def build_chat(tokenizer, prompt, model_name): 83 | # if "chatglm3" in model_name: 84 | # prompt = tokenizer.build_chat_input(prompt) 85 | # elif "chatglm" in model_name: 86 | # prompt = tokenizer.build_prompt(prompt) 87 | # elif "longchat" in model_name or "vicuna" in model_name: 88 | # from fastchat.model import get_conversation_template 89 | # conv = get_conversation_template("vicuna") 90 | # conv.append_message(conv.roles[0], prompt) 91 | # conv.append_message(conv.roles[1], None) 92 | # prompt = conv.get_prompt() 93 | # elif "llama2" in model_name: 94 | # prompt = f"[INST]{prompt}[/INST]" 95 | # elif "xgen" in model_name: 96 | # header = ( 97 | # "A chat between a curious human and an artificial intelligence assistant. 
" 98 | # "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n" 99 | # ) 100 | # prompt = header + f" ### Human: {prompt}\n###" 101 | # elif "internlm" in model_name: 102 | # prompt = f"<|User|>:{prompt}\n<|Bot|>:" 103 | return prompt 104 | 105 | def post_process(response, model_name): 106 | # if "xgen" in model_name: 107 | # response = response.strip().replace("Assistant:", "") 108 | # elif "internlm" in model_name: 109 | # response = response.split("")[0] 110 | return response 111 | 112 | def get_pred(rank, world_size, data, max_length, max_gen, prompt_format, dataset, device, model_name, out_path, cache_config): 113 | device = torch.device(f'cuda:{rank}') 114 | model, tokenizer = load_model_and_tokenizer(model_name, device) 115 | for json_obj in tqdm(data): 116 | prompt = prompt_format.format(**json_obj) 117 | # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions) 118 | tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0] 119 | # if "chatglm3" in model_name: 120 | # tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0] 121 | # if len(tokenized_prompt) > max_length: 122 | # half = int(max_length/2) 123 | # prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) 124 | # if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks 125 | # prompt = build_chat(tokenizer, prompt, model_name) 126 | # if "chatglm3" in model_name: 127 | # if dataset in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: 128 | # input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device) 129 | # else: 130 | # input = prompt.to(device) 131 | # else: 132 | input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device) 133 | context_length = input.input_ids.shape[-1] 134 | past_key_values = FlexibleVanillaQuantizedCache(cache_config=cache_config) 135 | if dataset == "samsum": # prevent illegal output on samsum (model endlessly repeat "\nDialogue"), might be a prompting issue 136 | output = model.generate( 137 | **input, 138 | max_new_tokens=max_gen, 139 | # num_beams=1, 140 | # do_sample=False, 141 | # temperature=1.0, 142 | min_length=context_length+1, 143 | eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]], 144 | past_key_values=past_key_values, 145 | use_cache=True 146 | )[0] 147 | else: 148 | output = model.generate( 149 | **input, 150 | max_new_tokens=max_gen, 151 | # num_beams=1, 152 | # do_sample=False, 153 | # temperature=1.0, 154 | past_key_values=past_key_values, 155 | use_cache=True 156 | )[0] 157 | pred = tokenizer.decode(output[context_length:], skip_special_tokens=True) 158 | pred = post_process(pred, model_name) 159 | with open(out_path, "a", encoding="utf-8") as f: 160 | json.dump({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]}, f, ensure_ascii=False) 161 | f.write('\n') 162 | # dist.destroy_process_group() 163 | 164 | def seed_everything(seed): 165 | torch.manual_seed(seed) 166 | torch.cuda.manual_seed(seed) 167 | np.random.seed(seed) 168 | random.seed(seed) 169 | torch.backends.cudnn.benchmark = False 170 | torch.backends.cudnn.deterministic = True 171 | 
torch.cuda.manual_seed_all(seed) 172 | 173 | def load_model_and_tokenizer(model_name, device): 174 | # if "chatglm" in model_name or "internlm" in model_name or "xgen" in model_name: 175 | # tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) 176 | # model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device) 177 | # elif "llama2" in model_name: 178 | # replace_llama_attn_with_flash_attn() 179 | # tokenizer = LlamaTokenizer.from_pretrained(path) 180 | # model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(device) 181 | # elif "longchat" in model_name or "vicuna" in model_name: 182 | # from fastchat.model import load_model 183 | # replace_llama_attn_with_flash_attn() 184 | # model, _ = load_model( 185 | # path, 186 | # device='cpu', 187 | # num_gpus=0, 188 | # load_8bit=False, 189 | # cpu_offloading=False, 190 | # debug=False, 191 | # ) 192 | # model = model.to(device) 193 | # model = model.bfloat16() 194 | # tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False) 195 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True) 196 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=CACHE_DIR, torch_dtype=torch.float16).to(device) 197 | model = model.eval() 198 | return model, tokenizer 199 | 200 | if __name__ == '__main__': 201 | seed_everything(42) 202 | args = parse_args() 203 | cache_config = FlexibleQuantizedCacheConfig(nbits_key=args.k_bits, nbits_value=args.v_bits, residual_length=args.residual_length, q_group_size=args.group_size, 204 | asym=args.asym, axis_key=args.axis_key, axis_value=args.axis_value, device='cuda') 205 | world_size = torch.cuda.device_count() 206 | mp.set_start_method('spawn', force=True) 207 | 208 | # model2path = json.load(open("config/model2path.json", "r")) 209 | # model2maxlen = json.load(open("config/model2maxlen.json", "r")) 210 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 211 | model_name = args.model 212 | # define your model 213 | # max_length = model2maxlen[model_name] 214 | max_length = args.max_length 215 | if args.e: 216 | datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", \ 217 | "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"] 218 | else: 219 | datasets = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \ 220 | "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \ 221 | "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"] 222 | # we design specific prompt format and max generation length for each task, feel free to modify them to optimize model output 223 | # dataset2prompt = json.load(open("config/dataset2prompt.json", "r")) 224 | # dataset2maxlen = json.load(open("config/dataset2maxlen.json", "r")) 225 | # predict on each dataset 226 | for dataset in datasets: 227 | if args.e: 228 | data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test') 229 | if not os.path.exists("pred_e"): 230 | os.makedirs("pred_e") 231 | if not os.path.exists(f"pred_e/{model_name}"): 232 | os.makedirs(f"pred_e/{model_name}") 233 | out_path = f"pred_e/{model_name}/{dataset}.jsonl" 234 | else: 235 | data = load_dataset('THUDM/LongBench', dataset, split='test') 236 | if not os.path.exists("pred"): 237 | os.makedirs("pred") 238 | if not 
os.path.exists(f"pred/{model_name}"): 239 | os.makedirs(f"pred/{model_name}") 240 | out_path = f"pred/{model_name}/{dataset}.jsonl" 241 | prompt_format = dataset2prompt[dataset] 242 | max_gen = dataset2maxlen[dataset] 243 | data_all = [data_sample for data_sample in data] 244 | data_subsets = [data_all[i::world_size] for i in range(world_size)] 245 | assert world_size == 1 246 | get_pred(0, world_size, data_subsets[0], max_length, max_gen, prompt_format, dataset, device, model_name, out_path, cache_config) 247 | # processes = [] 248 | # for rank in range(world_size): 249 | # p = mp.Process(target=get_pred, args=(rank, world_size, data_subsets[rank], max_length, \ 250 | # max_gen, prompt_format, dataset, device, model_name, model2path, out_path)) 251 | # p.start() 252 | # processes.append(p) 253 | # for p in processes: 254 | # p.join() 255 | --------------------------------------------------------------------------------
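Usage note (illustrative, not part of the repository): the sketch below shows one plausible way to drive FlexibleVanillaQuantizedCache directly, mirroring how pred_longbench.py builds its cache_config and passes past_key_values into model.generate. The model name, prompt, and generation length are assumptions chosen for the example; only FlexibleQuantizedCacheConfig / FlexibleVanillaQuantizedCache and the argument values are taken from the script above.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from flexible_quant.flexible_quantized_cache import (
    FlexibleQuantizedCacheConfig,
    FlexibleVanillaQuantizedCache,
)

model_name = "Qwen/Qwen2.5-7B-Instruct"  # assumption: same default model as pred_longbench.py
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda").eval()

# Cache settings copied from pred_longbench.py's defaults: 8-bit keys and values,
# 128 residual tokens kept unquantized, group size 64, asymmetric quantization,
# axis_key=0 / axis_value=1 (per the script's comment on the HQQ axis convention).
cache_config = FlexibleQuantizedCacheConfig(
    nbits_key=8, nbits_value=8, residual_length=128, q_group_size=64,
    asym=True, axis_key=0, axis_value=1, device="cuda",
)
past_key_values = FlexibleVanillaQuantizedCache(cache_config=cache_config)

prompt = "Briefly explain why quantizing the KV cache reduces memory usage."  # illustrative prompt
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
context_length = inputs.input_ids.shape[-1]
output = model.generate(**inputs, max_new_tokens=64, past_key_values=past_key_values, use_cache=True)[0]
print(tokenizer.decode(output[context_length:], skip_special_tokens=True))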