├── .gitignore ├── LICENSE ├── README.md ├── assets ├── deepcompressor.png ├── diffusion │ ├── .gitkeep │ └── svdquant │ │ ├── svdquant.gif │ │ └── teaser.jpg └── llm │ ├── .gitkeep │ └── qoq │ ├── qoq-qserve.png │ └── qoq.png ├── deepcompressor ├── __init__.py ├── app │ ├── __init__.py │ ├── diffusion │ │ ├── __init__.py │ │ ├── cache │ │ │ ├── __init__.py │ │ │ └── config.py │ │ ├── config.py │ │ ├── dataset │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── calib.py │ │ │ ├── collect │ │ │ │ ├── calib.py │ │ │ │ └── utils.py │ │ │ └── data │ │ │ │ ├── COCO │ │ │ │ ├── COCO.py │ │ │ │ └── __init__.py │ │ │ │ ├── DCI │ │ │ │ ├── DCI.py │ │ │ │ └── __init__.py │ │ │ │ ├── MJHQ │ │ │ │ ├── MJHQ.py │ │ │ │ └── __init__.py │ │ │ │ ├── __init__.py │ │ │ │ └── dump.py │ │ ├── eval │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── metrics │ │ │ │ ├── __init__.py │ │ │ │ ├── fid.py │ │ │ │ ├── image_reward.py │ │ │ │ ├── multimodal.py │ │ │ │ ├── run.py │ │ │ │ └── similarity.py │ │ ├── nn │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── patch.py │ │ │ └── struct.py │ │ ├── pipeline │ │ │ ├── __init__.py │ │ │ └── config.py │ │ ├── ptq.py │ │ ├── quant │ │ │ ├── __init__.py │ │ │ ├── activation.py │ │ │ ├── config.py │ │ │ ├── quantizer │ │ │ │ ├── __init__.py │ │ │ │ ├── config.py │ │ │ │ └── quantizer.py │ │ │ ├── rotate.py │ │ │ ├── smooth.py │ │ │ ├── utils.py │ │ │ └── weight.py │ │ └── utils.py │ └── llm │ │ ├── __init__.py │ │ ├── cache │ │ ├── __init__.py │ │ └── config.py │ │ ├── config.py │ │ ├── eval │ │ ├── __init__.py │ │ ├── base.py │ │ ├── config.py │ │ ├── custom.py │ │ ├── lm_eval.py │ │ └── longbench │ │ │ ├── __init__.py │ │ │ ├── eval.py │ │ │ ├── metrics.py │ │ │ └── task2prompt.json │ │ ├── model │ │ ├── __init__.py │ │ └── config.py │ │ ├── nn │ │ ├── __init__.py │ │ ├── patch.py │ │ └── struct.py │ │ ├── ptq.py │ │ └── quant │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── config.py │ │ ├── dataset.py │ │ ├── quantizer │ │ ├── __init__.py │ │ ├── config.py │ │ └── quantizer.py │ │ ├── reorder.py │ │ ├── rotate.py │ │ ├── smooth.py │ │ ├── utils.py │ │ └── weight.py ├── backend │ ├── __init__.py │ ├── nunchaku │ │ ├── __init__.py │ │ ├── convert.py │ │ ├── convert_lora.py │ │ └── utils.py │ ├── qserve │ │ ├── __init__.py │ │ ├── convert.py │ │ └── utils.py │ ├── tinychat │ │ ├── __init__.py │ │ ├── convert.py │ │ ├── csrc │ │ │ ├── load.py │ │ │ ├── pybind.cpp │ │ │ ├── quantization │ │ │ │ ├── dequantize.cuh │ │ │ │ ├── gemm │ │ │ │ │ ├── gemm_cuda.cu │ │ │ │ │ ├── gemm_cuda.h │ │ │ │ │ └── semaphore.h │ │ │ │ └── gemv │ │ │ │ │ ├── gemv_cuda.cu │ │ │ │ │ └── gemv_cuda.h │ │ │ └── utils.cuh │ │ ├── linear.py │ │ └── utils.py │ └── utils.py ├── calib │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ ├── lowrank.py │ │ ├── range.py │ │ ├── reorder.py │ │ ├── rotation.py │ │ ├── search.py │ │ └── smooth.py │ ├── lowrank.py │ ├── metric.py │ ├── range.py │ ├── reorder.py │ ├── rotate.py │ ├── search.py │ └── smooth.py ├── csrc │ ├── load.py │ ├── pybind.cpp │ └── quantize │ │ ├── quantize.cu │ │ └── quantize.h ├── data │ ├── __init__.py │ ├── cache.py │ ├── codebook.py │ ├── common.py │ ├── dtype.py │ ├── range.py │ ├── scale.py │ ├── tensor.py │ ├── utils │ │ ├── __init__.py │ │ ├── dtype.py │ │ ├── reshape.py │ │ ├── scale.py │ │ └── shape.py │ └── zero.py ├── dataset │ ├── __init__.py │ ├── action.py │ ├── cache.py │ └── config.py ├── nn │ ├── __init__.py │ ├── patch │ │ ├── __init__.py │ │ ├── conv.py │ │ ├── linear.py │ │ ├── lowrank.py │ │ └── sdpa.py │ └── 
struct │ │ ├── __init__.py │ │ ├── attn.py │ │ └── base.py ├── quantizer │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ ├── base.py │ │ ├── kernel.py │ │ └── lowrank.py │ ├── impl │ │ ├── __init__.py │ │ ├── base.py │ │ ├── info.py │ │ ├── scale.py │ │ ├── simple.py │ │ └── ste.py │ ├── kernel │ │ ├── __init__.py │ │ ├── gptq.py │ │ └── rtn.py │ └── processor.py ├── utils │ ├── __init__.py │ ├── common.py │ ├── config │ │ ├── __init__.py │ │ ├── base.py │ │ ├── model.py │ │ ├── output.py │ │ └── path.py │ ├── dataclass.py │ ├── hooks │ │ ├── __init__.py │ │ ├── branch.py │ │ ├── hook.py │ │ ├── packager.py │ │ └── processor.py │ ├── math │ │ ├── __init__.py │ │ ├── functional.py │ │ └── hadamard.py │ ├── patch.py │ └── tools │ │ ├── __init__.py │ │ ├── logging.py │ │ └── sys.py └── version.py ├── environment.yml ├── examples ├── diffusion │ ├── .gitignore │ ├── README.md │ ├── configs │ │ ├── __default__.yaml │ │ ├── collect │ │ │ └── qdiff.yaml │ │ ├── lora │ │ │ ├── __default__.yaml │ │ │ └── flux.1-dev │ │ │ │ ├── anime.yaml │ │ │ │ ├── ghibsky.yaml │ │ │ │ ├── realism.yaml │ │ │ │ ├── sketch.yaml │ │ │ │ └── yarn.yaml │ │ ├── model │ │ │ ├── flux.1-dev.yaml │ │ │ ├── flux.1-schnell.yaml │ │ │ ├── pixart-sigma.yaml │ │ │ └── sana-1.6b.yaml │ │ ├── svdquant │ │ │ ├── __default__.yaml │ │ │ ├── fast.yaml │ │ │ ├── gptq.yaml │ │ │ ├── int4.yaml │ │ │ └── nvfp4.yaml │ │ └── text │ │ │ ├── __default__.yaml │ │ │ └── awq.yaml │ ├── prompts │ │ ├── lora │ │ │ ├── anime.yaml │ │ │ ├── ghibsky.yaml │ │ │ ├── realism.yaml │ │ │ ├── sketch.yaml │ │ │ └── yarn.yaml │ │ └── qdiff.yaml │ └── scripts │ │ └── svdquant.sh └── llm │ ├── .gitignore │ ├── README.md │ ├── configs │ ├── __default__.yaml │ ├── awq.yaml │ ├── gptq.yaml │ ├── ooo.yaml │ ├── qoq-g128.yaml │ ├── qoq-gchn.yaml │ ├── smoothquant-dynamic.yaml │ └── smoothquant-static.yaml │ └── scripts │ ├── awq.sh │ ├── gptq.sh │ ├── qoq.sh │ └── smoothquant.sh └── pyproject.toml /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | 162 | # VS Code 163 | .vscode/ 164 | !.vscode/settings.json 165 | 166 | .DS_Store 167 | *.log 168 | *.pt 169 | .tmp/ 170 | runs 171 | exps 172 | runs/ 173 | exps/ 174 | wandb 175 | wandb/ 176 | -------------------------------------------------------------------------------- /assets/deepcompressor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/assets/deepcompressor.png -------------------------------------------------------------------------------- /assets/diffusion/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/assets/diffusion/.gitkeep -------------------------------------------------------------------------------- /assets/diffusion/svdquant/svdquant.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/assets/diffusion/svdquant/svdquant.gif -------------------------------------------------------------------------------- /assets/diffusion/svdquant/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/assets/diffusion/svdquant/teaser.jpg -------------------------------------------------------------------------------- /assets/llm/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/assets/llm/.gitkeep -------------------------------------------------------------------------------- /assets/llm/qoq/qoq-qserve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/assets/llm/qoq/qoq-qserve.png -------------------------------------------------------------------------------- /assets/llm/qoq/qoq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/assets/llm/qoq/qoq.png -------------------------------------------------------------------------------- /deepcompressor/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ # noqa: F401 2 | -------------------------------------------------------------------------------- /deepcompressor/app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/deepcompressor/app/__init__.py -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/deepcompressor/app/diffusion/__init__.py -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/cache/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .config import DiffusionPtqCacheConfig, DiffusionQuantCacheConfig 2 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/cache/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """LLM quantization cache configuration.""" 3 | 4 | import functools 5 | import re 6 | import typing as tp 7 | from dataclasses import dataclass, field 8 | 9 | from omniconfig import configclass 10 | 11 | from deepcompressor.utils.config.path import BasePathConfig 12 | 13 | from ..nn.struct import DiffusionModelStruct 14 | 15 | __all__ = ["DiffusionQuantCacheConfig", "DiffusionPtqCacheConfig"] 16 | 17 | 18 | @dataclass 19 | class DiffusionQuantCacheConfig(BasePathConfig): 20 | """Denoising diffusion model quantization cache path. 21 | 22 | Args: 23 | smooth (`str`, *optional*, default=`""`): 24 | The smoothing scales cache path. 25 | branch (`str`, *optional*, default=`""`): 26 | The low-rank branches cache path. 27 | wgts (`str`, *optional*, default=`""`): 28 | The weight quantizers state dict cache path. 29 | acts (`str`, *optional*, default=`""`): 30 | The activation quantizers state dict cache path 31 | """ 32 | 33 | smooth: str = "" 34 | branch: str = "" 35 | wgts: str = "" 36 | acts: str = "" 37 | 38 | @staticmethod 39 | def simplify_path(path: str, key_map: dict[str, set[str]]) -> str: 40 | """Simplify the cache path.""" 41 | to_replace = {} 42 | # we first extract all the parts matching the pattern "(skip|include).\[[a-zA-Z0-9_\+]+\]" 43 | for part in re.finditer(r"(skip|include)\.\[[a-zA-Z0-9_\+]+\]", path): 44 | # remove the "skip." or "include." 
prefix 45 | part = part.group(0) 46 | if part[0] == "s": 47 | prefix, keys = part[:4], part[6:-1] 48 | else: 49 | prefix, keys = part[:7], part[9:-1] 50 | # simplify the keys 51 | keys = "+".join( 52 | ( 53 | "".join((s[0] for s in x.split("_"))) 54 | for x in DiffusionModelStruct._simplify_keys(keys.split("+"), key_map=key_map) 55 | ) 56 | ) 57 | to_replace[part] = f"{prefix}.[{keys}]" 58 | # we then replace the parts 59 | for key, value in to_replace.items(): 60 | path = path.replace(key, value) 61 | return path 62 | 63 | def simplify(self, key_map: dict[str, set[str]]) -> tp.Self: 64 | """Simplify the cache paths.""" 65 | return self.apply(functools.partial(self.simplify_path, key_map=key_map)) 66 | 67 | 68 | @configclass 69 | @dataclass 70 | class DiffusionPtqCacheConfig: 71 | root: str 72 | dirpath: DiffusionQuantCacheConfig = field(init=False) 73 | path: DiffusionQuantCacheConfig = field(init=False) 74 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .base import DiffusionDataset 4 | from .calib import DiffusionCalibCacheLoader, DiffusionCalibCacheLoaderConfig 5 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/dataset/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Dataset for diffusion models.""" 3 | 4 | import os 5 | import random 6 | import typing as tp 7 | 8 | import numpy as np 9 | import torch 10 | import torch.utils.data 11 | from torch.nn import functional as F 12 | 13 | from deepcompressor.utils.common import tree_collate, tree_map 14 | 15 | __all__ = ["DiffusionDataset"] 16 | 17 | 18 | class DiffusionDataset(torch.utils.data.Dataset): 19 | path: str 20 | filenames: list[str] 21 | filepaths: list[str] 22 | 23 | def __init__(self, path: str, num_samples: int = -1, seed: int = 0, ext: str = ".npy") -> None: 24 | if os.path.exists(path): 25 | self.path = path 26 | if "caches" in os.listdir(path): 27 | path = os.path.join(path, "caches") 28 | filenames = [f for f in sorted(os.listdir(path)) if f.endswith(ext)] 29 | if num_samples > 0 and num_samples < len(filenames): 30 | random.Random(seed).shuffle(filenames) 31 | filenames = filenames[:num_samples] 32 | filenames = sorted(filenames) 33 | self.filenames = filenames 34 | self.filepaths = [os.path.join(path, f) for f in filenames] 35 | else: 36 | raise ValueError(f"Invalid data path: {path}") 37 | 38 | def __len__(self) -> int: 39 | return len(self.filepaths) 40 | 41 | def __getitem__(self, idx) -> dict[str, tp.Any]: 42 | data = np.load(self.filepaths[idx], allow_pickle=True).item() 43 | if isinstance(data["input_args"][0], str): 44 | name = data["input_args"][0] 45 | latent = np.load(os.path.join(self.path, "latents", name)) 46 | data["input_args"][0] = latent 47 | if isinstance(data["input_kwargs"]["encoder_hidden_states"], str): 48 | name = data["input_kwargs"]["encoder_hidden_states"] 49 | text_emb = np.load(os.path.join(self.path, "text_embs", name)) 50 | data["input_kwargs"]["encoder_hidden_states"] = text_emb 51 | data = tree_map(lambda x: torch.from_numpy(x), data) 52 | 53 | # Pad encoder_hidden_states to 300 for pixart 54 | if "encoder_attention_mask" in data["input_kwargs"]: 55 | encoder_attention_mask = data["input_kwargs"]["encoder_attention_mask"] 56 | encoder_hidden_states = 
data["input_kwargs"]["encoder_hidden_states"] 57 | encoder_hidden_states = F.pad( 58 | encoder_hidden_states, 59 | (0, 0, 0, encoder_attention_mask.shape[1] - encoder_hidden_states.shape[1]), 60 | ) 61 | data["input_kwargs"]["encoder_hidden_states"] = encoder_hidden_states 62 | 63 | return data 64 | 65 | def build_loader(self, **kwargs) -> torch.utils.data.DataLoader: 66 | return torch.utils.data.DataLoader(self, collate_fn=tree_collate, **kwargs) 67 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/dataset/collect/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Common utilities for collecting data.""" 3 | 4 | import inspect 5 | import typing as tp 6 | 7 | import torch 8 | import torch.nn as nn 9 | from diffusers.models.transformers import ( 10 | FluxTransformer2DModel, 11 | PixArtTransformer2DModel, 12 | SanaTransformer2DModel, 13 | ) 14 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel 15 | 16 | from deepcompressor.utils.common import tree_map, tree_split 17 | 18 | __all__ = ["CollectHook"] 19 | 20 | 21 | class CollectHook: 22 | def __init__(self, caches: list[dict[str, tp.Any]] = None, zero_redundancy: bool = False) -> None: 23 | self.caches = [] if caches is None else caches 24 | self.zero_redundancy = zero_redundancy 25 | 26 | def __call__( 27 | self, 28 | module: nn.Module, 29 | input_args: tuple[torch.Tensor, ...], 30 | input_kwargs: dict[str, tp.Any], 31 | output: tuple[torch.Tensor, ...], 32 | ) -> tp.Any: 33 | new_args = [] 34 | signature = inspect.signature(module.forward) 35 | bound_arguments = signature.bind(*input_args, **input_kwargs) 36 | arguments = bound_arguments.arguments 37 | args_to_kwargs = {k: v for k, v in arguments.items() if k not in input_kwargs} 38 | input_kwargs.update(args_to_kwargs) 39 | 40 | if isinstance(module, UNet2DConditionModel): 41 | sample = input_kwargs.pop("sample") 42 | new_args.append(sample) 43 | timestep = input_kwargs["timestep"] 44 | timesteps = timestep 45 | if not torch.is_tensor(timesteps): 46 | is_mps = sample.device.type == "mps" 47 | if isinstance(timestep, float): 48 | dtype = torch.float32 if is_mps else torch.float64 49 | else: 50 | dtype = torch.int32 if is_mps else torch.int64 51 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 52 | elif len(timesteps.shape) == 0: 53 | timesteps = timesteps[None].to(sample.device) 54 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 55 | timesteps = timesteps.expand(sample.shape[0]) 56 | input_kwargs["timestep"] = timesteps 57 | elif isinstance(module, (PixArtTransformer2DModel, SanaTransformer2DModel)): 58 | new_args.append(input_kwargs.pop("hidden_states")) 59 | elif isinstance(module, FluxTransformer2DModel): 60 | new_args.append(input_kwargs.pop("hidden_states")) 61 | else: 62 | raise ValueError(f"Unknown model: {module}") 63 | cache = tree_map(lambda x: x.cpu(), {"input_args": new_args, "input_kwargs": input_kwargs, "outputs": output}) 64 | split_cache = tree_split(cache) 65 | 66 | if isinstance(module, PixArtTransformer2DModel) and self.zero_redundancy: 67 | for cache in split_cache: 68 | cache_kwargs = cache["input_kwargs"] 69 | encoder_hidden_states = cache_kwargs.pop("encoder_hidden_states") 70 | assert encoder_hidden_states.shape[0] == 1 71 | encoder_attention_mask = cache_kwargs.get("encoder_attention_mask", None) 72 | if encoder_attention_mask is not None: 73 | 
encoder_hidden_states = encoder_hidden_states[:, : max(encoder_attention_mask.sum(), 1)] 74 | cache_kwargs["encoder_hidden_states"] = encoder_hidden_states 75 | 76 | self.caches.extend(split_cache) 77 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/dataset/data/COCO/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/deepcompressor/app/diffusion/dataset/data/COCO/__init__.py -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/dataset/data/DCI/DCI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import datasets 5 | import yaml 6 | from PIL import Image 7 | 8 | _CITATION = """\ 9 | @InProceedings{Urbanek_2024_CVPR, 10 | author = {Urbanek, Jack and Bordes, Florian and Astolfi, Pietro and Williamson, Mary and Sharma, Vasu and Romero-Soriano, Adriana}, 11 | title = {A Picture is Worth More Than 77 Text Tokens: Evaluating CLIP-Style Models on Dense Captions}, 12 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 13 | month = {June}, 14 | year = {2024}, 15 | pages = {26700-26709} 16 | } 17 | """ # noqa: E501 18 | 19 | _DESCRIPTION = """\ 20 | The Densely Captioned Images dataset, or DCI, consists of 7805 images from SA-1B, 21 | each with a complete description aiming to capture the full visual detail of what is present in the image. 22 | Much of the description is directly aligned to submasks of the image. 23 | """ 24 | 25 | _HOMEPAGE = "https://github.com/facebookresearch/DCI" 26 | 27 | _LICENSE = "Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/DCI/blob/main/LICENSE)" 28 | 29 | IMAGE_URL = "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.gz" 30 | 31 | PROMPT_URLS = {"sDCI": "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.yaml"} 32 | 33 | 34 | class DCIConfig(datasets.BuilderConfig): 35 | def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs): 36 | super(DCIConfig, self).__init__( 37 | name=kwargs.get("name", "default"), 38 | version=kwargs.get("version", "0.0.0"), 39 | data_dir=kwargs.get("data_dir", None), 40 | data_files=kwargs.get("data_files", None), 41 | description=kwargs.get("description", None), 42 | ) 43 | self.max_dataset_size = max_dataset_size 44 | self.return_gt = return_gt 45 | 46 | 47 | class DCI(datasets.GeneratorBasedBuilder): 48 | VERSION = datasets.Version("0.0.0") 49 | 50 | BUILDER_CONFIG_CLASS = DCIConfig 51 | BUILDER_CONFIGS = [DCIConfig(name="sDCI", version=VERSION, description="sDCI full prompt set")] 52 | DEFAULT_CONFIG_NAME = "sDCI" 53 | 54 | def _info(self): 55 | features = datasets.Features( 56 | { 57 | "filename": datasets.Value("string"), 58 | "image": datasets.Image(), 59 | "prompt": datasets.Value("string"), 60 | "meta_path": datasets.Value("string"), 61 | "image_root": datasets.Value("string"), 62 | "image_path": datasets.Value("string"), 63 | "split": datasets.Value("string"), 64 | } 65 | ) 66 | return datasets.DatasetInfo( 67 | description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION 68 | ) 69 | 70 | def _split_generators(self, dl_manager: datasets.download.DownloadManager): 71 | image_url = IMAGE_URL 
72 | meta_url = PROMPT_URLS[self.config.name] 73 | 74 | meta_path = dl_manager.download(meta_url) 75 | image_root = dl_manager.download_and_extract(image_url) 76 | 77 | return [ 78 | datasets.SplitGenerator( 79 | name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root} 80 | ) 81 | ] 82 | 83 | def _generate_examples(self, meta_path: str, image_root: str): 84 | meta = yaml.safe_load(open(meta_path, "r")) 85 | names = list(meta.keys()) 86 | if self.config.max_dataset_size > 0: 87 | random.Random(0).shuffle(names) 88 | names = names[: self.config.max_dataset_size] 89 | names = sorted(names) 90 | 91 | for i, name in enumerate(names): 92 | prompt = meta[name] 93 | image_path = os.path.join(image_root, f"{name}.jpg") 94 | yield ( 95 | i, 96 | { 97 | "filename": name, 98 | "image": Image.open(image_path) if self.config.return_gt else None, 99 | "prompt": prompt, 100 | "meta_path": meta_path, 101 | "image_root": image_root, 102 | "image_path": image_path, 103 | "split": self.config.name, 104 | }, 105 | ) 106 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/dataset/data/DCI/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/deepcompressor/app/diffusion/dataset/data/DCI/__init__.py -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/dataset/data/MJHQ/MJHQ.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | import datasets 6 | from PIL import Image 7 | 8 | _CITATION = """\ 9 | @misc{li2024playground, 10 | title={Playground v2.5: Three Insights towards Enhancing Aesthetic Quality in Text-to-Image Generation}, 11 | author={Daiqing Li and Aleks Kamko and Ehsan Akhgari and Ali Sabet and Linmiao Xu and Suhail Doshi}, 12 | year={2024}, 13 | eprint={2402.17245}, 14 | archivePrefix={arXiv}, 15 | primaryClass={cs.CV} 16 | } 17 | """ 18 | 19 | _DESCRIPTION = """\ 20 | We introduce a new benchmark, MJHQ-30K, for automatic evaluation of a model’s aesthetic quality. 21 | The benchmark computes FID on a high-quality dataset to gauge aesthetic quality. 
22 | """ 23 | 24 | _HOMEPAGE = "https://huggingface.co/datasets/playgroundai/MJHQ-30K" 25 | 26 | _LICENSE = ( 27 | "Playground v2.5 Community License " 28 | "(https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md)" 29 | ) 30 | 31 | IMAGE_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/mjhq30k_imgs.zip" 32 | 33 | META_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json" 34 | 35 | 36 | class MJHQConfig(datasets.BuilderConfig): 37 | def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs): 38 | super(MJHQConfig, self).__init__( 39 | name=kwargs.get("name", "default"), 40 | version=kwargs.get("version", "0.0.0"), 41 | data_dir=kwargs.get("data_dir", None), 42 | data_files=kwargs.get("data_files", None), 43 | description=kwargs.get("description", None), 44 | ) 45 | self.max_dataset_size = max_dataset_size 46 | self.return_gt = return_gt 47 | 48 | 49 | class DCI(datasets.GeneratorBasedBuilder): 50 | VERSION = datasets.Version("0.0.0") 51 | 52 | BUILDER_CONFIG_CLASS = MJHQConfig 53 | BUILDER_CONFIGS = [MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset")] 54 | DEFAULT_CONFIG_NAME = "MJHQ" 55 | 56 | def _info(self): 57 | features = datasets.Features( 58 | { 59 | "filename": datasets.Value("string"), 60 | "category": datasets.Value("string"), 61 | "image": datasets.Image(), 62 | "prompt": datasets.Value("string"), 63 | "prompt_path": datasets.Value("string"), 64 | "image_root": datasets.Value("string"), 65 | "image_path": datasets.Value("string"), 66 | "split": datasets.Value("string"), 67 | } 68 | ) 69 | return datasets.DatasetInfo( 70 | description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION 71 | ) 72 | 73 | def _split_generators(self, dl_manager: datasets.download.DownloadManager): 74 | meta_path = dl_manager.download(META_URL) 75 | image_root = dl_manager.download_and_extract(IMAGE_URL) 76 | return [ 77 | datasets.SplitGenerator( 78 | name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root} 79 | ), 80 | ] 81 | 82 | def _generate_examples(self, meta_path: str, image_root: str): 83 | with open(meta_path, "r") as f: 84 | meta = json.load(f) 85 | 86 | names = list(meta.keys()) 87 | if self.config.max_dataset_size > 0: 88 | random.Random(0).shuffle(names) 89 | names = names[: self.config.max_dataset_size] 90 | names = sorted(names) 91 | 92 | for i, name in enumerate(names): 93 | category = meta[name]["category"] 94 | prompt = meta[name]["prompt"] 95 | image_path = os.path.join(image_root, category, f"{name}.jpg") 96 | yield ( 97 | i, 98 | { 99 | "filename": name, 100 | "category": category, 101 | "image": Image.open(image_path) if self.config.return_gt else None, 102 | "prompt": prompt, 103 | "meta_path": meta_path, 104 | "image_root": image_root, 105 | "image_path": image_path, 106 | "split": self.config.name, 107 | }, 108 | ) 109 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/dataset/data/MJHQ/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/deepcompressor/app/diffusion/dataset/data/MJHQ/__init__.py -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/dataset/data/__init__.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import datasets 5 | import yaml 6 | 7 | __all__ = ["get_dataset"] 8 | 9 | 10 | def load_dataset_yaml(meta_path: str, max_dataset_size: int = -1, repeat: int = 4) -> dict: 11 | meta = yaml.safe_load(open(meta_path, "r")) 12 | names = list(meta.keys()) 13 | if max_dataset_size > 0: 14 | random.Random(0).shuffle(names) 15 | names = names[:max_dataset_size] 16 | names = sorted(names) 17 | 18 | ret = {"filename": [], "prompt": [], "meta_path": []} 19 | idx = 0 20 | for name in names: 21 | prompt = meta[name] 22 | for j in range(repeat): 23 | ret["filename"].append(f"{name}-{j}") 24 | ret["prompt"].append(prompt) 25 | ret["meta_path"].append(meta_path) 26 | idx += 1 27 | return ret 28 | 29 | 30 | def get_dataset( 31 | name: str, 32 | config_name: str | None = None, 33 | split: str = "train", 34 | max_dataset_size: int = -1, 35 | return_gt: bool = False, 36 | repeat: int = 4, 37 | chunk_start: int = 0, 38 | chunk_step: int = 1, 39 | ) -> datasets.Dataset: 40 | prefix = os.path.dirname(__file__) 41 | kwargs = { 42 | "name": config_name, 43 | "split": split, 44 | "trust_remote_code": True, 45 | "token": True, 46 | "max_dataset_size": max_dataset_size, 47 | } 48 | if name.endswith((".yaml", ".yml")): 49 | dataset = datasets.Dataset.from_dict( 50 | load_dataset_yaml(name, max_dataset_size=max_dataset_size, repeat=repeat), 51 | features=datasets.Features( 52 | { 53 | "filename": datasets.Value("string"), 54 | "prompt": datasets.Value("string"), 55 | "meta_path": datasets.Value("string"), 56 | } 57 | ), 58 | ) 59 | else: 60 | path = os.path.join(prefix, f"{name}") 61 | if name == "COCO": 62 | dataset = datasets.load_dataset(path, return_gt=return_gt, **kwargs) 63 | elif name == "DCI": 64 | dataset = datasets.load_dataset(path, return_gt=return_gt, **kwargs) 65 | elif name == "MJHQ": 66 | dataset = datasets.load_dataset(path, return_gt=return_gt, **kwargs) 67 | else: 68 | raise ValueError(f"Unknown dataset name: {name}") 69 | assert not hasattr(dataset, "_unchunk_size") 70 | assert not hasattr(dataset, "_chunk_start") 71 | assert not hasattr(dataset, "_chunk_step") 72 | unchunk_size = len(dataset) 73 | if chunk_step > 1 or chunk_start > 0: 74 | assert 0 <= chunk_start < chunk_step 75 | dataset = dataset.select(range(chunk_start, len(dataset), chunk_step)) 76 | else: 77 | chunk_start, chunk_step = 0, 1 78 | dataset._unchunk_size = unchunk_size 79 | dataset._chunk_start = chunk_start 80 | dataset._chunk_step = chunk_step 81 | return dataset 82 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/dataset/data/dump.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import yaml 5 | from tqdm import tqdm 6 | 7 | from ...utils import get_control 8 | from . 
import get_dataset 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--benchmarks", type=str, nargs="*", default=["COCO", "DCI", "MJHQ"]) 13 | parser.add_argument("--max-dataset-size", type=int, default=-1) 14 | parser.add_argument("--dump-root", type=str, default="benchmarks") 15 | parser.add_argument("--copy-images", action="store_true") 16 | parser.add_argument("--prompts-only", action="store_true") 17 | parser.add_argument("--controls", type=str, nargs="*", default=["canny-to-image", "depth-to-image", "inpainting"]) 18 | parser.add_argument("--chunk-start", type=int, default=0) 19 | parser.add_argument("--chunk-step", type=int, default=1) 20 | args = parser.parse_args() 21 | 22 | if "depth-to-image" in args.controls: 23 | from image_gen_aux import DepthPreprocessor 24 | 25 | processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf").to("cuda") 26 | 27 | for benchmark in args.benchmarks: 28 | dataset = get_dataset( 29 | benchmark, 30 | max_dataset_size=args.max_dataset_size, 31 | return_gt=True, 32 | chunk_start=args.chunk_start, 33 | chunk_step=args.chunk_step, 34 | ) 35 | prompts = {} 36 | benchmark_root = os.path.join(args.dump_root, benchmark, f"{dataset.config_name}-{dataset._unchunk_size}") 37 | for row in tqdm(dataset, desc=f"Dumping {dataset.config_name}"): 38 | prompts[row["filename"]] = row["prompt"] 39 | if not args.prompts_only: 40 | image = row.get("image", None) 41 | if image is not None: 42 | image_root = os.path.join(benchmark_root, "images") 43 | os.makedirs(image_root, exist_ok=True) 44 | if args.copy_images: 45 | image.save(os.path.join(image_root, row["filename"] + ".png")) 46 | else: 47 | ext = os.path.basename(row["image_path"]).split(".")[-1] 48 | os.symlink( 49 | os.path.abspath(os.path.expanduser(row["image_path"])), 50 | os.path.abspath(os.path.expanduser(os.path.join(image_root, row["filename"] + f".{ext}"))), 51 | ) 52 | if "canny-to-image" in args.controls: 53 | canny_root = os.path.join(benchmark_root, "canny_images") 54 | os.makedirs(canny_root, exist_ok=True) 55 | canny = get_control("canny-to-image", image) 56 | canny.save(os.path.join(canny_root, row["filename"] + ".png")) 57 | if "depth-to-image" in args.controls: 58 | depth_root = os.path.join(benchmark_root, "depth_images") 59 | os.makedirs(depth_root, exist_ok=True) 60 | depth = get_control("depth-to-image", image, processor=processor) 61 | depth.save(os.path.join(depth_root, row["filename"] + ".png")) 62 | if "inpainting" in args.controls: 63 | mask_root = os.path.join(benchmark_root, "mask_images") 64 | cropped_image_root = os.path.join(benchmark_root, "cropped_images") 65 | os.makedirs(mask_root, exist_ok=True) 66 | os.makedirs(cropped_image_root, exist_ok=True) 67 | cropped_image, mask_image = get_control("inpainting", image, names=row["filename"]) 68 | cropped_image.save(os.path.join(cropped_image_root, row["filename"] + ".png")) 69 | mask_image.save(os.path.join(mask_root, row["filename"] + ".png")) 70 | 71 | if args.chunk_step == 1: 72 | os.makedirs(benchmark_root, exist_ok=True) 73 | with open(os.path.join(benchmark_root, "prompts.yaml"), "w") as f: 74 | yaml.dump(prompts, f) 75 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/eval/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .config import DiffusionEvalConfig 4 | 
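The dataset helpers above (`DiffusionDataset` in dataset/base.py and `get_dataset` in dataset/data/__init__.py) are typically driven roughly as follows. This is an illustrative sketch rather than repository code: the cache directory and sample counts are hypothetical, and only the call signatures shown above are assumed.

# Illustrative usage only -- not part of the repository; paths are hypothetical.
from deepcompressor.app.diffusion.dataset.base import DiffusionDataset
from deepcompressor.app.diffusion.dataset.data import get_dataset

# Benchmark prompt set: MJHQ capped at 1024 rows, ground-truth images skipped.
benchmark = get_dataset("MJHQ", max_dataset_size=1024, return_gt=False)
print(len(benchmark), benchmark[0]["prompt"])

# Calibration caches collected as .npy files (see collect/utils.py and base.py above).
calib = DiffusionDataset("runs/collect/flux.1-dev", num_samples=256, seed=0)
loader = calib.build_loader(batch_size=1, num_workers=4)
for batch in loader:
    latents = batch["input_args"][0]   # latent input to the diffusion backbone
    kwargs = batch["input_kwargs"]     # e.g. encoder_hidden_states, timestep
    break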
-------------------------------------------------------------------------------- /deepcompressor/app/diffusion/eval/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from deepcompressor.app.diffusion.dataset.data import get_dataset 5 | 6 | from .fid import compute_fid 7 | from .image_reward import compute_image_reward 8 | from .multimodal import compute_image_multimodal_metrics 9 | from .similarity import compute_image_similarity_metrics 10 | 11 | logging.getLogger("PIL").setLevel(logging.WARNING) 12 | 13 | __all__ = ["compute_image_metrics"] 14 | 15 | 16 | def compute_image_metrics( 17 | gen_root: str, 18 | benchmarks: str | tuple[str, ...] = ("DCI", "GenAIBench", "GenEval", "MJHQ", "T2ICompBench"), 19 | max_dataset_size: int = -1, 20 | chunk_start: int = 0, 21 | chunk_step: int = 1, 22 | chunk_only: bool = False, 23 | ref_root: str = "", 24 | gt_stats_root: str = "", 25 | gt_metrics: tuple[str, ...] = ("clip_iqa", "clip_score", "image_reward", "fid"), 26 | ref_metrics: tuple[str, ...] = ("psnr", "lpips", "ssim", "fid"), 27 | ) -> dict: 28 | if chunk_start == 0 and chunk_step == 1: 29 | chunk_only = False 30 | assert chunk_start == 0 and chunk_step == 1, "Chunking is not supported for image data." 31 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 32 | if isinstance(benchmarks, str): 33 | benchmarks = (benchmarks,) 34 | gt_multimodal_metrics, gt_similarity_metrics, gt_other_metrics = categorize_metrics(gt_metrics) 35 | _, ref_similarity_metrics, ref_other_metrics = categorize_metrics(ref_metrics) 36 | results = {} 37 | for benchmark in benchmarks: 38 | benchmark_results = {} 39 | dataset = get_dataset(benchmark, max_dataset_size=max_dataset_size, return_gt=True) 40 | dirname = f"{dataset.config_name}-{dataset._unchunk_size}" 41 | if dataset._chunk_start == 0 and dataset._chunk_step == 1: 42 | filename = f"{dirname}.npz" 43 | else: 44 | filename = os.path.join(dirname, f"{dataset._chunk_start}-{dataset._chunk_step}.npz") 45 | if chunk_only: 46 | dirname += f".{dataset._chunk_start}.{dataset._chunk_step}" 47 | gen_dirpath = os.path.join(gen_root, "samples", benchmark, dirname) 48 | if gt_metrics: 49 | gt_results = compute_image_multimodal_metrics(dataset, gen_dirpath, metrics=gt_multimodal_metrics) 50 | if "image_reward" in gt_other_metrics: 51 | gt_results.update(compute_image_reward(dataset, gen_dirpath)) 52 | if benchmark in ("COCO", "DCI", "MJHQ"): 53 | gt_results.update(compute_image_similarity_metrics(dataset, gen_dirpath, metrics=gt_similarity_metrics)) 54 | if "fid" in gt_other_metrics: 55 | gt_results["fid"] = compute_fid( 56 | dataset, 57 | gen_dirpath, 58 | ref_cache_path=(os.path.join(gt_stats_root, benchmark, filename) if gt_stats_root else None), 59 | gen_cache_path=os.path.join(gen_root, "fid_stats", benchmark, filename), 60 | ) 61 | benchmark_results["with_gt"] = gt_results 62 | if ref_root and ref_metrics: 63 | assert os.path.exists(ref_root), f"Reference root directory {ref_root} does not exist." 
64 | ref_dirpath = os.path.join(ref_root, "samples", benchmark, dirname) 65 | ref_results = compute_image_similarity_metrics(ref_dirpath, gen_dirpath, metrics=ref_similarity_metrics) 66 | if "fid" in ref_other_metrics: 67 | ref_results["fid"] = compute_fid( 68 | ref_dirpath, 69 | gen_dirpath, 70 | ref_cache_path=os.path.join(ref_root, "fid_stats", benchmark, filename), 71 | gen_cache_path=os.path.join(gen_root, "fid_stats", benchmark, filename), 72 | ) 73 | benchmark_results["with_orig"] = ref_results 74 | print(f"{dirname} results:") 75 | print(benchmark_results) 76 | results[dirname] = benchmark_results 77 | return results 78 | 79 | 80 | def categorize_metrics(metrics: tuple[str, ...]) -> tuple[list[str], list[str], list[str]]: 81 | """ 82 | Categorize metrics into multimodal, similarity, and other metrics. 83 | 84 | Args: 85 | metrics (tuple[str, ...]): List of metrics. 86 | 87 | Returns: 88 | tuple[list[str], list[str], list[str]]: Tuple of multimodal, similarity, and other metrics. 89 | """ 90 | metrics = tuple(set(metrics)) 91 | multimodal_metrics, similarity_metrics, other_metrics = [], [], [] 92 | for metric in metrics: 93 | if metric in ("clip_iqa", "clip_score"): 94 | multimodal_metrics.append(metric) 95 | elif metric in ("psnr", "lpips", "ssim"): 96 | similarity_metrics.append(metric) 97 | else: 98 | other_metrics.append(metric) 99 | return multimodal_metrics, similarity_metrics, other_metrics 100 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/eval/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | import numpy as np 5 | import torch 6 | import torchvision 7 | from cleanfid import fid 8 | from cleanfid.resize import build_resizer 9 | from datasets import Dataset 10 | from tqdm import tqdm 11 | 12 | __all__ = ["compute_fid"] 13 | 14 | 15 | def get_dataset_features( 16 | dataset: Dataset, 17 | model, 18 | mode: str = "clean", 19 | batch_size: int = 128, 20 | device: str | torch.device = "cuda", 21 | ) -> np.ndarray: 22 | to_tensor = torchvision.transforms.ToTensor() 23 | fn_resize = build_resizer(mode) 24 | np_feats = [] 25 | for batch in tqdm( 26 | dataset.iter(batch_size=batch_size, drop_last_batch=False), 27 | desc=f"Extracting {dataset.config_name} features", 28 | total=(len(dataset) + batch_size - 1) // batch_size, 29 | ): 30 | resized_images = [fn_resize(np.array(image.convert("RGB"))) for image in batch["image"]] 31 | image_tensors = [] 32 | for resized_image in resized_images: 33 | if resized_image.dtype == "uint8": 34 | image_tensor = to_tensor(resized_image) * 255 35 | else: 36 | image_tensor = to_tensor(resized_image) 37 | image_tensors.append(image_tensor) 38 | image_tensors = torch.stack(image_tensors, dim=0) 39 | np_feats.append(fid.get_batch_features(image_tensors, model, device)) 40 | np_feats = np.concatenate(np_feats, axis=0) 41 | return np_feats 42 | 43 | 44 | def get_fid_features( 45 | dataset_or_folder: str | Dataset | None = None, 46 | cache_path: str | None = None, 47 | num: int | None = None, 48 | mode: str = "clean", 49 | num_workers: int = 8, 50 | batch_size: int = 64, 51 | device: str | torch.device = "cuda", 52 | force_overwrite: bool = False, 53 | verbose: bool = True, 54 | ) -> tuple[np.ndarray, np.ndarray]: 55 | if cache_path is not None and os.path.exists(cache_path) and not force_overwrite: 56 | npz = np.load(cache_path) 57 | mu, sigma = npz["mu"], npz["sigma"] 58 | else: 59 | 
feat_model = fid.build_feature_extractor(mode, device) 60 | if isinstance(dataset_or_folder, str): 61 | np_feats = fid.get_folder_features( 62 | dataset_or_folder, 63 | feat_model, 64 | num_workers=num_workers, 65 | num=num, 66 | batch_size=batch_size, 67 | device=device, 68 | verbose=verbose, 69 | mode=mode, 70 | description=f"Extracting {dataset_or_folder} features", 71 | ) 72 | else: 73 | assert isinstance(dataset_or_folder, Dataset) 74 | np_feats = get_dataset_features( 75 | dataset_or_folder, model=feat_model, mode=mode, batch_size=batch_size, device=device 76 | ) 77 | 78 | mu = np.mean(np_feats, axis=0) 79 | sigma = np.cov(np_feats, rowvar=False) 80 | if cache_path is not None: 81 | os.makedirs(os.path.abspath(os.path.dirname(cache_path)), exist_ok=True) 82 | np.savez(cache_path, mu=mu, sigma=sigma) 83 | 84 | return mu, sigma 85 | 86 | 87 | def compute_fid( 88 | ref_dirpath_or_dataset: str | Dataset, 89 | gen_dirpath: str, 90 | ref_cache_path: str | None = None, 91 | gen_cache_path: str | None = None, 92 | use_symlink: bool = True, 93 | timestamp: str | None = None, 94 | ) -> float: 95 | sym_ref_dirpath, sym_gen_dirpath = None, None 96 | if use_symlink: 97 | if timestamp is None: 98 | timestamp = datetime.now().strftime("%y%m%d.%H%M%S") 99 | 100 | os.makedirs(".tmp", exist_ok=True) 101 | 102 | if isinstance(ref_dirpath_or_dataset, str): 103 | sym_ref_dirpath = os.path.join(".tmp", f"ref-{hash(str(ref_dirpath_or_dataset))}-{timestamp}") 104 | os.symlink(os.path.abspath(ref_dirpath_or_dataset), os.path.abspath(sym_ref_dirpath)) 105 | ref_dirpath_or_dataset = sym_ref_dirpath 106 | 107 | sym_gen_dirpath = os.path.join(".tmp", f"gen-{hash(str(gen_dirpath))}-{timestamp}") 108 | os.symlink(os.path.abspath(gen_dirpath), os.path.abspath(sym_gen_dirpath)) 109 | gen_dirpath = sym_gen_dirpath 110 | mu1, sigma1 = get_fid_features(dataset_or_folder=ref_dirpath_or_dataset, cache_path=ref_cache_path) 111 | mu2, sigma2 = get_fid_features(dataset_or_folder=gen_dirpath, cache_path=gen_cache_path) 112 | fid_score = fid.frechet_distance(mu1, sigma1, mu2, sigma2) 113 | fid_score = float(fid_score) 114 | if use_symlink: 115 | if sym_ref_dirpath is not None: 116 | os.remove(sym_ref_dirpath) 117 | os.remove(sym_gen_dirpath) 118 | return fid_score 119 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/eval/metrics/image_reward.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import datasets 4 | import torch 5 | from tqdm import tqdm 6 | 7 | __all__ = ["compute_image_reward"] 8 | 9 | 10 | def compute_image_reward( 11 | ref_dataset: datasets.Dataset, 12 | gen_dirpath: str, 13 | ) -> dict[str, float]: 14 | # import here to remove dependency on `ImageReward` git repo 15 | import ImageReward as RM 16 | 17 | scores = [] 18 | model = RM.load("ImageReward-v1.0") 19 | for batch in tqdm( 20 | ref_dataset.iter(batch_size=1, drop_last_batch=False), 21 | desc=f"{ref_dataset.config_name} image reward", 22 | total=len(ref_dataset), 23 | dynamic_ncols=True, 24 | ): 25 | filename = batch["filename"][0] 26 | path = os.path.join(gen_dirpath, f"{filename}.png") 27 | prompt = batch["prompt"][0] 28 | with torch.inference_mode(): 29 | score = model.score(prompt, path) 30 | scores.append(score) 31 | result = {"image_reward": sum(scores) / len(scores)} 32 | return result 33 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/eval/metrics/multimodal.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | import datasets 4 | import numpy as np 5 | import torch 6 | import torchmetrics 7 | import torchvision 8 | from PIL import Image 9 | from torch.utils import data 10 | from torchmetrics.multimodal import CLIPImageQualityAssessment, CLIPScore 11 | from tqdm import tqdm 12 | 13 | __all__ = ["compute_image_multimodal_metrics"] 14 | 15 | 16 | class PromptImageDataset(data.Dataset): 17 | def __init__(self, ref_dataset: datasets.Dataset, gen_dirpath: str): 18 | super(data.Dataset, self).__init__() 19 | self.ref_dataset, self.gen_dirpath = ref_dataset, gen_dirpath 20 | self.transform = torchvision.transforms.ToTensor() 21 | 22 | def __len__(self): 23 | return len(self.ref_dataset) 24 | 25 | def __getitem__(self, idx: int): 26 | row = self.ref_dataset[idx] 27 | gen_image = Image.open(os.path.join(self.gen_dirpath, row["filename"] + ".png")).convert("RGB") 28 | gen_tensor = torch.from_numpy(np.array(gen_image)).permute(2, 0, 1) 29 | prompt = row["prompt"] 30 | return [gen_tensor, prompt] 31 | 32 | 33 | def compute_image_multimodal_metrics( 34 | ref_dataset: datasets.Dataset, 35 | gen_dirpath: str, 36 | metrics: tuple[str, ...] = ("clip_iqa", "clip_score"), 37 | batch_size: int = 64, 38 | num_workers: int = 8, 39 | device: str | torch.device = "cuda", 40 | ) -> dict[str, float]: 41 | if len(metrics) == 0: 42 | return {} 43 | metric_names = metrics 44 | metrics: dict[str, torchmetrics.Metric] = {} 45 | for metric_name in metric_names: 46 | if metric_name == "clip_iqa": 47 | metric = CLIPImageQualityAssessment(model_name_or_path="openai/clip-vit-large-patch14").to(device) 48 | elif metric_name == "clip_score": 49 | metric = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(device) 50 | else: 51 | raise NotImplementedError(f"Metric {metric_name} is not implemented") 52 | metrics[metric_name] = metric 53 | dataset = PromptImageDataset(ref_dataset, gen_dirpath) 54 | dataloader = data.DataLoader( 55 | dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False 56 | ) 57 | with torch.no_grad(): 58 | for batch in tqdm(dataloader, desc=f"{ref_dataset.config_name} multimodal metrics"): 59 | batch[0] = batch[0].to(device) 60 | for metric_name, metric in metrics.items(): 61 | if metric_name == "clip_iqa": 62 | metric.update(batch[0].to(torch.float32)) 63 | else: 64 | prompts = list(batch[1]) 65 | metric.update(batch[0], prompts) 66 | result = {metric_name: metric.compute().mean().item() for metric_name, metric in metrics.items()} 67 | return result 68 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/eval/metrics/run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Evaluate generated images or videos using the specified metrics.""" 3 | 4 | import json 5 | import os 6 | 7 | from ...config import DiffusionPtqRunConfig 8 | 9 | if __name__ == "__main__": 10 | config, _, unused_cfgs, unused_args, unknown_args = DiffusionPtqRunConfig.get_parser().parse_known_args() 11 | assert len(unknown_args) == 0, f"Unknown arguments: {unknown_args}" 12 | assert len(unused_cfgs) == 0, f"Unused configurations: {unused_cfgs}" 13 | assert unused_args is None, f"Unused arguments: {unused_args}" 14 | assert isinstance(config, DiffusionPtqRunConfig) 15 | results = config.eval.evaluate(pipeline=None, skip_gen=True, task=config.pipeline.task) 16 | save_path = 
os.path.join(config.eval.gen_root, f"results-{config.output.timestamp}.json") 17 | os.makedirs(os.path.abspath(os.path.dirname(save_path)), exist_ok=True) 18 | with open(save_path, "w") as f: 19 | json.dump(results, f, indent=2, sort_keys=True) 20 | print(results) 21 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/eval/metrics/similarity.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import datasets 4 | import torch 5 | import torchmetrics 6 | import torchvision 7 | from PIL import Image 8 | from torch.utils import data 9 | from torchmetrics.image import ( 10 | LearnedPerceptualImagePatchSimilarity, 11 | PeakSignalNoiseRatio, 12 | StructuralSimilarityIndexMeasure, 13 | ) 14 | from tqdm import tqdm 15 | 16 | __all__ = ["compute_image_similarity_metrics"] 17 | 18 | 19 | class MultiImageDataset(data.Dataset): 20 | def __init__(self, gen_dirpath: str, ref_dirpath_or_dataset: str | datasets.Dataset): 21 | super(data.Dataset, self).__init__() 22 | self.gen_names = sorted( 23 | [name for name in os.listdir(gen_dirpath) if name.endswith(".png") or name.endswith(".jpg")] 24 | ) 25 | self.gen_dirpath, self.ref_dirpath_or_dataset = gen_dirpath, ref_dirpath_or_dataset 26 | if isinstance(ref_dirpath_or_dataset, str): 27 | self.ref_names = sorted( 28 | [name for name in os.listdir(ref_dirpath_or_dataset) if name.endswith(".png") or name.endswith(".jpg")] 29 | ) 30 | assert len(self.ref_names) == len(self.gen_names) 31 | else: 32 | assert isinstance(ref_dirpath_or_dataset, datasets.Dataset) 33 | self.ref_names = self.gen_names 34 | assert len(ref_dirpath_or_dataset) == len(self.gen_names) 35 | self.transform = torchvision.transforms.ToTensor() 36 | 37 | def __len__(self): 38 | return len(self.ref_names) 39 | 40 | def __getitem__(self, idx: int): 41 | if isinstance(self.ref_dirpath_or_dataset, str): 42 | name = self.ref_names[idx] 43 | assert name == self.gen_names[idx] 44 | ref_image = Image.open(os.path.join(self.ref_dirpath_or_dataset, name)).convert("RGB") 45 | else: 46 | row = self.ref_dirpath_or_dataset[idx] 47 | ref_image = row["image"].convert("RGB") 48 | name = row["filename"] + ".png" 49 | gen_image = Image.open(os.path.join(self.gen_dirpath, name)).convert("RGB") 50 | gen_size = gen_image.size 51 | ref_size = ref_image.size 52 | if ref_size != gen_size: 53 | ref_image = ref_image.resize(gen_size, Image.Resampling.BICUBIC) 54 | gen_tensor = self.transform(gen_image) 55 | ref_tensor = self.transform(ref_image) 56 | return [gen_tensor, ref_tensor] 57 | 58 | 59 | def compute_image_similarity_metrics( 60 | ref_dirpath_or_dataset: str | datasets.Dataset, 61 | gen_dirpath: str, 62 | metrics: tuple[str, ...] 
= ("psnr", "lpips", "ssim"), 63 | batch_size: int = 64, 64 | num_workers: int = 8, 65 | device: str | torch.device = "cuda", 66 | ) -> dict[str, float]: 67 | if len(metrics) == 0: 68 | return {} 69 | metric_names = metrics 70 | metrics: dict[str, torchmetrics.Metric] = {} 71 | for metric_name in metric_names: 72 | if metric_name == "psnr": 73 | metric = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to(device) 74 | elif metric_name == "lpips": 75 | metric = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device) 76 | elif metric_name == "ssim": 77 | metric = StructuralSimilarityIndexMeasure(data_range=(0, 1)).to(device) 78 | else: 79 | raise NotImplementedError(f"Metric {metric_name} is not implemented") 80 | metrics[metric_name] = metric 81 | dataset = MultiImageDataset(gen_dirpath, ref_dirpath_or_dataset) 82 | dataloader = data.DataLoader( 83 | dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False 84 | ) 85 | with torch.no_grad(): 86 | desc = ( 87 | ref_dirpath_or_dataset.config_name 88 | if isinstance(ref_dirpath_or_dataset, datasets.Dataset) 89 | else os.path.basename(ref_dirpath_or_dataset) 90 | ) + " similarity metrics" 91 | for batch in tqdm(dataloader, desc=desc): 92 | batch = [tensor.to(device) for tensor in batch] 93 | for metric in metrics.values(): 94 | metric.update(batch[0], batch[1]) 95 | result = {metric_name: metric.compute().item() for metric_name, metric in metrics.items()} 96 | return result 97 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .config import DiffusionPipelineConfig 4 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/quant/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .activation import quantize_diffusion_activations 4 | from .config import DiffusionQuantCacheConfig, DiffusionQuantConfig 5 | from .quantizer import DiffusionActivationQuantizer, DiffusionWeightQuantizer 6 | from .rotate import rotate_diffusion 7 | from .smooth import smooth_diffusion 8 | from .weight import load_diffusion_weights_state_dict, quantize_diffusion_weights 9 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/quant/quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .config import DiffusionModuleQuantizerConfig 4 | from .quantizer import DiffusionActivationQuantizer, DiffusionWeightQuantizer 5 | -------------------------------------------------------------------------------- /deepcompressor/app/diffusion/quant/utils.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ..nn.struct import DiffusionAttentionStruct, DiffusionFeedForwardStruct, DiffusionModelStruct 7 | from .config import DiffusionQuantConfig 8 | 9 | __all__ = ["get_needs_inputs_fn", 
"get_needs_outputs_fn", "wrap_joint_attn"] 10 | 11 | 12 | def wrap_joint_attn(attn: nn.Module, /, *, indexes: int | tuple[int, ...] = 1) -> tp.Callable: 13 | if isinstance(indexes, int): 14 | 15 | def eval(*args, **kwargs) -> torch.Tensor: 16 | return attn(*args, **kwargs)[indexes] 17 | 18 | else: 19 | 20 | def eval(*args, **kwargs) -> tuple[torch.Tensor, ...]: 21 | tensors = attn(*args, **kwargs) 22 | result = torch.concat([tensors[i] for i in indexes], dim=-2) 23 | return result 24 | 25 | return eval 26 | 27 | 28 | def get_needs_inputs_fn( 29 | model: DiffusionModelStruct, config: DiffusionQuantConfig 30 | ) -> tp.Callable[[str, nn.Module], bool]: 31 | """Get function that checks whether the module needs to cache inputs. 32 | 33 | Args: 34 | model (`DiffusionModelStruct`): 35 | The diffused model. 36 | config (`DiffusionQuantConfig`): 37 | The quantization configuration. 38 | 39 | Returns: 40 | `Callable[[str, nn.Module], bool]`: 41 | The function that checks whether the module needs to cache inputs. 42 | """ 43 | 44 | needs_inputs_names = set() 45 | for module_key, module_name, _, parent, field_name in model.named_key_modules(): 46 | if (config.enabled_wgts and config.wgts.is_enabled_for(module_key)) or ( 47 | config.enabled_ipts and config.ipts.is_enabled_for(module_key) 48 | ): 49 | if isinstance(parent, DiffusionAttentionStruct): 50 | if field_name.endswith("o_proj"): 51 | needs_inputs_names.add(module_name) 52 | elif field_name in ("q_proj", "k_proj", "v_proj"): 53 | needs_inputs_names.add(parent.q_proj_name) 54 | if parent.parent.parallel and parent.idx == 0: 55 | needs_inputs_names.add(parent.parent.name) 56 | else: 57 | needs_inputs_names.add(parent.name) 58 | elif field_name in ("add_q_proj", "add_k_proj", "add_v_proj"): 59 | needs_inputs_names.add(parent.add_k_proj_name) 60 | if parent.parent.parallel and parent.idx == 0: 61 | needs_inputs_names.add(parent.parent.name) 62 | else: 63 | needs_inputs_names.add(parent.name) 64 | else: 65 | raise RuntimeError(f"Unknown field name: {field_name}") 66 | elif isinstance(parent, DiffusionFeedForwardStruct): 67 | if field_name == "up_proj": 68 | needs_inputs_names.update(parent.up_proj_names[: parent.config.num_experts]) 69 | elif field_name == "down_proj": 70 | needs_inputs_names.update(parent.down_proj_names[: parent.config.num_experts]) 71 | else: 72 | raise RuntimeError(f"Unknown field name: {field_name}") 73 | else: 74 | needs_inputs_names.add(module_name) 75 | 76 | def needs_inputs(name: str, module: nn.Module) -> bool: 77 | return name in needs_inputs_names 78 | 79 | return needs_inputs 80 | 81 | 82 | def get_needs_outputs_fn( 83 | model: DiffusionModelStruct, config: DiffusionQuantConfig 84 | ) -> tp.Callable[[str, nn.Module], bool]: 85 | """Get function that checks whether the module needs to cache outputs. 86 | 87 | Args: 88 | model (`DiffusionModelStruct`): 89 | The diffused model. 90 | config (`DiffusionQuantConfig`): 91 | The quantization configuration. 92 | 93 | Returns: 94 | `Callable[[str, nn.Module], bool]`: 95 | The function that checks whether the module needs to cache outputs. 96 | """ 97 | 98 | # TODO: Implement the function that checks whether the module needs to cache outputs. 
99 | 100 | def needs_outputs(name: str, module: nn.Module) -> bool: 101 | return False 102 | 103 | return needs_outputs 104 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/cache/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/cache/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """LLM quantization cache configuration.""" 3 | 4 | from dataclasses import dataclass, field 5 | 6 | from omniconfig import configclass 7 | 8 | from deepcompressor.utils.config.path import BasePathConfig 9 | 10 | __all__ = ["LlmQuantCacheConfig", "LlmCacheConfig"] 11 | 12 | 13 | @configclass 14 | @dataclass 15 | class LlmQuantCacheConfig(BasePathConfig): 16 | """Large language model quantization cache path. 17 | 18 | Args: 19 | rotation (`str`, *optional*, default=`""`): 20 | The rotation matrix cache path. 21 | reorder (`str`, *optional*, default=`""`): 22 | The reorder channel indexes cache path. 23 | smooth (`str`, *optional*, default=`""`): 24 | The smoothing scales cache path. 25 | wgts (`str`, *optional*, default=`""`): 26 | The weight quantizers state dict cache path. 27 | acts (`str`, *optional*, default=`""`): 28 | The activation quantizers state dict cache path. 29 | """ 30 | 31 | rotation: str = "" 32 | reorder: str = "" 33 | smooth: str = "" 34 | wgts: str = "" 35 | acts: str = "" 36 | 37 | 38 | @configclass 39 | @dataclass 40 | class LlmCacheConfig: 41 | """LLM quantization cache configuration. 42 | 43 | Attributes: 44 | root (`str`, *optional*, default=`""`): 45 | The root directory path for the cache. 46 | dirpath (`LlmQuantCacheConfig`, *optional*, default=`LlmQuantCacheConfig()`): 47 | The directory paths for the cache. 48 | path (`LlmQuantCacheConfig`, *optional*, default=`LlmQuantCacheConfig()`): 49 | The file paths for the cache. 
50 | """ 51 | 52 | root: str = field(default="") 53 | dirpath: LlmQuantCacheConfig = field(init=False, default_factory=LlmQuantCacheConfig) 54 | path: LlmQuantCacheConfig = field(default_factory=LlmQuantCacheConfig) 55 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configurations for evaluating a large language model.""" 3 | 4 | import os 5 | import random 6 | from dataclasses import dataclass, field 7 | 8 | import numpy as np 9 | import omniconfig 10 | import torch 11 | from omniconfig import ConfigParser, configclass 12 | 13 | from deepcompressor.data.utils import ScaleUtils 14 | from deepcompressor.utils.config.output import OutputConfig 15 | 16 | from .cache.config import LlmCacheConfig, LlmQuantCacheConfig 17 | from .eval.config import LlmEvalConfig 18 | from .model.config import LlmModelConfig 19 | from .quant.config import LlmQuantConfig 20 | 21 | __all__ = [ 22 | "LlmPtqRunConfig", 23 | "LlmCacheConfig", 24 | "LlmQuantCacheConfig", 25 | "LlmEvalConfig", 26 | "LlmModelConfig", 27 | "LlmQuantConfig", 28 | ] 29 | 30 | 31 | @configclass 32 | @dataclass 33 | class LlmPtqRunConfig: 34 | """Top-level config of post-training quantization for a large language model. 35 | 36 | Args: 37 | cache (`LlmCacheConfig`): 38 | Large language model quantization cache path configuration. 39 | output (`OutputConfig`): 40 | Output directory configuration. 41 | model (`LlmModelConfig`): 42 | Large language model configuration. 43 | eval (`LlmEvalConfig`): 44 | Large language model evaluation configuration. 45 | quant (`LlmQuantConfig`): 46 | Large language model quantization configuration. 47 | seed (`int`, *optional*, defaults to `12345`): 48 | Random seed. 49 | skip_eval (`bool`, *optional*, defaults to `False`): 50 | Whether to skip evaluation. 51 | load_model (`str`, *optional*, defaults to `""`): 52 | Directory path to load the model checkpoint. 53 | save_model (`str`, *optional*, defaults to `""`): 54 | Directory path to save the model checkpoint. 55 | copy_on_save (`bool`, *optional*, defaults to `False`): 56 | Whether to copy the quantization cache on save. 
57 | """ 58 | 59 | cache: LlmCacheConfig 60 | output: OutputConfig 61 | model: LlmModelConfig 62 | eval: LlmEvalConfig 63 | quant: LlmQuantConfig = field(metadata={omniconfig.ARGPARSE_KWARGS: {"prefix": ""}}) 64 | seed: int = 12345 65 | skip_eval: bool = False 66 | load_from: str = "" 67 | save_model: str = "" 68 | copy_on_save: bool = False 69 | 70 | def __post_init__(self): # noqa: C901 71 | # region set scale default dtype 72 | if self.quant.enabled_wgts: 73 | self.quant.wgts.scale_dtypes = tuple( 74 | ScaleUtils.infer_scale_dtypes(self.quant.wgts.scale_dtypes, default_dtype=self.model.dtype) 75 | ) 76 | if self.quant.enabled_ipts: 77 | self.quant.ipts.scale_dtypes = tuple( 78 | ScaleUtils.infer_scale_dtypes(self.quant.ipts.scale_dtypes, default_dtype=self.model.dtype) 79 | ) 80 | if self.quant.enabled_opts: 81 | self.quant.opts.scale_dtypes = tuple( 82 | ScaleUtils.infer_scale_dtypes(self.quant.opts.scale_dtypes, default_dtype=self.model.dtype) 83 | ) 84 | # endregion 85 | # region set num_gpus and batch_size for auto parallelism of large models 86 | self.eval.num_gpus = min(torch.cuda.device_count(), self.eval.num_gpus) 87 | if self.model.size < 50: 88 | self.eval.batch_size = min(8, self.eval.batch_size) 89 | elif self.model.size < 100: 90 | self.eval.batch_size = min(4, self.eval.batch_size) 91 | else: 92 | self.eval.batch_size = min(1, self.eval.batch_size) 93 | # endregion 94 | if self.quant.is_enabled(): 95 | if self.cache.path.is_all_empty(): 96 | self.cache.dirpath = self.quant.generate_cache_dirpath( 97 | root=self.cache.root, seed=self.seed, default_dtype=self.model.dtype 98 | ) 99 | self.cache.path = self.cache.dirpath.clone().add_children(f"{self.model.name}.pt") 100 | else: 101 | self.cache.dirpath = self.cache.path.clone().to_dirpath() 102 | if self.output.dirname == "default": 103 | self.output.dirname = self.quant.generate_default_dirname() 104 | self.output.dirpath = os.path.join( 105 | self.output.root, 106 | "llm", 107 | self.model.family, 108 | self.model.name, 109 | *self.quant.generate_dirnames(default_dtype=self.model.dtype)[:-1], 110 | self.quant.generate_calib_dirname(), 111 | self.output.dirname, 112 | ) 113 | random.seed(self.seed) 114 | torch.manual_seed(self.seed) 115 | torch.cuda.manual_seed_all(self.seed) 116 | np.random.seed(self.seed) 117 | 118 | @classmethod 119 | def get_parser(cls) -> ConfigParser: 120 | """Get a parser for evaluating a large language model. 121 | 122 | Returns: 123 | `ConfigParser`: A parser for evaluating a large language model. 
124 | """ 125 | parser = ConfigParser("Evaluate a large language model") 126 | parser.add_config(cls) 127 | return parser 128 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/eval/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/eval/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Language model evaluator base.""" 3 | 4 | from abc import ABC, abstractmethod 5 | 6 | from transformers import PreTrainedModel, PreTrainedTokenizer 7 | 8 | __all__ = ["LlmEvaluatorBase"] 9 | 10 | 11 | class LlmEvaluatorBase(ABC): 12 | def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer): 13 | self.model, self.tokenizer = model, tokenizer 14 | 15 | @abstractmethod 16 | def filter_tasks(self, tasks: list[str]) -> list[str]: 17 | """Filter the tasks to only include supported tasks.""" 18 | ... 19 | 20 | @abstractmethod 21 | def evaluate(self, tasks: list[str], **kwargs) -> dict[str, dict[str, dict[str, float]]]: 22 | """Evaluate the model on the given tasks.""" 23 | ... 24 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/eval/custom.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Language model customized evaluator.""" 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | from datasets import load_dataset 9 | from tqdm import tqdm 10 | from transformers import PreTrainedModel, PreTrainedTokenizer 11 | 12 | from .base import LlmEvaluatorBase 13 | 14 | __all__ = ["LlmCustomEvaluator"] 15 | 16 | 17 | class LlmCustomEvaluator(LlmEvaluatorBase): 18 | def filter_tasks(self, tasks: list[str]) -> list[str]: 19 | """Filter the tasks to only include supported tasks.""" 20 | return [task for task in tasks if task.startswith(("wikitext", "pile"))] 21 | 22 | def evaluate( 23 | self, tasks: list[str], max_length: int | None = None, **kwargs 24 | ) -> dict[str, dict[str, dict[str, float]]]: 25 | """Evaluate the model on the given tasks. 26 | 27 | Args: 28 | tasks (`list[str]`): List of tasks to evaluate on. 29 | max_length (`int`, optional, defaults to `None`): Maximum length for the model. 30 | 31 | Returns: 32 | dict[str, dict[str, dict[str, float]]]: Evaluation results `{"results": {"task": {"metric": score}}}`. 33 | """ 34 | result = {"results": {}, "versions": {}} 35 | for task in tasks: 36 | result["results"][task] = { 37 | "word_perplexity": _eval_ppl_with_gptq_evaluator( 38 | self.model, self.tokenizer, task=task, seq_length=max_length 39 | ) 40 | } 41 | result["versions"][task] = 1 42 | return result 43 | 44 | 45 | def _eval_ppl_with_gptq_evaluator( 46 | model: PreTrainedModel, 47 | /, 48 | tokenizer: PreTrainedTokenizer, 49 | task: str, 50 | seq_length: int = 2048, 51 | max_num_samples: int = -1, 52 | ) -> float: 53 | """Evaluate the perplexity of a model on a task using GPTQ style evaluation. 54 | 55 | Args: 56 | model (`PreTrainedModel`): 57 | The model. 58 | tokenizer (`PreTrainedTokenizer`): 59 | The tokenizer. 60 | task (`str`): 61 | The task name. 62 | seq_length (`int`, *optional*, defaults to `2048`): 63 | The sequence length. 64 | max_num_samples (`int`, *optional*, defaults to `-1`): 65 | The maximum number of samples to evaluate. 
66 | 67 | Returns: 68 | float: The perplexity. 69 | """ 70 | assert seq_length > 0, "seq_length must be positive" 71 | if task.startswith("wikitext"): 72 | test_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") 73 | test_dataset = tokenizer("\n\n".join(test_dataset["text"]), return_tensors="pt") 74 | elif task.startswith("pile"): 75 | test_dataset = load_dataset("pile", task, split="test") 76 | test_dataset = tokenizer("\n\n".join(test_dataset["text"]), return_tensors="pt") 77 | else: 78 | raise ValueError(f"Invalid task: {task}") 79 | 80 | test_dataset = test_dataset.input_ids.to(model.device) 81 | num_samples = test_dataset.numel() // seq_length 82 | if max_num_samples > 0: 83 | num_samples = min(num_samples, max_num_samples) 84 | model = model.eval() 85 | 86 | nlls = [] 87 | for i in tqdm(range(num_samples), desc=f"evaluating on {task} with seq_length {seq_length}", dynamic_ncols=True): 88 | batch = test_dataset[:, (i * seq_length) : ((i + 1) * seq_length)] 89 | with torch.inference_mode(): 90 | shift_logits = model(batch.to(model.device)).logits[:, :-1, :].contiguous().float() 91 | shift_labels = batch[:, 1:] 92 | loss = nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 93 | neg_log_likelihood = loss.float() * seq_length 94 | nlls.append(neg_log_likelihood) 95 | return math.exp(sum(nlls) / (num_samples * seq_length)) 96 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/eval/lm_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Language model evaluator using lm_eval.""" 3 | 4 | import lm_eval 5 | import lm_eval.models 6 | from transformers import PreTrainedModel, PreTrainedTokenizer 7 | 8 | from .base import LlmEvaluatorBase 9 | 10 | __all__ = ["LmevalEvaluator"] 11 | 12 | 13 | class LmevalEvaluator(LlmEvaluatorBase): 14 | def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, batch_size: int = 1): 15 | super().__init__(model=model, tokenizer=tokenizer) 16 | self.lm = lm_eval.models.huggingface.HFLM(pretrained=model, tokenizer=tokenizer, batch_size=batch_size) 17 | 18 | def filter_tasks(self, tasks: list[str]) -> list[str]: 19 | """Filter the tasks to only include supported tasks.""" 20 | return tasks 21 | 22 | def evaluate( 23 | self, 24 | tasks: list[str], 25 | max_length: int | None = None, 26 | num_shot: int | None = None, 27 | fewshot_as_multiturn: bool = False, 28 | apply_chat_template: bool = False, 29 | **kwargs, 30 | ) -> dict[str, dict[str, dict[str, float]]]: 31 | """Evaluate the model on the given tasks. 32 | 33 | Args: 34 | tasks (`list[str]`): List of tasks to evaluate on. 35 | max_length (`int`, optional, defaults to `None`): Maximum length for the model. 36 | 37 | Returns: 38 | dict[str, dict[str, dict[str, float]]]: Evaluation results `{"results": {"task": {"metric": score}}}`. 
39 | """ 40 | self.lm._max_length = max_length 41 | result = lm_eval.evaluator.simple_evaluate( 42 | model=self.lm, 43 | tasks=tasks, 44 | verbosity="ERROR", 45 | num_fewshot=num_shot, 46 | fewshot_as_multiturn=fewshot_as_multiturn, 47 | apply_chat_template=apply_chat_template, 48 | **kwargs, 49 | ) 50 | self.lm._max_length = None 51 | result.pop("samples", None) 52 | result.pop("config", None) 53 | return result 54 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/eval/longbench/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval import LongbenchEvaluator, LongbenchScorer 2 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/eval/longbench/task2prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 3 | "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 4 | "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 5 | "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", 6 | "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 7 | "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 8 | "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 9 | "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", 10 | "gov_report": "You are given a report by a government agency. 
Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", 11 | "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", 12 | "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", 13 | "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", 14 | "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", 15 | "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", 16 | "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", 17 | "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", 18 | "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", 19 | "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ", 20 | "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:", 21 | "lcc": "Please complete the code given below. \n{context}Next line of code:\n", 22 | "repobench-p": "Please complete the code given below. 
\n{context}{input}Next line of code:\n" 23 | } -------------------------------------------------------------------------------- /deepcompressor/app/llm/model/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .struct import LlmModelStruct, LlmTransformerBlockStruct, LlmTransformerStruct 4 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/quant/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .activation import quantize_llm_activations 4 | from .config import LlmQuantCacheConfig, LlmQuantConfig 5 | from .quantizer import LlmActivationQuantizer, LlmWeightQuantizer 6 | from .reorder import reorder_llm 7 | from .rotate import rotate_llm 8 | from .smooth import smooth_llm 9 | from .weight import quantize_llm_weights 10 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/quant/quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .config import LlmModuleQuantizerConfig 4 | from .quantizer import LlmActivationQuantizer, LlmWeightQuantizer 5 | -------------------------------------------------------------------------------- /deepcompressor/app/llm/quant/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """LLM quantization utils module.""" 3 | 4 | import typing as tp 5 | 6 | import torch.nn as nn 7 | 8 | from ..nn.struct import LlmModelStruct 9 | from .quantizer.config import LlmModuleQuantizerConfig 10 | 11 | __all__ = ["get_needs_inputs_fn", "get_needs_outputs_fn"] 12 | 13 | 14 | def get_needs_inputs_fn(model: LlmModelStruct, config: LlmModuleQuantizerConfig) -> tp.Callable[[str, nn.Module], bool]: 15 | """Get function that checks if the module needs to cache the inputs. 16 | 17 | Args: 18 | model (`LlmStruct`): 19 | Model struct. 20 | config (`LlmModuleQuantizerConfig`): 21 | Module quantization config. 22 | 23 | Returns: 24 | `Callable[[str, nn.Module], bool]`: 25 | Function to check if the module needs to cache the inputs. 
26 | """ 27 | 28 | needs_inputs_names = set() 29 | 30 | example_layer = model.backbone_struct.layer_structs[0] 31 | attn, ffn = example_layer.attn_struct, example_layer.ffn_struct 32 | if (config.enabled_wgts and config.wgts.is_enabled_for(attn.qkv_proj_key)) or ( 33 | config.enabled_ipts and config.ipts.is_enabled_for(attn.qkv_proj_key) 34 | ): 35 | needs_inputs_names.add(attn.rname) 36 | needs_inputs_names.add(attn.v_proj_rname) 37 | if (config.enabled_wgts and config.wgts.is_enabled_for(attn.out_proj_key)) or ( 38 | config.enabled_ipts and config.ipts.is_enabled_for(attn.out_proj_key) 39 | ): 40 | needs_inputs_names.add(attn.o_proj_rname) 41 | if (config.enabled_wgts and config.wgts.is_enabled_for(ffn.up_proj_key)) or ( 42 | config.enabled_ipts and config.ipts.is_enabled_for(ffn.up_proj_key) 43 | ): 44 | needs_inputs_names.add(ffn.rname) 45 | needs_inputs_names.add(ffn.up_proj_rnames[0]) 46 | if (config.enabled_wgts and config.wgts.is_enabled_for(ffn.down_proj_key)) or ( 47 | config.enabled_ipts and config.ipts.is_enabled_for(ffn.down_proj_key) 48 | ): 49 | needs_inputs_names.add(ffn.down_proj_rnames[0]) 50 | if config.enabled_opts: 51 | needs_inputs_names.add(attn.rname) 52 | 53 | needs_inputs_names = tuple(needs_inputs_names) 54 | 55 | def needs_inputs(name: str, module: nn.Module) -> bool: 56 | return name.endswith(needs_inputs_names) 57 | 58 | return needs_inputs 59 | 60 | 61 | def get_needs_outputs_fn( 62 | model: LlmModelStruct, config: LlmModuleQuantizerConfig 63 | ) -> tp.Callable[[str, nn.Module], bool]: 64 | """Get function that checks if the module needs to cache the outputs. 65 | 66 | Args: 67 | model (`LlmStruct`): 68 | Model struct. 69 | config (`LlmModuleQuantizerConfig`): 70 | Module quantization config. 71 | 72 | Returns: 73 | `Callable[[str, nn.Module], bool]`: 74 | Function to check if the module needs to cache the outputs. 
75 | """ 76 | 77 | attn = model.backbone_struct.layer_structs[0].attn_struct 78 | needs_outputs_names = set() 79 | if config.enabled_opts: 80 | needs_outputs_names.add(attn.q_rname) 81 | needs_outputs_names.add(attn.k_rname) 82 | needs_outputs_names.add(attn.v_rname) 83 | needs_outputs_names = tuple(needs_outputs_names) 84 | 85 | def needs_outputs(name: str, module: nn.Module) -> bool: 86 | return name.endswith(needs_outputs_names) 87 | 88 | return needs_outputs 89 | -------------------------------------------------------------------------------- /deepcompressor/backend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/deepcompressor/backend/__init__.py -------------------------------------------------------------------------------- /deepcompressor/backend/nunchaku/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/deepcompressor/backend/nunchaku/__init__.py -------------------------------------------------------------------------------- /deepcompressor/backend/qserve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/deepcompressor/backend/qserve/__init__.py -------------------------------------------------------------------------------- /deepcompressor/backend/tinychat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/deepcompressor/backend/tinychat/__init__.py -------------------------------------------------------------------------------- /deepcompressor/backend/tinychat/csrc/load.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """TinyChat Extension.""" 3 | 4 | import os 5 | 6 | from torch.utils.cpp_extension import load 7 | 8 | __all__ = ["_C"] 9 | 10 | dirpath = os.path.dirname(__file__) 11 | 12 | _C = load( 13 | name="deepcompressor_tinychat_C", 14 | sources=[ 15 | f"{dirpath}/pybind.cpp", 16 | f"{dirpath}/quantization/gemv/gemv_cuda.cu", 17 | f"{dirpath}/quantization/gemm/gemm_cuda.cu", 18 | ], 19 | extra_cflags=["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++20"], 20 | extra_cuda_cflags=[ 21 | "-O3", 22 | "-std=c++20", 23 | "-U__CUDA_NO_HALF_OPERATORS__", 24 | "-U__CUDA_NO_HALF_CONVERSIONS__", 25 | "-U__CUDA_NO_HALF2_OPERATORS__", 26 | "-U__CUDA_NO_HALF2_CONVERSIONS__", 27 | "-U__CUDA_NO_BFLOAT16_OPERATORS__", 28 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 29 | "-U__CUDA_NO_BFLOAT162_OPERATORS__", 30 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", 31 | "--expt-relaxed-constexpr", 32 | "--expt-extended-lambda", 33 | "--use_fast_math", 34 | "--ptxas-options=--allow-expensive-optimizations=true", 35 | "--threads=8", 36 | ], 37 | ) 38 | -------------------------------------------------------------------------------- /deepcompressor/backend/tinychat/csrc/pybind.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "quantization/gemm/gemm_cuda.h" 4 | #include "quantization/gemv/gemv_cuda.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 7 | { 8 | m.def("awq_gemm_forward_cuda", 
&awq_gemm_forward_cuda, "AWQ quantized GEMM kernel."); 9 | m.def("awq_gemv_forward_cuda", &awq_gemv_forward_cuda, "AWQ quantized GEMV kernel."); 10 | } 11 | -------------------------------------------------------------------------------- /deepcompressor/backend/tinychat/csrc/quantization/gemm/gemm_cuda.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | torch::Tensor awq_gemm_forward_cuda( 4 | torch::Tensor _in_feats, 5 | torch::Tensor _kernel, 6 | torch::Tensor _scales, 7 | torch::Tensor _zeros); 8 | -------------------------------------------------------------------------------- /deepcompressor/backend/tinychat/csrc/quantization/gemm/semaphore.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | * 30 | **************************************************************************************************/ 31 | /*! \file 32 | \brief Implementation of a CTA-wide semaphore for inter-CTA synchronization. 33 | */ 34 | 35 | #pragma once 36 | 37 | ///////////////////////////////////////////////////////////////////////////////////////////////// 38 | 39 | // namespace cutlass { 40 | 41 | ///////////////////////////////////////////////////////////////////////////////////////////////// 42 | 43 | /// CTA-wide semaphore for inter-CTA synchronization. 
44 | class Semaphore 45 | { 46 | public: 47 | int *lock; 48 | bool wait_thread; 49 | int state; 50 | 51 | public: 52 | /// Implements a semaphore to wait for a flag to reach a given value 53 | __host__ __device__ Semaphore(int *lock_, int thread_id) : lock(lock_), 54 | wait_thread(thread_id < 0 || thread_id == 0), 55 | state(-1) 56 | { 57 | } 58 | 59 | /// Permit fetching the synchronization mechanism early 60 | __device__ void fetch() 61 | { 62 | if (wait_thread) 63 | { 64 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 65 | asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); 66 | #else 67 | asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); 68 | #endif 69 | } 70 | } 71 | 72 | /// Gets the internal state 73 | __device__ int get_state() const 74 | { 75 | return state; 76 | } 77 | 78 | /// Waits until the semaphore is equal to the given value 79 | __device__ void wait(int status = 0) 80 | { 81 | while (__syncthreads_and(state != status)) 82 | { 83 | fetch(); 84 | } 85 | 86 | __syncthreads(); 87 | } 88 | 89 | /// Updates the lock with the given result 90 | __device__ void release(int status = 0) 91 | { 92 | __syncthreads(); 93 | 94 | if (wait_thread) 95 | { 96 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 97 | asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); 98 | #else 99 | asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); 100 | #endif 101 | } 102 | } 103 | }; 104 | 105 | ///////////////////////////////////////////////////////////////////////////////////////////////// 106 | 107 | // } // namespace cutlass 108 | 109 | ///////////////////////////////////////////////////////////////////////////////////////////////// 110 | -------------------------------------------------------------------------------- /deepcompressor/backend/tinychat/csrc/quantization/gemv/gemv_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | torch::Tensor awq_gemv_forward_cuda( 5 | torch::Tensor _in_feats, 6 | torch::Tensor _kernel, 7 | torch::Tensor _scaling_factors, 8 | torch::Tensor _zeros, 9 | int m, 10 | int n, 11 | int k, 12 | int group_size); 13 | -------------------------------------------------------------------------------- /deepcompressor/backend/tinychat/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """TinyChat backend utilities.""" 3 | 4 | import torch 5 | 6 | from ..utils import ceil_divide 7 | 8 | __all__ = ["ceil_num_groups", "convert_to_tinychat_w4x16y16_linear_weight"] 9 | 10 | 11 | def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) -> int: 12 | """Calculate the ceiling number of quantization groups. 13 | 14 | Args: 15 | in_features (`int`): 16 | input channel size. 17 | group_size (`int`): 18 | quantization group size. 19 | weight_bits (`int`, *optional*, defaults to `4`): 20 | quantized weight bits. 21 | 22 | Returns: 23 | `int`: 24 | ceiling number of quantization groups. 25 | """ 26 | assert in_features % group_size == 0, "input channel size should be divisible by group size." 27 | num_groups = in_features // group_size 28 | assert weight_bits in (4, 2, 1), "weight bits should be 4, 2, or 1." 
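    # Worked example of the steps below (illustration only, values not from the source):
    # with in_features=3584, group_size=64 and weight_bits=4, num_groups=56 and pack_size=8,
    # so num_packs=7; since group_size==64, num_packs_factor=2 and num_packs is rounded up
    # to 8, so the function returns 8 * 8 = 64, i.e. the group dimension is padded from 56 to 64.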
29 | pack_size = 32 // weight_bits # one INT32 contains `pack_size` elements of weights 30 | num_packs = ceil_divide(num_groups, pack_size) 31 | if group_size >= 128: 32 | num_packs_factor = 1 33 | elif group_size == 64: 34 | num_packs_factor = 2 35 | elif group_size == 32: 36 | num_packs_factor = 4 37 | else: 38 | raise NotImplementedError 39 | # make sure num_packs is a multiple of num_packs_factor 40 | num_packs = ceil_divide(num_packs, num_packs_factor) * num_packs_factor 41 | num_groups = num_packs * pack_size 42 | return num_groups 43 | 44 | 45 | def pack_w4(weight: torch.Tensor) -> torch.Tensor: 46 | assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}." 47 | oc, ic = weight.shape 48 | assert ic % 32 == 0, "input channel size should be divisible by 32." 49 | # [0, 1, ..., 31] -> [0, 8, 16, 24, 1, 9, 17, 25, ..., 7, 15, 23, 31] 50 | weight = weight.view(-1, 4, 8) 51 | weight = weight[:, 0] | (weight[:, 1] << 4) | (weight[:, 2] << 8) | (weight[:, 3] << 12) 52 | weight = weight.view(oc // 4, 4, ic // 64, 16).permute(0, 2, 1, 3).reshape(oc // 4, ic) 53 | return weight.to(torch.int16) 54 | 55 | 56 | def convert_to_tinychat_w4x16y16_linear_weight( 57 | weight: torch.Tensor, 58 | scale: torch.Tensor, 59 | zero: torch.Tensor, 60 | zero_pre_scaled: bool = False, 61 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 62 | """Convert a weight tensor to TinyChat W4-X16-Y16 linear weight format. 63 | 64 | Args: 65 | weight (`torch.Tensor`): 66 | weight tensor to be converted. 67 | scale (`torch.Tensor`): 68 | scale tensor for the weight tensor. 69 | zero (`torch.Tensor`): 70 | zero point tensor for the weight tensor. 71 | zero_pre_scaled (`bool`, *optional*, defaults to `False`): 72 | whether zero point tensor is pre-scaled. 73 | 74 | Returns: 75 | `tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: 76 | packed quantized weight tensor, scale tensor, and zero point tensor. 77 | """ 78 | dtype, device = weight.dtype, weight.device 79 | assert dtype in (torch.float16, torch.bfloat16), "currently tinychat only supports fp16 and bf16." 80 | assert scale is not None, "scale tensor is required for quantization." 81 | assert zero is not None, "zero point tensor is required for quantization." 82 | weight = weight.to(dtype=torch.float32) 83 | scale = scale.to(dtype=torch.float32, device=device) 84 | zero = zero.to(dtype=torch.float32, device=device) 85 | if zero_pre_scaled: 86 | zero = zero * scale 87 | oc, ic = weight.shape 88 | if scale.numel() == 1: 89 | scale = scale.view(1, 1).expand(oc, 1) 90 | ng, gs = 1, ic 91 | else: 92 | ng = scale.numel() // oc 93 | gs = ic // ng 94 | scale = scale.reshape(oc, ng).contiguous().view(oc, ng, 1) 95 | assert ic == gs * ng, "input channel size should be equal to group size times number of groups." 96 | if zero.numel() == 1: 97 | zero = zero.view(1, 1).expand(oc, ng) 98 | zero = zero.reshape(oc, ng).contiguous().view(oc, ng, 1) 99 | weight = weight.view(oc, ng, -1).add_(zero).div_(scale).round_().view(oc, ic) 100 | assert weight.min() >= 0 and weight.max() <= 15, "quantized weight should be in [0, 15]." 
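    # The remaining lines pack the int4 values into int16 words via pack_w4, allocate per-group
    # scale/zero buffers padded up to the ceiling group count from ceil_num_groups, transpose
    # them to a (num_groups, out_features) layout, and store the zero points negated before
    # returning the packed weight, scale, and zero tensors.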
101 | _weight = pack_w4(weight.to(torch.int32)) 102 | _ng = ceil_num_groups(ic, gs, weight_bits=4) 103 | _scale = torch.zeros((_ng, oc), dtype=dtype, device=device) 104 | _zero = torch.zeros((_ng, oc), dtype=dtype, device=device) 105 | _scale[:ng] = scale.view(oc, ng).t().to(dtype=dtype) 106 | _zero[:ng] = zero.view(oc, ng).t().to(dtype=dtype).neg_() 107 | return _weight, _scale, _zero 108 | -------------------------------------------------------------------------------- /deepcompressor/calib/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /deepcompressor/calib/config/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .lowrank import QuantLowRankCalibConfig, SkipBasedQuantLowRankCalibConfig 4 | from .range import DynamicRangeCalibConfig, SkipBasedDynamicRangeCalibConfig 5 | from .reorder import ChannelOrderCalibConfig, SkipBasedChannelOrderConfig 6 | from .rotation import QuantRotationConfig 7 | from .search import ( 8 | SearchBasedCalibConfig, 9 | SearchBasedCalibGranularity, 10 | SearchBasedCalibObjective, 11 | SearchBasedCalibStrategy, 12 | ) 13 | from .smooth import SkipBasedSmoothCalibConfig, SmoothCalibConfig, SmoothSpanMode, SmoothTransfomerConfig 14 | -------------------------------------------------------------------------------- /deepcompressor/calib/config/lowrank.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Quantization SVD calibration configuration.""" 3 | 4 | from dataclasses import dataclass, field 5 | 6 | from omniconfig import configclass 7 | 8 | from ...quantizer.config import QuantLowRankConfig 9 | from ...utils.common import num2str 10 | from ...utils.config import SkipBasedConfig 11 | from .search import SearchBasedCalibConfig, SearchBasedCalibGranularity, SearchBasedCalibStrategy 12 | 13 | __all__ = ["QuantLowRankCalibConfig", "SkipBasedQuantLowRankCalibConfig"] 14 | 15 | 16 | @configclass 17 | @dataclass 18 | class QuantLowRankCalibConfig(SearchBasedCalibConfig, QuantLowRankConfig): 19 | """Configuration for quantization low-rank branch calibration. 20 | 21 | Args: 22 | rank (`int`, *optional*, defaults to `32`): 23 | The rank of the low-rank branch. 24 | exclusive (`bool`, *optional*, defaults to `False`): 25 | Whether to use exclusive low-rank branch for each weight sharing the inputs. 26 | compensate (`bool`, *optional*, defaults to `False`): 27 | Whether the low-rank branch compensates the quantization error. 28 | degree (`int`, *optional*, default=`2`): 29 | The power degree for the quantization error. Defaults to `2`. 30 | objective (`SearchBasedCalibObjective`, *optional*, default=`SearchBasedCalibObjective.OutputsError`): 31 | The objective for quantization calibration. 32 | sample_batch_size (`int`, *optional*, default=`-1`): 33 | The samples batch size for calibration. 34 | sample_size (`int`, *optional*, default=`-1`): 35 | The calibration sample size. 36 | outputs_device (`str`, *optional*, default=`"cpu"`): 37 | The device to store the precomputed outputs of the module. 38 | num_iters (`int`, *optional*, default=`1`): 39 | The number of iterations. 40 | early_stop (`bool`, *optional*, default=`False`): 41 | Whether to stop the calibration early. 
42 | """ 43 | 44 | granularity: SearchBasedCalibGranularity = field(init=False, default=SearchBasedCalibGranularity.Layer) 45 | element_batch_size: int = field(init=False, default=-1) 46 | element_size: int = field(init=False, default=-1) 47 | pre_reshape: bool = field(init=False, default=True) 48 | num_iters: int = 1 49 | early_stop: bool = False 50 | 51 | def __post_init__(self): 52 | if self.strategy != SearchBasedCalibStrategy.Manual: 53 | self.strategy = SearchBasedCalibStrategy.GridSearch 54 | if self.compensate and self.num_iters <= 1: 55 | self.exclusive = True 56 | super().__post_init__() 57 | 58 | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]: 59 | """Generate the directory names of the configuration. 60 | 61 | Returns: 62 | list[str]: The directory names. 63 | """ 64 | names = super().generate_dirnames(**kwargs) 65 | name = f"i{num2str(self.num_iters)}.r{num2str(self.rank)}" 66 | if self.exclusive: 67 | name += ".exclusive" 68 | if self.compensate: 69 | name += ".compensate" 70 | if self.early_stop and self.num_iters > 1: 71 | name += ".earlystop" 72 | names.append(name) 73 | if prefix: 74 | names = [f"{prefix}.{name}" for name in names] 75 | return names 76 | 77 | 78 | @configclass 79 | @dataclass 80 | class SkipBasedQuantLowRankCalibConfig(SkipBasedConfig, QuantLowRankCalibConfig): 81 | """Configuration for Quantization Low-Rank Branch calibration. 82 | 83 | Args: 84 | rank (`int`, *optional*, defaults to `32`): 85 | The rank of the low-rank branch. 86 | exclusive (`bool`, *optional*, defaults to `False`): 87 | Whether to use exclusive low-rank branch for each weight sharing the inputs. 88 | compensate (`bool`, *optional*, defaults to `False`): 89 | Whether the low-rank branch compensates the quantization error. 90 | degree (`int`, *optional*, default=`2`): 91 | The power degree for the quantization error. Defaults to `2`. 92 | objective (`SearchBasedCalibObjective`, *optional*, default=`SearchBasedCalibObjective.OutputsError`): 93 | The objective for quantization calibration. 94 | sample_batch_size (`int`, *optional*, default=`-1`): 95 | The samples batch size for calibration. 96 | sample_size (`int`, *optional*, default=`-1`): 97 | The calibration sample size. 98 | outputs_device (`str`, *optional*, default=`"cpu"`): 99 | The device to store the precomputed outputs of the module. 100 | num_iters (`int`, *optional*, default=`1`): 101 | The number of iterations. 102 | early_stop (`bool`, *optional*, default=`False`): 103 | Whether to stop the calibration early. 104 | skips (`list[str]`, *optional*, default=`[]`): 105 | The keys of the modules to skip. 106 | """ 107 | 108 | pass 109 | -------------------------------------------------------------------------------- /deepcompressor/calib/config/rotation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Quantization Rotation configuration.""" 3 | 4 | import os 5 | import typing as tp 6 | from dataclasses import dataclass, field 7 | 8 | import omniconfig 9 | from omniconfig import configclass 10 | 11 | __all__ = ["QuantRotationConfig"] 12 | 13 | 14 | @configclass 15 | @dataclass 16 | class QuantRotationConfig: 17 | """Configuration for rotation quantization. 18 | 19 | Args: 20 | name (`str`): 21 | The name of the rotation quantization configuration. If `path` is provided, this is required. 22 | Otherwise, it is set to "random" if `random` is `True`, and "hadamard" otherwise. 
23 | path (`str`, *optional*, default=`""`): 24 | The path to the rotation matrix. If provided, `name` must be set. 25 | random (`bool`, *optional*, default=`False`): 26 | Whether to use random hadamard sample as rotation matrix. 27 | transforms (`list[str]`, *optional*, default=`[]`): 28 | The module keys using explicit hadamard transform. 29 | """ 30 | 31 | name: str = "" 32 | path: str = "" 33 | random: bool = False 34 | transforms: list[str] = field(default_factory=list) 35 | 36 | def __post_init__(self) -> None: 37 | self.transforms = sorted(set(self.transforms or [])) 38 | if self.path and os.path.exists(self.path): 39 | assert self.name, "The name of the rotation quantization configuration must be provided." 40 | self.random = False 41 | else: 42 | self.path = "" 43 | self.name = "random" if self.random else "hadamard" 44 | 45 | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]: 46 | """Get the directory names of the rotation quantization configuration. 47 | 48 | Returns: 49 | list[str]: The directory names of the rotation quantization configuration. 50 | """ 51 | name = self.name 52 | if self.transforms: 53 | name += f".[{'+'.join(self.transforms)}]" 54 | return [f"{prefix}.{name}" if prefix else name] 55 | 56 | @classmethod 57 | def update_get_arguments( 58 | cls: type["QuantRotationConfig"], 59 | *, 60 | overwrites: dict[str, tp.Callable[[omniconfig.Arguments], None] | None] | None = None, 61 | defaults: dict[str, tp.Any] | None = None, 62 | ) -> tuple[dict[str, tp.Callable[[omniconfig.Arguments], None] | None], dict[str, tp.Any]]: 63 | """Get the arguments for the rotation quantization configuration.""" 64 | overwrites = overwrites or {} 65 | defaults = defaults or {} 66 | 67 | collect_fn = omniconfig.ADD_PREFIX_BOOL_FIELDS("transform", **defaults) 68 | 69 | def add_transforms_argument(parser): 70 | collect_fn(parser) 71 | parser.add_argument("--transforms", nargs="+", default=[], help="The keys of the modules to transform.") 72 | 73 | overwrites.setdefault("transforms", add_transforms_argument) 74 | return overwrites, defaults 75 | 76 | @classmethod 77 | def update_from_dict( 78 | cls: type["QuantRotationConfig"], *, parsed_args: dict[str, tp.Any], overwrites: dict[str, tp.Any] 79 | ) -> tuple[dict[str, tp.Any], dict[str, tp.Any]]: 80 | """Create a rotation quantization configuration from the parsed arguments.""" 81 | parsed_args.setdefault("transforms", []).extend(omniconfig.COLLECT_PREFIX_BOOL_FIELDS(parsed_args, "transform")) 82 | return parsed_args, overwrites 83 | -------------------------------------------------------------------------------- /deepcompressor/calib/config/search.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Quantization calibrator configurations.""" 3 | 4 | import enum 5 | from dataclasses import dataclass 6 | 7 | from omniconfig import configclass 8 | 9 | from ...utils.common import num2str 10 | 11 | __all__ = [ 12 | "SearchBasedCalibStrategy", 13 | "SearchBasedCalibGranularity", 14 | "SearchBasedCalibObjective", 15 | "SearchBasedCalibConfig", 16 | ] 17 | 18 | 19 | class SearchBasedCalibStrategy(enum.Enum): 20 | """The strategy for search-based quantization calibration.""" 21 | 22 | Manual = enum.auto() 23 | GridSearch = enum.auto() 24 | # RandomSearch = enum.auto() 25 | # Bayesian = enum.auto() 26 | # EvolutionaryAlgorithm = enum.auto() 27 | # EvolutionaryStrategy = enum.auto() 28 | 29 | 30 | class SearchBasedCalibGranularity(enum.Enum): 31 | """The 
granularity for search-based quantization calibration.""" 32 | 33 | Group = enum.auto() 34 | ChannelGroup = enum.auto() 35 | Layer = enum.auto() 36 | 37 | 38 | class SearchBasedCalibObjective(enum.Enum): 39 | """The objective for search-based quantization calibration.""" 40 | 41 | TensorError = enum.auto() 42 | """minimize the quantization error of the tensor.""" 43 | ProductsError = enum.auto() 44 | """minimize the error of the multiplication products.""" 45 | OutputsError = enum.auto() 46 | """minimize the error of the outputs of the evaluation module.""" 47 | 48 | 49 | @configclass 50 | @dataclass 51 | class SearchBasedCalibConfig: 52 | """The base configuration for search-based quantization calibration. 53 | 54 | Args: 55 | degree (`int`, *optional*, default=`2`): 56 | The power degree for the quantization error. Defaults to `2`. 57 | objective (`SearchBasedCalibObjective`, *optional*, default=`SearchBasedCalibObjective.OutputsError`): 58 | The objective for quantization calibration. 59 | strategy (`SearchBasedCalibStrategy`, *optional*, default=`SearchBasedCalibStrategy.Manual`): 60 | The strategy for quantization calibration. 61 | granularity (`SearchBasedCalibGranularity`, *optional*, default=`SearchBasedCalibGranularity.Layer`): 62 | The granularity for quantization calibration. 63 | element_batch_size (`int`, *optional*, default=`-1`): 64 | The element batch size for calibration. 65 | sample_batch_size (`int`, *optional*, default=`-1`): 66 | The samples batch size for calibration. 67 | element_size (`int`, *optional*, default=`-1`): 68 | The calibration element size. 69 | sample_size (`int`, *optional*, default=`-1`): 70 | The calibration sample size. 71 | pre_reshape (`bool`, *optional*, default=`True`): 72 | Whether to enable reshaping the tensor before calibration. 73 | outputs_device (`str`, *optional*, default=`"cpu"`): 74 | The device to store the precomputed outputs of the module. 75 | """ 76 | 77 | degree: int = 2 78 | objective: SearchBasedCalibObjective = SearchBasedCalibObjective.OutputsError 79 | strategy: SearchBasedCalibStrategy = SearchBasedCalibStrategy.Manual 80 | granularity: SearchBasedCalibGranularity = SearchBasedCalibGranularity.Layer 81 | element_batch_size: int = -1 82 | sample_batch_size: int = -1 83 | element_size: int = -1 84 | sample_size: int = -1 85 | pre_reshape: bool = True 86 | outputs_device: str = "cpu" 87 | 88 | def __post_init__(self) -> None: 89 | if self.outputs_device != "cpu": 90 | self.outputs_device = None 91 | if self.element_size != 0 or self.sample_size != 0: 92 | assert self.element_batch_size != 0, "element_batch_size must not be zero" 93 | assert self.sample_batch_size != 0, "sample_batch_size must not be zero" 94 | assert self.element_size != 0, "element_size must not be zero" 95 | assert self.sample_size != 0, "sample_size must not be zero" 96 | else: 97 | assert self.objective == SearchBasedCalibObjective.TensorError 98 | if self.objective == SearchBasedCalibObjective.TensorError: 99 | pass 100 | elif self.granularity == SearchBasedCalibGranularity.Layer: 101 | self.objective = SearchBasedCalibObjective.OutputsError 102 | self.element_batch_size = -1 103 | self.element_size = -1 104 | 105 | @property 106 | def needs_search(self) -> bool: 107 | """Whether the search is enabled.""" 108 | return self.strategy != SearchBasedCalibStrategy.Manual 109 | 110 | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]: 111 | """Generate the directory names of the configuration.
112 | 113 | Args: 114 | prefix (`str`, *optional*, default=`""`): 115 | The prefix of the directory. 116 | 117 | Returns: 118 | `list[str]`: 119 | The directory names. 120 | """ 121 | name = f"{self.objective.name}.{self.strategy.name}.{self.granularity.name}.d{num2str(self.degree)}" 122 | name += f".e{num2str(self.element_size)}.s{num2str(self.sample_size)}" 123 | if prefix: 124 | name = f"{prefix}.{name}" 125 | return [name] 126 | -------------------------------------------------------------------------------- /deepcompressor/csrc/load.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Deepcompressor Extension.""" 3 | 4 | import os 5 | 6 | from torch.utils.cpp_extension import load 7 | 8 | __all__ = ["_C"] 9 | 10 | dirpath = os.path.dirname(__file__) 11 | 12 | _C = load( 13 | name="deepcompressor_C", 14 | sources=[f"{dirpath}/pybind.cpp", f"{dirpath}/quantize/quantize.cu"], 15 | extra_cflags=["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++20"], 16 | extra_cuda_cflags=[ 17 | "-O3", 18 | "-std=c++20", 19 | "-U__CUDA_NO_HALF_OPERATORS__", 20 | "-U__CUDA_NO_HALF_CONVERSIONS__", 21 | "-U__CUDA_NO_HALF2_OPERATORS__", 22 | "-U__CUDA_NO_HALF2_CONVERSIONS__", 23 | "-U__CUDA_NO_BFLOAT16_OPERATORS__", 24 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 25 | "-U__CUDA_NO_BFLOAT162_OPERATORS__", 26 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", 27 | "--expt-relaxed-constexpr", 28 | "--expt-extended-lambda", 29 | "--use_fast_math", 30 | "--ptxas-options=--allow-expensive-optimizations=true", 31 | "--threads=8", 32 | ], 33 | ) 34 | -------------------------------------------------------------------------------- /deepcompressor/csrc/pybind.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "quantize/quantize.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("round_to_nearest_in_codebook_cuda", &round_to_nearest_in_codebook_cuda, 8 | py::arg("tensor"), py::arg("codebook"), py::arg("inplace") = false, 9 | py::arg("bnb") = false, "RTN with codebook (CUDA)"); 10 | } 11 | -------------------------------------------------------------------------------- /deepcompressor/csrc/quantize/quantize.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "quantize.h" 12 | 13 | // The following code is adapted from the bitsandbytes library: 14 | // https://github.com/bitsandbytes-foundation/bitsandbytes/blob/main/csrc/kernels.cu#L232 15 | template 16 | __device__ __forceinline__ 17 | typename std::conditional::type 18 | bnb_nearest_neighbor(float_t x, float_t *codebook, const int C) 19 | { 20 | int mid = (C >> 1) - 1; 21 | int hi = C - 1; 22 | int lo = 0; 23 | 24 | float_t lval = codebook[lo]; 25 | float_t hval = codebook[hi]; 26 | float_t mval = codebook[mid]; 27 | 28 | for (int step = (C >> 2); step > 0; step >>= 1) 29 | { 30 | if (x > mval) 31 | { 32 | lo = mid; 33 | lval = mval; 34 | mid += step; 35 | } 36 | else 37 | { 38 | hi = mid; 39 | hval = mval; 40 | mid -= step; 41 | } 42 | mval = codebook[mid]; 43 | } 44 | 45 | if (x > mval) 46 | { 47 | if constexpr (ret_val) 48 | { 49 | return (x - mval > hval - x) ? hval : mval; 50 | } 51 | else 52 | { 53 | return (x - mval > hval - x) ? hi : mid; 54 | } 55 | } 56 | else 57 | { 58 | if constexpr (ret_val) 59 | { 60 | return (x - lval < mval - x) ? 
lval : mval; 61 | } 62 | else 63 | { 64 | return (x - lval < mval - x) ? lo : mid; 65 | } 66 | } 67 | } 68 | 69 | template <typename float_t, bool ret_val> 70 | __device__ __forceinline__ 71 | typename std::conditional<ret_val, float_t, int>::type 72 | nearest_neighbor(float_t x, const float_t *codebook, int C) 73 | { 74 | int lo = 0; 75 | int bit = 1 << (31 - __clz(C)); 76 | 77 | float_t lval = codebook[lo]; 78 | while (bit) 79 | { 80 | int next = lo | bit; 81 | float_t nval = codebook[next]; 82 | bool pred = next < C && nval <= x; 83 | lo = pred ? next : lo; 84 | lval = pred ? nval : lval; 85 | bit >>= 1; 86 | } 87 | 88 | int hi = lo + (lo < C - 1); 89 | float_t hval = codebook[hi]; 90 | 91 | if constexpr (ret_val) 92 | { 93 | return (x + x < lval + hval) ? lval : hval; 94 | } 95 | else 96 | { 97 | return (x + x < lval + hval) ? lo : hi; 98 | } 99 | } 100 | 101 | // CUDA kernel: Each thread processes one element from x and finds the nearest 102 | // codebook entry. The codebook (of size C < 256) is first loaded into shared 103 | // memory. 104 | template <typename float_t, bool bnb> 105 | __global__ void round_to_nearest_in_codebook_kernel( 106 | const float_t *__restrict__ x, const float_t *__restrict__ codebook, 107 | float_t *__restrict__ y, const int N, const int C) 108 | { 109 | // Use a shared memory array for the codebook. 110 | __shared__ float_t s_codebook[256]; 111 | 112 | // Have the first few threads load the codebook into shared memory. 113 | for (int i = threadIdx.x; i < C; i += blockDim.x) 114 | { 115 | s_codebook[i] = codebook[i]; 116 | } 117 | __syncthreads(); 118 | 119 | // Global index for the element processed by this thread. 120 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 121 | if (idx < N) 122 | { 123 | if constexpr (bnb) 124 | { 125 | y[idx] = bnb_nearest_neighbor<float_t, true>(x[idx], s_codebook, C); 126 | } 127 | else 128 | { 129 | y[idx] = nearest_neighbor<float_t, true>(x[idx], s_codebook, C); 130 | } 131 | } 132 | } 133 | 134 | torch::Tensor round_to_nearest_in_codebook_cuda(torch::Tensor tensor, 135 | torch::Tensor codebook, 136 | bool inplace, bool bnb) 137 | { 138 | auto x = tensor.contiguous(); 139 | auto c = codebook.contiguous(); 140 | auto y = inplace ?
x : torch::empty_like(tensor); 141 | const int N = x.numel(); 142 | const int C = c.numel(); 143 | const int threads = 256; 144 | const int blocks = (N + threads - 1) / threads; 145 | AT_DISPATCH_FLOATING_TYPES( 146 | tensor.scalar_type(), "round_to_nearest_in_codebook_cuda", [&] 147 | { 148 | if (bnb && (C & (C - 1)) == 0) { 149 | round_to_nearest_in_codebook_kernel<scalar_t, true> 150 | <<<blocks, threads>>>(x.data_ptr<scalar_t>(), 151 | c.data_ptr<scalar_t>(), 152 | y.data_ptr<scalar_t>(), N, C); 153 | } else { 154 | round_to_nearest_in_codebook_kernel<scalar_t, false> 155 | <<<blocks, threads>>>(x.data_ptr<scalar_t>(), 156 | c.data_ptr<scalar_t>(), 157 | y.data_ptr<scalar_t>(), N, C); 158 | } }); 159 | return y; 160 | } -------------------------------------------------------------------------------- /deepcompressor/csrc/quantize/quantize.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | torch::Tensor round_to_nearest_in_codebook_cuda(torch::Tensor tensor, 5 | torch::Tensor codebook, 6 | bool inplace = false, 7 | bool bnb = false); 8 | -------------------------------------------------------------------------------- /deepcompressor/data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .dtype import QDType, QuantDataType 4 | from .range import DynamicRange, LogQuantRange, QuantRange, RangeBound 5 | from .scale import QuantScale 6 | from .tensor import QuantTensor 7 | -------------------------------------------------------------------------------- /deepcompressor/data/common.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Common quantization data.""" 3 | 4 | import enum 5 | 6 | __all__ = ["TensorType"] 7 | 8 | 9 | class TensorType(enum.Enum): 10 | """The tensor type.""" 11 | 12 | Weights = enum.auto() 13 | Inputs = enum.auto() 14 | Outputs = enum.auto() 15 | -------------------------------------------------------------------------------- /deepcompressor/data/scale.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Quantization scale module.""" 3 | 4 | import typing as tp 5 | 6 | import torch 7 | 8 | __all__ = ["QuantScale"] 9 | 10 | 11 | class QuantScale: 12 | data: torch.Tensor 13 | _children: list["QuantScale"] 14 | _leaves: list[torch.Tensor] 15 | 16 | def __init__(self): 17 | self.data, self._children, self._leaves = None, [], [] # type: ignore 18 | 19 | @property 20 | def num_children(self) -> int: 21 | """Get the number of children.""" 22 | return len(self._children) 23 | 24 | @property 25 | def num_leaves(self) -> int: 26 | """Get the number of leaves.""" 27 | return len(self._leaves) 28 | 29 | def is_quantized(self) -> bool: 30 | """Check if the scale is quantized.""" 31 | return self.data is not None and bool(self._leaves or all(child.is_quantized() for child in self._children)) 32 | 33 | def get_child(self, index: int) -> "QuantScale": 34 | """Get a child scale.""" 35 | return self._children[index] 36 | 37 | def append(self, scale: tp.Union[torch.Tensor, "QuantScale"]) -> "QuantScale": 38 | """Append a scale.""" 39 | if isinstance(scale, torch.Tensor): 40 | assert not self._children, "Cannot append a tensor scale to a non-leaf QuantScale." 41 | self.data = _join_scale_tensor(self.data, scale) 42 | self._leaves.append(scale) 43 | elif isinstance(scale, QuantScale): 44 | assert not self._leaves, "Cannot append a non-leaf QuantScale to a leaf QuantScale."
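# A QuantScale is either a leaf node (holding raw scale tensors in `_leaves`) or an
# internal node (holding child QuantScale objects in `_children`); every append
# multiplies the new scale into the running product kept in `self.data`.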
45 | self.data = _join_scale_tensor(self.data, scale.data) 46 | self._children.append(scale) 47 | else: 48 | raise TypeError(f"Unsupported scale type: {type(scale)}") 49 | return self 50 | 51 | def extend(self, scale: "QuantScale") -> "QuantScale": 52 | """Extend with another QuantScale.""" 53 | self.data = _join_scale_tensor(self.data, scale.data) 54 | if scale._children: 55 | assert not self._leaves, "Cannot extend a leaf QuantScale with a non-leaf QuantScale." 56 | self._children.extend(scale._children) 57 | elif scale._leaves: 58 | assert not scale._children, "Cannot extend a non-leaf QuantScale with a leaf QuantScale." 59 | self._leaves.extend(scale._leaves) 60 | return self 61 | 62 | def join(self, scale: "QuantScale") -> "QuantScale": 63 | """Return a new QuantScale by joining with another QuantScale.""" 64 | return QuantScale().append(self).append(scale) 65 | 66 | def remove_zero(self) -> "QuantScale": 67 | """Remove zero scales.""" 68 | self.data[self.data == 0] = 1 69 | return self 70 | 71 | def state_dict( 72 | self, 73 | param_name: str, 74 | device: torch.device | str = "cpu", 75 | flatten: bool = True, 76 | level_base: int = 0, 77 | ) -> dict[str, torch.Tensor]: 78 | """Get the state dictionary.""" 79 | if self._children: 80 | state_dict = {} 81 | for i, child in enumerate(self._children): 82 | child_param_name = param_name if flatten else f"{param_name}.{i}" 83 | child_level_base = len(state_dict) if flatten else 0 84 | child_state_dict = child.state_dict(child_param_name, device, flatten, child_level_base) 85 | state_dict.update(child_state_dict) 86 | return state_dict 87 | else: 88 | return {f"{param_name}.{level_base + i}": leaf.to(device) for i, leaf in enumerate(self._leaves)} 89 | 90 | 91 | def _join_scale_tensor(global_scale: torch.Tensor | None, local_scale: torch.Tensor) -> torch.Tensor: 92 | """Multiply the local scale tensor by the global scale tensor. 93 | 94 | Args: 95 | global_scale (`torch.Tensor` or `None`): 96 | Global scale tensor. 97 | local_scale (`torch.Tensor`): 98 | Local scale tensor. 99 | 100 | Returns: 101 | `torch.Tensor`: 102 | The compounded scale tensor. 103 | """ 104 | # global_scale: (#gs_g0, 1, #gs_g1, 1, #gs_g2, 1, ...) 105 | # local_scale: (#ss_g0, 1, #ss_g1, 1, #ss_g2, 1, ...) -> (#gs_g0, rs0, #gs_g1, rs1, #gs_g2, rs2, ...) 
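# Hypothetical example: with global_scale of shape (4, 1) and local_scale of shape (16, 1),
# local_scale is viewed as (4, 4) so that broadcasting against global_scale (4, 1) scales each
# group of 4 local entries by its global entry, and the product is viewed back to (16, 1).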
106 | shape = local_scale.shape 107 | return ( 108 | local_scale 109 | if global_scale is None 110 | else local_scale.view( 111 | tuple( 112 | global_scale.shape[i] if j == 0 else local_scale.shape[i] // global_scale.shape[i] 113 | for i in range(0, len(global_scale.shape), 2) 114 | for j in range(2) 115 | ) 116 | ).mul(global_scale) 117 | ).view(shape) 118 | -------------------------------------------------------------------------------- /deepcompressor/data/tensor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Quantized tensor module.""" 3 | 4 | import torch 5 | 6 | from .scale import QuantScale 7 | 8 | __all__ = ["QuantTensor"] 9 | 10 | 11 | class QuantTensor: 12 | """Quantized tensor.""" 13 | 14 | _dequantized: torch.Tensor | None 15 | _quantized: torch.Tensor | None 16 | scale: QuantScale | None 17 | zero: torch.Tensor | float | None 18 | view_shape: torch.Size | None 19 | 20 | def __init__( 21 | self, 22 | dequantized: torch.Tensor | None = None, 23 | quantized: torch.Tensor | None = None, 24 | scale: QuantScale | None = None, 25 | zero: torch.Tensor | float | None = None, 26 | view_shape: torch.Size | None = None, 27 | ): 28 | """Initialize the quantized tensor.""" 29 | assert ( 30 | dequantized is not None or quantized is not None 31 | ), "Either the dequantized or quantized tensor must be provided." 32 | self.view_shape = view_shape 33 | self._dequantized = dequantized 34 | self._quantized = quantized 35 | self.scale = scale 36 | self.zero = zero 37 | 38 | @property 39 | def data(self) -> torch.Tensor | None: 40 | """Get the dequantized tensor.""" 41 | return self._dequantized 42 | 43 | @property 44 | def qdata(self) -> torch.Tensor | None: 45 | """Get the quantized tensor.""" 46 | return self._quantized 47 | -------------------------------------------------------------------------------- /deepcompressor/data/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from . import dtype as DtypeUtils 4 | from . import scale as ScaleUtils 5 | from . import shape as ShapeUtils 6 | -------------------------------------------------------------------------------- /deepcompressor/data/utils/dtype.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Utility functions for dtype in quantization.""" 3 | 4 | import torch 5 | 6 | from ..dtype import QuantDataType 7 | 8 | __all__ = ["infer_dtype_bits", "infer_dtype_name", "eval_dtype"] 9 | 10 | 11 | def infer_dtype_bits(dtype: torch.dtype | QuantDataType) -> int: 12 | """Get the number of bits of a torch.dtype or QuantDataType. 13 | 14 | Args: 15 | dtype (`torch.dtype` or `QuantDataType`): 16 | The dtype to get the number of bits of. 17 | 18 | Returns: 19 | `int`: 20 | The number of bits. 
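Example (illustrative):
    >>> infer_dtype_bits(torch.bfloat16)
    16
    >>> infer_dtype_bits(torch.int8)
    8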
21 | """ 22 | if isinstance(dtype, QuantDataType): 23 | return dtype.total_bits 24 | else: 25 | if dtype == torch.float32: 26 | return 32 27 | elif dtype == torch.float16 or dtype == torch.bfloat16: 28 | return 16 29 | elif dtype == torch.float64: 30 | return 64 31 | elif dtype == torch.int32: 32 | return 32 33 | elif dtype == torch.int16: 34 | return 16 35 | elif dtype == torch.int8: 36 | return 8 37 | elif dtype == torch.uint8: 38 | return 8 39 | else: 40 | raise ValueError(f"Unknown dtype {dtype}") 41 | 42 | 43 | def infer_dtype_name(dtype: torch.dtype | QuantDataType) -> str: 44 | """Get the string representation of a torch.dtype or QuantDataType. 45 | 46 | Args: 47 | dtype (`torch.dtype` | `QuantDataType`): 48 | The dtype to get the string representation of. 49 | 50 | Returns: 51 | `str`: 52 | The string representation. 53 | """ 54 | if isinstance(dtype, QuantDataType): 55 | return str(dtype) 56 | elif isinstance(dtype, torch.dtype): 57 | if dtype == torch.float16: 58 | return "fp16" 59 | elif dtype == torch.float32: 60 | return "fp32" 61 | elif dtype == torch.float64: 62 | return "fp64" 63 | elif dtype == torch.bfloat16: 64 | return "bf16" 65 | else: 66 | return str(dtype).split(".")[-1] 67 | else: 68 | raise ValueError(f"Unknown dtype {dtype}") 69 | 70 | 71 | def eval_dtype( # noqa: C901 72 | s: str | torch.dtype | QuantDataType | None, with_quant_dtype: bool = True, with_none: bool = True 73 | ) -> torch.dtype | QuantDataType | None: 74 | if isinstance(s, torch.dtype): 75 | return s 76 | if isinstance(s, QuantDataType): 77 | if with_quant_dtype: 78 | return s 79 | else: 80 | raise ValueError(f"Unknown dtype {s}") 81 | if s is None: 82 | if with_none: 83 | return None 84 | else: 85 | raise ValueError(f"Unknown dtype {s}") 86 | assert isinstance(s, str), f"Unknown dtype {s}" 87 | s = s.lower() 88 | if s in ("torch.float64", "float64", "fp64", "f64", "double"): 89 | return torch.float64 90 | elif s in ("torch.float32", "float32", "fp32", "f32", "single", "float"): 91 | return torch.float32 92 | elif s in ("torch.float16", "float16", "fp16", "f16", "half"): 93 | return torch.float16 94 | elif s in ("torch.bfloat16", "bfloat16", "bf16", "b16", "brain"): 95 | return torch.bfloat16 96 | elif s in ("torch.int64", "int64", "i64", "long"): 97 | return torch.int64 98 | elif s in ("torch.int32", "int32", "i32", "int"): 99 | return torch.int32 100 | elif s in ("torch.int16", "int16", "i16", "short"): 101 | return torch.int16 102 | elif s in ("torch.int8", "int8", "i8", "byte"): 103 | return torch.int8 104 | elif s in ("torch.uint8", "uint8", "u8", "ubyte"): 105 | return torch.uint8 106 | else: 107 | if with_none and s in ("", "none", "null", "nil"): 108 | return None 109 | if with_quant_dtype: 110 | return QuantDataType.from_str(s) 111 | raise ValueError(f"Unknown dtype {s}") 112 | -------------------------------------------------------------------------------- /deepcompressor/data/utils/scale.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Utility functions for quantization scale.""" 3 | 4 | import typing as tp 5 | 6 | import torch 7 | 8 | from ..dtype import QuantDataType 9 | 10 | __all__ = ["infer_scale_dtypes", "infer_scale_quant_spans", "infer_exponent_scale_level"] 11 | 12 | 13 | def infer_scale_dtypes( 14 | scale_dtypes: tp.Sequence[torch.dtype | QuantDataType | None], default_dtype: torch.dtype | QuantDataType 15 | ) -> list[torch.dtype | QuantDataType]: 16 | """Get the scale dtypes for the given tensor dtype. 
17 | 18 | Args: 19 | scale_dtypes (`Sequence[torch.dtype | QuantDataType | None]`): 20 | The scale dtypes. 21 | default_dtype (`torch.dtype`): 22 | The default scale dtype. 23 | 24 | Returns: 25 | `list[torch.dtype | QuantDataType]`: 26 | The scale dtypes. 27 | """ 28 | assert isinstance( 29 | default_dtype, (torch.dtype, QuantDataType) 30 | ), f"dtype must be torch.dtype or QuantDataType, got {default_dtype}" 31 | return [s_dtype or default_dtype for s_dtype in scale_dtypes] 32 | 33 | 34 | def infer_scale_quant_spans(scale_dtypes: tp.Sequence[QuantDataType], base: int = 1) -> list[float]: 35 | quant_spans: list[float] = [base] 36 | for s_dtype in reversed(scale_dtypes[1:]): 37 | assert isinstance(s_dtype, QuantDataType), f"s_dtype must be QuantDataType, got {s_dtype}" 38 | quant_spans.append(s_dtype.max_value * quant_spans[-1]) 39 | return list(reversed(quant_spans)) 40 | 41 | 42 | def infer_exponent_scale_level(scale_dtypes: tp.Sequence[torch.dtype | QuantDataType]) -> int: 43 | """Get the exponent scaling level. 44 | 45 | Args: 46 | scale_dtypes (`Sequence[torch.dtype | QuantDataType]`): 47 | The scale data types. 48 | 49 | Returns: 50 | `int`: The exponent scaling level. 51 | """ 52 | for level, scale_dtype in enumerate(scale_dtypes): 53 | if isinstance(scale_dtype, QuantDataType) and scale_dtype.is_exponent: 54 | return level 55 | return len(scale_dtypes) 56 | -------------------------------------------------------------------------------- /deepcompressor/data/zero.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Zero-point for quantization.""" 3 | 4 | import enum 5 | 6 | __all__ = ["ZeroPointDomain"] 7 | 8 | 9 | class ZeroPointDomain(enum.Enum): 10 | """Zero-point domain.""" 11 | 12 | PreScale = enum.auto() 13 | PostScale = enum.auto() 14 | -------------------------------------------------------------------------------- /deepcompressor/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .action import CacheAction, ConcatCacheAction 4 | from .cache import BaseCalibCacheLoader 5 | from .config import BaseDataLoaderConfig 6 | -------------------------------------------------------------------------------- /deepcompressor/dataset/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Configuration for collecting calibration dataset for quantization.""" 3 | 4 | from abc import ABC, abstractmethod 5 | from dataclasses import dataclass 6 | 7 | from omniconfig import configclass 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | from .cache import BaseCalibCacheLoader 11 | 12 | __all__ = ["BaseDataLoaderConfig"] 13 | 14 | 15 | @configclass 16 | @dataclass(kw_only=True) 17 | class BaseDataLoaderConfig(ABC): 18 | """Configuration for dataset loader. 19 | 20 | Args: 21 | data (`str`): 22 | Dataset name. 23 | num_samples (`int`): 24 | Number of dataset samples. 25 | batch_size (`int`): 26 | Batch size when loading dataset. 27 | """ 28 | 29 | data: str 30 | num_samples: int 31 | batch_size: int 32 | 33 | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]: 34 | """Get the names of the configuration fields. 35 | 36 | Args: 37 | prefix (`str`, *optional*): 38 | Prefix for the names. 39 | 40 | Returns: 41 | `list[str]`: 42 | Names of the configuration. 
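For example, a loader configured with `data="qdiff"` and `num_samples=128` yields `["qdiff.128"]`, or `["calib.qdiff.128"]` when called with `prefix="calib"`.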
43 | """ 44 | name = f"{self.data}.{self.num_samples}" 45 | return [f"{prefix}.{name}" if prefix else name] 46 | 47 | @abstractmethod 48 | def build_dataset(self, *args, **kwargs) -> Dataset: 49 | """Build dataset.""" 50 | ... 51 | 52 | @abstractmethod 53 | def build_loader(self, *args, **kwargs) -> DataLoader | BaseCalibCacheLoader: 54 | """Build data loader.""" 55 | ... 56 | -------------------------------------------------------------------------------- /deepcompressor/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /deepcompressor/nn/patch/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .conv import * 4 | from .linear import * 5 | from .lowrank import * 6 | from .sdpa import * 7 | -------------------------------------------------------------------------------- /deepcompressor/nn/patch/linear.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Concat Linear Module.""" 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | __all__ = ["ConcatLinear", "ShiftedLinear"] 8 | 9 | 10 | class ConcatLinear(nn.Module): 11 | def __init__( 12 | self, 13 | in_features_list: list[int], 14 | out_features: int, 15 | bias: bool = True, 16 | device=None, 17 | dtype=None, 18 | ) -> None: 19 | super().__init__() 20 | assert len(in_features_list) > 1, "ConcatLinear requires at least 2 input features" 21 | self.in_features_list = in_features_list 22 | self.in_features = sum(in_features_list) 23 | self.out_features = out_features 24 | num_linears = len(in_features_list) 25 | self.linears = nn.ModuleList( 26 | [ 27 | nn.Linear( 28 | in_features, 29 | out_features, 30 | bias if idx == num_linears - 1 else False, 31 | device, 32 | dtype, 33 | ) 34 | for idx, in_features in enumerate(in_features_list) 35 | ] 36 | ) 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | # slice x based on in_features_list 40 | x_splits: list[torch.Tensor] = x.split(self.in_features_list, dim=-1) 41 | # apply each linear to each slice (we have to make contiguous input for quantization) 42 | out_splits = [linear(x_split.contiguous()) for linear, x_split in zip(self.linears, x_splits, strict=True)] 43 | # sum the results 44 | return sum(out_splits) 45 | 46 | @staticmethod 47 | def from_linear(linear: nn.Linear, splits: list[int]) -> "ConcatLinear": 48 | splits.append(linear.in_features - sum(splits)) 49 | splits = [s for s in splits if s > 0] 50 | assert len(splits) > 1, "ConcatLinear requires at least 2 input features" 51 | concat_linear = ConcatLinear( 52 | in_features_list=splits, 53 | out_features=linear.out_features, 54 | bias=linear.bias is not None, 55 | device=linear.weight.device, 56 | dtype=linear.weight.dtype, 57 | ) 58 | used_in_features = 0 59 | for sub_linear in concat_linear.linears: 60 | assert isinstance(sub_linear, nn.Linear) 61 | in_features = sub_linear.in_features 62 | sub_linear.weight.data.copy_(linear.weight[:, used_in_features : used_in_features + in_features]) 63 | used_in_features += in_features 64 | if linear.bias is not None: 65 | assert sub_linear.bias is not None 66 | sub_linear.bias.data.copy_(linear.bias) 67 | return concat_linear 68 | 69 | 70 | class ShiftedLinear(nn.Module): 71 | shift: torch.Tensor 72 | 73 | def __init__( 74 | self, 75 | in_features: int, 76 | out_features: 
int, 77 | shift: float | torch.Tensor, 78 | bias: bool = True, 79 | device=None, 80 | dtype=None, 81 | ) -> None: 82 | super().__init__() 83 | self.linear = nn.Linear(in_features, out_features, bias, device, dtype) 84 | self.linear.shifted = True 85 | device, dtype = self.linear.weight.device, self.linear.weight.dtype 86 | if not isinstance(shift, torch.Tensor): 87 | shift = torch.tensor(shift, device=device, dtype=dtype) 88 | shift = shift.flatten().to(device=device, dtype=dtype) 89 | shift_features = shift.numel() 90 | if shift_features > 1: 91 | assert in_features >= shift_features and in_features % shift_features == 0 92 | shift = shift.view(-1, 1).expand(-1, in_features // shift_features).flatten() 93 | self.register_buffer("shift", shift) 94 | 95 | @property 96 | def in_features(self) -> int: 97 | return self.linear.in_features 98 | 99 | @property 100 | def out_features(self) -> int: 101 | return self.linear.out_features 102 | 103 | def forward(self, input: torch.Tensor) -> torch.Tensor: 104 | return self.linear(input + self.shift.view([1] * (input.dim() - 1) + [-1])) 105 | 106 | @staticmethod 107 | def from_linear(linear: nn.Linear, shift: float | torch.Tensor) -> "ShiftedLinear": 108 | device, dtype = linear.weight.device, linear.weight.dtype 109 | shifted = ShiftedLinear( 110 | in_features=linear.in_features, 111 | out_features=linear.out_features, 112 | shift=shift, 113 | bias=True, 114 | device=device, 115 | dtype=dtype, 116 | ) 117 | shifted.linear.weight.data.copy_(linear.weight) 118 | shift = shifted.shift 119 | if shift.numel() == 1: 120 | shifted_bias = linear.weight.double().sum(dim=1) * shift.double() 121 | else: 122 | shifted_bias = torch.matmul(linear.weight.double(), shift.view(1, -1).double()) 123 | shifted_bias = shifted_bias.view(shifted.linear.bias.size()) 124 | if linear.bias is not None: 125 | shifted.linear.bias.data.copy_((linear.bias.data.double() - shifted_bias).to(dtype)) 126 | else: 127 | shifted.linear.bias.data.copy_(shifted_bias.to(dtype).neg_()) 128 | return shifted 129 | -------------------------------------------------------------------------------- /deepcompressor/nn/patch/lowrank.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.linalg 5 | import torch.nn as nn 6 | 7 | from ...utils.hooks import AccumBranchHook, BaseInputPackager, BaseOutputPackager 8 | 9 | __all__ = ["LowRankBranch"] 10 | 11 | 12 | class LowRankBranch(nn.Module): 13 | def __init__( 14 | self, in_features: int, out_features: int, rank: int, alpha: float = 1.0, weight: torch.Tensor | None = None 15 | ): 16 | super().__init__() 17 | self.in_features = in_features 18 | self.out_features = out_features 19 | self.rank = rank 20 | self.alpha = alpha 21 | if rank == 0: 22 | self.a, self.b = None, None 23 | elif rank < 0: 24 | self.a, self.b = nn.Linear(in_features, out_features, bias=False), nn.Identity() 25 | else: 26 | self.a, self.b = nn.Linear(in_features, rank, bias=False), nn.Linear(rank, out_features, bias=False) 27 | self.reset_parameters(weight) 28 | 29 | @torch.no_grad() 30 | def reset_parameters(self, weight: torch.Tensor | None = None) -> None: 31 | if weight is None: 32 | if self.rank < 0: 33 | nn.init.zeros_(self.a.weight) 34 | elif self.rank > 0: 35 | nn.init.kaiming_uniform_(self.a.weight) 36 | nn.init.zeros_(self.b.weight) 37 | return 38 | if weight.ndim >= 2: 39 | assert weight.shape[2:].numel() == 1, "LinearLoRAHook only supports 2D input tensor" 40 | weight = 
weight.view(weight.shape[0], -1) 41 | device, dtype = weight.device, weight.dtype 42 | self.to(device=device, dtype=dtype) 43 | out_features, in_features = weight.shape 44 | assert self.in_features == in_features, "Input features size mismatch" 45 | assert self.out_features == out_features, "Output features size mismatch" 46 | if self.rank < 0: 47 | self.a.weight.data.copy_(weight) 48 | elif self.rank > 0: 49 | u, s, vh = torch.linalg.svd(weight.double()) 50 | # tensor: [oc, ic], u: [oc, oc], s: [oc], vh: [ic, ic] 51 | # us: [oc, rank], vh: [rank, ic] 52 | us = u[:, : self.rank] * s[: self.rank] 53 | vh = vh[: self.rank] 54 | assert not us.isnan().any(), "NaN in U * S" 55 | assert not vh.isnan().any(), "NaN in V^T" 56 | assert not us.isinf().any(), "Inf in U * S" 57 | assert not vh.isinf().any(), "Inf in V^T" 58 | self.a.weight.data.copy_(vh.to(dtype)) 59 | self.b.weight.data.copy_(us.to(dtype)) 60 | 61 | def get_effective_weight(self) -> torch.Tensor | None: 62 | if self.rank == 0: 63 | return None 64 | elif self.rank < 0: 65 | return self.a.weight 66 | else: 67 | return self.b.weight @ self.a.weight 68 | 69 | def forward(self, input: torch.Tensor) -> torch.Tensor | None: 70 | if self.a is None: 71 | return None 72 | else: 73 | if input.ndim <= 3: 74 | return self.alpha * self.b(self.a(input)) 75 | else: 76 | assert input.ndim == 4 77 | assert input.shape[-1] != self.in_features 78 | assert input.shape[1] == self.in_features 79 | # [B, C, H, W] -> [B, H, W, C] -> [B, H * W, C] 80 | B, C, H, W = input.shape 81 | input = input.permute(0, 2, 3, 1).reshape(B, H * W, C) 82 | output = self.alpha * self.b(self.a(input)) 83 | # [B, H * W, C] -> [B, H, W, C] -> [B, C, H, W] 84 | output = output.reshape(B, H, W, -1).permute(0, 3, 1, 2) 85 | return output 86 | 87 | def as_hook( 88 | self, 89 | input_packager: BaseInputPackager | None = None, 90 | output_packager: BaseOutputPackager | None = None, 91 | ) -> AccumBranchHook: 92 | """Wrap the module as a branch hook. 93 | 94 | Args: 95 | input_packager (`BaseInputPackager` or `None`, *optional*, defaults to `None`): 96 | Input packager. 97 | output_packager (`BaseOutputPackager` or `None`, *optional*, defaults to `None`): 98 | Output packager. 99 | Returns: 100 | `AccumBranchHook`: 101 | The branch hook. 
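The returned hook adds this branch's output to the wrapped module's output (see `AccumBranchHook.post_forward`), so the low-rank path runs alongside the hooked layer.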
102 | """ 103 | return AccumBranchHook(self, input_packager=input_packager, output_packager=output_packager) 104 | -------------------------------------------------------------------------------- /deepcompressor/nn/patch/sdpa.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Sparse attention module.""" 3 | 4 | import typing as tp 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | __all__ = ["ScaleDotProductAttention"] 11 | 12 | 13 | class ScaleDotProductAttention(nn.Module): 14 | def forward( 15 | self, 16 | query: torch.Tensor, 17 | key: torch.Tensor, 18 | value: torch.Tensor, 19 | attn_mask: tp.Optional[torch.Tensor] = None, 20 | dropout_p: float = 0.0, 21 | is_causal: bool = False, 22 | scale: tp.Optional[float] = None, 23 | ) -> torch.Tensor: 24 | return F.scaled_dot_product_attention( 25 | query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale 26 | ) 27 | -------------------------------------------------------------------------------- /deepcompressor/nn/struct/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .attn import * 4 | from .base import * 5 | -------------------------------------------------------------------------------- /deepcompressor/quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .processor import Quantizer 4 | -------------------------------------------------------------------------------- /deepcompressor/quantizer/config/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .base import BaseQuantizerConfig, DecomposedQuantizerConfig, ProgressiveQuantizerConfig, QuantizerConfig 4 | from .kernel import BaseKeyEnableQuantKernelConfig, BaseQuantKernelConfig 5 | from .lowrank import QuantLowRankConfig 6 | -------------------------------------------------------------------------------- /deepcompressor/quantizer/config/lowrank.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from dataclasses import dataclass 4 | 5 | from omniconfig import configclass 6 | 7 | from ...utils.common import num2str 8 | from ...utils.config import EnableConfig 9 | 10 | __all__ = ["QuantLowRankConfig"] 11 | 12 | 13 | @configclass 14 | @dataclass 15 | class QuantLowRankConfig(EnableConfig): 16 | """Quantization low-rank branch configuration. 17 | 18 | Args: 19 | rank (`int`, *optional*, defaults to `32`): 20 | The rank of the low-rank branch. 21 | exclusive (`bool`, *optional*, defaults to `False`): 22 | Whether to use exclusive low-rank branch for each weight sharing the inputs. 23 | compensate (`bool`, *optional*, defaults to `False`): 24 | Whether the low-rank branch compensates the quantization error. 25 | """ 26 | 27 | rank: int = 32 28 | exclusive: bool = False 29 | compensate: bool = False 30 | 31 | def is_enabled(self) -> bool: 32 | return self.rank != 0 33 | 34 | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]: 35 | """Generate the directory names of the configuration. 36 | 37 | Returns: 38 | list[str]: The directory names. 
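For example, `rank=32` with `exclusive=True` yields a name like `r32.exclusive` (the exact rank string depends on `num2str`), and an empty list is returned when the low-rank branch is disabled.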
39 | """ 40 | if not self.is_enabled(): 41 | return [] 42 | name = f"r{num2str(self.rank)}" 43 | if self.exclusive: 44 | name += ".exclusive" 45 | if self.compensate: 46 | name += ".compensate" 47 | return [f"{prefix}.{name}" if prefix else name] 48 | -------------------------------------------------------------------------------- /deepcompressor/quantizer/impl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/deepcompressor/69f3473f5e1c1504bae35cc50c7858ef900a9b17/deepcompressor/quantizer/impl/__init__.py -------------------------------------------------------------------------------- /deepcompressor/quantizer/impl/simple.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Simple quantization functions.""" 3 | 4 | import torch 5 | 6 | from ...data.dtype import QuantDataType 7 | from ...data.range import LogQuantRange, QuantRange 8 | from .ste import ste 9 | 10 | __all__ = ["simple_quantize"] 11 | 12 | 13 | def simple_quantize( 14 | tensor: torch.Tensor, 15 | *, 16 | quant_dtype: torch.dtype | QuantDataType, 17 | has_zero_point: bool, 18 | quant_range: QuantRange | None = None, 19 | round_delta: torch.Tensor | None = None, 20 | ) -> torch.Tensor: 21 | """Simple quantization function.""" 22 | requires_grad = tensor.requires_grad 23 | if isinstance(quant_dtype, torch.dtype): 24 | dtype = tensor.dtype 25 | tensor = tensor.to(dtype=quant_dtype).to(dtype=dtype) 26 | if round_delta is not None: 27 | tensor = tensor.add_(round_delta) 28 | if quant_range is not None and quant_range.is_set(): 29 | tensor = torch.clamp(tensor, min=quant_range.min, max=quant_range.max) 30 | return tensor 31 | elif isinstance(quant_dtype, QuantDataType): 32 | if quant_dtype.is_exponent: 33 | assert round_delta is None, "round_delta is not supported for exponential quantization" 34 | quant_range = LogQuantRange.construct(quant_dtype, quant_range) 35 | tensor = ste(tensor.log2(), torch.floor) if requires_grad else tensor.log2_().floor_() 36 | return tensor.clamp_(min=quant_range.min, max=quant_range.max).exp2_() 37 | elif quant_dtype.is_float_point: 38 | assert round_delta is None, "round_delta is not supported for float quantization" 39 | tensor = torch.clamp(tensor, min=quant_dtype.min_value, max=quant_dtype.max_value) 40 | tensor = ste(tensor, quant_dtype.round) 41 | if quant_range is not None and quant_range.is_set(): 42 | tensor = tensor.clamp_(min=quant_range.min, max=quant_range.max) 43 | return tensor 44 | else: 45 | quant_range = QuantRange.construct(quant_dtype, has_zero_point=has_zero_point, quant_range=quant_range) 46 | if round_delta is None: 47 | tensor = ste(tensor, torch.round) if requires_grad else tensor.round_() 48 | else: 49 | tensor = ste(tensor, torch.floor) if requires_grad else tensor.floor_() 50 | tensor = tensor.add_(round_delta) 51 | return tensor.clamp_(min=quant_range.min, max=quant_range.max) 52 | else: 53 | raise TypeError( 54 | f"quant_dtype must be either torch.dtype or QuantDataType, got {quant_dtype} ({type(quant_dtype)})" 55 | ) 56 | -------------------------------------------------------------------------------- /deepcompressor/quantizer/impl/ste.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Simple quantization functions.""" 3 | 4 | import typing as tp 5 | 6 | import torch 7 | 8 | __all__ = ["ste"] 9 | 10 | 11 | class STEFunction(torch.autograd.Function): 
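# Straight-through estimator: the forward pass applies an arbitrary non-differentiable
# function (e.g. rounding), while the backward pass passes the incoming gradient through
# unchanged, treating the function as the identity for gradient purposes.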
12 | """STEFunction for quantization.""" 13 | 14 | @staticmethod 15 | def forward(ctx: tp.Any, tensor: torch.Tensor, fn: tp.Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor: 16 | """Forward pass for DtypeSTEFunction.""" 17 | return fn(tensor) 18 | 19 | @staticmethod 20 | def backward(ctx: tp.Any, grad_output: torch.Tensor) -> tp.Tuple[torch.Tensor, None]: 21 | """Backward pass for DtypeSTEFunction.""" 22 | return grad_output, None 23 | 24 | 25 | def ste(tensor: torch.Tensor, fn: tp.Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor: 26 | """STE function.""" 27 | return STEFunction.apply(tensor, fn) # type: ignore 28 | -------------------------------------------------------------------------------- /deepcompressor/quantizer/kernel/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .gptq import QuantGptqConfig, QuantGptqKernel, gptq_quantize 4 | from .rtn import QuantRtnKernel, rtn_quantize 5 | -------------------------------------------------------------------------------- /deepcompressor/quantizer/kernel/rtn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Round-to-nearest (RTN) quantization module.""" 3 | 4 | import torch 5 | 6 | from ...data.dtype import QuantDataType 7 | from ...data.range import QuantRange 8 | from ...data.zero import ZeroPointDomain 9 | from ..config.kernel import BaseQuantKernel 10 | from ..impl.simple import simple_quantize 11 | 12 | __all__ = ["QuantRtnKernel", "rtn_quantize"] 13 | 14 | 15 | class QuantRtnKernel(BaseQuantKernel): 16 | """Round-to-nearest (RTN) Quantization kernel.""" 17 | 18 | def quantize( 19 | self, 20 | tensor: torch.Tensor, 21 | *, 22 | view_shape: torch.Size, 23 | quant_dtype: QuantDataType, 24 | zero_domain: ZeroPointDomain | None, 25 | scale: torch.Tensor, 26 | zero: torch.Tensor, 27 | quant_range: QuantRange | None = None, 28 | round_delta: torch.Tensor | None = None, 29 | **kwargs, 30 | ) -> torch.Tensor: 31 | """Quantize the tensor. 32 | 33 | Args: 34 | tensor (`torch.Tensor`): 35 | The tensor to quantize. 36 | view_shape (`torch.Size`): 37 | The view shape when quantizing the tensor. 38 | quant_dtype (`QuantDataType`): 39 | The quantization data type. 40 | zero_domain (`ZeroPointDomain` or `None`): 41 | The zero point domain. 42 | scale (`torch.Tensor`): 43 | The scale tensor. 44 | zero (`torch.Tensor`): 45 | The zero point tensor. 46 | quant_range (`QuantRange` or `None`, *optional*, defaults to `None`): 47 | The quantization range. 48 | round_delta (`torch.Tensor` or `None`, *optional*, defaults to `None`): 49 | The rounding delta. 50 | **kwargs: Other keyword arguments. 51 | 52 | Returns: 53 | `torch.Tensor`: 54 | The quantized tensor in the shape of ``view_shape``. 55 | """ 56 | return rtn_quantize( 57 | tensor, 58 | view_shape=view_shape, 59 | quant_dtype=quant_dtype, 60 | zero_domain=zero_domain, 61 | scale=scale, 62 | zero=zero, 63 | quant_range=quant_range, 64 | round_delta=round_delta, 65 | ) 66 | 67 | 68 | def rtn_quantize( 69 | tensor: torch.Tensor, 70 | *, 71 | view_shape: torch.Size, 72 | quant_dtype: QuantDataType, 73 | zero_domain: ZeroPointDomain | None, 74 | scale: torch.Tensor, 75 | zero: torch.Tensor, 76 | quant_range: QuantRange | None = None, 77 | round_delta: torch.Tensor | None = None, 78 | ) -> torch.Tensor: 79 | """Quantize the tensor using the RTN quantization kernel. 80 | 81 | Args: 82 | tensor (`torch.Tensor`): 83 | The tensor to quantize. 
84 | view_shape (`torch.Size`): 85 | The view shape when quantizing the tensor. 86 | quant_dtype (`QuantDataType`): 87 | The quantization data type. 88 | zero_domain (`ZeroPointDomain` or `None`): 89 | The zero point domain. 90 | scale (`torch.Tensor`): 91 | The scale tensor. 92 | zero (`torch.Tensor`): 93 | The zero point tensor. 94 | quant_range (`QuantRange` or `None`, *optional*, defaults to `None`): 95 | The quantization range. 96 | round_delta (`torch.Tensor` or `None`, *optional*, defaults to `None`): 97 | The rounding delta. 98 | 99 | Returns: 100 | `torch.Tensor`: 101 | The quantized tensor in the shape of ``view_shape``. 102 | """ 103 | qtensor = tensor.view(view_shape) 104 | round_delta = round_delta.view(view_shape) if round_delta is not None else None 105 | if zero_domain == ZeroPointDomain.PostScale: 106 | qtensor = qtensor.add_(zero) 107 | qtensor = qtensor.div(scale) 108 | if zero_domain == ZeroPointDomain.PreScale: 109 | qtensor = qtensor.add_(zero) 110 | qtensor = simple_quantize( 111 | qtensor, 112 | quant_dtype=quant_dtype, 113 | has_zero_point=zero_domain is not None, 114 | quant_range=quant_range, 115 | round_delta=round_delta, 116 | ) 117 | return qtensor 118 | -------------------------------------------------------------------------------- /deepcompressor/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .common import * 4 | from .patch import * 5 | -------------------------------------------------------------------------------- /deepcompressor/utils/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import EnableConfig, IncludeBasedConfig, KeyEnableConfig, SkipBasedConfig 2 | -------------------------------------------------------------------------------- /deepcompressor/utils/config/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Net configurations.""" 3 | 4 | import os 5 | import typing as tp 6 | from abc import ABC, abstractmethod 7 | from dataclasses import dataclass 8 | 9 | from omniconfig import configclass 10 | 11 | __all__ = ["BaseModelConfig"] 12 | 13 | 14 | @configclass 15 | @dataclass 16 | class BaseModelConfig(ABC): 17 | """Base class for all model configs. 18 | 19 | Args: 20 | name (`str`): 21 | Name of the model. 22 | family (`str`, *optional*, defaults to `""`): 23 | Family of the model. If not specified, it will be inferred from the name. 24 | path (`str`, *optional*, defaults to `""`): 25 | Path of the model. 26 | root (`str`, *optional*, defaults to `""`): 27 | Root directory path for models. 28 | local_path (`str`, *optional*, defaults to `""`): 29 | Local path of the model. 30 | local_root (`str`, *optional*, defaults to `""`): 31 | Local root directory path for models. 
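For example, a hypothetical `name="llama-2-7b"` with `root="models"` infers `family="llama"` and resolves `path` to `models/llama/llama-2-7b`, unless the corresponding `local_path` already exists on disk, in which case the local copy is used.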
32 | """ 33 | 34 | name: str 35 | family: str = "" 36 | path: str = "" 37 | root: str = "" 38 | local_path: str = "" 39 | local_root: str = "" 40 | 41 | def __post_init__(self): 42 | if not self.family: 43 | self.family = self.name.split("-")[0] 44 | self.local_root = os.path.expanduser(self.local_root) 45 | if not self.local_path: 46 | self.local_path = os.path.join(self.local_root, self.family, self.name) 47 | if not self.path: 48 | self.path = os.path.join(self.root, self.family, self.name) 49 | if os.path.exists(self.local_path): 50 | self.path = self.local_path 51 | 52 | @abstractmethod 53 | def build(self, *args, **kwargs) -> tp.Any: 54 | """Build model from config.""" 55 | ... 56 | -------------------------------------------------------------------------------- /deepcompressor/utils/config/output.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Output configuration.""" 3 | 4 | import os 5 | from dataclasses import dataclass, field 6 | from datetime import datetime as DateTime 7 | 8 | from omniconfig import configclass 9 | 10 | __all__ = ["OutputConfig"] 11 | 12 | 13 | @configclass 14 | @dataclass 15 | class OutputConfig: 16 | """Output configuration. 17 | 18 | Args: 19 | root (`str`, *optional*, defaults to `"runs"`): 20 | The output root directory. 21 | dirname (`str`, *optional*, defaults to `"default"`): 22 | The output directory name. 23 | job (`str`, *optional*, defaults to `"run"`): 24 | The job name. 25 | 26 | Attributes: 27 | dirpath (`str`): 28 | The output directory path. 29 | timestamp (`str`): 30 | The timestamp. 31 | """ 32 | 33 | root: str = "runs" 34 | dirname: str = "default" 35 | job: str = "run" 36 | dirpath: str = field(init=False) 37 | timestamp: str = field(init=False) 38 | 39 | def __post_init__(self): 40 | self.timestamp = self.generate_timestamp() 41 | self.dirpath = os.path.join(self.root, self.dirname) 42 | 43 | @property 44 | def running_dirpath(self) -> str: 45 | """Get the running directory path.""" 46 | return f"{self.dirpath}.RUNNING" 47 | 48 | @property 49 | def error_dirpath(self) -> str: 50 | """Get the error directory path.""" 51 | return f"{self.dirpath}.ERROR" 52 | 53 | @property 54 | def job_dirname(self) -> str: 55 | """Get the job directory name.""" 56 | return f"{self.job}-{self.timestamp}" 57 | 58 | @property 59 | def job_dirpath(self) -> str: 60 | """Get the job directory path.""" 61 | return os.path.join(self.dirpath, self.job_dirname) 62 | 63 | @property 64 | def running_job_dirname(self) -> str: 65 | """Get the running job directory name.""" 66 | return f"{self.job_dirname}.RUNNING" 67 | 68 | @property 69 | def error_job_dirname(self) -> str: 70 | """Get the error job directory name.""" 71 | return f"{self.job_dirname}.ERROR" 72 | 73 | @property 74 | def running_job_dirpath(self) -> str: 75 | """Get the running job directory path.""" 76 | return os.path.join(self.running_dirpath, self.running_job_dirname) 77 | 78 | def lock(self) -> None: 79 | """Lock the running (job) directory.""" 80 | try: 81 | if os.path.exists(self.dirpath): 82 | os.rename(self.dirpath, self.running_dirpath) 83 | elif os.path.exists(self.error_dirpath): 84 | os.rename(self.error_dirpath, self.running_dirpath) 85 | except Exception: 86 | pass 87 | os.makedirs(self.running_job_dirpath, exist_ok=True) 88 | 89 | def unlock(self, error: bool = False) -> None: 90 | """Unlock the running (job) directory.""" 91 | job_dirpath = os.path.join(self.running_dirpath, self.error_job_dirname if error else 
self.job_dirname) 92 | os.rename(self.running_job_dirpath, job_dirpath) 93 | if not self.is_locked_by_others(): 94 | os.rename(self.running_dirpath, self.error_dirpath if error else self.dirpath) 95 | 96 | def is_locked_by_others(self) -> bool: 97 | """Check if the running directory is locked by others.""" 98 | running_job_dirname = self.running_job_dirname 99 | for dirname in os.listdir(self.running_dirpath): 100 | if dirname.endswith(".RUNNING") and dirname != running_job_dirname: 101 | return True 102 | return False 103 | 104 | def get_running_path(self, filename: str) -> str: 105 | """Get the file path in the running directory.""" 106 | name, ext = os.path.splitext(filename) 107 | return os.path.join(self.running_dirpath, f"{name}-{self.timestamp}{ext}") 108 | 109 | def get_running_job_path(self, filename: str) -> str: 110 | """Get the file path in the running job directory.""" 111 | name, ext = os.path.splitext(filename) 112 | return os.path.join(self.running_job_dirpath, f"{name}-{self.timestamp}{ext}") 113 | 114 | @staticmethod 115 | def generate_timestamp() -> str: 116 | """Generate a timestamp.""" 117 | return DateTime.now().strftime("%y%m%d.%H%M%S") 118 | -------------------------------------------------------------------------------- /deepcompressor/utils/config/path.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Path configuration.""" 3 | 4 | import os 5 | import typing as tp 6 | 7 | from ..dataclass import get_fields 8 | 9 | __all__ = ["BasePathConfig"] 10 | 11 | 12 | class BasePathConfig: 13 | """Base path configuration.""" 14 | 15 | def is_all_set(self) -> bool: 16 | """Check if the path configuration is all set. 17 | 18 | Returns: 19 | `bool`: 20 | Whether the path configuration is all set. 21 | """ 22 | fields = get_fields(self) 23 | for f in fields: 24 | if not getattr(self, f.name): 25 | return False 26 | return True 27 | 28 | def is_all_empty(self) -> bool: 29 | """Check if the path configuration is all empty. 30 | 31 | Returns: 32 | `bool`: 33 | Whether the path configuration is all empty. 34 | """ 35 | fields = get_fields(self) 36 | for f in fields: 37 | if getattr(self, f.name): 38 | return False 39 | return True 40 | 41 | def clone(self) -> tp.Self: 42 | """Clone the path configuration. 43 | 44 | Returns: 45 | `Self`: 46 | The cloned path configuration. 47 | """ 48 | fields = get_fields(self) 49 | return self.__class__(**{f.name: getattr(self, f.name) for f in fields}) 50 | 51 | def add_parent_dirs(self, *parent_dirs: str) -> tp.Self: 52 | """Add the parent directories to the paths. 53 | 54 | Args: 55 | parent_dirs (`str`): 56 | The parent directories. 57 | """ 58 | fields = get_fields(self) 59 | for f in fields: 60 | path = getattr(self, f.name) 61 | if path: 62 | setattr(self, f.name, os.path.join(*parent_dirs, path)) 63 | return self 64 | 65 | def add_children(self, *children: str) -> tp.Self: 66 | """Add the children to the paths. 67 | 68 | Args: 69 | children (`str`): 70 | The children paths. 
71 | """ 72 | fields = get_fields(self) 73 | for f in fields: 74 | path = getattr(self, f.name) 75 | if path: 76 | setattr(self, f.name, os.path.join(path, *children)) 77 | return self 78 | 79 | def to_dirpath(self) -> tp.Self: 80 | """Convert the paths to directory paths.""" 81 | fields = get_fields(self) 82 | for f in fields: 83 | path = getattr(self, f.name) 84 | if path: 85 | setattr(self, f.name, os.path.dirname(path)) 86 | return self 87 | 88 | def apply(self, fn: tp.Callable) -> tp.Self: 89 | """Apply the function to the paths. 90 | 91 | Args: 92 | fn (`Callable`): 93 | The function to apply. 94 | """ 95 | fields = get_fields(self) 96 | for f in fields: 97 | path = getattr(self, f.name) 98 | if path: 99 | setattr(self, f.name, fn(path)) 100 | return self 101 | -------------------------------------------------------------------------------- /deepcompressor/utils/dataclass.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Dataclass utilities.""" 3 | 4 | from dataclasses import _FIELD, _FIELD_CLASSVAR, _FIELD_INITVAR, _FIELDS, Field 5 | 6 | __all__ = ["get_fields"] 7 | 8 | 9 | def get_fields(class_or_instance, *, init_vars: bool = False, class_vars: bool = False) -> tuple[Field, ...]: 10 | """Get the fields of the dataclass. 11 | 12 | Args: 13 | class_or_instance: 14 | The dataclass type or instance. 15 | init_vars (`bool`, *optional*, defaults to `False`): 16 | Whether to include the init vars. 17 | class_vars (`bool`, *optional*, defaults to `False`): 18 | Whether to include the class vars. 19 | 20 | Returns: 21 | tuple[Field, ...]: The fields. 22 | """ 23 | try: 24 | fields = getattr(class_or_instance, _FIELDS) 25 | except AttributeError: 26 | raise TypeError("must be called with a dataclass type or instance") from None 27 | return tuple( 28 | v 29 | for v in fields.values() 30 | if v._field_type is _FIELD 31 | or (init_vars and v._field_type is _FIELD_INITVAR) 32 | or (class_vars and v._field_type is _FIELD_CLASSVAR) 33 | ) 34 | -------------------------------------------------------------------------------- /deepcompressor/utils/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .branch import AccumBranchHook 2 | from .hook import EarlyStopException, EarlyStopHook, Hook, IOHook 3 | from .packager import ( 4 | BaseInputPackager, 5 | BaseOutputPackager, 6 | KeyedInputPackager, 7 | KeyedOutputPackager, 8 | SimpleInputPackager, 9 | SimpleOutputPackager, 10 | ) 11 | from .processor import BaseTensorProcessor, ProcessHook 12 | -------------------------------------------------------------------------------- /deepcompressor/utils/hooks/branch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Branch hook module.""" 3 | 4 | import typing as tp 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .hook import IOHook 10 | from .packager import BaseInputPackager, BaseOutputPackager 11 | 12 | __all__ = ["AccumBranchHook"] 13 | 14 | 15 | class AccumBranchHook(IOHook): 16 | branch: nn.Module | None 17 | 18 | def __init__( 19 | self, 20 | branch: nn.Module | None, 21 | input_packager: BaseInputPackager | None = None, 22 | output_packager: BaseOutputPackager | None = None, 23 | ): 24 | super().__init__(pre=True, post=True, input_packager=input_packager, output_packager=output_packager) 25 | self.branch = branch 26 | self.tensor = None 27 | 28 | def pre_forward( 29 | self, module: nn.Module, 
input_args: tuple[torch.Tensor, ...], input_kwargs: dict[str, tp.Any] 30 | ) -> None: 31 | """Pre-forward function. 32 | 33 | Args: 34 | module (nn.Module): Module. 35 | input_args (tuple[torch.Tensor, ...]): Input arguments. 36 | input_kwargs (dict[str, tp.Any]): Input keyword arguments. 37 | """ 38 | tensors = self.input_packager.unpack(module, input_args, input_kwargs) 39 | assert len(tensors) == 1, "BranchHook only supports single input tensor" 40 | self.tensor = next(iter(tensors.values())) 41 | return None 42 | 43 | def post_forward( 44 | self, 45 | module: nn.Module, 46 | input_args: tuple[torch.Tensor, ...], 47 | input_kwargs: dict[str, tp.Any], 48 | output: tuple[torch.Tensor, ...], 49 | ) -> tp.Any: 50 | """Post-forward function. 51 | 52 | Args: 53 | module (nn.Module): Module. 54 | input_args (tuple[torch.Tensor, ...]): Input arguments. 55 | input_kwargs (dict[str, tp.Any]): Input keyword arguments. 56 | output (tuple[torch.Tensor, ...]): Output. 57 | """ 58 | output_tensors = self.output_packager.unpack(module, input_args, input_kwargs, output) 59 | assert len(output_tensors) == 1, "LoRAHook only supports single output tensor" 60 | output_key, output_tensor = next(iter(output_tensors.items())) 61 | if self.branch is not None: 62 | output_tensor = output_tensor + self.branch(self.tensor) 63 | self.tensor = None 64 | return self.output_packager.repack({output_key: output_tensor}, module, input_args, input_kwargs, output) 65 | -------------------------------------------------------------------------------- /deepcompressor/utils/hooks/processor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Tensor processor.""" 3 | 4 | import abc 5 | import typing as tp 6 | 7 | import torch 8 | import torch.ao.quantization 9 | import torch.nn as nn 10 | import torch.utils.hooks 11 | 12 | from .hook import IOHook 13 | from .packager import BaseInputPackager, BaseOutputPackager 14 | 15 | __all__ = ["BaseTensorProcessor", "ProcessHook"] 16 | 17 | 18 | class BaseTensorProcessor(abc.ABC): 19 | @abc.abstractmethod 20 | def is_enabled(self) -> bool: ... 21 | 22 | @abc.abstractmethod 23 | def get_input_packager(self) -> BaseInputPackager | None: ... 24 | 25 | @abc.abstractmethod 26 | def get_output_packager(self) -> BaseOutputPackager | None: ... 27 | 28 | @abc.abstractmethod 29 | def process(self, tensor: torch.Tensor) -> torch.Tensor: ... 30 | 31 | def as_hook( 32 | self, func: tp.Callable[[torch.Tensor], torch.Tensor] | None = None, *, is_output: bool = False 33 | ) -> "ProcessHook": 34 | """Convert the processor to a hook. 35 | 36 | Args: 37 | func (`Callable[[torch.Tensor], torch.Tensor]` or `None`, *optional*, defaults to `None`): 38 | Function to process the tensors. 39 | is_output (`bool`, *optional*, defaults to `False`): 40 | Whether to process the output tensors. 41 | 42 | Returns: 43 | `ProcessHook`: 44 | The hook for processing the tensor. 
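By default the returned hook processes the module's unpacked input tensors before the forward pass; with `is_output=True` it instead processes the unpacked output tensors after the forward pass.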
45 | """ 46 | return ProcessHook(self, func, is_output=is_output) 47 | 48 | 49 | class ProcessHook(IOHook): 50 | def __init__( 51 | self, 52 | processor: BaseTensorProcessor, 53 | func: tp.Callable[[torch.Tensor], torch.Tensor] | None = None, 54 | is_output: bool = False, 55 | ): 56 | super().__init__( 57 | pre=not is_output, 58 | post=is_output, 59 | input_packager=processor.get_input_packager(), 60 | output_packager=processor.get_output_packager(), 61 | ) 62 | self.processor = processor 63 | self.func = func 64 | 65 | def process(self, tensors: dict[int | str, torch.Tensor]) -> dict[int | str, torch.Tensor]: 66 | for k, x in tensors.items(): 67 | assert isinstance(x, torch.Tensor) 68 | if self.func is not None: 69 | tensors[k] = self.func(x) 70 | else: 71 | tensors[k] = self.processor.process(x) 72 | return tensors 73 | 74 | def pre_forward( 75 | self, module: nn.Module, input_args: tuple[torch.Tensor, ...], input_kwargs: dict[str, tp.Any] 76 | ) -> tuple[tuple[torch.Tensor, ...], dict[str, tp.Any]]: 77 | if not self.processor.is_enabled(): 78 | return input_args, input_kwargs 79 | return self.input_packager.repack( 80 | self.process(self.input_packager.unpack(module, input_args, input_kwargs)), module, input_args, input_kwargs 81 | ) 82 | 83 | def post_forward( 84 | self, 85 | module: nn.Module, 86 | input_args: tuple[torch.Tensor, ...], 87 | input_kwargs: dict[str, tp.Any], 88 | output: tuple[torch.Tensor, ...], 89 | ) -> tp.Any: 90 | if not self.processor.is_enabled(): 91 | return output 92 | return self.output_packager.repack( 93 | self.process(self.output_packager.unpack(module, input_args, input_kwargs, output)), 94 | module, 95 | input_args, 96 | input_kwargs, 97 | output, 98 | ) 99 | -------------------------------------------------------------------------------- /deepcompressor/utils/math/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .functional import * 4 | from .hadamard import * 5 | -------------------------------------------------------------------------------- /deepcompressor/utils/math/functional.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Math utility functions.""" 3 | 4 | import torch 5 | 6 | __all__ = ["is_pow2", "root_"] 7 | 8 | 9 | def is_pow2(n: int) -> bool: 10 | """Check if a number is a power of 2. 11 | 12 | Args: 13 | n (`int`): 14 | The number to check. 15 | 16 | Returns: 17 | `bool`: 18 | Whether the number is a power of 2. 19 | """ 20 | return (n & (n - 1) == 0) and (n > 0) 21 | 22 | 23 | def root_(y: torch.Tensor, index: float) -> torch.Tensor: 24 | """In-place compute the root of a tensor element-wise. 25 | 26 | Args: 27 | y (`torch.Tensor`): 28 | The input tensor. 29 | index (`float`): 30 | The root index. 31 | 32 | Returns: 33 | `torch.Tensor`: 34 | The output tensor. 
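Example (illustrative):
    >>> root_(torch.tensor([16.0]), 4)
    tensor([2.])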
35 | """ 36 | return y.pow_(1 / index) if index != 2 else y.sqrt_() 37 | -------------------------------------------------------------------------------- /deepcompressor/utils/patch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Monkey-patching utilities.""" 3 | 4 | import copy 5 | import functools 6 | import types 7 | import typing 8 | 9 | import torch.nn as nn 10 | 11 | __all__ = ["copy_func", "get_module_parents_map"] 12 | 13 | 14 | def copy_func(f: types.FunctionType, globals: dict[str, typing.Any] | None = None): 15 | """Copied from https://stackoverflow.com/a/13503277/2988730 (@unutbu) 16 | and https://github.com/spcl/QuaRot/blob/main/fake_quant/monkeypatch.py. 17 | 18 | Copy a function. 19 | 20 | Args: 21 | f (`types.FunctionType`): 22 | Function to be copied. 23 | globals (`dict[str, typing.Any]` or `None`, *optional*, defaults to `None`): 24 | Globals. 25 | 26 | Returns: 27 | `types.FunctionType`: 28 | Copied function. 29 | """ 30 | if globals is None: 31 | globals = f.__globals__ 32 | g = types.FunctionType(f.__code__, globals, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__) 33 | g = functools.update_wrapper(g, f) 34 | g.__module__ = f.__module__ 35 | g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type: ignore 36 | return g 37 | 38 | 39 | def get_module_parents_map( 40 | module: nn.Module, name: str = "", parents_map: dict[nn.Module, list[tuple[str, nn.Module, str]]] | None = None 41 | ) -> dict[nn.Module, list[tuple[str, nn.Module, str]]]: 42 | """Get module parents map. 43 | 44 | Args: 45 | module (`nn.Module`): 46 | Module. 47 | name (`str`, *optional*, defaults to `""`): 48 | Name. 49 | parents_map (`dict[nn.Module, list[tuple[str, nn.Module, str]]]`, *optional*, defaults to `None`): 50 | Parents map. 51 | 52 | Returns: 53 | `dict[nn.Module, list[tuple[str, nn.Module, str]]]`: 54 | Module parents map. The key is the child module and the value is a list of tuples. 55 | Each tuple contains the name of the parent module, the parent module, 56 | and the child module name in the parent module. 57 | """ 58 | if parents_map is None: 59 | parents_map = {} 60 | for child_name, child_module in module._modules.items(): 61 | if child_module is None: 62 | continue 63 | parents_map.setdefault(child_module, []).append((name, module, child_name)) 64 | get_module_parents_map(child_module, f"{name}.{child_name}" if name else child_name, parents_map) 65 | return parents_map 66 | -------------------------------------------------------------------------------- /deepcompressor/utils/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from . import logging, sys 4 | -------------------------------------------------------------------------------- /deepcompressor/utils/tools/sys.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """System tools.""" 3 | 4 | import psutil 5 | import torch 6 | 7 | __all__ = ["get_max_memory_map"] 8 | 9 | 10 | def _get_visible_gpu_capacity_list() -> list[int]: 11 | """Get visible GPU capacity list. 12 | 13 | Returns: 14 | `list[int]`: Visible GPU capacity list. 15 | """ 16 | return [torch.cuda.get_device_properties(i).total_memory // 1024**3 for i in range(torch.cuda.device_count())] 17 | 18 | 19 | def _get_ram_capacity() -> int: 20 | """Get RAM capacity. 21 | 22 | Returns: 23 | `int`: RAM capacity in GiB. 
24 | """ 25 | return psutil.virtual_memory().total // 1024**3 # in GiB 26 | 27 | 28 | def get_max_memory_map(ratio: float = 0.9) -> dict[str, str]: 29 | """Get maximum memory map. 30 | 31 | Args: 32 | ratio (`float`, *optional*, defaults to `0.9`): The ratio of the maximum memory to use. 33 | 34 | Returns: 35 | `dict[str, str]`: Maximum memory map. 36 | """ 37 | gpu_capacity_list = _get_visible_gpu_capacity_list() 38 | ram_capacity = _get_ram_capacity() 39 | gpu_capacity_list = [str(int(c * ratio)) + "GiB" for c in gpu_capacity_list] 40 | ram_capacity = str(int(ram_capacity * ratio)) + "GiB" 41 | ret_dict = {str(idx): gpu_capacity_list[idx] for idx in range(len(gpu_capacity_list))} 42 | ret_dict["cpu"] = ram_capacity 43 | return ret_dict 44 | -------------------------------------------------------------------------------- /deepcompressor/version.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Version information.""" 3 | 4 | __version__ = "0.0.2" 5 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | dependencies: 4 | - python=3.12 5 | - pip 6 | - pip: 7 | - poetry 8 | -------------------------------------------------------------------------------- /examples/diffusion/.gitignore: -------------------------------------------------------------------------------- 1 | .tmp 2 | .tmp/ 3 | baselines 4 | baselines/ 5 | benchmarks 6 | benchmarks/ 7 | caches 8 | caches/ 9 | datasets 10 | datasets/ 11 | visualize/runs 12 | visualize/runs/ 13 | *.pdf 14 | -------------------------------------------------------------------------------- /examples/diffusion/configs/__default__.yaml: -------------------------------------------------------------------------------- 1 | seed: 12345 2 | enable_cache: true 3 | cache: 4 | root: runs 5 | output: 6 | root: runs 7 | dirname: default 8 | pipeline: 9 | dtype: torch.float16 10 | device: cuda 11 | shift_activations: false 12 | eval: 13 | num_samples: 5000 14 | height: null 15 | width: null 16 | guidance_scale: null 17 | num_steps: null 18 | gt_metrics: ["clip_iqa", "clip_score", "image_reward", "fid"] 19 | ref_metrics: ["psnr", "lpips", "ssim", "fid"] 20 | gen_root: "{output}/{job}" 21 | ref_root: baselines/{dtype}/{model}/{protocol} 22 | gt_stats_root: benchmarks/stats 23 | num_gpus: 8 24 | batch_size_per_gpu: 1 25 | chunk_start: 0 26 | chunk_step: 1 27 | benchmarks: 28 | - "MJHQ" 29 | - "DCI" 30 | control_root: "benchmarks" 31 | quant: 32 | calib: 33 | data: qdiff 34 | path: datasets/{dtype}/{model}/{protocol}/{data}/s128 35 | num_samples: 128 36 | num_workers: 8 37 | wgts: 38 | dtype: null 39 | zero_point: null 40 | group_shapes: 41 | - - 1 42 | - -1 43 | scale_dtypes: 44 | - null 45 | skips: [] 46 | enable_calib_range: true 47 | calib_range: 48 | degree: 2 49 | objective: OutputsError 50 | strategy: Manual 51 | granularity: Layer 52 | element_batch_size: 64 53 | sample_batch_size: 64 54 | element_size: 512 55 | sample_size: -1 56 | ratio: 1.0 57 | max_shrink: 0.2 58 | max_expand: 1.0 59 | num_grids: 80 60 | skips: [] 61 | low_rank: 62 | rank: 32 63 | exclusive: false 64 | compensate: false 65 | early_stop: false 66 | degree: 2 67 | objective: OutputsError 68 | sample_batch_size: 64 69 | sample_size: -1 70 | num_iters: 1 71 | skips: [] 72 | ipts: 73 | static: false 74 | dtype: null 75 | zero_point: null 76 | group_shapes: 77 | - - 1 78 | 
- -1 79 | scale_dtypes: 80 | - null 81 | allow_unsigned: false 82 | skips: [] 83 | enable_calib_range: false 84 | calib_range: 85 | degree: 2 86 | objective: OutputsError 87 | strategy: Manual 88 | granularity: Layer 89 | element_batch_size: 64 90 | sample_batch_size: 64 91 | element_size: 512 92 | sample_size: -1 93 | ratio: 1.0 94 | max_shrink: 0.2 95 | max_expand: 1.0 96 | num_grids: 80 97 | skips: [] 98 | enable_smooth: false 99 | smooth: 100 | enable_proj: false 101 | proj: 102 | degree: 2 103 | objective: OutputsError 104 | strategy: Manual 105 | granularity: Layer 106 | element_batch_size: -1 107 | sample_batch_size: 64 108 | element_size: -1 109 | sample_size: -1 110 | pre_reshape: true 111 | outputs_device: cpu 112 | spans: 113 | - - AbsMax 114 | - AbsMax 115 | alpha: 0.5 116 | beta: -1 117 | num_grids: 20 118 | skips: [] 119 | develop_dtype: torch.float32 120 | -------------------------------------------------------------------------------- /examples/diffusion/configs/collect/qdiff.yaml: -------------------------------------------------------------------------------- 1 | collect: 2 | root: datasets 3 | dataset_name: qdiff 4 | data_path: prompts/qdiff.yaml 5 | num_samples: 128 -------------------------------------------------------------------------------- /examples/diffusion/configs/lora/__default__.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | enable_lora: true 3 | skip_eval: true 4 | -------------------------------------------------------------------------------- /examples/diffusion/configs/lora/flux.1-dev/anime.yaml: -------------------------------------------------------------------------------- 1 | # https://huggingface.co/alvdansen/sonny-anime-fixed 2 | # alvdansen/sonny-anime-fixed 3 | # separate, rank=16 4 | eval: 5 | benchmarks: 6 | - prompts/lora/anime.yaml 7 | num_steps: 28 8 | pipeline: 9 | lora: 10 | alpha: 1 11 | path: alvdansen/sonny-anime-fixed 12 | weight_name: araminta_k_sonnyanime_fluxd_fixed.safetensors 13 | output: 14 | job: anime-1.0 15 | -------------------------------------------------------------------------------- /examples/diffusion/configs/lora/flux.1-dev/ghibsky.yaml: -------------------------------------------------------------------------------- 1 | # https://huggingface.co/aleksa-codes/flux-ghibsky-illustration 2 | # aleksa-codes/flux-ghibsky-illustration 3 | # separate, rank=16 4 | eval: 5 | benchmarks: 6 | - prompts/lora/ghibsky.yaml 7 | num_steps: 28 8 | pipeline: 9 | lora: 10 | alpha: 1 11 | path: aleksa-codes/flux-ghibsky-illustration 12 | weight_name: lora.safetensors 13 | output: 14 | job: ghibsky-1.0 15 | -------------------------------------------------------------------------------- /examples/diffusion/configs/lora/flux.1-dev/realism.yaml: -------------------------------------------------------------------------------- 1 | # https://huggingface.co/XLabs-AI/flux-RealismLora 2 | # XLabs-AI/flux-RealismLora 3 | # qkv fused, rank=16, only joint blocks 4 | eval: 5 | benchmarks: 6 | - prompts/lora/realism.yaml 7 | num_steps: 25 8 | pipeline: 9 | lora: 10 | alpha: 0.9 11 | path: mit-han-lab/FLUX.1-dev-LoRA-Collections 12 | weight_name: realism.safetensors 13 | output: 14 | job: realism-0.9 -------------------------------------------------------------------------------- /examples/diffusion/configs/lora/flux.1-dev/sketch.yaml: -------------------------------------------------------------------------------- 1 | # 
https://huggingface.co/Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch/tree/main 2 | # Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch 3 | # pretrained/converted/drawing.safetensors 4 | # fused, rank=64 5 | eval: 6 | benchmarks: 7 | - prompts/lora/sketch.yaml 8 | num_steps: 24 9 | pipeline: 10 | lora: 11 | alpha: 1 12 | path: mit-han-lab/FLUX.1-dev-LoRA-Collections 13 | weight_name: sketch.safetensors 14 | output: 15 | job: sketch-1.0 -------------------------------------------------------------------------------- /examples/diffusion/configs/lora/flux.1-dev/yarn.yaml: -------------------------------------------------------------------------------- 1 | # https://huggingface.co/linoyts/yarn_art_Flux_LoRA 2 | # linoyts/yarn_art_Flux_LoRA 3 | # separate, rank=4, both joint and single blocks 4 | eval: 5 | benchmarks: 6 | - prompts/lora/yarn.yaml 7 | num_steps: 28 8 | pipeline: 9 | lora: 10 | alpha: 1 11 | path: linoyts/yarn_art_Flux_LoRA 12 | weight_name: pytorch_lora_weights.safetensors 13 | output: 14 | job: yarn-1.0 15 | -------------------------------------------------------------------------------- /examples/diffusion/configs/model/flux.1-dev.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | name: flux.1-dev 3 | dtype: torch.bfloat16 4 | eval: 5 | num_steps: 50 6 | guidance_scale: 3.5 7 | protocol: fmeuler{num_steps}-g{guidance_scale} 8 | quant: 9 | calib: 10 | batch_size: 16 11 | wgts: 12 | calib_range: 13 | element_batch_size: 64 14 | sample_batch_size: 16 15 | element_size: 512 16 | sample_size: -1 17 | low_rank: 18 | sample_batch_size: 16 19 | sample_size: -1 20 | skips: 21 | - embed 22 | - resblock_shortcut 23 | - resblock_time_proj 24 | - transformer_proj_in 25 | - transformer_proj_out 26 | - down_sample 27 | - up_sample 28 | ipts: 29 | calib_range: 30 | element_batch_size: 64 31 | sample_batch_size: 16 32 | element_size: 512 33 | sample_size: -1 34 | skips: 35 | - embed 36 | - resblock_shortcut 37 | - resblock_time_proj 38 | - transformer_proj_in 39 | - transformer_proj_out 40 | - transformer_norm 41 | - transformer_add_norm 42 | - down_sample 43 | - up_sample 44 | opts: 45 | calib_range: 46 | element_batch_size: 64 47 | sample_batch_size: 16 48 | element_size: 512 49 | sample_size: -1 50 | smooth: 51 | proj: 52 | element_batch_size: -1 53 | sample_batch_size: 16 54 | element_size: -1 55 | sample_size: -1 56 | attn: 57 | sample_batch_size: 16 58 | sample_size: -1 59 | -------------------------------------------------------------------------------- /examples/diffusion/configs/model/flux.1-schnell.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | name: flux.1-schnell 3 | dtype: torch.bfloat16 4 | eval: 5 | num_steps: 4 6 | guidance_scale: 0 7 | protocol: fmeuler{num_steps}-g{guidance_scale} 8 | quant: 9 | calib: 10 | batch_size: 16 11 | wgts: 12 | calib_range: 13 | element_batch_size: 64 14 | sample_batch_size: 32 15 | element_size: 512 16 | sample_size: -1 17 | low_rank: 18 | sample_batch_size: 32 19 | sample_size: -1 20 | skips: 21 | - embed 22 | - resblock_shortcut 23 | - resblock_time_proj 24 | - transformer_proj_in 25 | - transformer_proj_out 26 | - down_sample 27 | - up_sample 28 | ipts: 29 | calib_range: 30 | element_batch_size: 64 31 | sample_batch_size: 32 32 | element_size: 512 33 | sample_size: -1 34 | skips: 35 | - embed 36 | - resblock_shortcut 37 | - resblock_time_proj 38 | - transformer_proj_in 39 | - transformer_proj_out 40 | - transformer_norm 41 | - 
transformer_add_norm 42 | - down_sample 43 | - up_sample 44 | opts: 45 | calib_range: 46 | element_batch_size: 64 47 | sample_batch_size: 32 48 | element_size: 512 49 | sample_size: -1 50 | smooth: 51 | proj: 52 | element_batch_size: -1 53 | sample_batch_size: 32 54 | element_size: -1 55 | sample_size: -1 56 | attn: 57 | sample_batch_size: 32 58 | sample_size: -1 59 | -------------------------------------------------------------------------------- /examples/diffusion/configs/model/pixart-sigma.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | name: pixart-sigma 3 | eval: 4 | num_steps: 20 5 | guidance_scale: 4.5 6 | protocol: dpm{num_steps}-g{guidance_scale} 7 | quant: 8 | calib: 9 | batch_size: 256 10 | wgts: 11 | calib_range: 12 | sample_batch_size: -1 13 | low_rank: 14 | sample_batch_size: -1 15 | skips: 16 | - embed 17 | - resblock_shortcut 18 | - resblock_time_proj 19 | - transformer_proj_in 20 | - transformer_proj_out 21 | - transformer_norm 22 | - transformer_add_norm 23 | - attn_add 24 | - ffn_add 25 | - down_sample 26 | - up_sample 27 | ipts: 28 | calib_range: 29 | sample_batch_size: -1 30 | skips: 31 | - embed 32 | - resblock_shortcut 33 | - resblock_time_proj 34 | - transformer_proj_in 35 | - transformer_proj_out 36 | - transformer_norm 37 | - transformer_add_norm 38 | - attn_add 39 | - ffn_add 40 | - down_sample 41 | - up_sample 42 | opts: 43 | calib_range: 44 | sample_batch_size: -1 45 | smooth: 46 | proj: 47 | sample_batch_size: -1 48 | attn: 49 | sample_batch_size: -1 50 | -------------------------------------------------------------------------------- /examples/diffusion/configs/model/sana-1.6b.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | name: sana-1.6b-1024px-bf16-ch5632 3 | path: Lawrence-cj/Sana_1600M_1024px_BF16_diffusers_ch5632 4 | dtype: torch.bfloat16 5 | eval: 6 | num_steps: 20 7 | guidance_scale: 4.5 8 | protocol: flowdpm{num_steps}-g{guidance_scale} 9 | quant: 10 | calib: 11 | batch_size: 256 12 | wgts: 13 | calib_range: 14 | element_batch_size: 64 15 | sample_batch_size: 32 16 | element_size: 512 17 | sample_size: -1 18 | low_rank: 19 | sample_batch_size: 32 20 | sample_size: -1 21 | skips: 22 | - embed 23 | - resblock_shortcut 24 | - resblock_time_proj 25 | - transformer_proj_in 26 | - transformer_proj_out 27 | - transformer_norm 28 | - transformer_add_norm 29 | - attn_add 30 | - ffn_add 31 | - down_sample 32 | - up_sample 33 | ipts: 34 | calib_range: 35 | element_batch_size: 64 36 | sample_batch_size: 32 37 | element_size: 512 38 | sample_size: -1 39 | skips: 40 | - embed 41 | - resblock_shortcut 42 | - resblock_time_proj 43 | - transformer_proj_in 44 | - transformer_proj_out 45 | - transformer_norm 46 | - transformer_add_norm 47 | - attn_add 48 | - ffn_add 49 | - down_sample 50 | - up_sample 51 | opts: 52 | calib_range: 53 | element_batch_size: 64 54 | sample_batch_size: 32 55 | element_size: 512 56 | sample_size: -1 57 | smooth: 58 | proj: 59 | element_batch_size: -1 60 | sample_batch_size: 32 61 | element_size: -1 62 | sample_size: -1 63 | attn: 64 | sample_batch_size: 32 65 | sample_size: -1 66 | -------------------------------------------------------------------------------- /examples/diffusion/configs/svdquant/__default__.yaml: -------------------------------------------------------------------------------- 1 | quant: 2 | enable_smooth: true 3 | smooth: 4 | enable_proj: true 5 | proj: 6 | objective: OutputsError 7 | strategy: GridSearch 8 | 
granularity: Layer 9 | spans: 10 | - - AbsMax 11 | - AbsMax 12 | alpha: 0.5 13 | beta: -2 14 | num_grids: 20 15 | allow_low_rank: true 16 | fuse_when_possible: false 17 | skips: 18 | - embed 19 | - resblock 20 | - transformer_proj_in 21 | - transformer_proj_out 22 | - transformer_norm 23 | - transformer_add_norm 24 | - down_sample 25 | - up_sample 26 | wgts: 27 | enable_low_rank: true 28 | low_rank: 29 | rank: 32 30 | early_stop: true 31 | degree: 2 32 | objective: OutputsError 33 | num_iters: 100 34 | skips: 35 | - embed 36 | - resblock 37 | - transformer_proj_in 38 | - transformer_proj_out 39 | - transformer_norm 40 | - transformer_add_norm 41 | - down_sample 42 | - up_sample -------------------------------------------------------------------------------- /examples/diffusion/configs/svdquant/fast.yaml: -------------------------------------------------------------------------------- 1 | quant: 2 | smooth: 3 | proj: 4 | num_grids: 10 5 | calib: 6 | num_samples: 64 -------------------------------------------------------------------------------- /examples/diffusion/configs/svdquant/gptq.yaml: -------------------------------------------------------------------------------- 1 | quant: 2 | wgts: 3 | enable_kernel_gptq: true 4 | kernel_gptq: 5 | damp_percentage: 0.01 6 | block_size: 128 7 | num_inv_tries: 250 8 | hessian_block_size: 512 -------------------------------------------------------------------------------- /examples/diffusion/configs/svdquant/int4.yaml: -------------------------------------------------------------------------------- 1 | quant: 2 | wgts: 3 | dtype: sint4 4 | group_shapes: 5 | - - 1 6 | - 64 7 | - 1 8 | - 1 9 | - 1 10 | scale_dtypes: 11 | - null 12 | ipts: 13 | static: false 14 | dtype: sint4 15 | group_shapes: 16 | - - 1 17 | - 64 18 | - 1 19 | - 1 20 | - 1 21 | scale_dtypes: 22 | - null 23 | allow_unsigned: true 24 | pipeline: 25 | shift_activations: true -------------------------------------------------------------------------------- /examples/diffusion/configs/svdquant/nvfp4.yaml: -------------------------------------------------------------------------------- 1 | quant: 2 | wgts: 3 | dtype: sfp4_e2m1_all 4 | group_shapes: 5 | - - -1 6 | - -1 7 | - - 1 8 | - 16 9 | - 1 10 | - 1 11 | - 1 12 | scale_dtypes: 13 | - null 14 | - sfp8_e4m3_nan 15 | ipts: 16 | static: false 17 | dtype: sfp4_e2m1_all 18 | group_shapes: 19 | - - 1 20 | - 16 21 | - 1 22 | - 1 23 | - 1 24 | scale_dtypes: 25 | - sfp8_e4m3_nan 26 | enable_extra_wgts: true 27 | extra_wgts: 28 | dtype: sint4 29 | group_shapes: 30 | - - 1 31 | - 64 32 | - 1 33 | - 1 34 | - 1 35 | scale_dtypes: 36 | - null 37 | includes: 38 | - transformer_norm 39 | - transformer_add_norm 40 | -------------------------------------------------------------------------------- /examples/diffusion/configs/text/__default__.yaml: -------------------------------------------------------------------------------- 1 | # copied from projects/llm/configs/__default__.yaml 2 | enable_text: true 3 | text: 4 | calib: 5 | data: pileval 6 | path: mit-han-lab/pile-val-backup 7 | num_samples: 128 8 | seq_length: 1024 9 | min_seq_length: 0 10 | max_seq_length: 0 11 | develop_dtype: torch.float32 12 | wgts: 13 | dtype: null 14 | zero_point: null 15 | group_shapes: 16 | - - 1 17 | - -1 18 | scale_dtypes: 19 | - null 20 | intermediate_dtypes: [] 21 | intermediate_levels: [] 22 | needs_dequant_saturation: false 23 | enable_kernel_gptq: false 24 | kernel_gptq: 25 | damp_percentage: 0.01 26 | block_size: 128 27 | num_inv_tries: 250 28 | hessian_block_size: 512 
29 | enable_calib_range: true 30 | calib_range: 31 | objective: OutputsError 32 | strategy: Manual 33 | granularity: Group 34 | degree: 2 35 | element_batch_size: 64 36 | sample_batch_size: -1 37 | element_size: 512 38 | sample_size: -1 39 | pre_reshape: true 40 | outputs_device: cpu 41 | ratio: 1.0 42 | max_shrink: 0.2 43 | max_expand: 1.0 44 | num_grids: 80 45 | skip_qkv_proj: false 46 | skip_out_proj: false 47 | skip_up_proj: false 48 | skip_down_proj: false 49 | skip_qkv_proj: false 50 | skip_out_proj: false 51 | skip_up_proj: false 52 | skip_down_proj: false 53 | ipts: 54 | static: false 55 | dtype: null 56 | zero_point: null 57 | group_shapes: 58 | - - 1 59 | - -1 60 | scale_dtypes: 61 | - null 62 | enable_calib_range: false 63 | calib_range: 64 | objective: OutputsError 65 | strategy: GridSearch 66 | granularity: ChannelGroup 67 | degree: 2 68 | element_batch_size: 64 69 | sample_batch_size: -1 70 | element_size: 512 71 | sample_size: -1 72 | pre_reshape: true 73 | outputs_device: cpu 74 | ratio: 1.0 75 | max_shrink: 0.2 76 | max_expand: 1.0 77 | num_grids: 80 78 | skip_qkv_proj: false 79 | skip_out_proj: false 80 | skip_up_proj: false 81 | skip_down_proj: false 82 | skip_qkv_proj: false 83 | skip_out_proj: false 84 | skip_up_proj: false 85 | skip_down_proj: false 86 | opts: 87 | static: false 88 | dtype: null 89 | zero_point: null 90 | group_shapes: 91 | - - 1 92 | - -1 93 | scale_dtypes: 94 | - null 95 | enable_calib_range: false 96 | calib_range: 97 | objective: OutputsError 98 | strategy: GridSearch 99 | granularity: ChannelGroup 100 | degree: 2 101 | element_batch_size: 64 102 | sample_batch_size: -1 103 | element_size: 512 104 | sample_size: -1 105 | pre_reshape: true 106 | outputs_device: cpu 107 | ratio: 1.0 108 | max_shrink: 0.2 109 | max_expand: 1.0 110 | num_grids: 80 111 | skip_attn_q: false 112 | skip_attn_k: false 113 | skip_attn_v: false 114 | skip_attn_q: false 115 | skip_attn_k: false 116 | skip_attn_v: false 117 | enable_rotation: false 118 | rotation: 119 | random: false 120 | transform_out_proj: false 121 | transform_down_proj: false 122 | enable_reorder: false 123 | reorder: 124 | strategy: Manual 125 | degree: 2 126 | element_batch_size: -1 127 | sample_batch_size: -1 128 | element_size: -1 129 | sample_size: -1 130 | pre_reshape: true 131 | outputs_device: cpu 132 | channel_metric: InputsAbsMax 133 | channel_index: Sequential 134 | dynamic: false 135 | skip_residual: true 136 | skip_out_proj: false 137 | skip_down_proj: false 138 | enable_smooth: false 139 | smooth: 140 | enable_proj: false 141 | proj: 142 | objective: OutputsError 143 | strategy: GridSearch 144 | granularity: Layer 145 | degree: 2 146 | element_batch_size: -1 147 | sample_batch_size: -1 148 | element_size: -1 149 | sample_size: -1 150 | pre_reshape: true 151 | outputs_device: cpu 152 | spans: 153 | - - AbsMax 154 | - AbsMax 155 | alpha: -3 156 | beta: -3 157 | num_grids: 20 158 | skip_qkv_proj: false 159 | skip_out_proj: false 160 | skip_up_proj: false 161 | skip_down_proj: false 162 | enable_attn: false 163 | attn: 164 | objective: OutputsError 165 | strategy: Manual 166 | granularity: Layer 167 | degree: 2 168 | element_batch_size: -1 169 | sample_batch_size: -1 170 | element_size: -1 171 | sample_size: -1 172 | pre_reshape: true 173 | outputs_device: cpu 174 | spans: 175 | - - AbsMax 176 | - AbsMax 177 | alpha: 0.5 178 | beta: 0 179 | num_grids: 20 180 | -------------------------------------------------------------------------------- /examples/diffusion/configs/text/awq.yaml: 
-------------------------------------------------------------------------------- 1 | # copied from projects/llm/configs/awq.yaml 2 | text: 3 | calib: 4 | num_samples: 128 5 | seq_length: 512 6 | min_seq_length: 0 7 | max_seq_length: 512 8 | wgts: 9 | dtype: uint4 10 | zero_point: PostScale 11 | group_shapes: 12 | - - 1 13 | - 128 14 | scale_dtypes: 15 | - torch.float16 16 | enable_calib_range: true 17 | calib_range: 18 | objective: ProductsError 19 | strategy: GridSearch 20 | granularity: Group 21 | degree: 2 22 | max_shrink: 0.8 23 | max_expand: 1.0 24 | num_grids: 20 25 | skip_qkv_proj: true 26 | ipts: 27 | static: false 28 | dtype: null 29 | group_shapes: 30 | - - 1 31 | - -1 32 | scale_dtypes: 33 | - null 34 | opts: 35 | static: false 36 | dtype: null 37 | group_shapes: 38 | - - 1 39 | - -1 40 | scale_dtypes: 41 | - null 42 | enable_smooth: true 43 | smooth: 44 | enable_proj: true 45 | proj: 46 | objective: OutputsError 47 | strategy: GridSearch 48 | granularity: Layer 49 | spans: 50 | - - AbsMax 51 | - AbsMax 52 | alpha: 0.5 53 | beta: 0 54 | num_grids: 20 55 | enable_attn: false -------------------------------------------------------------------------------- /examples/diffusion/scripts/svdquant.sh: -------------------------------------------------------------------------------- 1 | python -m deepcompressor.app.diffusion.ptq configs/model/flux.1-schnell.yaml configs/svdquant/int4.yaml -------------------------------------------------------------------------------- /examples/llm/.gitignore: -------------------------------------------------------------------------------- 1 | .tmp 2 | .tmp/ 3 | -------------------------------------------------------------------------------- /examples/llm/configs/__default__.yaml: -------------------------------------------------------------------------------- 1 | seed: 12345 2 | save_model: false 3 | cache: 4 | root: runs 5 | output: 6 | root: runs 7 | dirname: default 8 | model: 9 | name: llama-2-7b 10 | path: null 11 | root: '' 12 | local_path: null 13 | local_root: ~/models 14 | dtype: torch.float16 15 | eval: 16 | num_gpus: 8 17 | batch_size: 8 18 | tasks: 19 | - wikitext 20 | max_seq_length: -4096 21 | evaluators: 22 | - gptq 23 | quant: 24 | calib: 25 | data: pileval 26 | path: mit-han-lab/pile-val-backup 27 | num_samples: 128 28 | seq_length: 1024 29 | min_seq_length: 0 30 | max_seq_length: 0 31 | develop_dtype: torch.float32 32 | wgts: 33 | dtype: null 34 | zero_point: null 35 | group_shapes: 36 | - - 1 37 | - -1 38 | scale_dtypes: 39 | - null 40 | intermediate_dtypes: [] 41 | intermediate_levels: [] 42 | needs_dequant_saturation: false 43 | enable_kernel_gptq: false 44 | kernel_gptq: 45 | damp_percentage: 0.01 46 | block_size: 128 47 | num_inv_tries: 250 48 | hessian_block_size: 512 49 | enable_calib_range: true 50 | calib_range: 51 | objective: OutputsError 52 | strategy: Manual 53 | granularity: Group 54 | degree: 2 55 | element_batch_size: 64 56 | sample_batch_size: -1 57 | element_size: 512 58 | sample_size: -1 59 | pre_reshape: true 60 | outputs_device: cpu 61 | ratio: 1.0 62 | max_shrink: 0.2 63 | max_expand: 1.0 64 | num_grids: 80 65 | skips: [] 66 | skips: [] 67 | ipts: 68 | static: false 69 | dtype: null 70 | zero_point: null 71 | group_shapes: 72 | - - 1 73 | - -1 74 | scale_dtypes: 75 | - null 76 | enable_calib_range: false 77 | calib_range: 78 | objective: OutputsError 79 | strategy: GridSearch 80 | granularity: ChannelGroup 81 | degree: 2 82 | element_batch_size: 64 83 | sample_batch_size: -1 84 | element_size: 512 85 | sample_size: 
-1 86 | pre_reshape: true 87 | outputs_device: cpu 88 | ratio: 1.0 89 | max_shrink: 0.2 90 | max_expand: 1.0 91 | num_grids: 80 92 | skips: [] 93 | skips: [] 94 | opts: 95 | static: false 96 | dtype: null 97 | zero_point: null 98 | group_shapes: 99 | - - 1 100 | - -1 101 | scale_dtypes: 102 | - null 103 | enable_calib_range: false 104 | calib_range: 105 | objective: OutputsError 106 | strategy: GridSearch 107 | granularity: ChannelGroup 108 | degree: 2 109 | element_batch_size: 64 110 | sample_batch_size: -1 111 | element_size: 512 112 | sample_size: -1 113 | pre_reshape: true 114 | outputs_device: cpu 115 | ratio: 1.0 116 | max_shrink: 0.2 117 | max_expand: 1.0 118 | num_grids: 80 119 | skips: [] 120 | skips: [] 121 | enable_rotation: false 122 | rotation: 123 | random: false 124 | transforms: [] 125 | enable_reorder: false 126 | reorder: 127 | strategy: Manual 128 | degree: 2 129 | sample_batch_size: -1 130 | sample_size: -1 131 | outputs_device: cpu 132 | channel_metric: InputsAbsMax 133 | channel_index: Sequential 134 | dynamic: false 135 | skips: 136 | - residual 137 | enable_smooth: false 138 | smooth: 139 | enable_proj: false 140 | proj: 141 | objective: OutputsError 142 | strategy: GridSearch 143 | granularity: Layer 144 | degree: 2 145 | element_batch_size: -1 146 | sample_batch_size: -1 147 | element_size: -1 148 | sample_size: -1 149 | pre_reshape: true 150 | outputs_device: cpu 151 | spans: 152 | - - AbsMax 153 | - AbsMax 154 | alpha: -3 155 | beta: -3 156 | num_grids: 20 157 | skips: [] 158 | enable_attn: false 159 | attn: 160 | strategy: Manual 161 | degree: 2 162 | sample_batch_size: -1 163 | sample_size: -1 164 | outputs_device: cpu 165 | spans: 166 | - - AbsMax 167 | - AbsMax 168 | alpha: 0.5 169 | beta: 0 170 | num_grids: 20 171 | -------------------------------------------------------------------------------- /examples/llm/configs/awq.yaml: -------------------------------------------------------------------------------- 1 | quant: 2 | calib: 3 | num_samples: 128 4 | seq_length: 512 5 | min_seq_length: 0 6 | max_seq_length: 512 7 | wgts: 8 | dtype: uint4 9 | zero_point: PostScale 10 | group_shapes: 11 | - - 1 12 | - 128 13 | scale_dtypes: 14 | - torch.float16 15 | enable_calib_range: true 16 | calib_range: 17 | objective: ProductsError 18 | strategy: GridSearch 19 | granularity: Group 20 | degree: 2 21 | max_shrink: 0.8 22 | max_expand: 1.0 23 | num_grids: 20 24 | skips: 25 | - qkv_proj 26 | ipts: 27 | static: false 28 | dtype: null 29 | group_shapes: 30 | - - 1 31 | - -1 32 | scale_dtypes: 33 | - torch.float16 34 | opts: 35 | static: false 36 | dtype: null 37 | group_shapes: 38 | - - 1 39 | - -1 40 | scale_dtypes: 41 | - torch.float16 42 | enable_smooth: true 43 | smooth: 44 | enable_proj: true 45 | proj: 46 | objective: OutputsError 47 | strategy: GridSearch 48 | granularity: Layer 49 | spans: 50 | - - AbsMax 51 | - AbsMax 52 | alpha: 0.5 53 | beta: 0 54 | num_grids: 20 55 | enable_attn: false -------------------------------------------------------------------------------- /examples/llm/configs/gptq.yaml: -------------------------------------------------------------------------------- 1 | quant: 2 | calib: 3 | num_samples: 128 4 | seq_length: 2048 5 | min_seq_length: 2048 6 | max_seq_length: 0 7 | wgts: 8 | dtype: uint4 9 | zero_point: PostScale 10 | group_shapes: 11 | - - 1 12 | - 128 13 | scale_dtypes: 14 | - torch.float16 15 | enable_kernel_gptq: true 16 | kernel_gptq: 17 | damp_percentage: 0.01 18 | block_size: 128 19 | num_inv_tries: 250 20 | hessian_block_size: 
512 21 | enable_calib_range: false 22 | calib_range: 23 | objective: TensorError 24 | strategy: GridSearch 25 | granularity: Group 26 | degree: 2.4 27 | element_batch_size: -1 28 | sample_batch_size: -1 29 | element_size: -1 30 | sample_size: -1 31 | pre_reshape: true 32 | outputs_device: cpu 33 | max_shrink: 0.2 34 | max_expand: 1.0 35 | num_grids: 100 36 | ipts: 37 | static: false 38 | dtype: null 39 | group_shapes: 40 | - - 1 41 | - -1 42 | scale_dtypes: 43 | - torch.float16 44 | opts: 45 | static: false 46 | dtype: null 47 | group_shapes: 48 | - - 1 49 | - -1 50 | scale_dtypes: 51 | - torch.float16 -------------------------------------------------------------------------------- /examples/llm/configs/ooo.yaml: -------------------------------------------------------------------------------- 1 | quant: 2 | calib: 3 | num_samples: 128 4 | seq_length: 1024 5 | min_seq_length: 0 6 | max_seq_length: 0 7 | wgts: 8 | dtype: sint8 9 | group_shapes: 10 | - - 1 11 | - -1 12 | scale_dtypes: 13 | - torch.float16 14 | enable_kernel_gptq: true 15 | kernel_gptq: 16 | damp_percentage: 0.01 17 | block_size: 128 18 | num_inv_tries: 250 19 | hessian_block_size: 512 20 | enable_calib_range: true 21 | calib_range: 22 | objective: OutputsError 23 | strategy: GridSearch 24 | granularity: Group 25 | max_shrink: 0.2 26 | max_expand: 1.0 27 | num_grids: 80 28 | ipts: 29 | static: false 30 | dtype: sint8 31 | group_shapes: 32 | - - 1 33 | - -1 34 | scale_dtypes: 35 | - torch.float16 36 | opts: 37 | static: true 38 | dtype: sint8 39 | group_shapes: 40 | - - -1 41 | - -1 42 | scale_dtypes: 43 | - torch.float16 44 | enable_calib_range: true 45 | calib_range: 46 | objective: OutputsError 47 | strategy: Manual 48 | granularity: Layer 49 | degree: 2 50 | element_batch_size: -1 51 | sample_batch_size: -1 52 | element_size: -1 53 | sample_size: -1 54 | pre_reshape: true 55 | outputs_device: cpu 56 | enable_rotation: true 57 | rotation: 58 | transforms: 59 | - out_proj 60 | enable_smooth: true 61 | smooth: 62 | enable_proj: true 63 | proj: 64 | objective: OutputsError 65 | strategy: Manual 66 | granularity: Layer 67 | degree: 2 68 | spans: 69 | - - AbsMax 70 | - AbsMax 71 | alpha: 0.1 72 | beta: 0.9 73 | num_grids: 20 74 | skips: 75 | - qkv_proj 76 | - up_proj 77 | - out_proj 78 | enable_attn: true 79 | attn: 80 | strategy: GridSearch 81 | degree: 2 82 | spans: 83 | - - AbsMax 84 | - AbsMax 85 | alpha: 0.5 86 | beta: -2 87 | num_grids: 20 -------------------------------------------------------------------------------- /examples/llm/configs/qoq-g128.yaml: -------------------------------------------------------------------------------- 1 | quant: 2 | calib: 3 | num_samples: 128 4 | seq_length: 1024 5 | min_seq_length: 0 6 | max_seq_length: 0 7 | wgts: 8 | dtype: uint4 9 | zero_point: PostScale 10 | group_shapes: 11 | - - 1 12 | - -1 13 | - - 1 14 | - 128 15 | scale_dtypes: 16 | - torch.float16 17 | - sint8 18 | intermediate_dtypes: 19 | - sint8 20 | intermediate_levels: 21 | - 0 22 | needs_dequant_saturation: false 23 | enable_kernel_gptq: true 24 | kernel_gptq: 25 | damp_percentage: 0.01 26 | block_size: 128 27 | num_inv_tries: 250 28 | hessian_block_size: 512 29 | ipts: 30 | static: false 31 | dtype: sint8 32 | group_shapes: 33 | - - 1 34 | - -1 35 | scale_dtypes: 36 | - torch.float16 37 | opts: 38 | static: false 39 | dtype: uint4 40 | zero_point: PostScale 41 | group_shapes: 42 | - - 1 43 | - 128 44 | scale_dtypes: 45 | - torch.float16 46 | skips: 47 | - attn_q 48 | enable_rotation: true 49 | enable_reorder: true 50 | 
reorder: 51 | strategy: Manual 52 | channel_metric: InputsAbsMax 53 | channel_index: Sequential 54 | skips: 55 | - residual 56 | enable_smooth: true 57 | rotation: 58 | transforms: 59 | - out_proj 60 | smooth: 61 | enable_proj: true 62 | proj: 63 | objective: OutputsError 64 | strategy: Manual 65 | granularity: Layer 66 | degree: 2 67 | spans: 68 | - - AbsMax 69 | - AbsMax 70 | alpha: 0.3 71 | beta: 0.7 72 | num_grids: 20 73 | skips: 74 | - qkv_proj 75 | - up_proj 76 | - out_proj 77 | enable_attn: true 78 | attn: 79 | strategy: Manual 80 | degree: 2 81 | spans: 82 | - - AbsMax 83 | - AbsMax 84 | alpha: 0.5 85 | beta: 0 86 | num_grids: 20 -------------------------------------------------------------------------------- /examples/llm/configs/qoq-gchn.yaml: -------------------------------------------------------------------------------- 1 | quant: 2 | calib: 3 | num_samples: 128 4 | seq_length: 1024 5 | min_seq_length: 0 6 | max_seq_length: 0 7 | wgts: 8 | dtype: uint4 9 | zero_point: PostScale 10 | group_shapes: 11 | - - 1 12 | - -1 13 | scale_dtypes: 14 | - torch.float16 15 | enable_kernel_gptq: true 16 | kernel_gptq: 17 | damp_percentage: 0.01 18 | block_size: 128 19 | num_inv_tries: 250 20 | hessian_block_size: 512 21 | enable_calib_range: true 22 | calib_range: 23 | objective: OutputsError 24 | strategy: GridSearch 25 | granularity: Group 26 | max_shrink: 0.2 27 | max_expand: 1.0 28 | num_grids: 80 29 | ipts: 30 | static: false 31 | dtype: sint8 32 | group_shapes: 33 | - - 1 34 | - -1 35 | scale_dtypes: 36 | - torch.float16 37 | opts: 38 | static: false 39 | dtype: uint4 40 | zero_point: PostScale 41 | group_shapes: 42 | - - 1 43 | - 128 44 | scale_dtypes: 45 | - torch.float16 46 | skips: 47 | - attn_q 48 | enable_rotation: true 49 | rotation: 50 | transforms: 51 | - out_proj 52 | enable_smooth: true 53 | smooth: 54 | enable_proj: true 55 | proj: 56 | objective: OutputsError 57 | strategy: Manual 58 | granularity: Layer 59 | degree: 2 60 | spans: 61 | - - AbsMax 62 | - AbsMax 63 | alpha: 0.1 64 | beta: 0.9 65 | num_grids: 20 66 | skips: 67 | - qkv_proj 68 | - up_proj 69 | - out_proj 70 | enable_attn: true 71 | attn: 72 | strategy: GridSearch 73 | degree: 2 74 | spans: 75 | - - AbsMax 76 | - AbsMax 77 | alpha: 0.5 78 | beta: -2 79 | num_grids: 20 -------------------------------------------------------------------------------- /examples/llm/configs/smoothquant-dynamic.yaml: -------------------------------------------------------------------------------- 1 | quant: 2 | calib: 3 | num_samples: 128 4 | seq_length: 512 5 | min_seq_length: 0 6 | max_seq_length: 0 7 | wgts: 8 | dtype: sint8 9 | group_shapes: 10 | - - 1 11 | - -1 12 | scale_dtypes: 13 | - null 14 | ipts: 15 | static: false 16 | dtype: sint8 17 | group_shapes: 18 | - - 1 19 | - -1 20 | scale_dtypes: 21 | - null 22 | opts: 23 | static: false 24 | dtype: sint8 25 | group_shapes: 26 | - - 1 27 | - -1 28 | scale_dtypes: 29 | - null 30 | enable_smooth: true 31 | smooth: 32 | enable_proj: true 33 | proj: 34 | objective: OutputsError 35 | strategy: Manual 36 | granularity: Layer 37 | spans: 38 | - - AbsMax 39 | - AbsMax 40 | alpha: 0.85 41 | beta: 0.15 42 | skips: 43 | - out_proj 44 | - down_proj 45 | enable_attn: false -------------------------------------------------------------------------------- /examples/llm/configs/smoothquant-static.yaml: -------------------------------------------------------------------------------- 1 | quant: 2 | calib: 3 | num_samples: 128 4 | seq_length: 512 5 | min_seq_length: 0 6 | max_seq_length: 0 7 | 
wgts: 8 | dtype: sint8 9 | group_shapes: 10 | - - 1 11 | - -1 12 | scale_dtypes: 13 | - null 14 | ipts: 15 | static: false 16 | dtype: sint8 17 | group_shapes: 18 | - - 1 19 | - -1 20 | scale_dtypes: 21 | - null 22 | opts: 23 | static: true 24 | dtype: sint8 25 | group_shapes: 26 | - - -1 27 | - -1 28 | scale_dtypes: 29 | - null 30 | enable_calib_range: true 31 | calib_range: 32 | objective: OutputsError 33 | strategy: Manual 34 | granularity: Layer 35 | degree: 2 36 | element_batch_size: -1 37 | sample_batch_size: -1 38 | element_size: -1 39 | sample_size: -1 40 | pre_reshape: true 41 | outputs_device: cpu 42 | enable_smooth: true 43 | smooth: 44 | enable_proj: true 45 | proj: 46 | objective: OutputsError 47 | strategy: Manual 48 | granularity: Layer 49 | spans: 50 | - - AbsMax 51 | - AbsMax 52 | alpha: 0.85 53 | beta: 0.15 54 | skips: 55 | - out_proj 56 | - down_proj 57 | enable_attn: false -------------------------------------------------------------------------------- /examples/llm/scripts/awq.sh: -------------------------------------------------------------------------------- 1 | # AWQ (W4A16) on Llama2-7B 2 | python -m deepcompressor.app.llm.ptq configs/awq.yaml --model-name llama-2-7b 3 | 4 | # AWQ (W4A16) on Llama2-13B 5 | python -m deepcompressor.app.llm.ptq configs/awq.yaml --model-name llama-2-13b 6 | 7 | # AWQ (W4A16) on Llama2-70B 8 | python -m deepcompressor.app.llm.ptq configs/awq.yaml --model-name llama-2-70b 9 | 10 | # AWQ (W4A16) on Llama3-8B 11 | python -m deepcompressor.app.llm.ptq configs/awq.yaml --model-name llama-3-8b 12 | 13 | # AWQ (W4A16) on Llama3-70B 14 | python -m deepcompressor.app.llm.ptq configs/awq.yaml --model-name llama-3-70b -------------------------------------------------------------------------------- /examples/llm/scripts/gptq.sh: -------------------------------------------------------------------------------- 1 | # GPTQ-R (W4A16) on Llama2-7B 2 | python -m deepcompressor.app.llm.ptq configs/gptq.yaml --model-name llama-2-7b 3 | 4 | # GPTQ-R (W4A16) on Llama2-13B 5 | python -m deepcompressor.app.llm.ptq configs/gptq.yaml --model-name llama-2-13b 6 | 7 | # GPTQ-R (W4A16) on Llama2-70B 8 | python -m deepcompressor.app.llm.ptq configs/gptq.yaml --model-name llama-2-70b 9 | 10 | # GPTQ-R (W4A16) on Llama3-8B 11 | python -m deepcompressor.app.llm.ptq configs/gptq.yaml --model-name llama-3-8b 12 | 13 | # GPTQ-R (W4A16) on Llama3-70B 14 | python -m deepcompressor.app.llm.ptq configs/gptq.yaml --model-name llama-3-70b -------------------------------------------------------------------------------- /examples/llm/scripts/smoothquant.sh: -------------------------------------------------------------------------------- 1 | # SmoothQuant (W8A8 with per-token dynamic KV quantization) on Llama2-7B 2 | python -m deepcompressor.app.llm.ptq configs/smoothquant-dynamic.yaml --model-name llama-2-7b --smooth-proj-alpha 0.85 --smooth-proj-beta 0.15 3 | 4 | # SmoothQuant (W8A8 with per-tensor static KV quantization) on Llama2-7B 5 | python -m deepcompressor.app.llm.ptq configs/smoothquant-static.yaml --model-name llama-2-7b --smooth-proj-alpha 0.85 --smooth-proj-beta 0.15 6 | 7 | # SmoothQuant (W8A8 with per-token dynamic KV quantization) on Llama2-13B 8 | python -m deepcompressor.app.llm.ptq configs/smoothquant-dynamic.yaml --model-name llama-2-13b --smooth-proj-alpha 0.85 --smooth-proj-beta 0.15 9 | 10 | # SmoothQuant (W8A8 with per-token dynamic KV quantization) on Llama2-70B 11 | python -m deepcompressor.app.llm.ptq configs/smoothquant-dynamic.yaml --model-name 
llama-2-70b --smooth-proj-alpha 0.9 --smooth-proj-beta 0.1 12 | 13 | # SmoothQuant (W8A8 with per-token dynamic KV quantization) on Llama3-8B 14 | python -m deepcompressor.app.llm.ptq configs/smoothquant-dynamic.yaml --model-name llama-3-8b --smooth-proj-alpha 0.85 --smooth-proj-beta 0.15 15 | 16 | # SmoothQuant (W8A8 with per-token dynamic KV quantization) on Llama3-70B 17 | python -m deepcompressor.app.llm.ptq configs/smoothquant-dynamic.yaml --model-name llama-3-70b --smooth-proj-alpha 0.85 --smooth-proj-beta 0.15 18 | 19 | # SmoothQuant (W8A8 with per-token dynamic KV quantization) on Mistral-7B 20 | python -m deepcompressor.app.llm.ptq configs/smoothquant-dynamic.yaml --model-name mistral-7b --smooth-proj-alpha 0.8 --smooth-proj-beta 0.2 21 | 22 | # SmoothQuant (W8A8 with per-token dynamic KV quantization) on Mixtral-8x7B 23 | python -m deepcompressor.app.llm.ptq configs/smoothquant-dynamic.yaml --model-name mixtral-8x7b --smooth-proj-alpha 0.8 --smooth-proj-beta 0.2 24 | 25 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "deepcompressor-toolkit" 3 | version = "0.0.2" 4 | description = "This package is a model compression toolkit for large language models and diffusion models." 5 | authors = [ 6 | "Yujun Lin", 7 | "Muyang Li", 8 | "Shang Yang", 9 | "Zhekai Zhang", 10 | "Haotian Tang", 11 | "Song Han", 12 | ] 13 | packages = [ { include = "deepcompressor" } ] 14 | license = "Apache-2.0" 15 | readme = "README.md" 16 | 17 | [tool.poetry.dependencies] 18 | python = ">= 3.10 < 4.0" 19 | tqdm = ">= 4.66.0" 20 | torch = ">= 2.5.0" 21 | torchvision = ">= 0.18.1" 22 | torchmetrics = ">= 1.4.0" 23 | ninja = ">= 1.11.1" 24 | bitsandbytes = ">= 0.42.0" 25 | transformers = ">= 4.46.0" 26 | lm_eval = ">= 0.4.2" 27 | accelerate = ">= 0.26.0" 28 | datasets = ">= 2.16.0" 29 | sentencepiece = ">= 0.1.99" 30 | omniconfig = ">= 0.1.10" 31 | jieba = ">= 0.42.1" 32 | fuzzywuzzy = ">= 0.18.0" 33 | rouge = ">= 1.0.1" 34 | python-Levenshtein = ">=0.26.1" 35 | protobuf = ">= 5.26.0" 36 | diffusers = ">= 0.32.0" 37 | clean-fid = ">= 0.1.35" 38 | dominate = ">= 2.9.1" 39 | opencv-python = ">= 4.10.0" 40 | einops = ">= 0.8.0" 41 | timm = ">= 1.0.7" 42 | rotary-embedding-torch = ">= 0.6.4" 43 | bs4 = ">= 0.0.2" 44 | ftfy = ">= 6.2.0" 45 | cd-fvd = ">= 0.1.1" 46 | xformers = ">= 0.0.26" 47 | pyav = ">= 13.0.0" 48 | clip = ">= 0.2.0" 49 | image_reward = { git = "https://github.com/THUDM/ImageReward.git", branch = "main" } 50 | 51 | [tool.poetry.group.dev.dependencies] 52 | matplotlib = ">= 3.8.4" 53 | ipython = ">= 8.0.0" 54 | 55 | [build-system] 56 | requires = ["poetry-core"] 57 | build-backend = "poetry.core.masonry.api" 58 | 59 | [tool.ruff] 60 | line-length = 120 61 | indent-width = 4 62 | target-version = "py310" 63 | 64 | [tool.ruff.lint] 65 | select = ["B", "C", "E", "F", "I", "W"] 66 | ignore = [] 67 | 68 | [tool.ruff.lint.mccabe] 69 | max-complexity = 15 70 | 71 | [tool.ruff.lint.per-file-ignores] 72 | "__init__.py" = ["F401", "F403"] 73 | 74 | [tool.ruff.format] 75 | quote-style = "double" 76 | indent-style = "space" 77 | skip-magic-trailing-comma = false 78 | line-ending = "auto" 79 | --------------------------------------------------------------------------------