├── .dockerignore ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── any_precision ├── __init__.py ├── analyzer │ ├── __init__.py │ ├── analyzer.py │ ├── architectures │ │ ├── llama.yaml │ │ ├── mistral.yaml │ │ ├── opt.yaml │ │ └── phi.yaml │ └── utils.py ├── evaluate │ ├── __init__.py │ ├── eval.py │ └── helpers │ │ ├── __init__.py │ │ ├── dataloader.py │ │ └── utils.py ├── modules │ ├── AnyPrecisionForCausalLM.py │ ├── AnyPrecisionLinear.py │ ├── __init__.py │ └── kernels │ │ ├── dequant.cuh │ │ ├── main.cu │ │ ├── matmul.cuh │ │ └── setup.py └── quantization │ ├── __init__.py │ ├── config.py │ ├── datautils.py │ ├── dense_and_sparse.py │ ├── gradients.py │ ├── main.py │ ├── pack.py │ ├── quantize.py │ └── utils.py ├── demo.py ├── evaluate.sh ├── fake_pack.py ├── figures ├── incremental_upscaling.png └── software_engine.png ├── quantize.py ├── requirements.txt ├── run_eval.py └── setup.py /.dockerignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SNU-ARC/any-precision-llm/baa9d0272510d6342fef562b5200c3f9454f9070/.dockerignore -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | # Cache folder 163 | cache/ 164 | 165 | # JSON files 166 | *.json 167 | 168 | # Evaluate directories 169 | any_precision/evaluate/input_tokens_cache/ 170 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.1.1-devel-ubuntu22.04 2 | MAINTAINER SangLyul Cho 3 | 4 | RUN apt update && apt upgrade -y && apt install -y pip ninja-build vim 5 | 6 | WORKDIR /home/any-precision-llm 7 | COPY ./requirements.txt ./requirements.txt 8 | RUN pip install -r requirements.txt 9 | 10 | WORKDIR any_precision/modules/kernels 11 | COPY any_precision/modules/kernels . 12 | # TODO support sm60 and sm61 (Pascal) 13 | RUN TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0" python3 setup.py sdist bdist_wheel 14 | RUN pip install dist/any_precision_ext-0.0.0-cp310-cp310-linux_x86_64.whl 15 | 16 | WORKDIR ../../.. 17 | COPY . . 18 | RUN python3 setup.py sdist bdist_wheel 19 | RUN pip install dist/any_precision_llm-0.0.0-py3-none-any.whl 20 | RUN mv demo.py quantize.py .. 21 | 22 | WORKDIR .. 
23 | RUN rm -rf any-precision-llm 24 | 25 | CMD /bin/bash 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 SNU ARC Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Any-Precision LLM: Low-Cost Deployment of Multiple, Different-Sized LLMs [[Paper](http://www.arxiv.org/pdf/2402.10517)] 2 | 3 | ## Overview 4 | 5 | Any-precision LLM is a memory-efficient and cost-effective solution for deployment of multiple, different-sized LLMs. 6 | Specifically, any-precision LLM reduces the memory cost of deploying multiple, different-sized LLMs by overlaying LLMs 7 | quantized to varying bit-widths, such as 3, 4, ..., n bits, into a memory footprint comparable to a single n-bit LLM. 8 | This includes a lightweight any-precision quantization technique for LLMs called incremental upscaling, and a 9 | specialized software engine for efficient serving, which is equipped with a custom CUDA kernel supporting bitplane-based 10 | weight representation. 11 |
[figure: figures/incremental_upscaling.png - Illustration of incremental upscaling scheme]

[figure: figures/software_engine.png - Illustration of specialized software engine for any-precision LLM]
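To make the overlay idea concrete, below is a minimal NumPy sketch of how a single set of parent-precision codes can be read at more than one precision. It only mirrors the centroid-lookup relationship used elsewhere in this repo (the `qweight >> (max_bit - bit)` indexing in `any_precision/evaluate/eval.py`); the array names and shapes are invented for illustration, and it does not reproduce the packed bitplane layout or the CUDA kernel:

```python
import numpy as np

# Toy dimensions, for illustration only.
out_features, in_features = 4, 8
parent_bits, target_bits = 8, 3

# One parent-precision (8-bit) code per weight.
# The real engine stores these codes bit-sliced, as bitplanes.
codes = np.random.randint(0, 2 ** parent_bits, size=(out_features, in_features))

# Per-row centroid tables (LUTs), one per supported precision.
lut8 = np.random.randn(out_features, 2 ** parent_bits).astype(np.float16)
lut3 = np.random.randn(out_features, 2 ** target_bits).astype(np.float16)

# 8-bit weights: look up the full code in the 8-bit LUT.
w8 = np.take_along_axis(lut8, codes, axis=1)

# 3-bit weights reuse the *same* codes: keep only the top 3 bits of each code
# and look them up in the smaller 3-bit LUT. No separate 3-bit tensor is stored.
w3 = np.take_along_axis(lut3, codes >> (parent_bits - target_bits), axis=1)
```

Because the codes are stored as bitplanes in this engine's layout, reading the model at k bits only touches the k most significant planes, so the lower-precision models add essentially no memory cost (see `prune_precisions` in `any_precision/modules/AnyPrecisionLinear.py`, which simply slices `qweight` down to the highest requested precision).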
24 | 25 | ## Prerequisites 26 | 27 | - Python 3.11 28 | - CUDA Toolkit 12 or higher 29 | - gcc-9 or higher 30 | 31 | ## Setup 32 | 33 | 1. Clone this repository. 34 | 2. Install the required Python packages. 35 | We recommend using Python 3.11 with either Virtualenv or Conda to manage the dependencies. 36 | ```bash 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | 3. Install the Any-Precision CUDA kernels. 41 | 42 | ```bash 43 | cd any_precision/modules/kernels 44 | pip install . 45 | ``` 46 | You will need to have CUDA toolkit 12 or higher installed on your machine to build the CUDA kernels. You will also 47 | need gcc-9 or higher. 48 | 49 | ## Quantization 50 | 51 | To quantize a model to any-precision, run `quantize.py` with relevant arguments: 52 | 53 | ```bash 54 | python quantize.py [options] 55 | ``` 56 | 57 | The quantization process will automatically download the model from the Hugging Face repository if it is not already 58 | present. 59 | 60 | The quantization process consists of the following steps: 61 | 62 | 1. Gradient calculation: The gradients of the model are calculated, for a sensitivity metric for quantization. 63 | 2. Upscaling: The model is quantized to the seed precision, and then incrementally upscaled to the parent precision. 64 | 3. Packing: The quantized model is packed into a single file, with the weights stored in bitplanes. 65 | 66 | All steps are performed automatically. 67 | 68 | A `cache` directory is created to store both intermediate files and the final quantized model. 69 | The final quantized model can be found under `cache/packed`. 70 | 71 | **Notes**: 72 | - We have tested the quantization on Llama, OPT and Mistral models. Other models can be automatically quantized, 73 | but we do not guarantee the correctness of the quantization. 74 | - You need free space of approximately 2x the fp16 model size in RAM, VRAM, and disk to quantize the model. 75 | For Llama 2 7B, this is approximately 28 GB. 76 | - The quantization process utilizes both CPU and GPU resources. As the main quantization process is CPU-bound, 77 | use a machine with powerful multicore performance for faster quantization. However, our quantization pipeline is 78 | highly optimized, and Llama 2 7B can be quantized in under 2 minutes on an i9-13900K machine. 79 | On lower-end machines this will be a few times slower. 80 | 81 | ### Required Argument 82 | 83 | - `model`: The model to quantize. Should be a Hugging Face repository, or a local path to a model. 84 | e.g. `meta-llama/Llama-2-7b-chat-hf`. 85 | 86 | ### Optional Arguments 87 | 88 | - `--seed_precision`: The seed model precision. Default is 3. 89 | - `--parent_precision`: The parent model precision. Default is 8. 90 | (The final quantized model will support precisions in the range `[seed_precision, parent_precision]`.) 91 | - `--mode`: The mode of operation. Valid options are `gradients`, `quantize`, and `pack`, which are the three steps of 92 | quantization. 93 | The quantization process will abort after the specified operation. Default is `pack`, which completes the entire 94 | quantization process. 95 | - `--yaml_path`: The path to the architecture config yaml file. When not specified, the model architecture is inferred 96 | automatically. 97 | - `--cache_dir`: The directory to cache results in. Default is `cache`. 98 | - `--dataset`: The dataset to use for gradient calculation. Default is `c4`. 99 | - `--seq_len`: The sequence length to use for gradient calculation. Default is 512. 
100 | - `--num_examples`: The number of examples to use for gradient calculation. Default is 100. 101 | - `--cpu_count`: The number of cores to use for gradient calculation. Default is the number of available cores. 102 | - `--random_state`: The random state to use for reproducibility. When not set, the random state is not fixed. Use an 103 | integer. 104 | 105 | ### Flags 106 | 107 | - `--overwrite_gradients`: Whether to overwrite the gradients stored to disk. When not set, the gradients are 108 | loaded from disk if available. 109 | - `--overwrite_quantize`: Whether to overwrite the parent model stored to disk. When not set, the parent model is 110 | loaded from disk if available. 111 | - `--overwrite_pack`: Whether to overwrite the packed model stored to disk. When not set, the packed model will not be 112 | overwritten if it already exists. 113 | 114 | ### Example Command 115 | 116 | ```bash 117 | python quantize.py meta-llama/Llama-2-7b-chat-hf 118 | ``` 119 | 120 | ## Supporting Different Models 121 | 122 | Our quantization pipeline is designed to support any model that can be loaded from the Hugging Face repository, 123 | by automatically detecting the linear layers for quantization. However, for better reproducibility, 124 | we have preconfigured YAML files for the Llama, OPT, and Mistral models under the 125 | path `any_precision_llm/analyzer/architectures`. 126 | 127 | This is what the YAML file for the Llama model looks like: 128 | 129 | ```yaml 130 | architecture: "LlamaForCausalLM" 131 | arch_config: 132 | model_name: "model" 133 | layers_name: "layers" 134 | module_names: 135 | - "self_attn.q_proj" 136 | - "self_attn.k_proj" 137 | - "self_attn.v_proj" 138 | - "self_attn.o_proj" 139 | - "mlp.gate_proj" 140 | - "mlp.up_proj" 141 | - "mlp.down_proj" 142 | ``` 143 | 144 | The `architecture` field specifies what model class the YAML file is for. 145 | Under `arch_config`, the `model_name` field specifies the name of the model attribute that contains the model. 146 | The `layers_name` field specifies the name of the attribute that contains the layers of the model, under the model. 147 | The `module_names` field specifies the names of the linear layers to quantize. 148 | 149 | For models with no corresponding YAML file under `any_precision_llm/analyzer/architectures`, the quantization process 150 | will attempt to automatically detect the linear layers to quantize. This is not guaranteed to work, and may result in 151 | incorrect quantization. 152 | 153 | If you wish to experiment with different model types, you can create your own YAML file under the same directory, 154 | or specify the `yaml_path` argument to point to your custom YAML file, in which case the `architecture` field is 155 | unnecessary. 156 | 157 | ## Inference 158 | 159 | To use the quantized model for inference, you can use the `AnyPrecisionForCausalLM` class from the `any_precision` 160 | module. 
161 | Below is an example of how to load the quantized model and perform inference: 162 | 163 | ```python 164 | from any_precision import AnyPrecisionForCausalLM 165 | from transformers import AutoTokenizer 166 | 167 | quanitized_model_path = "./cache/packed/anyprec-(Llama-2-7b-chat-hf)-w8_orig3-gc1-c4_s100_blk512" 168 | 169 | model = AnyPrecisionForCausalLM.from_quantized( 170 | quanitized_model_path, 171 | trust_remote_code=True, 172 | fuse_layers=False, # Will be supported in the future 173 | precisions=[3, 4, 5, 6, 7, 8] # You may optionally specify a subset of supported precisions to load 174 | ) 175 | tokenizer = AutoTokenizer.from_pretrained(quanitized_model_path) 176 | 177 | # The following methods are supported by the quantized model, and work similarly to the original huggingface model. 178 | # Note that you can specify the precision to use for each method. 179 | model.forward(..., precision=3) 180 | model.generate(..., precision=5) 181 | 182 | # Or you can specify the precision like this: 183 | model.set_precision(8) 184 | ``` 185 | 186 | ## Demo 187 | 188 | We have provided a demo script to showcase the dynamic inference latency of the quantized model. 189 | To run the demo, execute the following command: 190 | 191 | ```bash 192 | python demo.py 193 | ``` 194 | 195 | Note that the demo script requires the quantized `Llama-2-7b-chat-hf` model to be present in the cache directory. 196 | Other models can be used by changing the `model_path` and `original_model_path` variables in the script. 197 | 198 | The demo script will load the quantized model, and perform inference on a custom prompt, using specified precisions. 199 | Include 16 to measure the latency of the original model in fp16. 200 | The latency at each precision will be measured and displayed. 201 | 202 | The demo will look like this when run properly: 203 | 204 | ![AnyPrec Latency Demo](https://github.com/SNU-ARC/any-precision-llm/assets/48833786/75a42bea-979a-489f-aee8-89697c55411a) 205 | 206 | Please note that this demo serves as a proof-of-concept. 207 | Further optimizations in the inference pipeline are needed to achieve the best performance of our engine. 208 | 209 | ## Evaluation 210 | 211 | We have provided a script to evaluate the perplexity of the quantized model on various datasets. 212 | To run the evaluation, execute the following command: 213 | 214 | ```bash 215 | python run_eval.py 216 | ``` 217 | 218 | You can specify which datasets to evaluate on by changing the `datasets` variable within the script. 219 | The results are automatically appended to a JSON file. (By default, `results.json`) 220 | 221 | ## Citation 222 | 223 | Please cite our paper if you find our work useful: 224 | 225 | ``` 226 | @inproceedings{park2024anyprecision, 227 | title={Any-Precision LLM: Low-Cost Deployment of Multiple, Different-Sized LLMs}, 228 | author={Yeonhong Park and Jake Hyun and SangLyul Cho and Bonggeun Sim and Jae W. Lee}, 229 | year={2024}, 230 | booktitle={Proceedings of the 41st International Conference on Machine Learning}, 231 | } 232 | ``` 233 | 234 | -------------------------------------------------------------------------------- /any_precision/__init__.py: -------------------------------------------------------------------------------- 1 | from . import modules 2 | from . 
import quantization 3 | from .modules import AnyPrecisionForCausalLM 4 | -------------------------------------------------------------------------------- /any_precision/analyzer/__init__.py: -------------------------------------------------------------------------------- 1 | from .analyzer import get_analyzer 2 | -------------------------------------------------------------------------------- /any_precision/analyzer/analyzer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, PreTrainedModel 2 | import torch 3 | import yaml 4 | import os 5 | import logging 6 | from .utils import load_model, load_tokenizer 7 | 8 | 9 | def get_analyzer(model, yaml_path=None, include_tokenizer=False): 10 | # Load model from string if necessary 11 | model = load_model(model) 12 | 13 | # Anyprecision quantized model 14 | if hasattr(model.config, 'anyprec'): 15 | return ModelAnalyzer.from_arch_config(model, model.config.anyprec['arch_config']) 16 | 17 | # Unspecified model quantization config 18 | if yaml_path is None: 19 | dirpath = os.path.dirname(os.path.realpath(__file__)) 20 | yaml_dir = os.path.join(dirpath, f'./architectures/') 21 | assert len(model.config.architectures) == 1, "Model has multiple architectures" 22 | # Check if there is a yaml file for the model architecture 23 | for file in os.listdir(yaml_dir): 24 | if file.endswith(".yaml"): 25 | with open(os.path.join(yaml_dir, file)) as f: 26 | yaml_contents = yaml.safe_load(f) 27 | if model.config.architectures[0] == yaml_contents['architecture']: 28 | return ModelAnalyzer.from_arch_config(model, yaml_contents['arch_config'], include_tokenizer) 29 | else: 30 | # If no yaml file is found, use AutoQuantConfig 31 | logging.warning((f"Attempting to use AutoArchConfig for architecture:" 32 | f" {model.config.architectures[0]}")) 33 | logging.warning("This may not work as expected!") 34 | return ModelAnalyzer.from_autoconfig(model, include_tokenizer=include_tokenizer) 35 | 36 | # Specified model quantization config 37 | else: 38 | if not os.path.exists(yaml_path): 39 | raise FileNotFoundError(f"Specified yaml file does not exist: {yaml_path}") 40 | with open(yaml_path) as f: 41 | yaml_contents = yaml.safe_load(f) 42 | return ModelAnalyzer.from_arch_config(model, yaml_contents['arch_config'], include_tokenizer=include_tokenizer) 43 | 44 | 45 | class ModelAnalyzer: 46 | """ModelAnalyzer is a class that provides an interface to access relevant model information for quantization. 47 | 48 | This class is intended to work for any model type, and the model-specific information should be passed in the 49 | constructor. Alternatively, you can instantiate from a yaml file using the from_yaml method. 
50 | """ 51 | 52 | def __init__(self, model: AutoModelForCausalLM, module_names, model_name, layers_name, include_tokenizer=False): 53 | self.model = model 54 | self.module_names = module_names 55 | self.model_name = model_name 56 | self.layers_name = layers_name 57 | self.config = model.config 58 | self.state_dict = model.state_dict() 59 | self.dropped_original_weights = False 60 | self.num_layers = len(self.get_layers()) 61 | self.tokenizer = None 62 | if include_tokenizer: 63 | self.tokenizer = load_tokenizer(model) 64 | self._model_weights = {} 65 | 66 | @classmethod 67 | def from_arch_config(cls, model: AutoModelForCausalLM, quant_config: dict, include_tokenizer=False): 68 | return cls(model, **quant_config, include_tokenizer=include_tokenizer) 69 | 70 | def get_arch_config(self): 71 | quant_config = { 72 | "module_names": self.module_names, 73 | "model_name": self.model_name, 74 | "layers_name": self.layers_name, 75 | } 76 | return quant_config 77 | 78 | @classmethod 79 | def from_autoconfig(cls, model: AutoModelForCausalLM, include_tokenizer=False): 80 | """Instantiate a ModelAnalyzer from an AutoConfig.""" 81 | auto_config = AutoArchConfig(model) 82 | return cls(model, **auto_config.to_dict(), include_tokenizer=include_tokenizer) 83 | 84 | def get_layers(self): 85 | """Return the layers of the model.""" 86 | if self.dropped_original_weights: 87 | raise ValueError("Original weights have been dropped") 88 | module = self.get_model() 89 | for attrib_name in self.layers_name.split('.'): 90 | module = getattr(module, attrib_name) 91 | return module 92 | 93 | def get_modules(self, layer): 94 | """Return the relevant modules of the layer.""" 95 | modules = {} 96 | for module_name in self.module_names: 97 | module = layer 98 | for attrib_name in module_name.split('.'): 99 | module = getattr(module, attrib_name) 100 | modules[module_name] = module 101 | return modules 102 | 103 | def get_layer_weights(self, layer_idx): 104 | """Return the relevant weights of the model.""" 105 | if self.dropped_original_weights: 106 | raise ValueError("Original weights have been dropped") 107 | if layer_idx in self._model_weights: 108 | return self._model_weights[layer_idx] 109 | layers = self.get_layers() 110 | layer_data = {} 111 | modules = self.get_modules(layers[layer_idx]) 112 | for name, module in modules.items(): 113 | layer_data[name] = module.weight.data.cpu() 114 | self._model_weights[layer_idx] = layer_data 115 | return layer_data 116 | 117 | def get_model(self): 118 | """Return the model.""" 119 | if self.dropped_original_weights: 120 | raise ValueError("Original weights have been dropped") 121 | module = self.model 122 | for attrib_name in self.model_name.split('.'): 123 | module = getattr(module, attrib_name) 124 | return module 125 | 126 | def drop_original_weights(self): 127 | weight_key_prefixes = [f'{self.model_name}.{self.layers_name}.{i}' for i in range(self.num_layers)] 128 | weight_key_postfix = 'weight' 129 | for prefix in weight_key_prefixes: 130 | for module_name in self.module_names: 131 | key = f"{prefix}.{module_name}.{weight_key_postfix}" 132 | self.state_dict.pop(key) 133 | 134 | self.model = None 135 | self._model_weights.clear() 136 | self.dropped_original_weights = True 137 | 138 | 139 | class AutoArchConfig: 140 | def __init__(self, model): 141 | self.model = model 142 | 143 | def to_dict(self): 144 | return { 145 | "module_names": self.get_module_names(), 146 | "model_name": self.get_model()[0], 147 | "layers_name": self.get_layers()[0], 148 | } 149 | 150 | def 
get_module_names(self): 151 | layers_name, layers = self.get_layers() 152 | first_layer = next(layers.children()) 153 | # find all linear layers 154 | module_names = [] 155 | for name, module in first_layer.named_modules(): 156 | if isinstance(module, torch.nn.Linear): 157 | module_names.append(name) 158 | return module_names 159 | 160 | def get_model(self): 161 | for name, module in self.model.named_modules(): 162 | if module is not self.model and isinstance(module, PreTrainedModel): 163 | return name, module 164 | else: 165 | raise ValueError("Model not found") 166 | 167 | def get_layers(self): 168 | model_name, model = self.get_model() 169 | for name, module in model.named_children(): 170 | if isinstance(module, torch.nn.ModuleList): 171 | return name, module 172 | else: 173 | raise ValueError("Model layers not found") 174 | -------------------------------------------------------------------------------- /any_precision/analyzer/architectures/llama.yaml: -------------------------------------------------------------------------------- 1 | architecture: "LlamaForCausalLM" 2 | arch_config: 3 | model_name: "model" 4 | layers_name: "layers" 5 | module_names: 6 | - "self_attn.q_proj" 7 | - "self_attn.k_proj" 8 | - "self_attn.v_proj" 9 | - "self_attn.o_proj" 10 | - "mlp.gate_proj" 11 | - "mlp.up_proj" 12 | - "mlp.down_proj" 13 | -------------------------------------------------------------------------------- /any_precision/analyzer/architectures/mistral.yaml: -------------------------------------------------------------------------------- 1 | architecture: "MistralForCausalLM" 2 | arch_config: 3 | model_name: "model" 4 | layers_name: "layers" 5 | module_names: 6 | - "self_attn.q_proj" 7 | - "self_attn.k_proj" 8 | - "self_attn.v_proj" 9 | - "self_attn.o_proj" 10 | - "mlp.gate_proj" 11 | - "mlp.up_proj" 12 | - "mlp.down_proj" 13 | -------------------------------------------------------------------------------- /any_precision/analyzer/architectures/opt.yaml: -------------------------------------------------------------------------------- 1 | architecture: "OPTForCausalLM" 2 | arch_config: 3 | model_name: "model.decoder" 4 | layers_name: "layers" 5 | module_names: 6 | - "self_attn.q_proj" 7 | - "self_attn.k_proj" 8 | - "self_attn.v_proj" 9 | - "self_attn.out_proj" 10 | - "fc1" 11 | - "fc2" 12 | -------------------------------------------------------------------------------- /any_precision/analyzer/architectures/phi.yaml: -------------------------------------------------------------------------------- 1 | architecture: "PhiForCausalLM" 2 | arch_config: 3 | model_name: "model" 4 | layers_name: "layers" 5 | module_names: 6 | - "self_attn.q_proj" 7 | - "self_attn.k_proj" 8 | - "self_attn.v_proj" 9 | - "self_attn.dense" 10 | - "mlp.fc1" 11 | - "mlp.fc2" 12 | -------------------------------------------------------------------------------- /any_precision/analyzer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, PreTrainedModel, AutoTokenizer, PreTrainedTokenizerBase 3 | 4 | 5 | def load_model(model_str_or_model, dtype=torch.float16): 6 | """Returns a model from a string or a model object. 
If a string is passed, it will be loaded from the HuggingFace""" 7 | if isinstance(model_str_or_model, str): 8 | model = AutoModelForCausalLM.from_pretrained( 9 | model_str_or_model, 10 | trust_remote_code=True, 11 | torch_dtype=dtype, 12 | device_map='auto', 13 | ) 14 | else: 15 | assert isinstance(model_str_or_model, PreTrainedModel), "model must be a string or a PreTrainedModel" 16 | model = model_str_or_model 17 | return model 18 | 19 | 20 | def load_tokenizer(model_str_or_model_or_tokenizer): 21 | """Returns a tokenizer from the model string or model object or tokenizer object""" 22 | if isinstance(model_str_or_model_or_tokenizer, str): 23 | model_str = model_str_or_model_or_tokenizer 24 | return AutoTokenizer.from_pretrained(model_str, trust_remote_code=True) 25 | elif isinstance(model_str_or_model_or_tokenizer, PreTrainedModel): 26 | model_str = model_str_or_model_or_tokenizer.name_or_path 27 | return AutoTokenizer.from_pretrained(model_str, trust_remote_code=True) 28 | else: 29 | assert isinstance(model_str_or_model_or_tokenizer, PreTrainedTokenizerBase), \ 30 | f"Unsupported type for model_str_or_model_or_tokenizer: {type(model_str_or_model_or_tokenizer)}" 31 | return model_str_or_model_or_tokenizer 32 | -------------------------------------------------------------------------------- /any_precision/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from . import helpers 2 | -------------------------------------------------------------------------------- /any_precision/evaluate/eval.py: -------------------------------------------------------------------------------- 1 | from .helpers import dataloader 2 | from tqdm import tqdm 3 | import torch 4 | from .helpers.utils import vprint, logprint, get_tokenizer_type, name_splitter, base_model_name_to_hf_repo_name 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | from ..modules import AnyPrecisionForCausalLM 7 | import os 8 | import json 9 | import lm_eval 10 | 11 | current_dir = os.path.dirname(os.path.realpath(__file__)) 12 | 13 | 14 | def fake_pack(parent_path, verbose=True): 15 | # Load from non-packed parent model to simulate quantization 16 | # WARNING: This is for PPL research only, and should not be used for any other purpose 17 | import re 18 | logprint(verbose, f"Simulating Any-Precision model from non-packed parent model at {parent_path}") 19 | 20 | if os.path.isdir('./cache/fake_packed'): 21 | for file in os.listdir('./cache/fake_packed'): 22 | if parent_path.split("/")[-1] in file: 23 | logprint(verbose, f"Faked packed model already exists for {parent_path.split('/')[-1]}. 
Skipping...") 24 | return 25 | 26 | # Check if D&S quantization is used 27 | dns = parent_path.split("/")[-1].startswith("dns") 28 | 29 | fields = name_splitter(parent_path) 30 | # get the field wrapped in () 31 | for field in fields: 32 | if field.startswith('(') and field.endswith(')'): 33 | base_model_name = field[1:-1] 34 | break 35 | else: 36 | raise ValueError(f"Could not find base model name in {parent_path}") 37 | original_model_repo = base_model_name_to_hf_repo_name(base_model_name) 38 | tokenizer = AutoTokenizer.from_pretrained(original_model_repo) 39 | 40 | logprint(verbose, f"Loading original model from {original_model_repo}") 41 | # Load the model from the original model repo 42 | model = AutoModelForCausalLM.from_pretrained(original_model_repo, torch_dtype=torch.float16, 43 | trust_remote_code=True) 44 | 45 | logprint(verbose, f"Loading quantized weights from {parent_path}") 46 | # Load the qweights 47 | files = os.listdir(parent_path + '/weights') 48 | layer_count = len(files) # this should suffice 49 | qweights = [None] * layer_count 50 | for file in tqdm(files, desc="Loading qweights", disable=not verbose): 51 | # filenames should be 'l0.pt' 52 | l = int(re.match(r'l(\d+).pt', file).group(1)) 53 | qweights[l] = torch.load(parent_path + '/weights/' + file) 54 | 55 | logprint(verbose, f"Loading LUTs from {parent_path}") 56 | # get a list of directories in the model_path 57 | dirs = os.listdir(parent_path) 58 | dirs.remove('weights') 59 | if dns: 60 | dirs.remove('sparse') 61 | luts = {} 62 | # Only the LUT directories should remain 63 | for lut_dir in dirs: 64 | # example: lut_3 65 | bit = int(re.match(r'lut_(\d+)', lut_dir).group(1)) 66 | for file in tqdm(os.listdir(parent_path + '/' + lut_dir), desc=f"Loading {bit}-bit LUTs", 67 | disable=not verbose): 68 | # example: l0.pt 69 | l = int(re.match(r'l(\d+).pt', file).group(1)) 70 | if bit not in luts: 71 | luts[bit] = [None] * layer_count 72 | luts[bit][l] = torch.load(parent_path + '/' + lut_dir + '/' + file) 73 | 74 | # Load D&S sparse weights if they exist 75 | sparse_model_weights = [] 76 | if dns: 77 | logprint(verbose, f"D&S quantization detected. 
Loading sparse weights...") 78 | for l in range(layer_count): 79 | sparse_weights = torch.load(parent_path + f'/sparse/l{l}.pt') 80 | sparse_model_weights.append(sparse_weights) 81 | 82 | logprint(verbose, f"Replacing qweights with centroids from LUTs...") 83 | 84 | max_bit = max(luts.keys()) 85 | 86 | for bit in luts: 87 | state_dict = model.state_dict() 88 | for l in tqdm(range(layer_count), desc=f"Replacing qweights with {bit}-bit centroids", ): 89 | qweight = qweights[l] 90 | lut = luts[bit][l] 91 | for module_name in qweight: 92 | full_param_name_suffix = f".{l}.{module_name}.weight" 93 | matching_keys = [key for key in state_dict.keys() if key.endswith(full_param_name_suffix)] 94 | assert len(matching_keys) == 1, f"Expected 1 matching key, got {len(matching_keys)}" 95 | matching_key = matching_keys[0] 96 | 97 | module_qweight = qweight[module_name] 98 | module_lut = lut[module_name] 99 | module_weights = [] 100 | for row_idx in range(module_qweight.shape[0]): 101 | row_weights = [] 102 | for group_idx in range(module_qweight.shape[1]): 103 | # fetch weights from the LUT 104 | group_weights = module_lut[row_idx][group_idx][ 105 | module_qweight[row_idx][group_idx] >> (max_bit - bit)] 106 | row_weights.append(torch.from_numpy(group_weights)) 107 | # join the group weights 108 | row_weights = torch.cat(row_weights, dim=0) 109 | module_weights.append(row_weights) 110 | module_weights = torch.stack(module_weights) 111 | # Add the sparse weights if they exist 112 | if dns: 113 | sparse_weights = sparse_model_weights[l][module_name] 114 | # get the indices of the sparse weights 115 | sparse_indices = sparse_weights.indices() 116 | # replace the weights with the sparse weights 117 | module_weights[sparse_indices[0], sparse_indices[1]] = sparse_weights.values() 118 | state_dict[matching_key] = module_weights 119 | 120 | save_path = f'./cache/fake_packed/fake_anyprec-p{bit}-{parent_path.split("/")[-1]}' 121 | os.makedirs(save_path, exist_ok=True) 122 | torch.save(state_dict, save_path + '/pytorch_model.bin') 123 | tokenizer.save_pretrained(save_path) 124 | model.config.save_pretrained(save_path) 125 | logprint(verbose, f"{bit}-bit model saved to {save_path}") 126 | 127 | 128 | @torch.no_grad() 129 | def auto_model_load(model_path, device='cuda', dtype=torch.float16, verbose=True): 130 | """ 131 | Args: 132 | model_path: path of the model to evaluate 133 | device: the device to use for evaluation, either 'cuda' or 'cpu' 134 | dtype: the dtype to use for evaluation, either torch.float16 or torch.float32 135 | verbose: whether to print progress 136 | 137 | Returns: 138 | (tokenizer, model) tuple loaded from the given path, with the given device and dtype. 139 | """ 140 | logprint(verbose, "Loading tokenizer and model...") 141 | 142 | if os.path.basename(model_path).startswith("anyprec-"): 143 | tokenizer = AutoTokenizer.from_pretrained(model_path) 144 | model = AnyPrecisionForCausalLM.from_quantized(model_path).to(device) 145 | else: 146 | tokenizer = AutoTokenizer.from_pretrained(model_path) 147 | model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype, 148 | trust_remote_code=True).to(device) 149 | 150 | logprint(verbose, f"{model.__class__.__name__} model loaded to device: {model.device}") 151 | 152 | tokenizer_type = get_tokenizer_type(model_path) 153 | 154 | if tokenizer_type is None: 155 | logprint(verbose, f"Unknown tokenizer type for {model_path}. 
Cannot use cached input tokens.") 156 | 157 | return tokenizer_type, tokenizer, model 158 | 159 | 160 | @torch.no_grad() 161 | def evaluate_ppl(model, tokenizer, testcases, verbose=True, chunk_size=2048, tokenizer_type=None): 162 | """ 163 | Args: 164 | model: model to evaluate 165 | tokenizer: tokenizer to use 166 | testcases: testcases names to evaluate on, passed on to dataloader.get_loaders 167 | verbose: whether to print progress 168 | chunk_size: the size of the chunks into which the test set is split 169 | tokenizer_type: set to llama, llama-2, or opt to use cached input tokens 170 | for the corresponding test set 171 | 172 | Returns: 173 | A dictionary of perplexity scores, with keys being the testcases names and values being the perplexity scores. 174 | 175 | Note that the perplexity scores are calculated over non-overlapping chunks of the test set. 176 | """ 177 | 178 | if isinstance(model, AnyPrecisionForCausalLM): 179 | is_anyprec = True 180 | else: 181 | is_anyprec = False 182 | 183 | model.eval() 184 | 185 | results = {} 186 | 187 | supported_bits = model.precisions if is_anyprec else [None] 188 | 189 | for bit in supported_bits: 190 | if is_anyprec: 191 | logprint(verbose, f"<<<< Setting model precision to {bit}-bit... >>>>") 192 | model.set_precision(bit) 193 | 194 | for testcase_name in testcases: 195 | vprint(verbose, f"---------------------- {testcase_name} ----------------------") 196 | 197 | input_tokens = _load_input_tokens(tokenizer_type, testcase_name, tokenizer, verbose) 198 | 199 | input_tokens.to(model.device) 200 | 201 | logprint(verbose, "Calculating perplexity...") 202 | 203 | seq_len = input_tokens.input_ids.size(1) 204 | nsamples = seq_len // chunk_size # floor(seq_len / chunk_size) 205 | 206 | neg_log_likelihoods = [] 207 | for i in tqdm(range(nsamples), disable=not verbose): 208 | begin_loc = i * chunk_size 209 | 210 | input_ids = input_tokens.input_ids[:, begin_loc:begin_loc + chunk_size] 211 | 212 | # add BOS token for Gemma-7B 213 | # https://github.com/huggingface/transformers/issues/29250 214 | if 'gemma' in model.config.architectures[0].lower(): 215 | # Mostly harmless to other models, but a slight drop in ppl is observed 216 | # Hence, we only add the BOS token for Gemma models for now 217 | input_ids[:, 0] = tokenizer.bos_token_id 218 | 219 | with torch.no_grad(): 220 | outputs = model(input_ids, labels=input_ids) 221 | neg_log_likelihood = outputs.loss 222 | neg_log_likelihoods.append(neg_log_likelihood) 223 | 224 | ppl = torch.exp(torch.stack(neg_log_likelihoods).mean()) 225 | logprint(verbose, f"Perplexity: {ppl.item()}") 226 | 227 | results[f"{testcase_name}:{bit}-bit"] = ppl.item() 228 | 229 | if not is_anyprec: 230 | break 231 | 232 | return results 233 | 234 | 235 | @torch.no_grad() 236 | def run_lm_eval(tokenizer, model, tasks, verbose=True): 237 | """ Run lm-eval on the given model and tasks and return the results. 238 | 239 | Receives an already initialized hf model, and a list of task names. 240 | """ 241 | if isinstance(model, AnyPrecisionForCausalLM): 242 | is_anyprec = True 243 | else: 244 | is_anyprec = False 245 | 246 | model.eval() 247 | 248 | results = {} 249 | 250 | supported_bits = model.precisions if is_anyprec else [None] 251 | 252 | for bit in supported_bits: 253 | if is_anyprec: 254 | logprint(verbose, f"<<<< Setting model precision to {bit}-bit... 
>>>>") 255 | model.set_precision(bit) 256 | 257 | model_lm = lm_eval.models.huggingface.HFLM(pretrained=model, tokenizer=tokenizer) 258 | eval_results = lm_eval.simple_evaluate(model=model_lm, tasks=tasks) 259 | 260 | if verbose: 261 | logprint(verbose, json.dumps(eval_results['results'], indent=4)) 262 | 263 | for task in tasks: 264 | results[f"{task}:{bit}-bit"] = eval_results['results'][task] 265 | 266 | if not is_anyprec: 267 | break 268 | 269 | return results 270 | 271 | 272 | def _load_input_tokens(tokenizer_type, testcase_name, tokenizer, verbose): 273 | """ Load input tokens from cache if available, otherwise load from dataloader and save to cache. """ 274 | input_tokens_cache_path = f"{current_dir}/input_tokens_cache/dataloader-{tokenizer_type}-{testcase_name}-test.pt" 275 | if tokenizer_type and os.path.exists(input_tokens_cache_path): 276 | logprint(verbose, f"Loading cached input tokens from {input_tokens_cache_path}...") 277 | input_tokens = torch.load(input_tokens_cache_path) 278 | else: 279 | logprint(verbose, "Loading test set...") 280 | 281 | raw_text = dataloader.get_loaders(testcase_name) 282 | 283 | logprint(verbose, "Tokenizing test set...") 284 | 285 | input_tokens = tokenizer(raw_text, return_tensors='pt') 286 | # save input_tokens to cache 287 | if tokenizer_type: 288 | logprint(verbose, f"Caching input tokens to {input_tokens_cache_path}...") 289 | # we must create the directory if it doesn't exist 290 | os.makedirs(os.path.dirname(input_tokens_cache_path), exist_ok=True) 291 | torch.save(input_tokens, input_tokens_cache_path) 292 | 293 | return input_tokens 294 | -------------------------------------------------------------------------------- /any_precision/evaluate/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SNU-ARC/any-precision-llm/baa9d0272510d6342fef562b5200c3f9454f9070/any_precision/evaluate/helpers/__init__.py -------------------------------------------------------------------------------- /any_precision/evaluate/helpers/dataloader.py: -------------------------------------------------------------------------------- 1 | # Originally from https://github.com/IST-DASLab/gptq/blob/main/datautils.py 2 | # Modified to: 3 | # - Only return the test set 4 | # - Skip the tokenization step (return the datasets as-is) 5 | 6 | import numpy as np 7 | import torch 8 | from datasets import load_dataset 9 | 10 | 11 | def set_seed(seed): 12 | np.random.seed(seed) 13 | torch.random.manual_seed(seed) 14 | 15 | 16 | def get_wikitext2(): 17 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 18 | return "\n\n".join(testdata['text']) 19 | 20 | 21 | def get_ptb(): 22 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') 23 | return "\n\n".join(valdata['sentence']) 24 | 25 | 26 | def get_c4(): 27 | raise NotImplementedError("Only C4-new has been refactored to use the new dataset API") 28 | 29 | 30 | def get_ptb_new(): 31 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') 32 | return " ".join(testdata['sentence']) 33 | 34 | 35 | def get_ptb_new_sliced(): 36 | raw_text = get_ptb_new() 37 | sliced = raw_text.replace('', '< u n k >') 38 | return sliced 39 | 40 | 41 | def get_c4_new(): 42 | valdata = load_dataset( 43 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, 44 | split='validation' 45 | ) 46 | # The original datautils from the GPTQ paper had two filters: 47 | # 1. 
get only the first 1100 examples 48 | # 2. tokenize, then get the first seqlen * 256 tokens, where seqlen defaulted to 2048 49 | # This resulted in 524288 tokens, which in turn decode back into 2088532 characters. 50 | 51 | # However in my version, I am only returning the text, and leaving the tokenization to the caller. 52 | # Therefore, I replace the second filter of tokens into an equivalent filter of characters. 53 | return " ".join(valdata[:1100]['text'])[:2088528] 54 | 55 | 56 | def get_loaders(name): 57 | if 'wikitext2' in name: 58 | return get_wikitext2() 59 | if 'ptb' in name: 60 | if 'new' in name: 61 | if 'sliced' in name: 62 | return get_ptb_new_sliced() 63 | else: 64 | return get_ptb_new() 65 | return get_ptb() 66 | if 'c4' in name: 67 | if 'new' in name: 68 | return get_c4_new() 69 | return get_c4() 70 | 71 | raise ValueError(f"Unknown dataset {name}") 72 | -------------------------------------------------------------------------------- /any_precision/evaluate/helpers/utils.py: -------------------------------------------------------------------------------- 1 | """ Utility functions """ 2 | 3 | import datetime 4 | import os 5 | 6 | 7 | def get_timestamp(): 8 | """ Get the current timestamp for prefixing log entries """ 9 | return datetime.datetime.now().strftime("%H:%M:%S") 10 | 11 | 12 | def logprint(verbose, *args, **kwargs): 13 | """ Print if verbose is True, and prefix with timestamp """ 14 | assert isinstance(verbose, bool), "The first argument `verbose` must be a boolean." 15 | if verbose: 16 | print(f"[{get_timestamp()}]", end=" ") 17 | print(*args, **kwargs) 18 | 19 | 20 | def vprint(verbose, *args, **kwargs): 21 | """ Print if verbose is True """ 22 | assert isinstance(verbose, bool), "The first argument `verbose` must be a boolean." 
23 | if verbose: 24 | print(*args, **kwargs) 25 | 26 | 27 | def get_subdirs(path): 28 | if not os.path.exists(path): 29 | return [] 30 | return [os.path.join(path, o) for o in sorted(os.listdir(path)) 31 | if os.path.isdir(os.path.join(path, o))] 32 | 33 | 34 | def get_files(path): 35 | if not os.path.exists(path): 36 | return [] 37 | return [os.path.join(path, o) for o in sorted(os.listdir(path)) 38 | if os.path.isfile(os.path.join(path, o))] 39 | 40 | 41 | def get_base_models(include_prequant=False, relevant_models_only=False): 42 | """ Get the repo names of all base models """ 43 | repo_names = [ 44 | 'meta-llama/Llama-2-7b-hf', 45 | 'mistralai/Mistral-7B-v0.1', 46 | 'facebook/opt-1.3b', 47 | 'facebook/opt-2.7b', 48 | 'facebook/opt-6.7b', 49 | ] 50 | if not relevant_models_only: 51 | repo_names.append('huggyllama/llama-7b') 52 | repo_names.append('microsoft/phi-2') 53 | 54 | if include_prequant: 55 | repo_names += ['TheBloke/Llama-2-7B-AWQ', 'TheBloke/Llama-2-7B-GPTQ', 'TheBloke/Mistral-7B-v0.1-AWQ'] 56 | return repo_names 57 | 58 | 59 | def base_model_name_to_hf_repo_name(base_model_name): 60 | """ Convert a base model name to the full HF repository name """ 61 | if base_model_name == 'Llama-2-7b-hf': 62 | return 'meta-llama/' + base_model_name 63 | elif base_model_name == 'llama-7b': 64 | return 'huggyllama/' + base_model_name 65 | elif 'opt' in base_model_name: 66 | return 'facebook/' + base_model_name 67 | elif base_model_name == 'Mistral-7B-v0.1': 68 | return 'mistralai/' + base_model_name 69 | elif base_model_name == 'phi-2': 70 | return 'microsoft/' + base_model_name 71 | else: 72 | raise ValueError(f"Unknown base model name {base_model_name}") 73 | 74 | 75 | def find_matching_paren(string, start): 76 | """ Find the matching parenthesis for the parenthesis at index start """ 77 | assert string[start] == '(' 78 | count = 1 79 | for i in range(start + 1, len(string)): 80 | if string[i] == '(': 81 | count += 1 82 | elif string[i] == ')': 83 | count -= 1 84 | if count == 0: 85 | return i 86 | return -1 87 | 88 | 89 | def name_splitter(full_model_name): 90 | """ Split a model name into its components """ 91 | model_name = full_model_name.split('/')[-1] 92 | 93 | # Find the indices of the separators, skipping over parentheses 94 | separator_indexes = [] 95 | i = 0 96 | while i < len(model_name): 97 | if model_name[i] == '-': 98 | separator_indexes.append(i) 99 | elif model_name[i] == '(': 100 | i = find_matching_paren(model_name, i) 101 | i += 1 102 | 103 | # Split the model name into its components, based on previously found separators 104 | fields = [] 105 | start = 0 106 | for end in separator_indexes: 107 | fields.append(model_name[start:end]) 108 | start = end + 1 109 | fields.append(model_name[start:]) 110 | 111 | return fields 112 | 113 | 114 | def get_tokenizer_type(model_path): 115 | if 'llama-2' in model_path.lower(): 116 | tokenizer_type = 'llama-2' 117 | elif 'llama' in model_path.lower(): 118 | tokenizer_type = 'llama' 119 | elif 'opt' in model_path.lower(): 120 | tokenizer_type = 'opt' 121 | elif 'mistral' in model_path.lower(): 122 | tokenizer_type = 'mistral' 123 | elif 'phi-2' in model_path.lower(): 124 | tokenizer_type = 'phi-2' 125 | elif 'gemma' in model_path.lower(): 126 | tokenizer_type = 'gemma' 127 | else: 128 | tokenizer_type = None 129 | 130 | return tokenizer_type 131 | -------------------------------------------------------------------------------- /any_precision/modules/AnyPrecisionForCausalLM.py: 
-------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | import torch.nn as nn 4 | from tqdm import tqdm 5 | from transformers import ( 6 | PreTrainedModel, 7 | PretrainedConfig, 8 | AutoConfig, 9 | AutoModelForCausalLM, 10 | ) 11 | from accelerate.big_modeling import ( 12 | init_empty_weights, 13 | load_checkpoint_and_dispatch, 14 | ) 15 | 16 | from .AnyPrecisionLinear import AnyPrecisionLinear 17 | from any_precision.analyzer.analyzer import get_analyzer 18 | 19 | 20 | def replace_module_by_name(layer, module_name, new_module): 21 | levels = module_name.split('.') 22 | module = layer 23 | for level in levels[:-1]: 24 | module = getattr(module, level) if not level.isdigit() else module[int(level)] 25 | setattr(module, levels[-1], new_module) 26 | 27 | 28 | class AnyPrecisionForCausalLM(nn.Module): 29 | def __init__( 30 | self, 31 | model_path, 32 | config, 33 | precisions=None, 34 | torch_dtype=torch.float16, 35 | fuse_layers=False, 36 | trust_remote_code=True, 37 | ): 38 | super().__init__() 39 | 40 | self.config = config 41 | 42 | self.supported_bits = list(range(self.config.anyprec['seed_precision'], 43 | self.config.anyprec['parent_precision'] + 1)) 44 | if precisions is None: 45 | self.precisions = self.supported_bits 46 | else: 47 | assert len(precisions) == len(set(precisions)), "Precisions must be unique" 48 | assert all(bit in self.supported_bits for bit in precisions), \ 49 | f"Supported bits {precisions} must be a subset of model supported bits {self.supported_bits}" 50 | self.precisions = precisions 51 | 52 | self.precision = max(self.precisions) 53 | 54 | with init_empty_weights(): 55 | self.model = AutoModelForCausalLM.from_config( 56 | config=config, 57 | torch_dtype=torch_dtype, 58 | trust_remote_code=trust_remote_code, 59 | # attn_implementation="flash_attention_2", 60 | ) 61 | 62 | self.analyzer = get_analyzer(self.model) 63 | 64 | self.ap_linears = [] 65 | # Replace to AnyPrecisionLinear layers 66 | self._load_quantized_modules() 67 | 68 | self.tie_weights() 69 | 70 | device_map = {key: 'cpu' for key in self.model.state_dict().keys()} 71 | 72 | # loads the weights into modules and distributes 73 | # across available devices automatically 74 | load_checkpoint_and_dispatch( 75 | self.model, 76 | checkpoint=model_path, 77 | device_map=device_map, 78 | no_split_module_classes=[self.layer_type], 79 | dtype=torch_dtype, 80 | ) 81 | 82 | # Dispath to devices 83 | if fuse_layers: 84 | self.fuse_layers() 85 | 86 | self.prune_precisions() 87 | 88 | def forward(self, *args, **kwargs): 89 | prev_precision = self.precision 90 | if 'precision' in kwargs: 91 | precision = kwargs.pop('precision') 92 | self.set_precision(precision) 93 | 94 | results = self.model.forward(*args, **kwargs) 95 | 96 | self.set_precision(prev_precision) 97 | return results 98 | 99 | def generate(self, *args, **kwargs): 100 | if 'precision' in kwargs: 101 | prev_precision = self.precision 102 | precision = kwargs.pop('precision') 103 | self.set_precision(precision) 104 | else: 105 | prev_precision = self.precision 106 | 107 | with torch.inference_mode(): 108 | results = self.model.generate(*args, **kwargs) 109 | 110 | self.set_precision(prev_precision) 111 | return results 112 | 113 | @staticmethod 114 | def _load_config( 115 | model_path, 116 | trust_remote_code=True, 117 | ): 118 | config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) 119 | return config 120 | 121 | @classmethod 122 | def from_quantized( 123 | cls, 124 
| quant_model_path, 125 | trust_remote_code=True, 126 | fuse_layers=False, 127 | precisions=None 128 | ): 129 | config = cls._load_config(quant_model_path, trust_remote_code) 130 | 131 | ap_model = cls( 132 | model_path=quant_model_path, 133 | precisions=precisions, 134 | config=config, 135 | fuse_layers=fuse_layers, 136 | trust_remote_code=trust_remote_code, 137 | ) 138 | 139 | return ap_model 140 | 141 | def _load_quantized_modules(self): 142 | # Get blocks of model 143 | layers = self.analyzer.get_layers() 144 | 145 | for layer in tqdm(layers, desc="Loading AP Layers"): 146 | # Get every linear layer in a block 147 | named_linears = self.analyzer.get_modules(layer) 148 | 149 | # Replace nn.Linear with AnyPrecisionLinear 150 | for name, module in named_linears.items(): 151 | wqlinear = AnyPrecisionLinear( 152 | module.in_features, module.out_features, 153 | self.supported_bits, 154 | bias=module.bias is not None, 155 | precisions=self.precisions, 156 | device=module.weight.device, 157 | ) 158 | self.ap_linears.append(wqlinear) 159 | replace_module_by_name(layer, name, wqlinear) 160 | 161 | torch.cuda.empty_cache() 162 | gc.collect() 163 | 164 | def prune_precisions(self): 165 | for ap_linear in self.ap_linears: 166 | ap_linear.prune_precisions() 167 | 168 | torch.cuda.empty_cache() 169 | gc.collect() 170 | 171 | def set_precision(self, precision): 172 | for ap_linear in self.ap_linears: 173 | ap_linear.set_precision(precision) 174 | self.precision = precision 175 | 176 | def tie_weights(self): 177 | if hasattr(self.model, "tie_weights"): 178 | self.model.tie_weights() 179 | 180 | def get_model_layers(self): 181 | module = self.model 182 | for attrib_name in self.config.anyprec['arch_config']['model_name'].split('.'): 183 | module = getattr(module, attrib_name) 184 | return getattr(module, self.config.anyprec['arch_config']['layers_name']) 185 | 186 | def fuse_layers(self): 187 | if 'fuse_target_layers' not in self.model_config: 188 | raise NotImplementedError("This model does not support layer fusion") 189 | # TODO implement layer fusion 190 | pass 191 | 192 | @property 193 | def layer_type(self): 194 | for layer in self.get_model_layers(): 195 | layer_class_name = layer.__class__.__name__ 196 | if layer_class_name.endswith("DecoderLayer"): 197 | return layer_class_name 198 | return None 199 | 200 | @property 201 | def device(self): 202 | return self.model.device 203 | -------------------------------------------------------------------------------- /any_precision/modules/AnyPrecisionLinear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | try: 5 | from any_precision_ext import matmul_kbit, dequant_kbit 6 | except: 7 | matmul_kbit, dequant_kbit = None, None 8 | 9 | 10 | class AnyPrecisionLinear(nn.Module): 11 | def __init__(self, in_features, out_features, supported_bits, bias=True, precisions=None, device=None, 12 | dtype=None): 13 | super().__init__() 14 | if dequant_kbit is None or matmul_kbit is None: 15 | raise ModuleNotFoundError('Please install any precision CUDA kernel extension from modules/kernels.') 16 | if precisions is None: 17 | precisions = supported_bits 18 | if not isinstance(precisions, list): 19 | raise RuntimeError('supported_bits must be a list of integers.') 20 | if dtype is not None and dtype != torch.float16: 21 | raise RuntimeError('Only float16 is supported for now.') 22 | 23 | self.in_features = in_features 24 | self.out_features = out_features 25 | self.precisions = precisions 
26 | self.precision = max(self.precisions) 27 | self.supported_bits = supported_bits 28 | 29 | self.register_buffer( 30 | 'qweight', 31 | torch.empty((max(supported_bits), out_features, in_features // 32), dtype=torch.int32, device=device) 32 | ) 33 | 34 | for bit in supported_bits: 35 | self.register_buffer( 36 | f'lut{bit}', 37 | torch.empty((out_features, 2 ** bit), dtype=dtype, device=device) 38 | ) 39 | 40 | if bias: 41 | self.register_buffer( 42 | "bias", 43 | torch.empty((out_features,), dtype=dtype, device=device) 44 | ) 45 | else: 46 | self.bias = None 47 | 48 | def prune_precisions(self): 49 | self.qweight = self.qweight[:max(self.precisions)] 50 | for bit in self.supported_bits: 51 | if bit not in self.precisions: 52 | delattr(self, f'lut{bit}') 53 | 54 | def forward(self, x, **kwargs): 55 | if 'precision' in kwargs: 56 | w_bits = kwargs['precision'] 57 | else: 58 | w_bits = self.precision 59 | 60 | if x.numel() // x.shape[-1] > 8: 61 | weight = dequant_kbit(self.qweight, self._buffers[f'lut{w_bits}'], w_bits) 62 | x = torch.matmul(x, weight.T) 63 | else: 64 | x = matmul_kbit(x, self.qweight, self._buffers[f'lut{w_bits}'], w_bits) 65 | 66 | if self.bias is not None: 67 | x += self.bias 68 | 69 | return x 70 | 71 | def set_precision(self, precision): 72 | if precision not in self.precisions: 73 | raise RuntimeError(f"{self.precisions}-bit precisions are supported but {precision}-bit was specified.") 74 | 75 | self.precision = precision 76 | 77 | def extra_repr(self) -> str: 78 | return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}' 79 | -------------------------------------------------------------------------------- /any_precision/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .AnyPrecisionForCausalLM import AnyPrecisionForCausalLM 2 | -------------------------------------------------------------------------------- /any_precision/modules/kernels/dequant.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | /* macros */ 6 | 7 | #define num_rows 4 8 | #define DIV_ROUND_UP(x,y) (((x)+(y)-1)/(y)) 9 | 10 | template 11 | __device__ __forceinline__ void dequant(const uint32_t q[], uint32_t q_w[]); 12 | 13 | template <> 14 | __device__ __forceinline__ void dequant<3, true>(const uint32_t q[3], uint32_t q_w[4]) { 15 | constexpr uint32_t mask0 = 0x88888888; 16 | constexpr uint32_t mask1 = 0x44444444; 17 | constexpr uint32_t mask2 = 0x22222222; 18 | constexpr uint32_t mask3 = 0x11111111; 19 | 20 | // fast transpose 21 | q_w[0] = (((q[0]&mask0)) | ((q[1]&mask0) >> 1) | ((q[2]&mask0)>>2))>>1; 22 | q_w[1] = ((q[0]&mask1)) | ((q[1]&mask1) >> 1) | ((q[2]&mask1)>>2); 23 | q_w[2] = ((q[0]&mask2) << 1) | ((q[1]&mask2)) | ((q[2]&mask2)>>1); 24 | q_w[3] = ((q[0]&mask3) << 2) | ((q[1]&mask3) << 1) | ((q[2]&mask3)); 25 | 26 | // table lookup merge 27 | #pragma unroll 28 | for (int i = 0; i < 4; i++) 29 | q_w[i] = (q_w[i] & 0x0f0f0f0f) | ((q_w[i] & 0xf0f0f0f0) >> 1); 30 | } 31 | 32 | template <> 33 | __device__ __forceinline__ void dequant<3, false>(const uint32_t q[3], uint32_t q_w[8]) { 34 | constexpr uint32_t mask0 = 0x88888888; 35 | constexpr uint32_t mask1 = 0x44444444; 36 | constexpr uint32_t mask2 = 0x22222222; 37 | constexpr uint32_t mask3 = 0x11111111; 38 | 39 | q_w[0] = (((q[0]&mask0)) | ((q[1]&mask0) >> 1) | ((q[2]&mask0)>>2))>>1; 40 | q_w[1] = ((q[0]&mask1)) | ((q[1]&mask1) >> 1) | ((q[2]&mask1)>>2); 41 | 
q_w[2] = ((q[0]&mask2) << 1) | ((q[1]&mask2)) | ((q[2]&mask2)>>1); 42 | q_w[3] = ((q[0]&mask3) << 2) | ((q[1]&mask3) << 1) | ((q[2]&mask3)); 43 | 44 | constexpr uint32_t mask = 0x0f0f0f0f; 45 | q_w[4] = q_w[0] & mask; 46 | q_w[5] = q_w[1] & mask; 47 | q_w[6] = q_w[2] & mask; 48 | q_w[7] = q_w[3] & mask; 49 | 50 | q_w[0] = (q_w[0] >> 4) & mask; 51 | q_w[1] = (q_w[1] >> 4) & mask; 52 | q_w[2] = (q_w[2] >> 4) & mask; 53 | q_w[3] = (q_w[3] >> 4) & mask; 54 | } 55 | 56 | template <> 57 | __device__ __forceinline__ void dequant<4, true>(const uint32_t q[4], uint32_t q_w[4]) { 58 | constexpr uint32_t mask0 = 0x88888888; 59 | constexpr uint32_t mask1 = 0x44444444; 60 | constexpr uint32_t mask2 = 0x22222222; 61 | constexpr uint32_t mask3 = 0x11111111; 62 | 63 | q_w[0] = ((q[0]&mask0)) | ((q[1]&mask0) >> 1) | ((q[2]&mask0)>>2) | ((q[3]&mask0) >> 3); 64 | q_w[1] = ((q[0]&mask1) << 1) | (q[1]&mask1) | ((q[2]&mask1)>>1) | ((q[3]&mask1) >> 2); 65 | q_w[2] = ((q[0]&mask2) << 2) | ((q[1]&mask2) << 1) | (q[2]&mask2) | ((q[3]&mask2) >> 1); 66 | q_w[3] = ((q[0]&mask3) << 3) | ((q[1]&mask3) << 2) | ((q[2]&mask3) << 1) | (q[3]&mask3); 67 | } 68 | 69 | template <> 70 | __device__ __forceinline__ void dequant<4, false>(const uint32_t q[4], uint32_t q_w[8]) { 71 | constexpr uint32_t mask0 = 0x88888888; 72 | constexpr uint32_t mask1 = 0x44444444; 73 | constexpr uint32_t mask2 = 0x22222222; 74 | constexpr uint32_t mask3 = 0x11111111; 75 | 76 | q_w[0] = ((q[0]&mask0)) | ((q[1]&mask0) >> 1) | ((q[2]&mask0)>>2) | ((q[3]&mask0) >> 3); 77 | q_w[1] = ((q[0]&mask1) << 1) | (q[1]&mask1) | ((q[2]&mask1)>>1) | ((q[3]&mask1) >> 2); 78 | q_w[2] = ((q[0]&mask2) << 2) | ((q[1]&mask2) << 1) | (q[2]&mask2) | ((q[3]&mask2) >> 1); 79 | q_w[3] = ((q[0]&mask3) << 3) | ((q[1]&mask3) << 2) | ((q[2]&mask3) << 1) | (q[3]&mask3); 80 | 81 | constexpr uint32_t mask = 0x0f0f0f0f; 82 | q_w[4] = q_w[0] & mask; 83 | q_w[5] = q_w[1] & mask; 84 | q_w[6] = q_w[2] & mask; 85 | q_w[7] = q_w[3] & mask; 86 | 87 | q_w[0] = (q_w[0] >> 4) & mask; 88 | q_w[1] = (q_w[1] >> 4) & mask; 89 | q_w[2] = (q_w[2] >> 4) & mask; 90 | q_w[3] = (q_w[3] >> 4) & mask; 91 | } 92 | 93 | template <> 94 | __device__ __forceinline__ void dequant<8, false>(const uint32_t q[8], uint32_t q_w[8]) { 95 | constexpr uint32_t mask0 = 0x80808080; 96 | constexpr uint32_t mask1 = 0x40404040; 97 | constexpr uint32_t mask2 = 0x20202020; 98 | constexpr uint32_t mask3 = 0x10101010; 99 | constexpr uint32_t mask4 = 0x08080808; 100 | constexpr uint32_t mask5 = 0x04040404; 101 | constexpr uint32_t mask6 = 0x02020202; 102 | constexpr uint32_t mask7 = 0x01010101; 103 | 104 | q_w[0] = ((q[0]&mask0)>>0) | ((q[1]&mask0)>>1) | ((q[2]&mask0)>>2) | ((q[3]&mask0)>>3) | ((q[4]&mask0)>>4) | ((q[5]&mask0)>>5) | ((q[6]&mask0)>>6) | ((q[7]&mask0)>>7); 105 | q_w[1] = ((q[0]&mask1)<<1) | ((q[1]&mask1)>>0) | ((q[2]&mask1)>>1) | ((q[3]&mask1)>>2) | ((q[4]&mask1)>>3) | ((q[5]&mask1)>>4) | ((q[6]&mask1)>>5) | ((q[7]&mask1)>>6); 106 | q_w[2] = ((q[0]&mask2)<<2) | ((q[1]&mask2)<<1) | ((q[2]&mask2)>>0) | ((q[3]&mask2)>>1) | ((q[4]&mask2)>>2) | ((q[5]&mask2)>>3) | ((q[6]&mask2)>>4) | ((q[7]&mask2)>>5); 107 | q_w[3] = ((q[0]&mask3)<<3) | ((q[1]&mask3)<<2) | ((q[2]&mask3)<<1) | ((q[3]&mask3)>>0) | ((q[4]&mask3)>>1) | ((q[5]&mask3)>>2) | ((q[6]&mask3)>>3) | ((q[7]&mask3)>>4); 108 | q_w[4] = ((q[0]&mask4)<<4) | ((q[1]&mask4)<<3) | ((q[2]&mask4)<<2) | ((q[3]&mask4)<<1) | ((q[4]&mask4)>>0) | ((q[5]&mask4)>>1) | ((q[6]&mask4)>>2) | ((q[7]&mask4)>>3); 109 | q_w[5] = ((q[0]&mask5)<<5) | ((q[1]&mask5)<<4) | ((q[2]&mask5)<<3) | 
((q[3]&mask5)<<2) | ((q[4]&mask5)<<1) | ((q[5]&mask5)>>0) | ((q[6]&mask5)>>1) | ((q[7]&mask5)>>2); 110 | q_w[6] = ((q[0]&mask6)<<6) | ((q[1]&mask6)<<5) | ((q[2]&mask6)<<4) | ((q[3]&mask6)<<3) | ((q[4]&mask6)<<2) | ((q[5]&mask6)<<1) | ((q[6]&mask6)>>0) | ((q[7]&mask6)>>1); 111 | q_w[7] = ((q[0]&mask7)<<7) | ((q[1]&mask7)<<6) | ((q[2]&mask7)<<5) | ((q[3]&mask7)<<4) | ((q[4]&mask7)<<3) | ((q[5]&mask7)<<2) | ((q[6]&mask7)<<1) | ((q[7]&mask7)>>0); 112 | } 113 | 114 | template <> 115 | __device__ __forceinline__ void dequant<7, false>(const uint32_t q[7], uint32_t q_w[8]) { 116 | constexpr uint32_t mask0 = 0x80808080; 117 | constexpr uint32_t mask1 = 0x40404040; 118 | constexpr uint32_t mask2 = 0x20202020; 119 | constexpr uint32_t mask3 = 0x10101010; 120 | constexpr uint32_t mask4 = 0x08080808; 121 | constexpr uint32_t mask5 = 0x04040404; 122 | constexpr uint32_t mask6 = 0x02020202; 123 | constexpr uint32_t mask7 = 0x01010101; 124 | 125 | q_w[0] = ((q[0]&mask0)>>1) | ((q[1]&mask0)>>2) | ((q[2]&mask0)>>3) | ((q[3]&mask0)>>4) | ((q[4]&mask0)>>5) | ((q[5]&mask0)>>6) | ((q[6]&mask0)>>7); 126 | q_w[1] = ((q[0]&mask1)>>0) | ((q[1]&mask1)>>1) | ((q[2]&mask1)>>2) | ((q[3]&mask1)>>3) | ((q[4]&mask1)>>4) | ((q[5]&mask1)>>5) | ((q[6]&mask1)>>6); 127 | q_w[2] = ((q[0]&mask2)<<1) | ((q[1]&mask2)>>0) | ((q[2]&mask2)>>1) | ((q[3]&mask2)>>2) | ((q[4]&mask2)>>3) | ((q[5]&mask2)>>4) | ((q[6]&mask2)>>5); 128 | q_w[3] = ((q[0]&mask3)<<2) | ((q[1]&mask3)<<1) | ((q[2]&mask3)>>0) | ((q[3]&mask3)>>1) | ((q[4]&mask3)>>2) | ((q[5]&mask3)>>3) | ((q[6]&mask3)>>4); 129 | q_w[4] = ((q[0]&mask4)<<3) | ((q[1]&mask4)<<2) | ((q[2]&mask4)<<1) | ((q[3]&mask4)>>0) | ((q[4]&mask4)>>1) | ((q[5]&mask4)>>2) | ((q[6]&mask4)>>3); 130 | q_w[5] = ((q[0]&mask5)<<4) | ((q[1]&mask5)<<3) | ((q[2]&mask5)<<2) | ((q[3]&mask5)<<1) | ((q[4]&mask5)>>0) | ((q[5]&mask5)>>1) | ((q[6]&mask5)>>2); 131 | q_w[6] = ((q[0]&mask6)<<5) | ((q[1]&mask6)<<4) | ((q[2]&mask6)<<3) | ((q[3]&mask6)<<2) | ((q[4]&mask6)<<1) | ((q[5]&mask6)>>0) | ((q[6]&mask6)>>1); 132 | q_w[7] = ((q[0]&mask7)<<6) | ((q[1]&mask7)<<5) | ((q[2]&mask7)<<4) | ((q[3]&mask7)<<3) | ((q[4]&mask7)<<2) | ((q[5]&mask7)<<1) | ((q[6]&mask7)>>0); 133 | } 134 | 135 | template <> 136 | __device__ __forceinline__ void dequant<6, false>(const uint32_t q[6], uint32_t q_w[8]) { 137 | constexpr uint32_t mask0 = 0x80808080; 138 | constexpr uint32_t mask1 = 0x40404040; 139 | constexpr uint32_t mask2 = 0x20202020; 140 | constexpr uint32_t mask3 = 0x10101010; 141 | constexpr uint32_t mask4 = 0x08080808; 142 | constexpr uint32_t mask5 = 0x04040404; 143 | constexpr uint32_t mask6 = 0x02020202; 144 | constexpr uint32_t mask7 = 0x01010101; 145 | 146 | q_w[0] = ((q[0]&mask0)>>2) | ((q[1]&mask0)>>3) | ((q[2]&mask0)>>4) | ((q[3]&mask0)>>5) | ((q[4]&mask0)>>6) | ((q[5]&mask0)>>7); 147 | q_w[1] = ((q[0]&mask1)>>1) | ((q[1]&mask1)>>2) | ((q[2]&mask1)>>3) | ((q[3]&mask1)>>4) | ((q[4]&mask1)>>5) | ((q[5]&mask1)>>6); 148 | q_w[2] = ((q[0]&mask2)>>0) | ((q[1]&mask2)>>1) | ((q[2]&mask2)>>2) | ((q[3]&mask2)>>3) | ((q[4]&mask2)>>4) | ((q[5]&mask2)>>5); 149 | q_w[3] = ((q[0]&mask3)<<1) | ((q[1]&mask3)>>0) | ((q[2]&mask3)>>1) | ((q[3]&mask3)>>2) | ((q[4]&mask3)>>3) | ((q[5]&mask3)>>4); 150 | q_w[4] = ((q[0]&mask4)<<2) | ((q[1]&mask4)<<1) | ((q[2]&mask4)>>0) | ((q[3]&mask4)>>1) | ((q[4]&mask4)>>2) | ((q[5]&mask4)>>3); 151 | q_w[5] = ((q[0]&mask5)<<3) | ((q[1]&mask5)<<2) | ((q[2]&mask5)<<1) | ((q[3]&mask5)>>0) | ((q[4]&mask5)>>1) | ((q[5]&mask5)>>2); 152 | q_w[6] = ((q[0]&mask6)<<4) | ((q[1]&mask6)<<3) | ((q[2]&mask6)<<2) | 
((q[3]&mask6)<<1) | ((q[4]&mask6)>>0) | ((q[5]&mask6)>>1);
153 | q_w[7] = ((q[0]&mask7)<<5) | ((q[1]&mask7)<<4) | ((q[2]&mask7)<<3) | ((q[3]&mask7)<<2) | ((q[4]&mask7)<<1) | ((q[5]&mask7)<<0);
154 | }
155 | 
156 | template <>
157 | __device__ __forceinline__ void dequant<5, false>(const uint32_t q[5], uint32_t q_w[8]) {
158 | constexpr uint32_t mask0 = 0x80808080;
159 | constexpr uint32_t mask1 = 0x40404040;
160 | constexpr uint32_t mask2 = 0x20202020;
161 | constexpr uint32_t mask3 = 0x10101010;
162 | constexpr uint32_t mask4 = 0x08080808;
163 | constexpr uint32_t mask5 = 0x04040404;
164 | constexpr uint32_t mask6 = 0x02020202;
165 | constexpr uint32_t mask7 = 0x01010101;
166 | 
167 | q_w[0] = ((q[0]&mask0)>>3) | ((q[1]&mask0)>>4) | ((q[2]&mask0)>>5) | ((q[3]&mask0)>>6) | ((q[4]&mask0)>>7);
168 | q_w[1] = ((q[0]&mask1)>>2) | ((q[1]&mask1)>>3) | ((q[2]&mask1)>>4) | ((q[3]&mask1)>>5) | ((q[4]&mask1)>>6);
169 | q_w[2] = ((q[0]&mask2)>>1) | ((q[1]&mask2)>>2) | ((q[2]&mask2)>>3) | ((q[3]&mask2)>>4) | ((q[4]&mask2)>>5);
170 | q_w[3] = ((q[0]&mask3)>>0) | ((q[1]&mask3)>>1) | ((q[2]&mask3)>>2) | ((q[3]&mask3)>>3) | ((q[4]&mask3)>>4);
171 | q_w[4] = ((q[0]&mask4)<<1) | ((q[1]&mask4)>>0) | ((q[2]&mask4)>>1) | ((q[3]&mask4)>>2) | ((q[4]&mask4)>>3);
172 | q_w[5] = ((q[0]&mask5)<<2) | ((q[1]&mask5)<<1) | ((q[2]&mask5)>>0) | ((q[3]&mask5)>>1) | ((q[4]&mask5)>>2);
173 | q_w[6] = ((q[0]&mask6)<<3) | ((q[1]&mask6)<<2) | ((q[2]&mask6)<<1) | ((q[3]&mask6)>>0) | ((q[4]&mask6)>>1);
174 | q_w[7] = ((q[0]&mask7)<<4) | ((q[1]&mask7)<<3) | ((q[2]&mask7)<<2) | ((q[3]&mask7)<<1) | ((q[4]&mask7)>>0);
175 | }
176 | 
177 | template <int bits>
178 | __global__ void dequant_kbit_store(
179 | const uint32_t * W,
180 | const uint32_t N, const uint32_t K,
181 | const __half * C, __half * O
182 | ) {
183 | static_assert(bits >= 3 && bits <= 8);
184 | constexpr int num_centroids = 1 << bits, warp_size = 32;
185 | 
186 | const uint32_t row_idx = blockIdx.x * num_rows + threadIdx.y;
187 | const int centroid_idx = threadIdx.y * num_centroids;
188 | 
189 | __shared__ __half shC[num_rows * num_centroids];
190 | 
191 | if constexpr (bits < 6) {
192 | if (threadIdx.x < num_centroids)
193 | shC[centroid_idx + threadIdx.x] = C[num_centroids * row_idx + threadIdx.x];
194 | } else if constexpr (bits == 6) {
195 | ((half2 *)shC)[centroid_idx / 2 + threadIdx.x] = ((half2 *)C)[num_centroids * row_idx / 2 + threadIdx.x];
196 | } else if constexpr (bits == 7) {
197 | ((float2 *)shC)[centroid_idx / 4 + threadIdx.x] = ((float2 *)C)[num_centroids * row_idx / 4 + threadIdx.x];
198 | } else if constexpr (bits == 8) {
199 | ((float4 *)shC)[centroid_idx / 8 + threadIdx.x] = ((float4 *)C)[num_centroids * row_idx / 8 + threadIdx.x];
200 | }
201 | __syncthreads();
202 | 
203 | int eff_warp_size = warp_size;
204 | uint32_t q[bits], q_w[8];
205 | half2 dq_w[16];
206 | 
207 | const uint32_t maxi = DIV_ROUND_UP(K, 32 * warp_size);
208 | for (int i = 0; i < maxi; i++) {
209 | if (i == K / (32 * warp_size)) {
210 | eff_warp_size = (K % (32 * warp_size)) / 32;
211 | if (threadIdx.x >= eff_warp_size) break;
212 | }
213 | 
214 | // load quantized weight
215 | #pragma unroll
216 | for (int j = 0; j < bits; j++) {
217 | const int k = (j * N + row_idx) * (K / 32) + i * 32 + threadIdx.x;
218 | q[j] = W[k];
219 | }
220 | 
221 | // dequantize
222 | dequant<bits, false>(q, q_w);
223 |
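// after dequant(), every q_w word holds four 8-bit LUT indices, one per byte;
// the lookup below consumes the low byte of each word and then shifts q_w right by 8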
224 | // lookup
225 | #pragma unroll
226 | for (int j = 3; j >= 0; j--) {
227 | #pragma unroll
228 | for (int k = 0; k < 4; k++) {
229 | const __half x = shC[centroid_idx | (q_w[k*2+0] & 0xff)];
230 | const __half y = shC[centroid_idx | (q_w[k*2+1] & 0xff)];
231 | dq_w[j * 4 + k] = make_half2(x, y);
232 | }
233 | #pragma unroll
234 | for (int k = 0; k < 8; k++)
235 | q_w[k] >>= 8;
236 | }
237 | 
238 | #pragma unroll
239 | for (int j = 0; j < 4; j++)
240 | ((float4 *)O)[(row_idx*K + 8*eff_warp_size*j + i*warp_size*32 + 8*threadIdx.x)/8] = ((float4 *)dq_w)[j];
241 | }
242 | }
243 | 
--------------------------------------------------------------------------------
/any_precision/modules/kernels/main.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <cuda_fp16.h>
3 | #include "matmul.cuh"
4 | 
5 | void cudaError(cudaError_t errCode, const char * filename, int linenum) {
6 | if(errCode != cudaSuccess) {
7 | printf("Error : %s (%s : %d)\n", cudaGetErrorString(errCode), filename, linenum);
8 | exit(EXIT_FAILURE);
9 | }
10 | }
11 | 
12 | #define HANDLE_ERROR(err) (cudaError(err, __FILE__, __LINE__))
13 | 
14 | typedef void (* matmul_func) (
15 | const __half *, const uint32_t *,
16 | const uint32_t, const uint32_t, const uint32_t,
17 | const __half *, __half *
18 | );
19 | 
20 | template <int s, int e>
21 | struct get_matmul_func {
22 | void operator()(matmul_func func[][9][2]) const {
23 | if constexpr (s <= e) {
24 | func[s][1][0] = matmul_kbit_32<1, s, false>;
25 | func[s][1][1] = matmul_kbit_32<1, s, true>;
26 | func[s][2][0] = matmul_kbit_32<2, s, false>;
27 | func[s][3][0] = matmul_kbit_32<3, s, false>;
28 | func[s][4][0] = matmul_kbit_32<4, s, false>;
29 | func[s][5][0] = matmul_kbit_32<5, s, false>;
30 | func[s][6][0] = matmul_kbit_32<6, s, false>;
31 | func[s][7][0] = matmul_kbit_32<7, s, false>;
32 | func[s][8][0] = matmul_kbit_32<8, s, false>;
33 | get_matmul_func<s + 1, e>()(func);
34 | }
35 | }
36 | };
37 | 
38 | typedef void (* dequant_func) (
39 | const uint32_t *,
40 | const uint32_t, const uint32_t,
41 | const __half *, __half *
42 | );
43 | 
44 | template <int s, int e>
45 | struct get_dequant_func {
46 | void operator()(dequant_func func[]) const {
47 | if constexpr (s <= e) {
48 | func[s] = dequant_kbit_store<s>;
49 | get_dequant_func<s + 1, e>()(func);
50 | }
51 | }
52 | };
53 | 
54 | bool dequant_initialized = false;
55 | bool matmul_initialized = false;
56 | bool is_orin = false;
57 | matmul_func matmul_functions[9][9][2] = {NULL, };
58 | dequant_func dequant_functions[9] = {NULL, };
59 | 
60 | torch::Tensor dequant_kbit(
61 | torch::Tensor qweight,
62 | torch::Tensor lut,
63 | int w_bits
64 | ) {
65 | // Set correct device
66 | HANDLE_ERROR(cudaSetDevice(qweight.device().index()));
67 | 
68 | assert(qweight.ndimension() == 3 && qweight.dtype() == torch::kInt && lut.dtype() == torch::kHalf);
69 | assert(qweight.device() == lut.device() && qweight.is_cuda());
70 | assert(w_bits >= 3 && w_bits <= 8);
71 | const int N = qweight.size(1);
72 | const int K = qweight.size(2) * 32;
73 | 
74 | if (!dequant_initialized) {
75 | get_dequant_func<3, 8>()(dequant_functions);
76 | dequant_initialized = true;
77 | }
78 | 
79 | auto options = torch::TensorOptions().dtype(torch::kHalf).device(qweight.device());
80 | at::Tensor weight = torch::empty({N, K}, options);
81 | 
82 | dim3 grid(N/num_rows), block(32, num_rows);
83 | dequant_functions[w_bits]<<<grid, block>>>(
84 | (uint32_t *)qweight.data_ptr(),
85 | N, K,
86 | (__half *)lut.data_ptr(),
87 | (__half *)weight.data_ptr()
88 | );
89 | 
90 | return weight;
91 | }
92 | 
93 | torch::Tensor matmul_kbit(
94 | torch::Tensor in,
95 | torch::Tensor qweight,
96 | torch::Tensor lut,
97 | int w_bits
98 | ) {
99 | // Set correct device
100 | HANDLE_ERROR(cudaSetDevice(qweight.device().index()));
101 |
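// expected layouts (as registered by AnyPrecisionLinear):
//   qweight : int32 [num_bit_planes][N][K/32] -- one bit of every weight per plane, 32 weights per word
//   lut     : fp16  [N][2^w_bits]             -- per-output-row centroid table
//   in      : fp16  [..., K] activations, M = in.numel()/K rows (1..8 for this kernel)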
102 | const int N = qweight.size(1);
103 | const int K = qweight.size(2) * 32;
104 | int64_t in_ndim = in.ndimension();
105 | const int M = in.numel() / K;
106 | 
107 | // TODO assert with size or dtype
108 | assert(M >= 1 && M <= 8 && w_bits >= 3 && w_bits <= 8);
109 | assert(in.device() == qweight.device() && in.device() == lut.device() && in.is_cuda());
110 | assert(qweight.ndimension() == 3 && qweight.dtype() == torch::kInt && lut.dtype() == torch::kHalf);
111 | assert(in.dtype() == torch::kHalf);
112 | 
113 | if (!matmul_initialized) {
114 | int device;
115 | HANDLE_ERROR(cudaGetDevice(&device));
116 | cudaDeviceProp prop;
117 | HANDLE_ERROR(cudaGetDeviceProperties(&prop, device));
118 | is_orin = strcmp(prop.name, "Orin") == 0;
119 | 
120 | get_matmul_func<3, 8>()(matmul_functions);
121 | matmul_initialized = true;
122 | }
123 | 
124 | auto sizes = in.sizes().vec();
125 | sizes.at(in_ndim - 1) = N;
126 | auto options = torch::TensorOptions().dtype(torch::kHalf).device(in.device());
127 | at::Tensor out = torch::empty(sizes, options);
128 | 
129 | const int multi_row = (M == 1 ? 1 : 4);
130 | const int use_ksplit = !is_orin && M == 1 && K > 4096 && w_bits >= 7;
131 | const int num_ksplit = (use_ksplit ? DIV_ROUND_UP(K, 4096) : 1);
132 | 
133 | dim3 grid(N/(num_rows*multi_row)), block(32, num_rows, num_ksplit);
134 | matmul_functions[w_bits][M][use_ksplit]<<<grid, block>>>(
135 | (__half *)in.data_ptr(),
136 | (uint32_t *)qweight.data_ptr(),
137 | M, N, K,
138 | (__half *)lut.data_ptr(),
139 | (__half *)out.data_ptr()
140 | );
141 | 
142 | return out;
143 | }
144 | 
145 | PYBIND11_MODULE(any_precision_ext, m) {
146 | m.def("matmul_kbit", &matmul_kbit, "kbit quantized matmul_function");
147 | m.def("dequant_kbit", &dequant_kbit, "kbit dequantize function");
148 | }
149 | 
--------------------------------------------------------------------------------
/any_precision/modules/kernels/matmul.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <cuda_fp16.h>
4 | #include "dequant.cuh"
5 | 
6 | /* warp-wide sum with tree-reduction */
7 | __device__ __forceinline__ __half warp_reduce_sum(
8 | __half sum
9 | ) {
10 | #pragma unroll
11 | for (int i = 4; i >= 0; i--)
12 | sum = __hadd(sum, __shfl_down_sync(0xffffffff, sum, 1<<i));
13 | return sum;
14 | }
15 | 
16 | template <int maxm, int bits, bool use_ksplit>
17 | __global__ void matmul_kbit_32(
18 | const __half * I, const uint32_t * W,
19 | const uint32_t M, const uint32_t N, const uint32_t K,
20 | const __half * C, __half * O
21 | ) {
22 | static_assert(maxm >= 1 && bits >= 3 && bits <= 8);
23 | static_assert(!use_ksplit || maxm == 1);
24 | constexpr bool use_half2_centroid = (bits == 3 || (bits == 4 && maxm > 1));
25 | constexpr int multi_row = (maxm == 1 ? 1 : 4);
26 | 
27 | constexpr int num_centroids = 1 << bits, warp_size = 32;
28 | constexpr int shC_siz = (use_half2_centroid ? num_centroids * num_centroids * 2 : num_centroids);
29 | constexpr int q_w_siz = (use_half2_centroid ? 4 : 8);
30 | 
31 | const uint32_t row_idx_base = blockIdx.x * num_rows * multi_row + threadIdx.y;
32 | const int centroid_idx_base = threadIdx.y * (use_half2_centroid ? num_centroids * num_centroids : num_centroids);
33 | 
34 | __shared__ __half shC[num_rows * multi_row * shC_siz];
35 | 
36 | if (!use_ksplit || threadIdx.z == 0) {
37 | #pragma unroll
38 | for (int h = 0; h < multi_row; h++) {
39 | const uint32_t row_idx = row_idx_base + h * num_rows;
40 | const int centroid_idx = centroid_idx_base + h * num_rows * (use_half2_centroid ?
num_centroids * num_centroids : num_centroids); 41 | if constexpr (use_half2_centroid) { 42 | const int xx = threadIdx.x % num_centroids, yy = threadIdx.x / num_centroids; 43 | const __half fragCX = C[row_idx * num_centroids | xx]; 44 | #pragma unroll 45 | for (int i = 0; i < shC_siz / warp_size / 2; i++) { 46 | const int yidx = yy | (i * warp_size / num_centroids); 47 | const __half fragCY = C[row_idx * num_centroids | yidx]; 48 | ((__half2 * )shC)[centroid_idx | (yidx * num_centroids) | xx] = make_half2(fragCY, fragCX); 49 | } 50 | } else if constexpr (bits < 6) { 51 | if (threadIdx.x < num_centroids) 52 | shC[centroid_idx + threadIdx.x] = C[num_centroids * row_idx + threadIdx.x]; 53 | } else if constexpr (bits == 6) { 54 | ((__half2 *)shC)[centroid_idx / 2 + threadIdx.x] = ((__half2 *)C)[num_centroids * row_idx / 2 + threadIdx.x]; 55 | } else if constexpr (bits == 7) { 56 | ((float2 *)shC)[centroid_idx / 4 + threadIdx.x] = ((float2 *)C)[num_centroids * row_idx / 4 + threadIdx.x]; 57 | } else if constexpr (bits == 8) { 58 | ((float4 *)shC)[centroid_idx / 8 + threadIdx.x] = ((float4 *)C)[num_centroids * row_idx / 8 + threadIdx.x]; 59 | } 60 | } 61 | } 62 | __syncthreads(); 63 | 64 | int eff_warp_size = warp_size; 65 | __half partial_sum[maxm * multi_row] = {__float2half(0.0), }; 66 | uint32_t q[bits], q_w[q_w_siz]; 67 | __half2 dq_w[16]; 68 | 69 | int mini = (use_ksplit ? threadIdx.z * 4 : 0); 70 | int maxi = DIV_ROUND_UP(K, 32 * warp_size); 71 | if (use_ksplit && maxi > mini + 4) maxi = mini + 4; 72 | for (int i = mini; i < maxi; i++) { 73 | if (i == K / (32 * warp_size)) { 74 | eff_warp_size = (K % (32 * warp_size)) / 32; 75 | if (threadIdx.x >= eff_warp_size) break; 76 | } 77 | 78 | #pragma unroll 79 | for (int h = 0; h < multi_row; h++) { 80 | const uint32_t row_idx = row_idx_base + h * num_rows; 81 | const int centroid_idx = centroid_idx_base + h * num_rows * (use_half2_centroid ? 
num_centroids * num_centroids : num_centroids);
82 | 
83 | // load quantized weight
84 | #pragma unroll
85 | for (int j = 0; j < bits; j++) {
86 | const int k = (j * N + row_idx) * (K / 32) + i * 32 + threadIdx.x;
87 | q[j] = W[k];
88 | }
89 | 
90 | // dequantize
91 | dequant<bits, use_half2_centroid>(q, q_w);
92 | 
93 | // lookup
94 | #pragma unroll
95 | for (int j = 3; j >= 0; j--) {
96 | if constexpr (use_half2_centroid) {
97 | #pragma unroll
98 | for (int k = 0; k < 2; k++) {
99 | const __half2 x = ((__half2 *)shC)[centroid_idx | (q_w[k*2+0] & 0xff)];
100 | const __half2 y = ((__half2 *)shC)[centroid_idx | (q_w[k*2+1] & 0xff)];
101 | dq_w[j * 4 + k + 0] = make_half2(x.x, y.x);
102 | dq_w[j * 4 + k + 2] = make_half2(x.y, y.y);
103 | }
104 | } else {
105 | #pragma unroll
106 | for (int k = 0; k < 4; k++) {
107 | const __half x = shC[centroid_idx | (q_w[k*2+0] & 0xff)];
108 | const __half y = shC[centroid_idx | (q_w[k*2+1] & 0xff)];
109 | dq_w[j * 4 + k] = make_half2(x, y);
110 | }
111 | }
112 | #pragma unroll
113 | for (int k = 0; k < q_w_siz; k++)
114 | q_w[k] >>= 8;
115 | }
116 | 
117 | // accumulate
118 | #pragma unroll
119 | for (int l = 0; l < maxm; l++) {
120 | __half2 sum = make_half2(__float2half(0.0), __float2half(0.0));
121 | #pragma unroll
122 | for (int j = 3; j >= 0; j--) {
123 | const int idx = (l*K/8 + eff_warp_size*j) + i*warp_size*4 + threadIdx.x;
124 | float4 in_buf = ((float4 *)I)[idx];
125 | __half2 * in_half = (__half2 *)&in_buf;
126 | #pragma unroll
127 | for (int k = 0; k < 4; k++)
128 | sum = __hfma2(dq_w[j * 4 + k], in_half[k], sum);
129 | }
130 | partial_sum[l + h * maxm] = __hadd(partial_sum[l + h * maxm], __hadd(sum.x, sum.y));
131 | }
132 | }
133 | }
134 | 
135 | #pragma unroll
136 | for (int i = 0; i < maxm * multi_row; i++)
137 | partial_sum[i] = warp_reduce_sum(partial_sum[i]);
138 | 
139 | if constexpr (use_ksplit) {
140 | __shared__ __half shO[maxm * multi_row * num_rows];
141 | if (threadIdx.x == 0 && threadIdx.z == 0)
142 | #pragma unroll
143 | for (int j = 0; j < multi_row; j++)
144 | shO[j + threadIdx.y * multi_row] = __float2half(0.0);
145 | __syncthreads();
146 | if (threadIdx.x == 0)
147 | #pragma unroll
148 | for (int j = 0; j < multi_row; j++)
149 | atomicAdd(shO + j + threadIdx.y * multi_row, partial_sum[j]);
150 | __syncthreads();
151 | if (threadIdx.x == 0 && threadIdx.z == 0)
152 | #pragma unroll
153 | for (int j = 0; j < multi_row; j++)
154 | partial_sum[j] = shO[j + threadIdx.y * multi_row];
155 | }
156 | 
157 | if (threadIdx.x == 0 && (!use_ksplit || threadIdx.z == 0)) {
158 | #pragma unroll
159 | for (int i = 0; i < maxm; i++) {
160 | #pragma unroll
161 | for (int j = 0; j < multi_row; j++) {
162 | const uint32_t row_idx = row_idx_base + j * num_rows;
163 | O[i * N + row_idx] = partial_sum[i + j * maxm];
164 | }
165 | }
166 | }
167 | }
168 | 
--------------------------------------------------------------------------------
/any_precision/modules/kernels/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils import cpp_extension
3 | 
4 | setup(
5 | name="any_precision_ext",
6 | ext_modules=[
7 | cpp_extension.CUDAExtension(
8 | "any_precision_ext", ["main.cu"]
9 | )
10 | ],
11 | cmdclass={"build_ext": cpp_extension.BuildExtension},
12 | )
13 |
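The extension built by this setup.py is what AnyPrecisionLinear.forward calls into at runtime. Below is a minimal usage sketch, not part of the repository: the module and function names come from setup.py and the PYBIND11_MODULE block in main.cu, the tensor shapes mirror the buffers registered in AnyPrecisionLinear, and the batch-size dispatch mirrors AnyPrecisionLinear.forward; the concrete sizes are illustrative only.

import torch
from any_precision_ext import matmul_kbit, dequant_kbit

w_bits = 3                                  # any supported precision in [3, 8]
N, K = 4096, 4096                           # illustrative out_features / in_features
qweight = torch.zeros(8, N, K // 32, dtype=torch.int32, device='cuda')  # packed bit-planes
lut = torch.zeros(N, 2 ** w_bits, dtype=torch.half, device='cuda')      # per-row centroid table
x = torch.randn(1, K, dtype=torch.half, device='cuda')                  # fp16 activations

if x.numel() // x.shape[-1] > 8:
    # many rows: dequantize the whole weight matrix once, then use a regular matmul
    y = x @ dequant_kbit(qweight, lut, w_bits).T
else:
    # up to 8 rows: fused lookup-table matmul kernel
    y = matmul_kbit(x, qweight, lut, w_bits)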
--------------------------------------------------------------------------------
/any_precision/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | from .main import any_precision_quantize
2 | 
--------------------------------------------------------------------------------
/any_precision/quantization/config.py:
--------------------------------------------------------------------------------
1 | DEFAULT_DATASET = 'c4'
2 | DEFAULT_SEQ_LEN = 512
3 | DEFAULT_NUM_EXAMPLES = 100
4 | DEFAULT_CACHE_DIR = 'cache'
5 | DEFAULT_SEED_PRECISION = 3
6 | DEFAULT_PARENT_PRECISION = 8
7 | 
--------------------------------------------------------------------------------
/any_precision/quantization/datautils.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import random
3 | import numpy as np
4 | import logging
5 | 
6 | 
7 | def _get_wikitext2(split):
8 | assert split in ['train', 'validation', 'test'], f"Unknown split {split} for wikitext2"
9 | 
10 | data = load_dataset('wikitext', 'wikitext-2-raw-v1', split=split, trust_remote_code=True)
11 | return data['text']
12 | 
13 | 
14 | def _get_ptb(split, slice_unk=True):
15 | assert split in ['train', 'validation', 'test'], f"Unknown split {split} for ptb"
16 | 
17 | data = load_dataset('ptb_text_only', 'penn_treebank', split=split,
18 | trust_remote_code=True)
19 | data_list = data['sentence']
20 | 
21 | if slice_unk:
22 | data_list = [s.replace('<unk>', '< u n k >') for s in data_list]
23 | 
24 | return data_list
25 | 
26 | 
27 | def _get_c4(split):
28 | assert split in ['train', 'validation'], f"Unknown split {split} for c4"
29 | 
30 | if split == 'train':
31 | data = load_dataset(
32 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train',
33 | trust_remote_code=True
34 | )
35 | else:
36 | assert split == 'validation'
37 | data = load_dataset(
38 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation',
39 | trust_remote_code=True
40 | )
41 | 
42 | return data['text']
43 | 
44 | 
45 | def _get_pileval(split):
46 | if split != 'validation':
47 | logging.warning(f"Pileval only has a validation split, but got split={split}. 
Using validation split.") 48 | data = load_dataset("mit-han-lab/pile-val-backup", split="validation", trust_remote_code=True) 49 | 50 | return data['text'] 51 | 52 | 53 | def _sample_and_tokenize(texts, tokenizer, seq_len, num_samples, seed=None): 54 | assert num_samples <= len(texts), \ 55 | f"num_samples({num_samples}) should be less than or equal to the number of texts({len(texts)})" 56 | 57 | # this works for None too, effectively setting random seeds 58 | random.seed(seed) 59 | np.random.seed(seed) 60 | 61 | selected_indices = set() 62 | 63 | samples = [] 64 | while len(samples) < num_samples: 65 | idx = random.randint(0, len(texts) - 1) 66 | if idx in selected_indices: # we don't want to sample the same text twice 67 | continue 68 | text = texts[idx] 69 | 70 | tokens = tokenizer(text, return_tensors='pt')['input_ids'][0] 71 | if len(tokens) < seq_len: # if the text is too short, we skip it 72 | continue 73 | 74 | tokens = tokens[:seq_len] 75 | 76 | selected_indices.add(idx) 77 | samples.append(tokens) 78 | 79 | return samples 80 | 81 | 82 | def _get_dataset(dataset_name, split): 83 | if dataset_name == 'wikitext2': 84 | return _get_wikitext2(split) 85 | elif dataset_name == 'ptb': 86 | return _get_ptb(split) 87 | elif dataset_name == 'c4': 88 | return _get_c4(split) 89 | elif dataset_name == 'pileval': 90 | return _get_pileval(split) 91 | else: 92 | raise ValueError(f"Unknown dataset {dataset_name}") 93 | 94 | 95 | def get_tokens(dataset_name, split, tokenizer, seq_len, num_samples, seed=None): 96 | logging.info(f"Fetching dataset: {dataset_name}") 97 | texts = _get_dataset(dataset_name, split) 98 | logging.info(f"Sampling {num_samples} samples of length {seq_len} from {dataset_name}...") 99 | return _sample_and_tokenize(texts, tokenizer, seq_len, num_samples, seed) 100 | -------------------------------------------------------------------------------- /any_precision/quantization/dense_and_sparse.py: -------------------------------------------------------------------------------- 1 | import numba 2 | import logging 3 | import torch 4 | 5 | import numpy as np 6 | from tqdm.contrib.concurrent import process_map 7 | 8 | from tqdm import tqdm 9 | 10 | 11 | @numba.njit(cache=True) 12 | def _module_get_threshold_from_range(weights, trange): 13 | assert len(weights.shape) == 1, "Weights must be 1D" 14 | # Assumes sorted weights, O(1) 15 | q1 = weights[len(weights) // 4] 16 | q3 = weights[3 * len(weights) // 4] 17 | low = q1 - trange * (q3 - q1) 18 | high = q3 + trange * (q3 - q1) 19 | larger_abs = max(abs(low), abs(high)) 20 | return larger_abs 21 | 22 | 23 | @numba.njit(cache=True) 24 | def _module_get_outlier_count_from_threshold(weights, threshold): 25 | # Assumes sorted weights, O(log n) 26 | return np.searchsorted(weights, -threshold) + len(weights) - np.searchsorted(weights, threshold) 27 | 28 | 29 | @numba.njit(cache=True) 30 | def _module_get_outlier_count_from_range(weights, trange): 31 | # Assumes sorted weights, O(log n) 32 | threshold = _module_get_threshold_from_range(weights, trange) 33 | return _module_get_outlier_count_from_threshold(weights, threshold), threshold 34 | 35 | 36 | @numba.njit(cache=True) 37 | def _get_outlier_count_from_range(trange, sorted_flattened_weights): 38 | total_outliers = 0 39 | thresholds = np.empty(len(sorted_flattened_weights), dtype=np.float32) 40 | for i, module_weight in enumerate(sorted_flattened_weights): 41 | num_outliers, threshold = _module_get_outlier_count_from_range(module_weight, trange) 42 | thresholds[i] = threshold 43 | 
total_outliers += num_outliers 44 | return total_outliers, thresholds 45 | 46 | 47 | def _process_module(module_data): 48 | layer_index, module_name, module_weight = module_data 49 | sorted_weights = np.sort(module_weight.flatten()).astype(np.float32) # fp32 for numba 50 | total_params = module_weight.numel() 51 | return layer_index, module_name, sorted_weights, total_params 52 | 53 | 54 | def _find_thresholds(analyzer, outlier_percent, tolerance=0.0001): 55 | assert outlier_percent < 50, "Outlier ratio must be less than 0.5" 56 | 57 | model_weights = [analyzer.get_layer_weights(l) for l in range(analyzer.num_layers)] 58 | 59 | tasks = [] 60 | for layer_index, model_layer in enumerate(model_weights): 61 | for module_name in analyzer.module_names: 62 | module_weight = model_layer[module_name] 63 | tasks.append((layer_index, module_name, module_weight)) 64 | 65 | sorted_flattened_weights = [] 66 | total_params = 0 67 | 68 | results = process_map(_process_module, tasks, chunksize=1, max_workers=None, desc="Preprocessing weights") 69 | 70 | for layer_index, module_name, sorted_weights, params in results: 71 | sorted_flattened_weights.append(sorted_weights) 72 | total_params += params 73 | 74 | # Find the trange by binary search 75 | low = 0 76 | high = 32 # this seems like an extra overkill upper bound but adjust if necessary 77 | thresholds = None 78 | logging.info(f"Begin trange search for outlier percent {outlier_percent}%") 79 | while low < high: 80 | mid = (low + high) / 2 # Note that this is a float as we are searching for a float threshold 81 | total_outliers, thresholds = _get_outlier_count_from_range(mid, sorted_flattened_weights) 82 | percent = total_outliers / total_params * 100 83 | logging.info(f"Search range: [{low:.5f}, {high:.4f}] Threshold: {mid:.5f}, Outlier ratio: {percent:.5f}%") 84 | if abs(percent - outlier_percent) < tolerance: 85 | logging.info(f"Found threshold: {mid:.5f}, Outlier ratio: {percent:.5f}%") 86 | break 87 | elif percent < outlier_percent: 88 | high = mid 89 | else: 90 | low = mid 91 | 92 | thresholds_by_module = [{} for _ in range(analyzer.num_layers)] 93 | idx = 0 94 | for layer_index in range(analyzer.num_layers): 95 | for module_name in analyzer.module_names: 96 | thresholds_by_module[layer_index][module_name] = thresholds[idx] 97 | idx += 1 98 | 99 | return thresholds_by_module 100 | 101 | 102 | def _remove_outliers_by_threshold(analyzer, thresholds_by_module): 103 | sparse_model_weights = [] 104 | 105 | for l in tqdm(range(analyzer.num_layers), desc="Removing threshold outliers"): 106 | model_layer = analyzer.get_layer_weights(l) 107 | sparse_model_weights_by_layer = {} 108 | for module_name in analyzer.module_names: 109 | module_weight = model_layer[module_name] 110 | threshold = thresholds_by_module[l][module_name] 111 | dense_mask = torch.abs(module_weight) < threshold 112 | sparse_mask = ~dense_mask 113 | # Save the sparse weights 114 | sparse_model_weights_by_layer[module_name] = module_weight * sparse_mask 115 | # Zero out the outliers 116 | module_weight[sparse_mask] = 0 117 | 118 | sparse_model_weights.append(sparse_model_weights_by_layer) 119 | 120 | return sparse_model_weights 121 | 122 | 123 | def _remove_outliers_by_sensitivity(analyzer, gradients, sensitivity_outlier_percent): 124 | sparse_model_weights = [] 125 | 126 | for l in tqdm(range(analyzer.num_layers), desc="Removing sensitivity outliers"): 127 | model_layer = analyzer.get_layer_weights(l) 128 | sparse_model_weights_by_layer = {} 129 | for module_name in analyzer.module_names: 
130 | module_weight = model_layer[module_name] 131 | gradient = gradients[l][module_name] # this is a torch tensor 132 | # get the top sensitivity_outlier_percent% of the gradients 133 | topk = int(sensitivity_outlier_percent * gradient.numel() / 100) 134 | _, indices = torch.topk(gradient.abs().flatten(), topk) 135 | sparse_mask = torch.zeros_like(module_weight, dtype=torch.bool) 136 | sparse_mask.view(-1)[indices] = 1 137 | # Save the sparse weights 138 | sparse_model_weights_by_layer[module_name] = module_weight * sparse_mask 139 | # Zero out the outliers 140 | module_weight[sparse_mask] = 0 141 | 142 | sparse_model_weights.append(sparse_model_weights_by_layer) 143 | 144 | return sparse_model_weights 145 | 146 | 147 | def remove_outliers(analyzer, gradients, sensitivity_outlier_percent, threshold_outlier_percent): 148 | # This removes the sensitivity outliers from the weights stored in the analyzer, and returns the removed weights 149 | sparse_model_weights_1 = _remove_outliers_by_sensitivity(analyzer, gradients, sensitivity_outlier_percent) 150 | 151 | # Find the thresholds to achieve the desired outlier percentage 152 | thresholds_by_module = _find_thresholds(analyzer, threshold_outlier_percent) 153 | 154 | # This removes the threshold outliers from the weights stored in the analyzer, and returns the removed weights 155 | sparse_model_weights_2 = _remove_outliers_by_threshold(analyzer, thresholds_by_module) 156 | 157 | # Add the two sets of sparse weights 158 | sparse_model_weights = [] 159 | for l in range(analyzer.num_layers): 160 | sparse_model_weights_by_layer = {} 161 | for module_name in analyzer.module_names: 162 | sparse_model_weights_by_layer[module_name] = (sparse_model_weights_1[l][module_name] + 163 | sparse_model_weights_2[l][module_name]).to_sparse() 164 | sparse_model_weights.append(sparse_model_weights_by_layer) 165 | 166 | return sparse_model_weights 167 | -------------------------------------------------------------------------------- /any_precision/quantization/gradients.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | import logging 5 | from .config import * 6 | from any_precision.analyzer.analyzer import get_analyzer 7 | from .datautils import get_tokens 8 | 9 | 10 | def get_gradients( 11 | analyzer, 12 | dataset=DEFAULT_DATASET, 13 | seq_len=DEFAULT_SEQ_LEN, 14 | num_examples=DEFAULT_NUM_EXAMPLES, 15 | save_path=None, 16 | random_state=None, 17 | ): 18 | if save_path is not None and os.path.isfile(save_path): 19 | logging.info(f"Gradients already calculated and saved at {save_path}.") 20 | logging.info(f"Loading gradients...") 21 | return torch.load(save_path) 22 | logging.info(f"Calculating gradients on dataset {dataset} with sequence length {seq_len} and " 23 | f"{num_examples} examples...") 24 | 25 | model = analyzer.model 26 | tokenizer = analyzer.tokenizer 27 | 28 | input_tokens = get_tokens(dataset, 'train', tokenizer, seq_len, num_examples, seed=random_state) 29 | 30 | if analyzer is None: 31 | analyzer = get_analyzer(model) 32 | 33 | model = model.bfloat16() 34 | model.eval() 35 | 36 | if model.device.type != 'cuda': 37 | model.cuda() 38 | 39 | layers = analyzer.get_layers() 40 | 41 | # Register hook to store the square of the gradients 42 | def square_grad_hook(grad): 43 | return grad.pow(2) 44 | 45 | hooks = [] 46 | 47 | for layer in layers: 48 | for module in analyzer.get_modules(layer).values(): 49 | 
hooks.append(module.weight.register_hook(square_grad_hook)) 50 | 51 | # Calculate gradients through loss.backward() 52 | for tokens in tqdm(input_tokens, desc="Calculating gradients"): 53 | tokens = tokens.to(model.device) 54 | tokens = tokens.unsqueeze(0) 55 | outputs = model(input_ids=tokens, labels=tokens) 56 | loss = outputs.loss 57 | loss.backward() 58 | 59 | # Remove hooks 60 | for hook in hooks: 61 | hook.remove() 62 | 63 | # Move model back to cpu 64 | model.cpu() 65 | 66 | # Harvest the gradients 67 | gradients = [] 68 | for layer in layers: 69 | gradients_per_layer = {} 70 | for module_name, module in analyzer.get_modules(layer).items(): 71 | gradients_per_layer[module_name] = module.weight.grad 72 | gradients.append(gradients_per_layer) 73 | 74 | # Save the gradients to file 75 | # Note that when saving, the gradients are stored as bf16, 76 | # but are converted to np.float32 before returning, for the next steps in the pipeline 77 | if save_path is not None: 78 | logging.info(f"Saving gradients to {save_path}...") 79 | # add file extension if not present 80 | if not save_path.endswith('.pt'): 81 | save_path = save_path + '.pt' 82 | # check if the file already exists 83 | if os.path.exists(save_path): 84 | input(f"[WARNING] File {save_path} already exists. Press enter to overwrite or Ctrl+C to cancel.") 85 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 86 | torch.save(gradients, save_path) 87 | 88 | return gradients 89 | -------------------------------------------------------------------------------- /any_precision/quantization/main.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import os.path 4 | import shutil 5 | import logging 6 | 7 | from .config import * 8 | from ..analyzer import get_analyzer 9 | from .gradients import get_gradients 10 | from .quantize import seed_and_upscale 11 | from .pack import pack 12 | from .dense_and_sparse import remove_outliers 13 | import torch 14 | 15 | # Disable parallelism in tokenizers to prevent warnings when forking in the seed generation step 16 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 17 | 18 | # Logging with time sans date, level name, and message 19 | logging.basicConfig(level=logging.INFO, format='[%(asctime)s | %(levelname)s] %(message)s', datefmt='%H:%M:%S') 20 | 21 | 22 | def any_precision_quantize( 23 | model, 24 | seed_precision=DEFAULT_SEED_PRECISION, 25 | parent_precision=DEFAULT_PARENT_PRECISION, 26 | mode='pack', 27 | yaml_path=None, cache_dir=DEFAULT_CACHE_DIR, 28 | dataset=DEFAULT_DATASET, seq_len=DEFAULT_SEQ_LEN, num_examples=DEFAULT_NUM_EXAMPLES, 29 | cpu_count=os.cpu_count(), 30 | overwrite_gradients=False, 31 | overwrite_quantize=False, 32 | overwrite_pack=False, 33 | random_state=None, 34 | group_count=1, 35 | dns=False, 36 | sensitivity_outlier_percent=0.05, 37 | threshold_outlier_percent=0.40 38 | ): 39 | assert mode in ['gradients', 'quantize', 'pack'], \ 40 | "mode must be one of 'gradients', 'quantize', or 'pack'. Use 'pack' to run the entire pipeline." 41 | 42 | if overwrite_gradients: 43 | if not overwrite_quantize: 44 | logging.warning("Parent model needs to be recalculated if gradients are recalculated. " 45 | "Setting overwrite_quantize to True.") 46 | overwrite_quantize = True 47 | 48 | if overwrite_quantize: 49 | if not overwrite_pack: 50 | logging.warning("Packed model needs to be recalculated if parent model is recalculated. 
" 51 | "Setting overwrite_pack to True.") 52 | overwrite_pack = True 53 | 54 | if mode == 'gradients': 55 | logging.info("Running: [Gradients]") 56 | elif mode == 'quantize': 57 | logging.info("Running: [Gradients -> Quantize]") 58 | else: 59 | logging.info("Running: [Gradients -> Quantize -> Pack]") 60 | 61 | model_string = model if isinstance(model, str) else model.name_or_path 62 | model_name = model_string.split("/")[-1] 63 | 64 | logging.info(f"Running Any-Precision Quantization on {model_name} with seed precision {seed_precision} and " 65 | f"parent precision {parent_precision} using {dataset} for gradient calculation") 66 | 67 | # ------------------- Load model ------------------- 68 | 69 | analyzer = get_analyzer(model, yaml_path=yaml_path, include_tokenizer=True) 70 | 71 | # ------------------- Set cache paths ------------------- 72 | 73 | gradients_cache_path = (f"{cache_dir}/gradients/" 74 | f"({model_name})-{dataset}_s{num_examples}_blk{seq_len}.pt") 75 | 76 | quantized_cache_path = (f"{cache_dir}/quantized/" 77 | f"{'dns-' if dns else ''}({model_name})-w{parent_precision}_orig{seed_precision}" 78 | f"-gc{group_count}-{dataset}_s{num_examples}_blk{seq_len}") 79 | 80 | model_output_path = (f"{cache_dir}/packed/" 81 | f"anyprec-({model_name})-w{parent_precision}_orig{seed_precision}" 82 | f"-gc{group_count}-{dataset}_s{num_examples}_blk{seq_len}") 83 | 84 | # ------------------- Gradients ------------------- 85 | 86 | logging.info("------------------- Gradients -------------------") 87 | 88 | logging.info("Beginning gradient calculation...") 89 | # Calculate or load gradients 90 | if overwrite_gradients and os.path.exists(gradients_cache_path): 91 | # if the user wants to recalculate the gradients, delete the cached gradients 92 | logging.info(f"Detected cached gradients at {gradients_cache_path}. 
Will delete and recalculate.") 93 | os.remove(gradients_cache_path) 94 | 95 | # this will load and return the gradients if they exist, or calculate them if they don't 96 | model_gradients = get_gradients( 97 | analyzer=analyzer, 98 | dataset=dataset, 99 | seq_len=seq_len, 100 | num_examples=num_examples, 101 | save_path=gradients_cache_path, 102 | random_state=random_state, 103 | ) 104 | logging.info("Gradient calculation complete.") 105 | 106 | if mode == 'gradients': 107 | return 108 | 109 | # ------------------- Dense & Sparse ------------------- 110 | 111 | if dns: 112 | logging.info("------------------- Dense & Sparse -------------------") 113 | sparse_model_weights = remove_outliers( 114 | analyzer=analyzer, 115 | gradients=model_gradients, 116 | sensitivity_outlier_percent=sensitivity_outlier_percent, 117 | threshold_outlier_percent=threshold_outlier_percent, 118 | ) 119 | 120 | sparse_path = f"{quantized_cache_path}/sparse" 121 | os.makedirs(sparse_path, exist_ok=True) 122 | for l in range(analyzer.num_layers): 123 | torch.save(sparse_model_weights[l], f"{sparse_path}/l{l}.pt") 124 | 125 | del sparse_model_weights 126 | 127 | # ------------------- Quantize: Seed + Upscale ------------------- 128 | 129 | logging.info("------------------- Quantize: Seed + Upscale -------------------") 130 | 131 | # Calculate or load parent 132 | logging.info(f"Beginning {seed_precision}~{parent_precision}-bit Any-Precision Quantization...") 133 | # Note that this saves the seed model to the cache path and must be loaded for the upscale step 134 | if overwrite_quantize and os.path.exists(quantized_cache_path): 135 | # if the user wants to recalculate the seed, delete the cached seed 136 | logging.info(f"Detected cached parent at {quantized_cache_path}. Will delete and recalculate.") 137 | shutil.rmtree(quantized_cache_path) 138 | 139 | # this skips over existing layers in the cache, and doesn't overwrite them 140 | seed_and_upscale( 141 | analyzer=analyzer, 142 | gradients=model_gradients, 143 | output_folder=quantized_cache_path, 144 | seed_precision=seed_precision, 145 | parent_precision=parent_precision, 146 | cpu_count=cpu_count, 147 | random_state=random_state, 148 | group_count=group_count, 149 | ) 150 | 151 | if mode == 'quantize': 152 | return 153 | 154 | del model_gradients # free up memory 155 | analyzer.drop_original_weights() # drop the original weights to save memory 156 | 157 | logging.info("Quantization(Seed + Upscale) complete.") 158 | 159 | # ------------------- Pack ------------------- 160 | logging.info("------------------- Pack -------------------") 161 | 162 | # check for non-empty directory 163 | if os.path.exists(model_output_path) and os.path.isdir(model_output_path) and os.listdir(model_output_path): 164 | if overwrite_pack: 165 | logging.info(f"Model output path {model_output_path} already exists and is not empty. Will delete and " 166 | f"re-pack.") 167 | shutil.rmtree(model_output_path) 168 | else: 169 | # if the user doesn't want to overwrite the pack, but the directory is not empty, skip packing 170 | logging.info(f"Model output path {model_output_path} already exists and is not empty. 
Will skip packing.") 171 | return 172 | 173 | pack( 174 | analyzer=analyzer, 175 | lut_path=quantized_cache_path, 176 | output_model_path=model_output_path, 177 | seed_precision=seed_precision, 178 | parent_precision=parent_precision, 179 | cpu_count=cpu_count, 180 | group_count=group_count, 181 | dns=dns, 182 | ) 183 | 184 | logging.info("Packing complete.") 185 | -------------------------------------------------------------------------------- /any_precision/quantization/pack.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import os 4 | import torch 5 | import logging 6 | from multiprocessing import Pool 7 | import numba 8 | 9 | _bytes_per_thread = 4 10 | 11 | 12 | @numba.njit(cache=True) 13 | def _permute_bitmaps(bitmaps): 14 | _, _, total_bytes = bitmaps.shape 15 | assert total_bytes % 4 == 0, "Number of bytes must be a multiple of 4" 16 | 17 | threads_per_warp = 32 18 | bytes_per_warp = threads_per_warp * _bytes_per_thread 19 | 20 | # Calculate the number of full warps and the starting index of remaining bytes 21 | full_warps_bytes = (total_bytes // bytes_per_warp) * bytes_per_warp 22 | remaining_bytes_start_idx = full_warps_bytes 23 | 24 | # Create an array of byte indices for full warps 25 | full_warp_byte_indices = np.arange(full_warps_bytes) 26 | # Calculate new indices for full warp bytes 27 | new_full_warp_byte_indices = _calculate_new_indices(full_warp_byte_indices, threads_per_warp) 28 | 29 | remaining_bytes = total_bytes - full_warps_bytes 30 | # Handle remaining bytes 31 | if remaining_bytes: 32 | remaining_byte_indices = np.arange(remaining_bytes) 33 | # Adjust the calculation for remaining bytes, which might not fill a complete warp 34 | adjusted_threads_per_warp = remaining_byte_indices.size // _bytes_per_thread 35 | new_remaining_byte_indices = _calculate_new_indices(remaining_byte_indices, 36 | adjusted_threads_per_warp, 37 | offset=remaining_bytes_start_idx) 38 | 39 | # Combine indices - the choice to not use np.concatenate is for numba compatibility 40 | new_byte_indices = np.empty(total_bytes, dtype=np.int64) 41 | new_byte_indices[:full_warps_bytes] = new_full_warp_byte_indices 42 | new_byte_indices[full_warps_bytes:] = new_remaining_byte_indices 43 | else: 44 | new_byte_indices = new_full_warp_byte_indices 45 | 46 | permuted_bitmaps = bitmaps[:, :, np.argsort(new_byte_indices)] 47 | 48 | return permuted_bitmaps 49 | 50 | 51 | @numba.njit(cache=True) 52 | def _calculate_new_indices(byte_indices, threads_per_warp, offset=0): 53 | """ 54 | Calculate new byte indices for a given array of byte indices. 
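Each warp covers 128 bytes (32 threads x 4 bytes): every thread is assigned four consecutive output bytes, gathered from input bytes 32 apart, and the byte order within each 4-byte group is reversed (the endianness change below) before _permute_bitmaps_int32 reinterprets the permuted bytes as 32-bit words.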
55 | """ 56 | bytes_per_warp = threads_per_warp * _bytes_per_thread 57 | 58 | warp_idx, byte_offsets_within_warp = np.divmod(byte_indices, bytes_per_warp) 59 | 60 | warp_offsets = warp_idx * bytes_per_warp 61 | thread_indices = byte_indices % threads_per_warp 62 | 63 | # Change endianness within each thread and calculate new byte positions 64 | byte_offsets_within_thread = byte_offsets_within_warp // threads_per_warp 65 | byte_offsets_within_thread ^= 3 # Change endianness 66 | new_byte_indices = warp_offsets + thread_indices * _bytes_per_thread + byte_offsets_within_thread + offset 67 | 68 | return new_byte_indices 69 | 70 | 71 | @numba.njit(cache=True) 72 | def _permute_bitmaps_int32(bitmaps): 73 | """Return a permuted version of the input bitmaps, reshaped to int32.""" 74 | w_bits, N, total_bytes = bitmaps.shape 75 | bitmaps = _permute_bitmaps(bitmaps) 76 | return bitmaps.reshape(-1, 4).view(np.int32).reshape(w_bits, N, total_bytes // 4) 77 | 78 | 79 | def _process_layer_data(args): 80 | layer_idx, lut_path, model_name, layers_name, module_names, parent_precision, seed_precision = args 81 | layer_data = {} 82 | 83 | weightpath = os.path.join(lut_path, 'weights', f'l{layer_idx}.pt') 84 | layer_weights = torch.load(weightpath) 85 | 86 | for i, name in enumerate(module_names): 87 | N, group_count, group_size = layer_weights[name].shape 88 | K = group_count * group_size 89 | 90 | qweight_flattened = layer_weights[name].flatten() 91 | bitarray = np.empty((parent_precision, len(qweight_flattened) // 8), dtype=np.uint8) 92 | mask = 1 << (parent_precision - 1) # MSB first 93 | for bit in range(parent_precision): 94 | curbitpack = np.packbits((qweight_flattened & mask).astype(bool)) 95 | bitarray[bit] = curbitpack 96 | mask >>= 1 97 | 98 | bitarray = bitarray.reshape((parent_precision, N, K // 8)) 99 | weighttensor = _permute_bitmaps_int32(bitarray) 100 | 101 | param_name = f'{model_name}.{layers_name}.{layer_idx}.{name}' 102 | layer_data[param_name + '.qweight'] = weighttensor 103 | 104 | for bit in range(seed_precision, parent_precision + 1): 105 | layer_lut_path = os.path.join(lut_path, f'lut_{bit}', f'l{layer_idx}.pt') 106 | layer_lut = torch.load(layer_lut_path) 107 | 108 | curLUT = np.empty((N, 2 ** bit), dtype=np.float16) 109 | for r_idx in range(N): 110 | curLUT[r_idx] = layer_lut[name][r_idx][0] # the 0 here assumes group_count == 1 111 | 112 | layer_data[param_name + '.lut' + str(bit)] = curLUT 113 | 114 | return layer_idx, layer_data 115 | 116 | 117 | def pack( 118 | analyzer, 119 | lut_path, 120 | output_model_path, 121 | seed_precision, 122 | parent_precision, 123 | group_count=1, 124 | dns=False, 125 | cpu_count=None 126 | ): 127 | 128 | if group_count != 1: 129 | raise NotImplementedError("Group counts other than 1 are not supported yet for packing") 130 | 131 | if dns: 132 | raise NotImplementedError("D&S packing is not supported yet") 133 | 134 | if cpu_count is None: 135 | cpu_count = os.cpu_count() 136 | 137 | # Limit cpu_count to 8 as larger values use too much memory, without much speedup 138 | _max_cpu_count = 8 139 | if cpu_count > _max_cpu_count: 140 | logging.warning(f"cpu_count will be limited to 8 to avoid excessive memory usage. 
" 141 | f"Original value: {cpu_count}") 142 | cpu_count = _max_cpu_count 143 | 144 | tokenizer = analyzer.tokenizer 145 | 146 | num_layers = analyzer.num_layers 147 | 148 | model_name = analyzer.model_name 149 | layers_name = analyzer.layers_name 150 | module_names = analyzer.module_names 151 | config = analyzer.config # original model config 152 | arch_config = analyzer.get_arch_config() 153 | 154 | state_dict = analyzer.state_dict 155 | 156 | args_list = [(layer_idx, lut_path, model_name, layers_name, module_names, parent_precision, seed_precision) for 157 | layer_idx in range(num_layers)] 158 | 159 | with Pool(cpu_count) as pool: 160 | for layer_idx, layer_data in tqdm(pool.imap(_process_layer_data, args_list), total=num_layers, desc="Packing"): 161 | for key, value in layer_data.items(): 162 | state_dict[key] = torch.from_numpy(value) # Update with modified weights 163 | 164 | # add new config parameters 165 | anyprec_configs = { 166 | 'seed_precision': seed_precision, 167 | 'parent_precision': parent_precision, 168 | 'group_count': group_count, 169 | 'arch_config': arch_config 170 | } 171 | config.anyprec = anyprec_configs 172 | 173 | logging.info(f"Writing model to disk...") 174 | os.makedirs(output_model_path, exist_ok=True) 175 | torch.save(state_dict, os.path.join(output_model_path, 'pytorch_model.bin')) 176 | tokenizer.save_pretrained(output_model_path) 177 | config.save_pretrained(output_model_path) 178 | logging.info(f"Model saved to {output_model_path}") 179 | -------------------------------------------------------------------------------- /any_precision/quantization/quantize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | import numba 7 | from concurrent.futures import ThreadPoolExecutor 8 | import flash1dkmeans 9 | 10 | 11 | @numba.njit(cache=True) 12 | def _upscale_group(orig_centroids, 13 | orig_cluster_borders, weights, 14 | weighted_X_prefix_sum, sample_weight_prefix_sum, 15 | seed_bit, parent_bit): 16 | """WARNING: labels, weights and sample_weight should be sorted by weights in ascending order""" 17 | luts_by_bit = [orig_centroids] 18 | 19 | cluster_borders = orig_cluster_borders 20 | 21 | # Run the upscale 22 | for i in range(seed_bit, parent_bit): 23 | centroids, cluster_borders = _increment_group(luts_by_bit[-1], cluster_borders, weights, 24 | weighted_X_prefix_sum, 25 | sample_weight_prefix_sum, i) 26 | luts_by_bit.append(centroids) 27 | 28 | return luts_by_bit, cluster_borders 29 | 30 | 31 | @numba.njit(cache=True) 32 | def _increment_group(orig_centroids, cluster_borders, weights, weighted_X_prefix_sum, 33 | sample_weight_prefix_sum, 34 | seed_bit): 35 | """WARNING: labels, weights and sample_weight should be sorted by weights in ascending order""" 36 | new_centroids = np.empty(2 ** (seed_bit + 1), dtype=np.float32) 37 | new_cluster_borders = np.empty(2 ** (seed_bit + 1) + 1, dtype=np.int32) 38 | 39 | assert len(orig_centroids) == 2 ** seed_bit, "The number of centroids should be 2^seed_bit" 40 | assert len(cluster_borders) == 2 ** seed_bit + 1, \ 41 | "The number of cluster start indices should be 2^seed_bit + 1" 42 | 43 | for c in range(2 ** seed_bit): 44 | start_idx = cluster_borders[c] 45 | stop_idx = cluster_borders[c + 1] 46 | 47 | if start_idx == stop_idx: 48 | # These are empty clusters, but we still need to save the centroids 49 | new_centroids[c * 2] = orig_centroids[c] 50 | new_centroids[c * 2 + 1] = 
orig_centroids[c] 51 | # new_cluster_borders still needs to be set 52 | new_cluster_borders[c * 2] = start_idx 53 | new_cluster_borders[c * 2 + 1] = start_idx 54 | continue 55 | 56 | cluster_centers, local_cluster_borders = flash1dkmeans.numba_kmeans_1d_two_cluster( 57 | sorted_X=weights, 58 | weights_prefix_sum=sample_weight_prefix_sum, 59 | weighted_X_prefix_sum=weighted_X_prefix_sum, 60 | start_idx=start_idx, 61 | stop_idx=stop_idx 62 | ) 63 | 64 | # local_cluster_borders is [start_idx, division_point, stop_idx] 65 | 66 | # save the new centroids and labels 67 | new_centroids[c * 2] = cluster_centers[0] 68 | new_centroids[c * 2 + 1] = cluster_centers[1] 69 | new_cluster_borders[c * 2] = start_idx 70 | new_cluster_borders[c * 2 + 1] = local_cluster_borders[1] 71 | 72 | new_cluster_borders[-1] = cluster_borders[-1] # the final border must be set manually 73 | 74 | return new_centroids, new_cluster_borders 75 | 76 | 77 | @numba.njit(parallel=True, cache=True) 78 | def _seed_and_upscale_layer(layer_gradients, layer_modules, seed_bit, parent_bit, group_count, random_state=None): 79 | # The shape of LUTs are different for each module and bit. 80 | # The logical thing to do would be to use a list of lists(for each bit-width) of numpy arrays(for each module). 81 | # However as numba doesn't like nested python lists, we will use a list of numpy arrays instead, 82 | # in such a way that we flatten the list of lists into a single list. 83 | lut_by_bit_by_module = [] 84 | parent_weights_by_modules = [] 85 | 86 | n_cluster = 2 ** seed_bit 87 | 88 | for m_idx in range(len(layer_modules)): 89 | module_gradient = layer_gradients[m_idx] 90 | module_weight = layer_modules[m_idx] 91 | 92 | row_count = module_weight.shape[0] 93 | group_size = module_weight.shape[1] // group_count 94 | 95 | assert group_size * group_count == module_weight.shape[1], \ 96 | f"Group count {group_count} does not divide the number of columns {module_weight.shape[1]}" 97 | 98 | parent_weights = np.empty((row_count, group_count, group_size), dtype=np.float32) 99 | 100 | lut_by_bit = [] 101 | for bit in range(seed_bit, parent_bit + 1): 102 | lut_by_bit.append(np.empty((row_count, group_count, 2 ** bit), dtype=np.float32)) 103 | 104 | for r_idx in numba.prange(module_weight.shape[0]): 105 | for g_idx in range(group_count): 106 | start_col_idx = g_idx * group_size 107 | end_col_idx = (g_idx + 1) * group_size 108 | 109 | weights_np = module_weight[r_idx, start_col_idx:end_col_idx] 110 | 111 | weight_mask = weights_np != 0 112 | sample_weight = module_gradient[r_idx, start_col_idx:end_col_idx] 113 | sample_weight = sample_weight * weight_mask 114 | 115 | # ---------------- Preprocessing ---------------- 116 | 117 | # Use fp64 to avoid precision loss in prefix sum subtraction 118 | X = weights_np 119 | sorted_indices = np.argsort(X) 120 | 121 | sorted_X = X[sorted_indices] 122 | sorted_weights = sample_weight[sorted_indices] 123 | 124 | sorted_X_fp64 = sorted_X.astype(np.float64) 125 | sorted_weights_fp64 = sorted_weights.astype(np.float64) 126 | sorted_weights_prefix_sum = np.cumsum(sorted_weights_fp64) 127 | 128 | if sorted_weights_prefix_sum[-1] == 0: 129 | # If the sum of the sample weights is zero, we act as if the sample weights are all 1 130 | sorted_weights_prefix_sum = np.arange(1, len(sorted_weights_fp64) + 1, dtype=np.float64) 131 | sorted_weighted_X_prefix_sum = np.cumsum(sorted_X_fp64) 132 | sorted_weighted_X_squared_prefix_sum = np.cumsum(sorted_X_fp64 ** 2) 133 | else: 134 | # Else we proceed with the normal prefix sum 
calculations 135 | sorted_weighted_X_fp64 = sorted_X_fp64 * sorted_weights_fp64 136 | sorted_weighted_X_squared_fp64 = sorted_weighted_X_fp64 * sorted_X_fp64 137 | 138 | sorted_weighted_X_prefix_sum = np.cumsum(sorted_weighted_X_fp64) 139 | sorted_weighted_X_squared_prefix_sum = np.cumsum(sorted_weighted_X_squared_fp64) 140 | 141 | # ---------------- Seed ---------------- 142 | 143 | # Generate the seed weights 144 | 145 | centroids, cluster_borders = flash1dkmeans.numba_kmeans_1d_k_cluster( 146 | sorted_X=sorted_X, 147 | n_clusters=n_cluster, 148 | max_iter=50, 149 | weights_prefix_sum=sorted_weights_prefix_sum, 150 | weighted_X_prefix_sum=sorted_weighted_X_prefix_sum, 151 | weighted_X_squared_prefix_sum=sorted_weighted_X_squared_prefix_sum, 152 | start_idx=0, 153 | stop_idx=len(sorted_X), 154 | random_state=random_state, 155 | ) 156 | 157 | centroids = centroids.astype(np.float32) 158 | 159 | # ---------------- Upscale ---------------- 160 | 161 | # Upscale the seed weights 162 | lut_per_bit, parent_cluster_borders = _upscale_group( 163 | centroids, cluster_borders, sorted_X, 164 | sorted_weighted_X_prefix_sum, sorted_weights_prefix_sum, 165 | seed_bit, parent_bit) 166 | 167 | # ---------------- Postprocessing ---------------- 168 | 169 | # Save the LUTs 170 | for k, bit in enumerate(range(seed_bit, parent_bit + 1)): 171 | lut_by_bit[k][r_idx][g_idx] = lut_per_bit[k] 172 | 173 | # Convert cluster_borders back to labels 174 | labels = np.empty(group_size, dtype=np.uint8) 175 | for i in range(2 ** parent_bit): 176 | labels[parent_cluster_borders[i]:parent_cluster_borders[i + 1]] = i 177 | 178 | # Unsort the labels 179 | labels = labels[np.argsort(sorted_indices)] 180 | 181 | # Save the parent weights 182 | parent_weights[r_idx][g_idx] = labels 183 | 184 | parent_weights_by_modules.append(parent_weights) 185 | lut_by_bit_by_module.append(lut_by_bit) 186 | 187 | return lut_by_bit_by_module, parent_weights_by_modules 188 | 189 | 190 | def _get_layer_loader(analyzer, gradients): 191 | def layer_loader(l): 192 | # Convert from torch.bf16 to np.fp32 for numba processing 193 | # Only converts one layer at a time to avoid excessive memory usage 194 | gradient_layer = [gradients[l][name].float().numpy() for name in analyzer.module_names] 195 | model_layer = [analyzer.get_layer_weights(l)[name].float().numpy() for name in analyzer.module_names] 196 | return gradient_layer, model_layer 197 | 198 | return layer_loader 199 | 200 | 201 | def _save_results(parent_parameters_path, seed_precision, parent_precision, module_names, 202 | luts_by_bit_by_module, parent_weights, l): 203 | # Note that it is important to cast the luts to fp16 before saving them, 204 | # as we previously cast them to fp32 to use with numba 205 | for i, bit in enumerate(range(seed_precision, parent_precision + 1)): 206 | output_lut_file_name = f"{parent_parameters_path}/lut_{bit}/l{l}.pt" 207 | os.makedirs(os.path.dirname(output_lut_file_name), exist_ok=True) 208 | lut_dict = {} 209 | for j in range(len(module_names)): 210 | lut_dict[module_names[j]] = luts_by_bit_by_module[j][i].astype(np.float16) 211 | torch.save(lut_dict, output_lut_file_name) 212 | 213 | parent_weight_dict = {module_names[j]: parent_weights[j].astype(np.uint8) 214 | for j in range(len(module_names))} 215 | 216 | output_weights_layer_file_name = f"{parent_parameters_path}/weights/l{l}.pt" 217 | os.makedirs(os.path.dirname(output_weights_layer_file_name), exist_ok=True) 218 | torch.save(parent_weight_dict, output_weights_layer_file_name) 219 | 220 | 221 | def 
_get_saver(parent_parameters_path, seed_precision, parent_precision, module_names): 222 | """Returns a function that saves the results for a given layer""" 223 | 224 | def save_results(luts_by_bit_by_module, parent_weights, l): 225 | return _save_results(parent_parameters_path, seed_precision, parent_precision, module_names, 226 | luts_by_bit_by_module, parent_weights, l) 227 | 228 | return save_results 229 | 230 | 231 | def _load_progress(parent_parameters_path, seed_precision, parent_precision, layer_count): 232 | # Check if the layer has already been processed 233 | todo_ran = [] 234 | processed_ran = [] 235 | for l in range(layer_count): 236 | if all([os.path.exists(f"{parent_parameters_path}/lut_{bit}/l{l}.pt") 237 | for bit in range(seed_precision, parent_precision + 1)]) and \ 238 | os.path.exists(f"{parent_parameters_path}/weights/l{l}.pt"): 239 | processed_ran.append(l) 240 | else: 241 | todo_ran.append(l) 242 | return todo_ran, processed_ran 243 | 244 | 245 | def seed_and_upscale( 246 | analyzer, 247 | gradients, 248 | output_folder, 249 | seed_precision, 250 | parent_precision, 251 | cpu_count=None, 252 | random_state=None, 253 | group_count=1, 254 | ): 255 | assert seed_precision <= parent_precision, "Parent precision should be equal or higher than seed precision" 256 | 257 | if cpu_count is None: 258 | cpu_count = os.cpu_count() 259 | # Determine IO and threading settings based on the number of cores 260 | if cpu_count >= 8: 261 | pipelined_io = True 262 | io_workers = 2 if cpu_count >= 64 else 1 263 | numba.set_num_threads(cpu_count - io_workers) 264 | else: 265 | pipelined_io = False 266 | io_workers = 0 # No separate IO workers needed for non-pipelined IO 267 | numba.set_num_threads(cpu_count) 268 | 269 | logging.info(f"Using {cpu_count} cores for parallelization") 270 | 271 | logging.info(f"Seeding & Upscaling from {seed_precision}-bit to {parent_precision}-bit") 272 | 273 | layers_to_process, completed_layers = _load_progress(output_folder, seed_precision, parent_precision, 274 | analyzer.num_layers) 275 | 276 | if completed_layers: 277 | logging.info(f"The following layers will be skipped as they have already been processed:\n{completed_layers}") 278 | logging.info(f"To reprocess these layers, delete the corresponding files in {output_folder}") 279 | 280 | if not layers_to_process: 281 | logging.info("All layers have already been processed. 
Exiting...") 282 | return 283 | 284 | logging.info(f"Quantizing layers {layers_to_process}") 285 | 286 | layer_loader = _get_layer_loader(analyzer, gradients) 287 | layer_saver = _get_saver(output_folder, seed_precision, parent_precision, analyzer.module_names) 288 | 289 | if pipelined_io: 290 | with ThreadPoolExecutor(max_workers=io_workers) as io_executor: 291 | for l in tqdm(layers_to_process, desc="Quantizing layers..."): 292 | if l == layers_to_process[0]: 293 | future_load = io_executor.submit(layer_loader, l) 294 | 295 | gradient_layer, model_layer = future_load.result() 296 | 297 | if l != layers_to_process[-1]: 298 | future_load = io_executor.submit(layer_loader, l + 1) 299 | 300 | luts_by_bit_by_module, parent_weights = _seed_and_upscale_layer( 301 | gradient_layer, 302 | model_layer, 303 | seed_precision, 304 | parent_precision, 305 | group_count, 306 | random_state=random_state, 307 | ) 308 | 309 | io_executor.submit(layer_saver, luts_by_bit_by_module, parent_weights, l) 310 | logging.info("Waiting for IO to finish...") 311 | else: 312 | for l in tqdm(layers_to_process, desc="Quantizing layers..."): 313 | gradient_layer, model_layer = layer_loader(l) 314 | 315 | luts_by_bit_by_module, parent_weights = _seed_and_upscale_layer( 316 | gradient_layer, 317 | model_layer, 318 | seed_precision, 319 | parent_precision, 320 | group_count, 321 | random_state=random_state 322 | ) 323 | 324 | layer_saver(luts_by_bit_by_module, parent_weights, l) 325 | -------------------------------------------------------------------------------- /any_precision/quantization/utils.py: -------------------------------------------------------------------------------- 1 | import numba 2 | 3 | 4 | @numba.njit(cache=True) 5 | def query_prefix_sum(arr_prefix_sum, start, stop): 6 | """Returns the sum of elements in the range [start, stop) of arr.""" 7 | return arr_prefix_sum[stop - 1] - arr_prefix_sum[start - 1] if start > 0 else arr_prefix_sum[stop - 1] 8 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from any_precision import AnyPrecisionForCausalLM 3 | from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM 4 | import logging 5 | import time 6 | from argparse import ArgumentParser 7 | 8 | # Logging with time sans date, level name, and message 9 | logging.basicConfig(level=logging.INFO, format='[%(asctime)s | %(levelname)s] %(message)s', datefmt='%H:%M:%S') 10 | 11 | parser = ArgumentParser() 12 | parser.add_argument('-p', '--precisions', nargs='+', type=int, default=None, 13 | help="The precisions to benchmark. If not specified, all available precisions will be benchmarked." 
14 | ) 15 | 16 | args = parser.parse_args() 17 | 18 | if __name__ == '__main__': 19 | model_path = './cache/packed/anyprec-(Llama-2-7b-chat-hf)-w8_orig3-gc1-c4_s100_blk512' 20 | original_model_path = 'meta-llama/Llama-2-7b-chat-hf' 21 | 22 | # Configure the precisions to benchmark 23 | do_fp16 = True 24 | if args.precisions is not None: 25 | precisions = args.precisions 26 | if 16 in precisions: 27 | precisions.remove(16) 28 | else: 29 | do_fp16 = False 30 | else: 31 | precisions = None # Benchmark all available precisions 32 | 33 | # Load model and tokenizer 34 | tokenizer = AutoTokenizer.from_pretrained(model_path) 35 | streamer = TextStreamer(tokenizer) 36 | 37 | model = AnyPrecisionForCausalLM.from_quantized(model_path, precisions=precisions) 38 | model = model.eval().cuda() 39 | 40 | # Warm up CUDA cache for stable performance 41 | print("~~~~~~~ Warming up CUDA cache ~~~~~~~") 42 | input_context = "A CUDA cache warm-up is needed to" 43 | input_ids = tokenizer.encode(input_context, return_tensors="pt").cuda() 44 | output = model.generate( 45 | input_ids, 46 | precision=min(model.precisions), 47 | max_new_tokens=32, 48 | pad_token_id=tokenizer.eos_token_id, 49 | streamer=streamer, 50 | ) 51 | print("~~~~~~~ Warm up complete ~~~~~~~\n") 52 | 53 | # Now begin bit-width benchmarking 54 | input_context = input("Prompt/Context: ") 55 | input_ids = tokenizer.encode(input_context, return_tensors="pt").cuda() 56 | 57 | results = {} 58 | 59 | for precision in model.precisions: 60 | print(f"=============== generation with {precision}-bit precision ===============") 61 | torch.cuda.synchronize() 62 | start_time = time.time() 63 | output = model.generate( 64 | input_ids, 65 | precision=precision, 66 | max_new_tokens=256, 67 | pad_token_id=tokenizer.eos_token_id, 68 | streamer=streamer, 69 | ) 70 | torch.cuda.synchronize() 71 | end_time = time.time() 72 | 73 | # Calculate generation speed 74 | token_count = len(output[0]) - len(input_ids[0]) 75 | tokens_per_second = token_count / (end_time - start_time) 76 | ms_per_token = 1 / tokens_per_second * 1000 77 | 78 | results[precision] = (tokens_per_second, ms_per_token) 79 | 80 | print(f"\n( Generation speed: {tokens_per_second:.1f} tok/s | Latency: {ms_per_token:.2f} ms/tok )\n") 81 | 82 | # Clear memory 83 | del model 84 | torch.cuda.empty_cache() 85 | 86 | if do_fp16: 87 | # Benchmark the original model 88 | print(f"=============== generation with fp16 precision ===============") 89 | model = AutoModelForCausalLM.from_pretrained(original_model_path, torch_dtype=torch.float16).eval().cuda() 90 | torch.cuda.synchronize() 91 | start_time = time.time() 92 | output = model.generate( 93 | input_ids, 94 | max_length=256, 95 | pad_token_id=tokenizer.eos_token_id, 96 | streamer=streamer, 97 | ) 98 | torch.cuda.synchronize() 99 | end_time = time.time() 100 | 101 | # Calculate generation speed 102 | token_count = len(output[0]) - len(input_ids[0]) 103 | tokens_per_second = token_count / (end_time - start_time) 104 | ms_per_token = 1 / tokens_per_second * 1000 105 | 106 | results[16] = (tokens_per_second, ms_per_token) 107 | 108 | print(f"\n( Generation speed: {tokens_per_second:.1f} tok/s | Latency: {ms_per_token:.2f} ms/tok )\n") 109 | 110 | print("=============== Summary ===============") 111 | print(f"\nModel: {model_path}\n") 112 | 113 | for precision, (tokens_per_second, ms_per_token) in results.items(): 114 | print(f"{precision}-bit: {tokens_per_second:.1f} tok/s | {ms_per_token:.2f} ms/tok") 115 | 
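Usage sketch for the benchmark script above (illustrative, not a file in the repository; it assumes the packed checkpoint hard-coded in demo.py's model_path already exists and that the CUDA extension kernels are installed):

# benchmark every packed bit-width, then the fp16 baseline
python demo.py
# benchmark only the 3-, 4- and 8-bit kernels (the fp16 baseline is skipped when 16 is not listed)
python demo.py -p 3 4 8
# benchmark the 3-bit kernel plus the fp16 baseline
python demo.py -p 3 16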
-------------------------------------------------------------------------------- /evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=0 python run_eval.py 4 | -------------------------------------------------------------------------------- /fake_pack.py: -------------------------------------------------------------------------------- 1 | from any_precision.evaluate import eval 2 | from any_precision.evaluate.helpers import utils 3 | 4 | parents = utils.get_subdirs('./cache/upscaled') 5 | 6 | for parent in parents: 7 | eval.fake_pack(parent, verbose=True) 8 | -------------------------------------------------------------------------------- /figures/incremental_upscaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SNU-ARC/any-precision-llm/baa9d0272510d6342fef562b5200c3f9454f9070/figures/incremental_upscaling.png -------------------------------------------------------------------------------- /figures/software_engine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SNU-ARC/any-precision-llm/baa9d0272510d6342fef562b5200c3f9454f9070/figures/software_engine.png -------------------------------------------------------------------------------- /quantize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from any_precision.quantization import any_precision_quantize 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser(description="Quantize a model to any precision") 6 | parser.add_argument("model", type=str, help="The model to quantize") 7 | parser.add_argument("--seed_precision", type=int, help="The precision to quantize the seed to") 8 | parser.add_argument("--parent_precision", type=int, help="The precision to quantize the parent to") 9 | parser.add_argument("--mode", type=str, default="pack", help="The mode to run in") 10 | parser.add_argument("--yaml_path", type=str, help="The path to the architecture config yaml file") 11 | parser.add_argument("--cache_dir", type=str, help="The directory to cache results in") 12 | parser.add_argument("--dataset", type=str, help="The dataset to use") 13 | parser.add_argument("--seq_len", type=int, help="The sequence length to use") 14 | parser.add_argument("--num_examples", type=int, help="The number of examples to use") 15 | parser.add_argument("--cpu_count", type=int, help="The number of CPUs to use for parallelization") 16 | parser.add_argument('--overwrite_gradients', action="store_true", 17 | help="Whether to overwrite the gradients stored to disk") 18 | parser.add_argument("--overwrite_quantize", action="store_true", 19 | help="Whether to overwrite the parent model stored to disk") 20 | parser.add_argument("--overwrite_pack", action="store_true", 21 | help="Whether to overwrite the packed model stored to disk") 22 | parser.add_argument("--random_state", type=int, 23 | help="The random state to use for reproducibility\n" 24 | "[WARNING] May not be reproducible across different machines") 25 | parser.add_argument("--group_count", type=int, 26 | help="Experimental: Group count per row - the default is 1") 27 | parser.add_argument("--dns", action="store_true", 28 | help="REALLY Experimental: Whether to run Dense & Sparse quantization") 29 | 30 | args = parser.parse_args() 31 | 32 | # only pass options that are not None 33 | any_precision_quantize(**{k: v 
for k, v in args.__dict__.items() if v is not None}) 34 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy~=1.26.4 2 | torch~=2.2.2 3 | transformers~=4.39.3 4 | tqdm~=4.66.2 5 | numba~=0.60.0 6 | datasets~=2.17.0 7 | accelerate~=0.29.2 8 | setuptools~=68.2.0 9 | pandas~=2.2.0 10 | safetensors~=0.4.2 11 | threadpoolctl~=3.2.0 12 | pyyaml~=6.0.1 13 | attributedict~=0.3.0 14 | flash1dkmeans==0.2.2 15 | lm-eval==0.4.3 16 | -------------------------------------------------------------------------------- /run_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | from any_precision.evaluate.helpers import utils 6 | from any_precision.evaluate import eval 7 | 8 | print("""This script will evaluate all models in the cache directory by: 9 | 1. Calculating perplexity on specified datasets, and 10 | 2. Evaluating downstream tasks using lm_eval on specified tasks. 11 | 12 | To view and modify the datasets and tasks to be evaluated, please modify this script directly. 13 | Also check the provided command line arguments for more options. 14 | """) 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--output_file', type=str, default='results.json') 18 | parser.add_argument('--redo', action='store_true') 19 | parser.add_argument('--cache_dir', type=str, default='./cache') 20 | parser.add_argument('--downstream', action='store_true') 21 | args = parser.parse_args() 22 | 23 | model_paths = [] 24 | 25 | # Uncomment the line below to run baseline models 26 | # model_paths += utils.get_base_models(include_prequant=False, relevant_models_only=True) 27 | model_paths += utils.get_subdirs(f'{args.cache_dir}/fake_packed') 28 | model_paths += utils.get_subdirs(f'{args.cache_dir}/packed') 29 | 30 | # testcases for perplexity calculation 31 | datasets = ['wikitext2', 'c4_new', 'ptb_new_sliced'] 32 | 33 | # tasks for lm_eval 34 | if args.downstream: 35 | tasks = ['winogrande', 'piqa', 'arc_easy', 'arc_challenge', 'hellaswag'] 36 | else: 37 | tasks = [] 38 | 39 | # read previous results 40 | if os.path.exists(args.output_file): 41 | with open(args.output_file) as f: 42 | all_results = json.load(f) 43 | else: 44 | all_results = {} 45 | 46 | new_results = {} # results that are newly calculated, to be printed at the end 47 | 48 | total_tests_to_run = {} # tasks to be run will be stored here 49 | skipped_models = [] # models that are skipped will be stored here 50 | 51 | # Check which models/testcases need to be run 52 | # This is done first so that we know how many tasks there are in total, 53 | # and thus we can print the progress 54 | for model_path in model_paths: 55 | model_name = os.path.basename(model_path) 56 | model_jobs = {'to_print': [], 'ppl': [], 'lm-eval': []} 57 | 58 | # Check if all results already exist for any bit-width. If so, skip that dataset/task. 
59 | datasets_with_results = [testcase for testcase in datasets if 60 | any(testcase == key.split(':')[0] for key in 61 | all_results.get(model_name, {}).get('ppl', {}).keys())] 62 | tasks_with_results = [task for task in tasks if 63 | any(task == key.split(':')[0] for key in 64 | all_results.get(model_name, {}).get('lm-eval', {}).keys())] 65 | if not args.redo: 66 | model_jobs['ppl'] = [testcase for testcase in datasets if testcase not in datasets_with_results] 67 | model_jobs['lm-eval'] = [task for task in tasks if task not in tasks_with_results] 68 | if not model_jobs['ppl'] and not model_jobs['lm-eval']: 69 | # All results of the target model/testcases and model/tasks combination exist, skip 70 | skipped_models.append(model_name) 71 | continue 72 | else: 73 | if datasets_with_results: 74 | model_jobs['to_print'].append(f"Skipping datasets: " 75 | f"{datasets_with_results} because results already exist") 76 | if tasks_with_results: 77 | model_jobs['to_print'].append(f"Skipping tasks: " 78 | f"{tasks_with_results} because results already exist") 79 | else: 80 | if datasets_with_results: 81 | model_jobs['to_print'].append(f"Redoing all datasets, overwriting for {datasets_with_results}") 82 | else: 83 | model_jobs['to_print'].append("No previous ppl results to overwrite.") 84 | if tasks_with_results: 85 | model_jobs['to_print'].append(f"Redoing all tasks, overwriting for {tasks_with_results}") 86 | else: 87 | model_jobs['to_print'].append("No previous task results to overwrite.") 88 | model_jobs['ppl'] = datasets 89 | model_jobs['lm-eval'] = tasks 90 | model_jobs['to_print'].append(f"Running datasets: {model_jobs['ppl']}") 91 | model_jobs['to_print'].append(f"Running tasks: {model_jobs['lm-eval']}") 92 | total_tests_to_run[model_path] = model_jobs 93 | 94 | total_ppl_job_count = sum(len(model_tasks['ppl']) for model_tasks in total_tests_to_run.values()) 95 | total_lm_eval_job_count = sum(len(model_tasks['lm-eval']) for model_tasks in total_tests_to_run.values()) 96 | if skipped_models: 97 | print(f">> {len(skipped_models)} models will be skipped because all dataset results already exist.") 98 | # print('\n'.join(skipped_models) + '\n') 99 | print(f">> Summary: {total_ppl_job_count} ppl jobs and {total_lm_eval_job_count} lm-eval tasks" 100 | f" over {len(total_tests_to_run)} models:") 101 | print('\n'.join(os.path.basename(model_path) for model_path in total_tests_to_run) + '\n') 102 | 103 | 104 | def save_results(results_dict): 105 | def recursive_sort_dict(d): 106 | if isinstance(d, dict): 107 | return {k: recursive_sort_dict(v) for k, v in sorted(d.items())} 108 | return d 109 | 110 | sorted_results = recursive_sort_dict(results_dict) 111 | 112 | with open(args.output_file, 'w') as f: 113 | json.dump(sorted_results, f, indent=2) 114 | 115 | 116 | # Run all tasks 117 | for i, model_path in enumerate(total_tests_to_run): 118 | model_name = os.path.basename(model_path) 119 | model_jobs = total_tests_to_run[model_path] 120 | to_print = model_jobs['to_print'] 121 | datasets_to_evaluate = model_jobs['ppl'] 122 | tasks_to_evaluate = model_jobs['lm-eval'] 123 | print("==================================================") 124 | print(f" Model: {model_name}") 125 | print(f"Progress: {i + 1}/{len(total_tests_to_run)}") 126 | print("==================================================") 127 | datasets_with_results = [testcase for testcase in datasets if testcase in all_results.get(model_name, {})] 128 | 129 | for line in to_print: 130 | print('>> ' + line) 131 | 132 | ppl_results = {} 133 | 
lm_eval_results = {} 134 | 135 | # Run evaluation 136 | tokenizer_type, tokenizer, model = eval.auto_model_load(model_path) 137 | if datasets_to_evaluate: 138 | ppl_results = eval.evaluate_ppl(model, tokenizer, datasets_to_evaluate, verbose=True, 139 | chunk_size=2048, tokenizer_type=tokenizer_type) 140 | 141 | # Update ppl results 142 | new_results[model_name] = {} 143 | if ppl_results: 144 | new_results[model_name]['ppl'] = ppl_results 145 | all_results.setdefault(model_name, {}).setdefault('ppl', {}).update(ppl_results) 146 | 147 | save_results(all_results) 148 | 149 | # Run lm_eval 150 | if tasks_to_evaluate: 151 | lm_eval_results = eval.run_lm_eval(tokenizer, model, tasks_to_evaluate) 152 | 153 | # Update lm_eval results 154 | if lm_eval_results: 155 | new_results[model_name]['lm-eval'] = lm_eval_results 156 | all_results.setdefault(model_name, {}).setdefault('lm-eval', {}).update(lm_eval_results) 157 | 158 | save_results(all_results) 159 | 160 | print() 161 | 162 | del model # clear memory 163 | 164 | print("---------------------- All Results ----------------------") 165 | # print new results as formatted json 166 | print(json.dumps(new_results, indent=4)) 167 | 168 | if len(total_tests_to_run) == 0: 169 | exit(1) 170 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='any_precision_llm', 5 | version='0.0.0', 6 | packages=find_packages(), 7 | package_data={'any_precision': ['analyzer/architectures/*.yaml']}, 8 | ) 9 | --------------------------------------------------------------------------------
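Putting the scripts above together, a minimal end-to-end sketch (the model name, bit-widths and cache location are illustrative assumptions that mirror the 3-bit-seed / 8-bit-parent Llama-2-7B-chat checkpoint referenced in demo.py, not prescribed defaults):

pip install -r requirements.txt
# quantize a Hugging Face model with a 3-bit seed incrementally upscaled to an 8-bit parent, caching results under ./cache
python quantize.py meta-llama/Llama-2-7b-chat-hf --seed_precision 3 --parent_precision 8 --cache_dir ./cache
# perplexity evaluation of every model found under ./cache (see run_eval.py); add --downstream to also run the lm_eval tasks
bash evaluate.sh
python run_eval.py --downstream
# interactive generation-speed comparison across bit-widths (edit model_path in demo.py to point at your packed checkpoint)
python demo.py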