├── .gitignore
├── LICENSE
├── README.md
├── lpmm
│   ├── __init__.py
│   ├── config.py
│   ├── configs
│   │   ├── 1st_moment_group_1024.yml
│   │   ├── 1st_moment_group_2048.yml
│   │   ├── 1st_moment_group_256.yml
│   │   ├── 1st_moment_group_4096.yml
│   │   ├── 1st_moment_group_512.yml
│   │   ├── 2nd_moment_group_128.yml
│   │   ├── 2nd_moment_only_default.yml
│   │   ├── 2nd_moment_only_group_128.yml
│   │   ├── default.yml
│   │   ├── default_5b.yml
│   │   ├── default_6b.yml
│   │   ├── default_7b.yml
│   │   └── default_8b.yml
│   ├── cpp_extension
│   │   ├── common.h
│   │   ├── fused_adamw.cc
│   │   ├── fused_adamw_kernel.cu
│   │   ├── quantization.cc
│   │   └── quantization_kernel.cu
│   ├── functional.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── adamw.py
│   │   ├── optimizer.py
│   │   └── sgd.py
│   └── utils.py
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Results
3 | results
4 | *.log
5 | *.png
6 | *.tsv
7 | *.json
8 | *.csv
9 |
10 | # Build
11 | build/
12 | *.egg-info
13 | *.so
14 |
15 | # Python
16 | *.pyc
17 | __pycache__
18 |
19 | # VIM
20 | *.swp
21 |
22 | # Others
23 | .DS_Store
24 | .vscode
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Low-bit Optimizers
2 |
3 | Official implementation of the paper: *[Memory Efficient Optimizers with 4-bit States](https://arxiv.org/abs/2309.01507)*.
4 |
5 | Optimizer states are a major source of memory consumption for training neural networks, limiting the maximum trainable model size within a given memory budget.
Compressing the optimizer states from 32-bit floating point to a lower bitwidth is a promising way to reduce the training memory footprint, but the lowest bitwidth achieved so far is 8 bits. In this work, we push the optimizer state bitwidth down to 4 bits through a detailed empirical analysis of the first and second moments. Specifically, we find that the moments have complicated outlier patterns that current block-wise quantization cannot accurately approximate. We use a smaller block size and propose to use both row-wise and column-wise information for better quantization. We further identify a zero-point problem when quantizing the second moment and solve it with a linear quantizer that excludes the zero point. Our 4-bit optimizers are evaluated on a wide variety of benchmarks, including natural language understanding, machine translation, image classification, and instruction tuning. On all tasks, our optimizers achieve accuracy comparable to their full-precision counterparts while enjoying better memory efficiency.
6 |
7 | ## Installation
8 |
9 | **Requirements**
10 | Python >= 3.7, CUDA >= 11.0, and torch >= 1.13.0.
11 |
12 | To install, run:
13 |
14 | ```bash
15 | git clone https://github.com/thu-ml/low-bit-optimizers.git
16 | pip install -v -e .
17 | ```
18 |
19 | ## Usage
20 |
21 | ### Using 4-bit Optimizers
22 |
23 | To get started with 4-bit optimizers, simply replace your existing optimizer with one of our 4-bit optimizers: 4-bit AdamW, 4-bit Factor, or 4-bit AdamW (fused).
24 |
25 | ```python
26 | import lpmm
27 |
28 | # Comment out or remove the old optimizer
29 | # optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
30 |
31 | # Use 4-bit AdamW
32 | optimizer = lpmm.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
33 |
34 | # Or, use 4-bit Factor
35 | optimizer = lpmm.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), factor_second_moment=True)
36 |
37 | # Or, use 4-bit AdamW (fused)
38 | optimizer = lpmm.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), fused=True)
39 | ```
40 |
41 | Currently, the supported optimizers are Adam (AdamW) and SGD.
42 |
43 | ### Modifying Quantization Hyperparameters
44 |
45 | To modify the quantization configuration (e.g., normalization function, quantization map, bits) of non-fused optimizers, create a new configuration file and pass its path to the optimizer via the `qconfig` argument. Example configurations can be found in the [lpmm/configs](lpmm/configs) directory.
46 | By default, the quantization configuration for non-fused optimizers is specified in [lpmm/configs/default.yml](lpmm/configs/default.yml), while for fused optimizers it is specified in [lpmm/configs/2nd_moment_group_128.yml](lpmm/configs/2nd_moment_group_128.yml). The configuration for fused optimizers is currently fixed and cannot be changed.
47 |
48 | To use a new configuration file, follow the example below:
49 |
50 | ```python
51 | config_path = "configs/default.yml"  # path to your configuration file
52 | optimizer = lpmm.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), qconfig=config_path)
53 | ```
54 | Commonly used hyperparameters and their possible values include:
55 | - SCALE_TYPE (normalization function): tensor, dim0, dim1, group, rank1
56 | - QUANT_TYPE (quantization map): nonlinear, power-1, power-2
57 | - BITS: 4, 5, 6, 7, 8
58 | - ENABLE (whether to quantize the state): True, False
59 |
60 | We recommend using BITS = 4 or 8. A complete example configuration is shown below.
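For concreteness, here is a minimal sketch of writing a custom configuration file that spells these hyperparameters out and passing it to a non-fused optimizer. The YAML contents mirror [lpmm/configs/default.yml](lpmm/configs/default.yml) (`QUANT.M` configures the first moment and `QUANT.SQM` the second moment, following `lpmm/config.py`); the file name `my_qconfig.yml` is only an example, and `model` is assumed to be defined as in the snippets above.

```python
import lpmm

# Example quantization config; the contents mirror lpmm/configs/default.yml.
# QUANT.M configures the first moment, QUANT.SQM the second moment.
config_text = """\
QUANT:
  M:
    ENABLE: True
    BITS: 4
    GROUP_SIZE: 128
    SCALE_TYPE:
      DEFAULT: group
    QUANT_TYPE:
      DEFAULT: nonlinear
    ROUND_TYPE: real-nearest
  SQM:
    ENABLE: True
    BITS: 4
    GROUP_SIZE: 128
    SCALE_TYPE:
      DEFAULT: rank1
    QUANT_TYPE:
      DEFAULT: power-1
    ROUND_TYPE: real-nearest
"""
with open("my_qconfig.yml", "w") as f:
    f.write(config_text)

# Pass the file to a non-fused 4-bit optimizer via `qconfig`.
optimizer = lpmm.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), qconfig="my_qconfig.yml")
```

The other files in [lpmm/configs](lpmm/configs) follow the same structure with different group sizes and bitwidths.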
61 | 62 | ### Overriding Quantization Enablement for Specific Parameters 63 | 64 | To optimize certain parameters using 32-bit precision instead of quantizing them, use the `override_quantize_enable` method as shown below: 65 | 66 | ```python 67 | optimizer = lpmm.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999)) 68 | optimizer.override_quantize_enable(module, param_name, enable=False) 69 | ``` 70 | 71 | In this example, `module` is the module containing the parameter, and `param_name` is the name of the parameter you wish to optimize with 32-bit precision. Setting `enable=False` will prevent quantization of the specified parameter. -------------------------------------------------------------------------------- /lpmm/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from . import optim -------------------------------------------------------------------------------- /lpmm/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import yaml 4 | from yacs.config import CfgNode as CN 5 | from .utils import get_rank, get_world_size 6 | 7 | _C = CN() 8 | 9 | # Base config files 10 | _C.BASE = [''] 11 | 12 | _C.QUANT = CN(new_allowed=True) 13 | _C.QUANT.INIT_STATES = ['param'] 14 | 15 | _C.QUANT.P = CN(new_allowed=True) 16 | _C.QUANT.P.ENABLE = False 17 | _C.QUANT.P.THRESHOLD = 4096 18 | _C.QUANT.P.EXCLUDE_SUFFIX = [''] # model-related 19 | _C.QUANT.P.EXCLUDE_REGEX = [''] 20 | _C.QUANT.P.BITS = 8 21 | _C.QUANT.P.SCALE_TYPE = CN(new_allowed=True) 22 | _C.QUANT.P.SCALE_TYPE.DEFAULT = 'group' 23 | _C.QUANT.P.SCALE_TYPE.DEFAULT_ONLY = True 24 | _C.QUANT.P.QUANT_TYPE = CN(new_allowed=True) 25 | _C.QUANT.P.QUANT_TYPE.DEFAULT = 'linear' 26 | _C.QUANT.P.QUANT_TYPE.DEFAULT_ONLY = True 27 | _C.QUANT.P.ROUND_TYPE = 'sr' 28 | _C.QUANT.P.GROUP_SIZE = 64 29 | _C.QUANT.P.SIGNED = True 30 | 31 | _C.QUANT.G = CN(new_allowed=True) 32 | _C.QUANT.G.ENABLE = False 33 | _C.QUANT.G.THRESHOLD = 4096 34 | 35 | _C.QUANT.M = CN(new_allowed=True) 36 | _C.QUANT.M.ENABLE = True 37 | _C.QUANT.M.THRESHOLD = 4096 38 | _C.QUANT.M.EXCLUDE_SUFFIX = [''] # model-related 39 | _C.QUANT.M.EXCLUDE_REGEX = [''] 40 | _C.QUANT.M.BITS = 4 41 | _C.QUANT.M.SCALE_TYPE = CN(new_allowed=True) 42 | _C.QUANT.M.SCALE_TYPE.DEFAULT = 'group' 43 | _C.QUANT.M.SCALE_TYPE.DEFAULT_ONLY = True 44 | _C.QUANT.M.QUANT_TYPE = CN(new_allowed=True) 45 | _C.QUANT.M.QUANT_TYPE.DEFAULT = 'nonlinear' 46 | _C.QUANT.M.QUANT_TYPE.DEFAULT_ONLY = True 47 | _C.QUANT.M.ROUND_TYPE = 'real-nearest' 48 | _C.QUANT.M.GROUP_SIZE = 128 49 | _C.QUANT.M.SIGNED = True 50 | 51 | _C.QUANT.SQM = CN(new_allowed=True) 52 | _C.QUANT.SQM.ENABLE = True 53 | _C.QUANT.SQM.THRESHOLD = 4096 54 | _C.QUANT.SQM.EXCLUDE_SUFFIX = [''] # model-related 55 | _C.QUANT.SQM.EXCLUDE_REGEX = [''] 56 | _C.QUANT.SQM.BITS = 4 57 | _C.QUANT.SQM.SCALE_TYPE = CN(new_allowed=True) 58 | _C.QUANT.SQM.SCALE_TYPE.DEFAULT = 'group' 59 | _C.QUANT.SQM.SCALE_TYPE.DEFAULT_ONLY = True 60 | _C.QUANT.SQM.QUANT_TYPE = CN(new_allowed=True) 61 | _C.QUANT.SQM.QUANT_TYPE.DEFAULT = 'power-1' 62 | _C.QUANT.SQM.QUANT_TYPE.DEFAULT_ONLY = True 63 | _C.QUANT.SQM.ROUND_TYPE = 'real-nearest' 64 | _C.QUANT.SQM.GROUP_SIZE = 128 65 | _C.QUANT.SQM.SIGNED = False 66 | 67 | _C.QUANT.DEBUG = CN(new_allowed=True) 68 | _C.QUANT.DEBUG.TRUNCATED_RATE_STAT_ITER = False 69 | _C.QUANT.DEBUG.ROW_ABSMAX_STAT_ITER = False 70 | _C.QUANT.DEBUG.ROW_ABSMAX_STAT_EPOCH = False 71 | 72 | _C.TRAIN = CN(new_allowed=True) 73 | 74 | # 
----------------------------------------------------------------------------- 75 | # Misc 76 | # ----------------------------------------------------------------------------- 77 | _C.OUTPUT = '.' 78 | _C.TAG = '' # (optional) for index of repeat experiments 79 | _C.LOCAL_RANK = 0 80 | 81 | 82 | def _update_config_from_file(config, cfg_file): 83 | config.defrost() 84 | with open(cfg_file, 'r') as f: 85 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 86 | 87 | for cfg in yaml_cfg.setdefault('BASE', ['']): 88 | if cfg: 89 | _update_config_from_file( 90 | config, os.path.join(os.path.dirname(cfg_file), cfg) 91 | ) 92 | print('=> merge config from {}'.format(cfg_file)) 93 | config.merge_from_file(cfg_file) 94 | config.freeze() 95 | 96 | 97 | def update_config(config, args): 98 | def _check_args(name): 99 | if hasattr(args, name) and getattr(args, name) is not None: 100 | return True 101 | return False 102 | 103 | config.defrost() 104 | if _check_args('output'): 105 | config.OUTPUT = args.output 106 | elif _check_args('workspace'): 107 | config.OUTPUT = args.workspace 108 | elif _check_args('output_dir'): 109 | config.OUTPUT = args.output_dir 110 | elif _check_args('outdir'): 111 | config.OUTPUT = args.outdir 112 | elif _check_args('save_dir'): 113 | config.OUTPUT = args.save_dir 114 | elif _check_args('work_dir'): 115 | config.OUTPUT = args.work_dir 116 | if _check_args('tag'): 117 | config.TAG = args.tag 118 | # output folder, make sure that is consistent with the main output foler 119 | config.OUTPUT = os.path.join(config.OUTPUT, config.TAG) 120 | config.freeze() 121 | 122 | if _check_args('q_cfg'): 123 | if args.q_cfg is not None: 124 | _update_config_from_file(config, args.q_cfg) 125 | return 126 | 127 | config.defrost() 128 | if _check_args('lpmm_enable'): 129 | config.QUANT.P.ENABLE = bool(args.lpmm_enable & 1) 130 | config.QUANT.G.ENABLE = bool(args.lpmm_enable & 2) 131 | config.QUANT.M.ENABLE = bool(args.lpmm_enable & 4) 132 | config.QUANT.SQM.ENABLE = bool(args.lpmm_enable & 8) 133 | if _check_args('pb'): 134 | config.QUANT.P.BITS = args.pb 135 | if _check_args('gb'): 136 | config.QUANT.G.BITS = args.gb 137 | if _check_args('mb'): 138 | config.QUANT.M.BITS = args.mb 139 | if _check_args('sqmb'): 140 | config.QUANT.SQM.BITS = args.sqmb 141 | if _check_args('round_type'): 142 | if args.round_type in ['sr', 'up', 'down', 'nearest', 'sr1', 'real-nearest', 'real-sr']: 143 | config.QUANT.P.ROUND_TYPE = args.round_type 144 | config.QUANT.M.ROUND_TYPE = args.round_type 145 | config.QUANT.SQM.ROUND_TYPE = args.round_type 146 | if _check_args('scale_type'): 147 | if args.scale_type in ['tensor', 'dim0', 'dim1', 'dim01', 'dim10', 'group', 'rank1', 'rank1-group']: 148 | # config.QUANT.P.SCALE_TYPE.DEFAULT = args.scale_type 149 | # config.QUANT.M.SCALE_TYPE.DEFAULT = args.scale_type 150 | config.QUANT.SQM.SCALE_TYPE.DEFAULT = args.scale_type 151 | if args.scale_type[:5] == 'group' and len(args.scale_type) > 5: 152 | group_size = int(args.scale_type[5:]) # format 'group[xxx]' where 'xxx' is the exact group size 153 | # config.QUANT.P.SCALE_TYPE.DEFAULT = 'group' 154 | # config.QUANT.M.SCALE_TYPE.DEFAULT = 'group' 155 | config.QUANT.SQM.SCALE_TYPE.DEFAULT = 'group' 156 | config.QUANT.M.GROUP_SIZE = group_size 157 | config.QUANT.SQM.GROUP_SIZE = group_size 158 | if _check_args('q_oracle'): # NOTE: improvising 159 | if args.q_oracle in ['linear', 'nonlinear', 'nonlinear-nozero', 160 | 'power-1', 'power-2', 'power-3', 161 | 'float-point']: 162 | # config.QUANT.P.QUANT_TYPE.DEFAULT = args.q_oracle 163 
| config.QUANT.M.QUANT_TYPE.DEFAULT = args.q_oracle 164 | config.QUANT.SQM.QUANT_TYPE.DEFAULT = args.q_oracle 165 | if _check_args('group_size'): 166 | if args.group_size > 0 and config.QUANT.M.SCALE_TYPE.DEFAULT == 'group': 167 | print(f"[Warn] Set M.GROUP_SIZE from {config.QUANT.M.GROUP_SIZE} to {args.group_size}.") 168 | config.QUANT.M.GROUP_SIZE = args.group_size 169 | 170 | # set local rank for distributed training 171 | if _check_args('local_rank'): 172 | config.LOCAL_RANK = args.local_rank 173 | 174 | config.freeze() 175 | 176 | # init output dir 177 | if get_rank() == 0: 178 | os.makedirs(config.OUTPUT, exist_ok=True) 179 | 180 | 181 | def get_config(args): 182 | """Get a yacs CfgNode object with default values.""" 183 | # Return a clone so that the defaults will not be altered 184 | # This is for the "local variable" use pattern 185 | config = _C.clone() 186 | if isinstance(args, str) : 187 | _update_config_from_file(config, args) 188 | elif args is not None: 189 | update_config(config, args) 190 | 191 | if get_rank() == 0: 192 | print(config) 193 | if config.OUTPUT is not None: 194 | config_file = os.path.join(config.OUTPUT, "lpmm_config.txt") 195 | with open(config_file, "w") as fout: 196 | fout.write(str(config)) 197 | 198 | return config -------------------------------------------------------------------------------- /lpmm/configs/1st_moment_group_1024.yml: -------------------------------------------------------------------------------- 1 | QUANT: 2 | M: 3 | ENABLE: True 4 | BITS: 4 5 | GROUP_SIZE: 1024 6 | SCALE_TYPE: 7 | DEFAULT: group 8 | QUANT_TYPE: 9 | DEFAULT: nonlinear 10 | ROUND_TYPE: real-nearest 11 | SQM: 12 | ENABLE: True 13 | BITS: 4 14 | GROUP_SIZE: 128 15 | SCALE_TYPE: 16 | DEFAULT: rank1 17 | QUANT_TYPE: 18 | DEFAULT: power-1 19 | ROUND_TYPE: real-nearest 20 | -------------------------------------------------------------------------------- /lpmm/configs/1st_moment_group_2048.yml: -------------------------------------------------------------------------------- 1 | QUANT: 2 | M: 3 | ENABLE: True 4 | BITS: 4 5 | GROUP_SIZE: 2048 6 | SCALE_TYPE: 7 | DEFAULT: group 8 | QUANT_TYPE: 9 | DEFAULT: nonlinear 10 | ROUND_TYPE: real-nearest 11 | SQM: 12 | ENABLE: True 13 | BITS: 4 14 | GROUP_SIZE: 128 15 | SCALE_TYPE: 16 | DEFAULT: rank1 17 | QUANT_TYPE: 18 | DEFAULT: power-1 19 | ROUND_TYPE: real-nearest 20 | -------------------------------------------------------------------------------- /lpmm/configs/1st_moment_group_256.yml: -------------------------------------------------------------------------------- 1 | QUANT: 2 | M: 3 | ENABLE: True 4 | BITS: 4 5 | GROUP_SIZE: 256 6 | SCALE_TYPE: 7 | DEFAULT: group 8 | QUANT_TYPE: 9 | DEFAULT: nonlinear 10 | ROUND_TYPE: real-nearest 11 | SQM: 12 | ENABLE: True 13 | BITS: 4 14 | GROUP_SIZE: 128 15 | SCALE_TYPE: 16 | DEFAULT: rank1 17 | QUANT_TYPE: 18 | DEFAULT: power-1 19 | ROUND_TYPE: real-nearest 20 | -------------------------------------------------------------------------------- /lpmm/configs/1st_moment_group_4096.yml: -------------------------------------------------------------------------------- 1 | QUANT: 2 | M: 3 | ENABLE: True 4 | BITS: 4 5 | GROUP_SIZE: 4096 6 | SCALE_TYPE: 7 | DEFAULT: group 8 | QUANT_TYPE: 9 | DEFAULT: nonlinear 10 | ROUND_TYPE: real-nearest 11 | SQM: 12 | ENABLE: True 13 | BITS: 4 14 | GROUP_SIZE: 128 15 | SCALE_TYPE: 16 | DEFAULT: rank1 17 | QUANT_TYPE: 18 | DEFAULT: power-1 19 | ROUND_TYPE: real-nearest 20 | -------------------------------------------------------------------------------- 
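The YAML files in `lpmm/configs`, such as the ones above and below, are merged on top of the defaults declared in `lpmm/config.py` through `get_config` / `_update_config_from_file`. A minimal sketch of loading one of them for inspection, assuming the package and its dependencies (torch, pyyaml, yacs) are importable and the script is run from the repository root; note that, per the code above, `get_config` also prints the merged config and writes `lpmm_config.txt` into the `OUTPUT` directory on rank 0:

```python
from lpmm.config import get_config

# get_config accepts a path string and merges that YAML file on top of the
# defaults in lpmm/config.py (side effects: prints the merged config and
# writes lpmm_config.txt to config.OUTPUT on rank 0).
cfg = get_config("lpmm/configs/default.yml")

print(cfg.QUANT.M.BITS)                  # 4 (first-moment bitwidth)
print(cfg.QUANT.SQM.SCALE_TYPE.DEFAULT)  # rank1 (second-moment normalization)
```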
/lpmm/configs/1st_moment_group_512.yml: -------------------------------------------------------------------------------- 1 | QUANT: 2 | M: 3 | ENABLE: True 4 | BITS: 4 5 | GROUP_SIZE: 512 6 | SCALE_TYPE: 7 | DEFAULT: group 8 | QUANT_TYPE: 9 | DEFAULT: nonlinear 10 | ROUND_TYPE: real-nearest 11 | SQM: 12 | ENABLE: True 13 | BITS: 4 14 | GROUP_SIZE: 128 15 | SCALE_TYPE: 16 | DEFAULT: rank1 17 | QUANT_TYPE: 18 | DEFAULT: power-1 19 | ROUND_TYPE: real-nearest 20 | -------------------------------------------------------------------------------- /lpmm/configs/2nd_moment_group_128.yml: -------------------------------------------------------------------------------- 1 | QUANT: 2 | M: 3 | ENABLE: True 4 | BITS: 4 5 | GROUP_SIZE: 128 6 | SCALE_TYPE: 7 | DEFAULT: group 8 | QUANT_TYPE: 9 | DEFAULT: nonlinear 10 | ROUND_TYPE: real-nearest 11 | SQM: 12 | ENABLE: True 13 | BITS: 4 14 | GROUP_SIZE: 128 15 | SCALE_TYPE: 16 | DEFAULT: group 17 | QUANT_TYPE: 18 | DEFAULT: power-1 19 | ROUND_TYPE: real-nearest -------------------------------------------------------------------------------- /lpmm/configs/2nd_moment_only_default.yml: -------------------------------------------------------------------------------- 1 | QUANT: 2 | M: 3 | ENABLE: False 4 | BITS: 4 5 | GROUP_SIZE: 128 6 | SCALE_TYPE: 7 | DEFAULT: group 8 | QUANT_TYPE: 9 | DEFAULT: nonlinear 10 | ROUND_TYPE: real-nearest 11 | SQM: 12 | ENABLE: True 13 | BITS: 4 14 | GROUP_SIZE: 128 15 | SCALE_TYPE: 16 | DEFAULT: rank1 17 | QUANT_TYPE: 18 | DEFAULT: power-1 19 | ROUND_TYPE: real-nearest -------------------------------------------------------------------------------- /lpmm/configs/2nd_moment_only_group_128.yml: -------------------------------------------------------------------------------- 1 | QUANT: 2 | M: 3 | ENABLE: False 4 | BITS: 4 5 | GROUP_SIZE: 128 6 | SCALE_TYPE: 7 | DEFAULT: group 8 | QUANT_TYPE: 9 | DEFAULT: nonlinear 10 | ROUND_TYPE: real-nearest 11 | SQM: 12 | ENABLE: True 13 | BITS: 4 14 | GROUP_SIZE: 128 15 | SCALE_TYPE: 16 | DEFAULT: group 17 | QUANT_TYPE: 18 | DEFAULT: power-1 19 | ROUND_TYPE: real-nearest -------------------------------------------------------------------------------- /lpmm/configs/default.yml: -------------------------------------------------------------------------------- 1 | QUANT: 2 | M: 3 | ENABLE: True 4 | BITS: 4 5 | GROUP_SIZE: 128 6 | SCALE_TYPE: 7 | DEFAULT: group 8 | QUANT_TYPE: 9 | DEFAULT: nonlinear 10 | ROUND_TYPE: real-nearest 11 | SQM: 12 | ENABLE: True 13 | BITS: 4 14 | GROUP_SIZE: 128 15 | SCALE_TYPE: 16 | DEFAULT: rank1 17 | QUANT_TYPE: 18 | DEFAULT: power-1 19 | ROUND_TYPE: real-nearest -------------------------------------------------------------------------------- /lpmm/configs/default_5b.yml: -------------------------------------------------------------------------------- 1 | QUANT: 2 | M: 3 | ENABLE: True 4 | BITS: 5 5 | GROUP_SIZE: 128 6 | SCALE_TYPE: 7 | DEFAULT: group 8 | QUANT_TYPE: 9 | DEFAULT: nonlinear 10 | ROUND_TYPE: real-nearest 11 | SQM: 12 | ENABLE: True 13 | BITS: 5 14 | GROUP_SIZE: 128 15 | SCALE_TYPE: 16 | DEFAULT: rank1 17 | QUANT_TYPE: 18 | DEFAULT: power-1 19 | ROUND_TYPE: real-nearest 20 | -------------------------------------------------------------------------------- /lpmm/configs/default_6b.yml: -------------------------------------------------------------------------------- 1 | QUANT: 2 | M: 3 | ENABLE: True 4 | BITS: 6 5 | GROUP_SIZE: 128 6 | SCALE_TYPE: 7 | DEFAULT: group 8 | QUANT_TYPE: 9 | DEFAULT: nonlinear 10 | ROUND_TYPE: real-nearest 11 | SQM: 12 | ENABLE: True 
13 | BITS: 6 14 | GROUP_SIZE: 128 15 | SCALE_TYPE: 16 | DEFAULT: rank1 17 | QUANT_TYPE: 18 | DEFAULT: power-1 19 | ROUND_TYPE: real-nearest 20 | -------------------------------------------------------------------------------- /lpmm/configs/default_7b.yml: -------------------------------------------------------------------------------- 1 | QUANT: 2 | M: 3 | ENABLE: True 4 | BITS: 7 5 | GROUP_SIZE: 128 6 | SCALE_TYPE: 7 | DEFAULT: group 8 | QUANT_TYPE: 9 | DEFAULT: nonlinear 10 | ROUND_TYPE: real-nearest 11 | SQM: 12 | ENABLE: True 13 | BITS: 7 14 | GROUP_SIZE: 128 15 | SCALE_TYPE: 16 | DEFAULT: rank1 17 | QUANT_TYPE: 18 | DEFAULT: power-1 19 | ROUND_TYPE: real-nearest 20 | -------------------------------------------------------------------------------- /lpmm/configs/default_8b.yml: -------------------------------------------------------------------------------- 1 | QUANT: 2 | M: 3 | ENABLE: True 4 | BITS: 8 5 | GROUP_SIZE: 128 6 | SCALE_TYPE: 7 | DEFAULT: group 8 | QUANT_TYPE: 9 | DEFAULT: nonlinear 10 | ROUND_TYPE: real-nearest 11 | SQM: 12 | ENABLE: True 13 | BITS: 8 14 | GROUP_SIZE: 128 15 | SCALE_TYPE: 16 | DEFAULT: rank1 17 | QUANT_TYPE: 18 | DEFAULT: power-1 19 | ROUND_TYPE: real-nearest 20 | -------------------------------------------------------------------------------- /lpmm/cpp_extension/common.h: -------------------------------------------------------------------------------- 1 | // Helper for type check 2 | #define CHECK_CUDA_TENSOR_DIM_TYPE(name, n_dim, type) \ 3 | TORCH_CHECK(name.device().is_cuda(), #name " must be a CUDA tensor!"); \ 4 | TORCH_CHECK(name.is_contiguous(), #name " must be contiguous!"); \ 5 | TORCH_CHECK(name.dim() == n_dim, "The dimension of " #name " is not correct!"); \ 6 | TORCH_CHECK(name.dtype() == type, "The type of " #name " is not correct!"); \ 7 | 8 | // Helper for type check 9 | #define CHECK_CUDA_TENSOR_TYPE(name, type) \ 10 | TORCH_CHECK(name.device().is_cuda(), #name " must be a CUDA tensor!"); \ 11 | TORCH_CHECK(name.is_contiguous(), #name " must be contiguous!"); \ 12 | TORCH_CHECK(name.dtype() == type, "The type of " #name " is not correct!"); \ 13 | 14 | // Helper for type check 15 | #define CHECK_CUDA_TENSOR_FLOAT(name) \ 16 | TORCH_CHECK(name.device().is_cuda(), #name " must be a CUDA tensor!"); \ 17 | TORCH_CHECK(name.is_contiguous(), #name " must be contiguous!"); \ 18 | TORCH_CHECK(name.dtype() == torch::kFloat32 || name.dtype() == torch::kFloat16, \ 19 | "The type of " #name " is not kFloat32 or kFloat16!"); \ 20 | 21 | // Helper for type check 22 | #define CHECK_CUDA_TENSOR_DIM_FLOAT(name, n_dim) \ 23 | TORCH_CHECK(name.device().is_cuda(), #name " must be a CUDA tensor!"); \ 24 | TORCH_CHECK(name.is_contiguous(), #name " must be contiguous!"); \ 25 | TORCH_CHECK(name.dim() == n_dim, "The dimension of " #name " is not correct!"); \ 26 | TORCH_CHECK(name.dtype() == torch::kFloat32 || name.dtype() == torch::kFloat16, \ 27 | "The type of " #name " is not kFloat32 or kFloat16!"); \ 28 | 29 | -------------------------------------------------------------------------------- /lpmm/cpp_extension/fused_adamw.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "common.h" 5 | 6 | using torch::Tensor; 7 | 8 | void fused_adamw_cuda(Tensor& p, Tensor& g, Tensor& exp_avg, 9 | Tensor& exp_avg_sq, float beta1, float beta2, 10 | float lr, float weight_decay, float eps, float step); 11 | void fused_adamw4bit_cuda(Tensor& p, Tensor& g, Tensor& exp_avg, Tensor& exp_avg_sq, 12 | 
Tensor& exp_avg_scale, Tensor& exp_avg_sq_scale, Tensor& exp_avg_qmap, Tensor& exp_avg_sq_qmap, 13 | float beta1, float beta2, float lr, float weight_decay, float eps, float step); 14 | 15 | void adamw_single_tensor( 16 | Tensor& p, 17 | Tensor& g, 18 | Tensor& exp_avg, 19 | Tensor& exp_avg_sq, 20 | float beta1, 21 | float beta2, 22 | float lr, 23 | float weight_decay, 24 | float eps, 25 | float step 26 | ) { 27 | CHECK_CUDA_TENSOR_FLOAT(p); 28 | CHECK_CUDA_TENSOR_FLOAT(exp_avg); 29 | CHECK_CUDA_TENSOR_FLOAT(exp_avg_sq); 30 | CHECK_CUDA_TENSOR_FLOAT(g); 31 | int64_t num_elem = p.numel(); 32 | AT_ASSERTM(exp_avg.numel() == num_elem, 33 | "number of elements in exp_avg and p tensors should be equal"); 34 | AT_ASSERTM(exp_avg_sq.numel() == num_elem, 35 | "number of elements in exp_avg_sq and p tensors should be equal"); 36 | AT_ASSERTM(g.numel() == num_elem, 37 | "number of elements in g and p tensors should be equal"); 38 | 39 | fused_adamw_cuda(p, g, exp_avg, exp_avg_sq, 40 | beta1, beta2, lr, weight_decay, eps, step); 41 | } 42 | 43 | 44 | void adamw4bit_single_tensor( 45 | Tensor& p, 46 | Tensor& g, 47 | Tensor& exp_avg, 48 | Tensor& exp_avg_sq, 49 | Tensor& exp_avg_scale, 50 | Tensor& exp_avg_sq_scale, 51 | Tensor& exp_avg_qmap, 52 | Tensor& exp_avg_sq_qmap, 53 | float beta1, 54 | float beta2, 55 | float lr, 56 | float weight_decay, 57 | float eps, 58 | float step 59 | ) { 60 | CHECK_CUDA_TENSOR_FLOAT(p); 61 | CHECK_CUDA_TENSOR_FLOAT(g); 62 | CHECK_CUDA_TENSOR_FLOAT(exp_avg_scale); 63 | CHECK_CUDA_TENSOR_FLOAT(exp_avg_sq_scale); 64 | CHECK_CUDA_TENSOR_FLOAT(exp_avg_qmap); 65 | CHECK_CUDA_TENSOR_FLOAT(exp_avg_sq_qmap); 66 | 67 | int64_t num_elem = p.numel(); 68 | AT_ASSERTM(exp_avg.numel() == num_elem / 2, 69 | "number of elements in exp_avg and p tensors should be equal"); 70 | AT_ASSERTM(exp_avg_sq.numel() == num_elem / 2, 71 | "number of elements in exp_avg_sq and p tensors should be equal"); 72 | AT_ASSERTM(g.numel() == num_elem, 73 | "number of elements in g and p tensors should be equal"); 74 | 75 | fused_adamw4bit_cuda(p, g, exp_avg, exp_avg_sq, 76 | exp_avg_scale, exp_avg_sq_scale, exp_avg_qmap, exp_avg_sq_qmap, 77 | beta1, beta2, lr, weight_decay, eps, step); 78 | } 79 | 80 | 81 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 82 | m.def("adamw_single_tensor", &adamw_single_tensor); 83 | m.def("adamw4bit_single_tensor", &adamw4bit_single_tensor); 84 | } 85 | -------------------------------------------------------------------------------- /lpmm/cpp_extension/fused_adamw_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Cuda kernels for fused adamw and adamw4bit 3 | */ 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | using torch::Tensor; 18 | 19 | __device__ __forceinline__ float atomicMaxNonneg (float * addr, float value) { 20 | float old; 21 | // assert value >= 0 22 | old = __int_as_float(atomicMax((int *)addr, __float_as_int(value))); 23 | return old; 24 | } 25 | 26 | 27 | __device__ __forceinline__ int q_mapping(const float* __restrict__ qmap, 28 | int bits, 29 | float x) { 30 | int lo = 0; 31 | int hi = 1 << bits; 32 | 33 | if (x <= qmap[lo]) 34 | return lo; 35 | if (qmap[hi - 1] <= x) 36 | return (hi - 1); 37 | 38 | while (lo < hi){ 39 | int mi = (lo + hi) >> 1; 40 | if (qmap[mi] <= x) lo = mi + 1; 41 | else hi = mi; 42 | } 43 | // return lo - 1; 44 | 45 | int rank = 0; 46 | float mid_val = (qmap[lo - 1] + qmap[lo]) * 0.5f; 47 | 
rank = (mid_val < x) ? lo : lo - 1; 48 | return rank; 49 | } 50 | 51 | 52 | template 53 | __global__ void adamw_cuda_kernel( 54 | T* __restrict__ p, 55 | T* __restrict__ exp_avg, 56 | T* __restrict__ exp_avg_sq, 57 | const T * __restrict__ g, 58 | const float beta1, 59 | const float beta2, 60 | const float lr, 61 | const float weight_decay, 62 | const float eps, 63 | const float step, 64 | const size_t total_size) 65 | { 66 | const int global_id = blockIdx.x * blockDim.x + threadIdx.x; 67 | if (global_id >= total_size) return; 68 | 69 | exp_avg[global_id] = beta1 * exp_avg[global_id] + (1 - beta1) * g[global_id]; 70 | exp_avg_sq[global_id] = beta2 * exp_avg_sq[global_id] + (1 - beta2) * g[global_id] * g[global_id]; 71 | 72 | const float correction1 = 1.0f - powf(beta1, step); 73 | const float correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); 74 | 75 | float denom = (sqrtf(exp_avg_sq[global_id]) / correction2_sqrt + eps) * correction1; 76 | float update = (exp_avg[global_id]/denom) + (weight_decay * p[global_id]); 77 | p[global_id] = p[global_id] - (lr * update); 78 | } 79 | 80 | 81 | template 82 | __global__ void adamw4bit_cuda_kernel( 83 | T* __restrict__ p, 84 | const T * __restrict__ g, 85 | int8_t* __restrict__ exp_avg, 86 | int8_t* __restrict__ exp_avg_sq, 87 | T* __restrict__ exp_avg_scale, 88 | T* __restrict__ exp_avg_sq_scale, 89 | const float* __restrict__ exp_avg_qmap, 90 | const float* __restrict__ exp_avg_sq_qmap, 91 | const float beta1, 92 | const float beta2, 93 | const float lr, 94 | const float weight_decay, 95 | const float eps, 96 | const float step, 97 | const size_t total_size) 98 | { 99 | const int global_id = blockIdx.x * blockDim.x + threadIdx.x; 100 | const int scale_id = blockIdx.x; 101 | const int working_id0 = global_id << 1; 102 | const int working_id1 = (global_id << 1) + 1; 103 | const float correction1 = 1.0f - powf(beta1, step); 104 | const float correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); 105 | 106 | __shared__ float absmax_exp_avg; 107 | __shared__ float absmax_exp_avg_sq; 108 | if (threadIdx.x == 0) { 109 | absmax_exp_avg = 0.0f; 110 | absmax_exp_avg_sq = 0.0f; 111 | } 112 | __syncthreads(); 113 | 114 | if (working_id0 >= total_size) return; 115 | 116 | const int8_t mask = (1 << 4) - 1; 117 | // dequantize optimizer state, and run optimizer 118 | // Note that we require the 'rank' of right 4-bits is smaller than that of left 4-bits in one byte 119 | const uint8_t exp_avg_idx0 = (exp_avg[global_id]) & mask; 120 | T exp_avg0 = (T)exp_avg_qmap[exp_avg_idx0] * exp_avg_scale[scale_id]; 121 | exp_avg0 = beta1 * exp_avg0 + (1 - beta1) * g[working_id0]; 122 | const uint8_t exp_avg_sq_idx0 = (exp_avg_sq[global_id]) & mask; 123 | T exp_avg_sq0 = (T)exp_avg_sq_qmap[exp_avg_sq_idx0] * exp_avg_sq_scale[scale_id]; 124 | exp_avg_sq0 = beta2 * exp_avg_sq0 + (1 - beta2) * g[working_id0] * g[working_id0]; 125 | 126 | float denom0 = (sqrtf(exp_avg_sq0) / correction2_sqrt + eps) * correction1; 127 | float update0 = (exp_avg0/denom0) + (weight_decay * p[working_id0]); 128 | p[working_id0] = p[working_id0] - (lr * update0); 129 | 130 | T exp_avg1 = 0; 131 | T exp_avg_sq1 = 0; 132 | if (working_id1 < total_size) { 133 | const uint8_t exp_avg_idx1 = (exp_avg[global_id] >> 4) & mask; 134 | exp_avg1 = (T)exp_avg_qmap[exp_avg_idx1] * exp_avg_scale[scale_id]; 135 | exp_avg1 = beta1 * exp_avg1 + (1 - beta1) * g[working_id1]; 136 | const uint8_t exp_avg_sq_idx1 = (exp_avg_sq[global_id] >> 4) & mask; 137 | exp_avg_sq1 = (T)exp_avg_sq_qmap[exp_avg_sq_idx1] * 
exp_avg_sq_scale[scale_id]; 138 | exp_avg_sq1 = beta2 * exp_avg_sq1 + (1 - beta2) * g[working_id1] * g[working_id1]; 139 | 140 | float denom1 = (sqrtf(exp_avg_sq1) / correction2_sqrt + eps) * correction1; 141 | float update1 = (exp_avg1/denom1) + (weight_decay * p[working_id1]); 142 | p[working_id1] = p[working_id1] - (lr * update1); 143 | } 144 | 145 | // compute new scale for quantization 146 | float local_absmax_exp_avg = fmaxf(fabsf((float)exp_avg0), fabsf((float)exp_avg1)); 147 | float local_absmax_exp_avg_sq = fmaxf((float)exp_avg_sq0, (float)exp_avg_sq1); 148 | atomicMaxNonneg(&absmax_exp_avg, local_absmax_exp_avg); 149 | atomicMaxNonneg(&absmax_exp_avg_sq, local_absmax_exp_avg_sq); 150 | __syncthreads(); 151 | 152 | // quantize optimizer state and write new scales 153 | int8_t local_packed_exp_avg = 0; 154 | int8_t local_packed_exp_avg_sq = 0; 155 | const int8_t q_exp_avg0 = (int8_t)q_mapping(exp_avg_qmap, 4, (float)exp_avg0 / absmax_exp_avg); 156 | const int8_t q_exp_avg_sq0 = (int8_t)q_mapping(exp_avg_sq_qmap, 4, (float)exp_avg_sq0 / absmax_exp_avg_sq); 157 | local_packed_exp_avg |= (q_exp_avg0 & mask); 158 | local_packed_exp_avg_sq |= (q_exp_avg_sq0 & mask); 159 | 160 | if (working_id1 < total_size) { 161 | const int8_t q_exp_avg1 = (int8_t)q_mapping(exp_avg_qmap, 4, (float)exp_avg1 / absmax_exp_avg); 162 | const int8_t q_exp_avg_sq1 = (int8_t)q_mapping(exp_avg_sq_qmap, 4, (float)exp_avg_sq1 / absmax_exp_avg_sq); 163 | local_packed_exp_avg |= ((q_exp_avg1 & mask) << 4); 164 | local_packed_exp_avg_sq |= ((q_exp_avg_sq1 & mask) << 4); 165 | } 166 | 167 | exp_avg[global_id] = local_packed_exp_avg; 168 | exp_avg_sq[global_id] = local_packed_exp_avg_sq; 169 | if (threadIdx.x == 0) { 170 | exp_avg_scale[scale_id] = (T)absmax_exp_avg; 171 | exp_avg_sq_scale[scale_id] = (T)absmax_exp_avg_sq; 172 | } 173 | __syncthreads(); 174 | } 175 | 176 | 177 | void fused_adamw_cuda(Tensor& p, Tensor& g, Tensor& exp_avg, Tensor& exp_avg_sq, 178 | float beta1, float beta2, float lr, float weight_decay, float eps, float step) { 179 | // Get tensor size 180 | int total_size = p.numel(); 181 | AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), 182 | "parameter tensor is too large to be indexed with int32"); 183 | 184 | const int block_dim = 128; 185 | int grid_dim = ((total_size + block_dim - 1) / block_dim); 186 | const dim3 blocks(grid_dim); 187 | 188 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(p.scalar_type(), "fused_adam_cuda", ([&] { 189 | adamw_cuda_kernel<<>>( 190 | p.data_ptr(), 191 | exp_avg.data_ptr(), 192 | exp_avg_sq.data_ptr(), 193 | g.data_ptr(), 194 | beta1, 195 | beta2, 196 | lr, 197 | weight_decay, 198 | eps, 199 | step, 200 | total_size 201 | ); 202 | })); 203 | 204 | AT_CUDA_CHECK(cudaGetLastError()); 205 | } 206 | 207 | 208 | void fused_adamw4bit_cuda(Tensor& p, Tensor& g, Tensor& exp_avg, Tensor& exp_avg_sq, 209 | Tensor& exp_avg_scale, Tensor& exp_avg_sq_scale, Tensor& exp_avg_qmap, Tensor& exp_avg_sq_qmap, 210 | float beta1, float beta2, float lr, float weight_decay, float eps, float step) { 211 | // Get tensor size 212 | int total_size = p.numel(); 213 | AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), 214 | "parameter tensor is too large to be indexed with int32"); 215 | 216 | const int block_dim = 128; 217 | int grid_dim = ((total_size + block_dim - 1) / block_dim); 218 | TORCH_CHECK(grid_dim == exp_avg_scale.numel()); 219 | TORCH_CHECK(grid_dim == exp_avg_sq_scale.numel()); 220 | const dim3 blocks(grid_dim); 221 | 222 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(p.scalar_type(), 
"fused_adam4bit_cuda", ([&] { 223 | adamw4bit_cuda_kernel<<>>( 224 | p.data_ptr(), 225 | g.data_ptr(), 226 | exp_avg.data_ptr(), 227 | exp_avg_sq.data_ptr(), 228 | exp_avg_scale.data_ptr(), 229 | exp_avg_sq_scale.data_ptr(), 230 | exp_avg_qmap.data_ptr(), 231 | exp_avg_sq_qmap.data_ptr(), 232 | beta1, 233 | beta2, 234 | lr, 235 | weight_decay, 236 | eps, 237 | step, 238 | total_size 239 | ); 240 | })); 241 | 242 | AT_CUDA_CHECK(cudaGetLastError()); 243 | } -------------------------------------------------------------------------------- /lpmm/cpp_extension/quantization.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Cuda operators for quantization and mixed-precision packing 3 | */ 4 | 5 | #include 6 | #include 7 | 8 | #include "common.h" 9 | 10 | using torch::autograd::Function; 11 | using torch::autograd::AutogradContext; 12 | using torch::autograd::tensor_list; 13 | using torch::Tensor; 14 | using torch::IntArrayRef; 15 | 16 | // Declarations for functions in ext_quantization_cuda_kernel.cu 17 | // Pack and unpack 18 | Tensor pack_absmax_linear_cuda( 19 | Tensor data, Tensor absmax, int bits, bool stochastic); 20 | Tensor unpack_absmax_linear_cuda( 21 | Tensor data, int bits, Tensor absmax, int64_t num_groups, int64_t group_size); 22 | Tensor pack_nonlinear_cuda( 23 | Tensor data, Tensor qmap, int bits, bool stochastic); 24 | Tensor unpack_nonlinear_cuda( 25 | Tensor data, Tensor qmap, int bits, int64_t num_groups, int64_t group_size); 26 | 27 | 28 | 29 | // Pack/Unpack with absmax linear quantization 30 | Tensor pack_absmax_linear(Tensor data, 31 | Tensor absmax, 32 | int bits, 33 | bool stochastic) { 34 | CHECK_CUDA_TENSOR_DIM_FLOAT(data, 2); 35 | CHECK_CUDA_TENSOR_DIM_FLOAT(absmax, 2); 36 | 37 | return pack_absmax_linear_cuda(data, absmax, bits, stochastic); 38 | } 39 | 40 | Tensor unpack_absmax_linear(Tensor data, 41 | int bits, 42 | Tensor absmax, 43 | int64_t num_groups, 44 | int64_t group_size) { 45 | CHECK_CUDA_TENSOR_DIM_TYPE(data, 1, torch::kInt8); 46 | CHECK_CUDA_TENSOR_DIM_FLOAT(absmax, 2); 47 | 48 | return unpack_absmax_linear_cuda(data, bits, absmax, 49 | num_groups, group_size); 50 | } 51 | 52 | 53 | // Pack/Unpack with nonlinear quantization 54 | Tensor pack_nonlinear(Tensor data, 55 | Tensor qmap, 56 | int bits, 57 | bool stochastic) { 58 | TORCH_CHECK(bits <= 8); 59 | CHECK_CUDA_TENSOR_DIM_FLOAT(data, 2); 60 | CHECK_CUDA_TENSOR_DIM_FLOAT(qmap, 1); 61 | 62 | return pack_nonlinear_cuda(data, qmap, bits, stochastic); 63 | } 64 | 65 | Tensor unpack_nonlinear(Tensor data, 66 | Tensor qmap, 67 | int bits, 68 | int64_t num_groups, 69 | int64_t group_size) { 70 | TORCH_CHECK(bits <= 8); 71 | CHECK_CUDA_TENSOR_DIM_TYPE(data, 1, torch::kInt8); 72 | CHECK_CUDA_TENSOR_DIM_FLOAT(qmap, 1); 73 | 74 | return unpack_nonlinear_cuda(data, qmap, bits, 75 | num_groups, group_size); 76 | } 77 | 78 | 79 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 80 | m.def("pack_absmax_linear", &pack_absmax_linear); 81 | m.def("unpack_absmax_linear", &unpack_absmax_linear); 82 | m.def("pack_nonlinear", &pack_nonlinear); 83 | m.def("unpack_nonlinear", &unpack_nonlinear); 84 | } 85 | -------------------------------------------------------------------------------- /lpmm/cpp_extension/quantization_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Cuda kernels for quantization and mixed-precision packing. 
3 | */ 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #define BLOCK_Y_DIM_MAX ((((int64_t)(1)) << 16) - 1) 14 | #define fmax(a, b) ((a) > (b) ? (a): (b)) 15 | #define fmin(a, b) ((a) < (b) ? (a): (b)) 16 | 17 | using torch::IntArrayRef; 18 | using torch::Tensor; 19 | 20 | 21 | /**************************************************/ 22 | /***** Pack/Unpack Absmax Linear Quantization *****/ 23 | /**************************************************/ 24 | // Pack float16/32 data into int8 bit stream, for bits <= 8 25 | template 26 | __global__ void pack_absmax_linear_8bit_kernel(int32_t bits, 27 | const scalar_t* __restrict__ data, 28 | const scalar_t* __restrict__ absmax, 29 | int8_t* __restrict__ packed, 30 | std::pair seeds) { 31 | const int group_id = blockIdx.x; 32 | const int d = threadIdx.x; 33 | const int64_t global_thread_id = group_id * blockDim.x + d; 34 | const int work_per_int = 8 / bits; 35 | const int workint_per_thread = 4; 36 | const int work_per_thread = work_per_int << 2; 37 | const float B = (1 << (bits - 1)) - 1; 38 | const int32_t mask = (1 << bits) - 1; 39 | curandStatePhilox4_32_10_t state; 40 | curand_init(seeds.first, global_thread_id, seeds.second, &state); 41 | 42 | // debug 43 | // printf("work per int: %d, work per thread: %d\n", work_per_int, work_per_thread); 44 | // for (int k = 0; k < work_per_thread; k++) { 45 | // int data_id = (int)(global_thread_id * work_per_thread + k); 46 | // printf("pack: group id: %d, thread id: %d, global id: %d, data id: %d, data: %f, absmax: %f\n", group_id, d, (int)global_thread_id, data_id, data[data_id], absmax[group_id]); 47 | // // printf("pack: group id: %d, thread id: %d, global id: %d, data id: %d, data[int64]: %f\n", group_id, d, (int)global_thread_id, data_id, data[global_thread_id * work_per_thread + k]); 48 | // } 49 | // curandStatePhilox4_32_10_t state; 50 | // curand_init(seeds.first, global_thread_id, seeds.second, &state); 51 | 52 | for (int i = 0; i < workint_per_thread; i++) { 53 | uint8_t local_packed = 0; 54 | int64_t global_int_id = global_thread_id * workint_per_thread + i; 55 | for (int j = 0; j < work_per_int; j++) { 56 | const int64_t id = global_thread_id * work_per_thread + i * work_per_int + j; 57 | const float noise = curand_uniform(&state); 58 | const int32_t val = __float2int_rn(fmax(fmin((data[id] / absmax[group_id]) * B + noise - 0.5, B), -B)); 59 | local_packed |= ((val & mask) << (j * bits)); 60 | } 61 | 62 | packed[global_int_id] = local_packed; 63 | } 64 | 65 | // debug 66 | // for (int i = 0; i < work_per_int; i++) { 67 | // int int_id = global_thread_id * work_per_int + i; 68 | // printf("group id: %d, thread id: %d, global id: %d, int id: %d, int: %f\n", group_id, d, global_thread_id, int_id, packed[int_id]); 69 | // } 70 | 71 | } 72 | 73 | template 74 | __global__ void print_kernel(int32_t bits, 75 | const scalar_t* __restrict__ data, 76 | const scalar_t* __restrict__ absmax, 77 | int8_t* __restrict__ packed, 78 | std::pair seeds) { 79 | const int group_id = blockIdx.x; 80 | const int d = threadIdx.x; 81 | const int64_t global_thread_id = group_id * blockDim.x + d; 82 | const int work_per_int = 8 / bits; 83 | const int workint_per_thread = 4; 84 | const int work_per_thread = work_per_int << 2; 85 | const float B = (1 << (bits - 1)) - 1; 86 | const int32_t mask = (1 << bits) - 1; 87 | 88 | printf("group id: %d, thread id: %d, global id: %d\n", group_id, d, global_thread_id); 89 | printf("data: %lf\n", data[global_thread_id * work_per_thread + 
1]); 90 | // for (int i = 0; i < workint_per_thread; i++) { 91 | // uint8_t local_packed = 1; 92 | // int64_t global_int_id = global_thread_id * workint_per_thread + i; 93 | // packed[global_int_id] = local_packed; 94 | // printf("group id: %d, thread id: %d, global id: %d, data: %f\n", group_id, d, global_thread_id, data[global_thread_id * work_per_thread + i]); 95 | // } 96 | } 97 | 98 | Tensor pack_absmax_linear_8bit_cuda(Tensor data, 99 | Tensor absmax, 100 | int bits, 101 | bool stochastic) { 102 | int64_t num_groups = data.size(0); 103 | int64_t group_size = data.size(1); 104 | 105 | // Compute total bits 106 | const int work_per_int = 8 / bits; 107 | const int workint_per_thread = 4; 108 | const int work_per_thread = work_per_int * workint_per_thread; 109 | TORCH_CHECK(8 % bits == 0); 110 | 111 | int64_t total_bits = (int64_t)bits * (num_groups * group_size); 112 | auto options = torch::TensorOptions().dtype(torch::kInt8).device(data.device()); 113 | Tensor packed = torch::empty({(total_bits + 8) / 8,}, options); 114 | 115 | // Random number generator 116 | int threads = group_size; 117 | auto gen = at::check_generator(at::cuda::detail::getDefaultCUDAGenerator()); 118 | std::pair rng_engine_inputs; 119 | { 120 | // See Note [Acquire lock when using random generators] 121 | std::lock_guard lock(gen->mutex_); 122 | rng_engine_inputs = gen->philox_engine_inputs(threads * work_per_thread); 123 | } 124 | TORCH_CHECK(stochastic); 125 | 126 | // debug 127 | // for (int i = 0; i < num_groups; i++) { 128 | // for (int j = 0; j < group_size; j++) { 129 | // printf("in pack_absmax_linear_cuda before kernel, data: %f\n", data[i][j].item()); 130 | // } 131 | // printf("in pack_absmax_linear_cuda before kernel, absmax: %f\n", absmax[i][0].item()); 132 | // } 133 | 134 | // printf("before entering kernel: %d, %d, %d, %f\n", bits, num_groups, group_size, data[0]); 135 | // Call pack kernels 136 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(data.scalar_type(), "pack_absmax_linear_8bit", ([&] { 137 | pack_absmax_linear_8bit_kernel<<>>( 138 | bits, 139 | data.data_ptr(), 140 | absmax.data_ptr(), 141 | packed.data_ptr(), 142 | rng_engine_inputs); 143 | })); 144 | 145 | // debug 146 | // for (int i = 0; i < num_groups; i++) { 147 | // for (int j = 0; j < group_size; j++) { 148 | // printf("in pack_absmax_linear_cuda after kernel, data: %f\n", data[i][j].item()); 149 | // } 150 | // printf("in pack_absmax_linear_cuda after kernel, absmax: %f\n", absmax[i][0].item()); 151 | // } 152 | // printf("in pack_absmax_linear_cuda after kernel, packed: %f\n", packed[0].item()); 153 | 154 | return packed; 155 | } 156 | 157 | // Pack float16/32 data into int8 bit stream, for 8 < bits <= 16 158 | template 159 | __global__ void pack_absmax_linear_16bit_kernel(int32_t bits, 160 | const scalar_t* __restrict__ data, 161 | const scalar_t* __restrict__ absmax, 162 | int8_t* __restrict__ packed, 163 | std::pair seeds, 164 | int64_t group_size) { 165 | 166 | const int64_t group_id = blockIdx.x; 167 | const int64_t d = threadIdx.x; 168 | const int64_t global_thread_id = group_id * blockDim.x + d; 169 | const int workbit_per_thread = 64; 170 | const int work_per_thread = workbit_per_thread / bits; 171 | const uint8_t packed8_mask = 0xff; 172 | const int B = (1 << (bits - 1)) - 1; 173 | const int64_t mask = (1 << bits) - 1; 174 | curandStatePhilox4_32_10_t state; 175 | curand_init(seeds.first, global_thread_id, seeds.second, &state); 176 | 177 | uint64_t local_packed = 0; 178 | for (int i = 0; i < work_per_thread; i++) { 179 | if (d * 
work_per_thread + i >= group_size) 180 | break; 181 | const int64_t data_id = group_id * group_size + d * work_per_thread + i; 182 | const float noise = curand_uniform(&state); 183 | const float x = data[data_id] / absmax[group_id]; 184 | // ensure positivity of 'val': [0, 2B], which was not introduced in 8-bit kernel 185 | const int64_t val = __float2int_rn(fmax(fmin(x * B + noise - 0.5, (float)B), -(float)B)) + B; 186 | local_packed |= ((val & mask) << (i * bits)); 187 | } 188 | 189 | for (int i = 0; i < 8; i++) { 190 | const int64_t global_int_id = global_thread_id * 8 + i; 191 | uint8_t local_packed8 = (local_packed >> (i << 3)) & packed8_mask; 192 | packed[global_int_id] = local_packed8; 193 | } 194 | } 195 | 196 | Tensor pack_absmax_linear_16bit_cuda(Tensor data, 197 | Tensor absmax, 198 | int bits, 199 | bool stochastic) { 200 | int64_t num_groups = data.size(0); 201 | int64_t group_size = data.size(1); 202 | 203 | // Compute total bits 204 | const int workbit_per_thread = 64; 205 | const int work_per_thread = workbit_per_thread / bits; 206 | int64_t threads_num = (group_size + work_per_thread - 1) / work_per_thread; 207 | TORCH_CHECK(bits > 8); 208 | TORCH_CHECK(bits <= 16); 209 | 210 | int64_t total_bits = num_groups * threads_num * workbit_per_thread; 211 | auto options = torch::TensorOptions().dtype(torch::kInt8).device(data.device()); 212 | Tensor packed = torch::empty({(total_bits) / 8,}, options); 213 | 214 | // Random number generator 215 | int threads = group_size; 216 | auto gen = at::check_generator(at::cuda::detail::getDefaultCUDAGenerator()); 217 | std::pair rng_engine_inputs; 218 | { 219 | // See Note [Acquire lock when using random generators] 220 | std::lock_guard lock(gen->mutex_); 221 | rng_engine_inputs = gen->philox_engine_inputs(threads * work_per_thread); 222 | } 223 | TORCH_CHECK(stochastic); 224 | 225 | // Call pack kernels 226 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(data.scalar_type(), "pack_absmax_linear_16bit", ([&] { 227 | pack_absmax_linear_16bit_kernel<<>>( 228 | bits, 229 | data.data_ptr(), 230 | absmax.data_ptr(), 231 | packed.data_ptr(), 232 | rng_engine_inputs, 233 | group_size); 234 | })); 235 | 236 | return packed; 237 | } 238 | 239 | Tensor pack_absmax_linear_cuda(Tensor data, 240 | Tensor absmax, 241 | int bits, 242 | bool stochastic) { 243 | if (bits <= 8) { 244 | return pack_absmax_linear_8bit_cuda(data, absmax, bits, stochastic); 245 | } else { 246 | return pack_absmax_linear_16bit_cuda(data, absmax, bits, stochastic); 247 | } 248 | } 249 | 250 | // Unpack int8 bit stream to float16/32 data, for bits <= 8 251 | template 252 | __global__ void unpack_absmax_linear_8bit_kernel(int32_t bits, 253 | const int8_t* __restrict__ data, 254 | const scalar_t* __restrict__ absmax, 255 | scalar_t* __restrict__ unpacked) { 256 | const int group_id = blockIdx.x; 257 | const int d = threadIdx.x; 258 | const int64_t global_thread_id = group_id * blockDim.x + d; 259 | const int work_per_int = 8 / bits; 260 | const int workint_per_thread = 4; 261 | const int work_per_thread = work_per_int << 2; 262 | const scalar_t B = (1 << (bits - 1)) - 1; 263 | const int8_t mask = (1 << bits) - 1; // 00001111 264 | 265 | for (int i = 0; i < workint_per_thread; i++) { 266 | int64_t global_int_id = global_thread_id * workint_per_thread + i; 267 | const int8_t local_packed = data[global_int_id]; 268 | for (int j = 0; j < work_per_int; j++) { 269 | const int64_t id = global_thread_id * work_per_thread + i * work_per_int + j; 270 | 271 | const int8_t unsigned_val = (local_packed >> 
(j * bits)) & mask; 272 | // const int8_t sign_mask = ~mask; 273 | // const int8_t sign = (0 - (unsigned_val >> (bits - 1))) << (8 - bits); 274 | 275 | const int8_t val = ((unsigned_val > (int)B) ? (unsigned_val | (~mask)) : unsigned_val) ; 276 | // const int8_t val = sign | unsigned_val; 277 | // const int8_t val = ((local_packed << (8 - (1 + j) * bits)) >> (8 - bits)) & mask; 278 | unpacked[id] = ((scalar_t)val) * (absmax[group_id] / B); 279 | // printf("unpack: group id: %d, thread id: %d, data id: %d, unsigned_val: %d, val: %d, absmax: %f, unpacked: %f\n", group_id, d, (int)id, unsigned_val, val, absmax[group_id], unpacked[id]); 280 | } 281 | } 282 | 283 | // for (int k = 0; k < work_per_thread; k++) { 284 | // int data_id = (int)(global_thread_id * work_per_thread + k); 285 | // printf("unpack: group id: %d, thread id: %d, global id: %d, data id: %d, data[int]: %f\n", group_id, d, (int)global_thread_id, data_id, unpacked[data_id]); 286 | // printf("unpack: group id: %d, thread id: %d, global id: %d, data id: %d, data[int64]: %f\n", group_id, d, (int)global_thread_id, data_id, unpacked[global_thread_id * work_per_thread + k]); 287 | // } 288 | } 289 | 290 | Tensor unpack_absmax_linear_8bit_cuda(Tensor data, 291 | int bits, 292 | Tensor absmax, 293 | int64_t num_groups, 294 | int64_t group_size) { 295 | auto options = torch::TensorOptions().dtype(absmax.dtype()).device(data.device()); 296 | Tensor unpacked = torch::empty({num_groups, group_size}, options); 297 | 298 | const int work_per_int = 8 / bits; 299 | const int workint_per_thread = 4; 300 | const int work_per_thread = work_per_int * workint_per_thread; 301 | TORCH_CHECK(8 % bits == 0); 302 | 303 | // debug 304 | // for (int i = 0; i < num_groups; i++) { 305 | // printf("in unpack_absmax_linear_cuda before kernel, absmax: %f\n", absmax[i][0].item()); 306 | // } 307 | 308 | // Call unpack kernels 309 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(absmax.scalar_type(), "unpack_absmax_linear_8bit", ([&] { 310 | unpack_absmax_linear_8bit_kernel<<>>( 311 | bits, 312 | data.data_ptr(), 313 | absmax.data_ptr(), 314 | unpacked.data_ptr()); 315 | })); 316 | 317 | return unpacked; 318 | } 319 | 320 | // Unpack int8 bit stream to float16/32 data, for 8 < bits <= 16 321 | template 322 | __global__ void unpack_absmax_linear_16bit_kernel(int32_t bits, 323 | const int8_t* __restrict__ data, 324 | const scalar_t* __restrict__ absmax, 325 | scalar_t* __restrict__ unpacked, 326 | int64_t group_size) { 327 | const int64_t group_id = blockIdx.x; 328 | const int64_t d = threadIdx.x; 329 | const int64_t global_thread_id = group_id * blockDim.x + d; 330 | const int workbit_per_thread = 64; 331 | const int work_per_thread = workbit_per_thread / bits; 332 | const int B = (1 << (bits - 1)) - 1; 333 | const uint8_t packed8_mask = 0xff; 334 | const int64_t val_mask = (1 << bits) - 1; 335 | 336 | uint64_t local_packed = 0; 337 | for (int i = 0; i < 8; i++) { 338 | const int64_t global_int_id = global_thread_id * 8 + i; 339 | uint64_t local_packed8 = (uint64_t)(data[global_int_id] & packed8_mask) << (i << 3); 340 | local_packed |= local_packed8; 341 | } 342 | 343 | for (int i = 0; i < work_per_thread; i++) { 344 | if (d * work_per_thread + i >= group_size) 345 | break; 346 | const int64_t data_id = group_id * group_size + d * work_per_thread + i; 347 | const int64_t q_val_nonneg = (local_packed >> (i * bits)) & val_mask; // [0, 2B] 348 | unpacked[data_id] = (scalar_t)((q_val_nonneg - B)) * (absmax[group_id] / B); 349 | } 350 | 351 | } 352 | 353 | Tensor 
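// (clarifying note, not part of the original source) Host wrapper for the
// 16-bit unpack path: it launches num_groups blocks of threads_num threads,
// and each thread reloads the 8 bytes (64 bits) written by the matching pack
// thread, then emits up to work_per_thread = 64 / bits values of its group.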
unpack_absmax_linear_16bit_cuda(Tensor data, 354 | int bits, 355 | Tensor absmax, 356 | int64_t num_groups, 357 | int64_t group_size) { 358 | auto options = torch::TensorOptions().dtype(absmax.dtype()).device(data.device()); 359 | Tensor unpacked = torch::empty({num_groups, group_size}, options); 360 | 361 | const int workbit_per_thread = 64; 362 | const int work_per_thread = workbit_per_thread / bits; 363 | int64_t threads_num = (group_size + work_per_thread - 1) / work_per_thread; 364 | TORCH_CHECK(bits > 8); 365 | TORCH_CHECK(bits <= 16); 366 | 367 | // Call unpack kernels 368 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(absmax.scalar_type(), "unpack_absmax_linear_16bit", ([&] { 369 | unpack_absmax_linear_16bit_kernel<<>>( 370 | bits, 371 | data.data_ptr(), 372 | absmax.data_ptr(), 373 | unpacked.data_ptr(), 374 | group_size); 375 | })); 376 | 377 | return unpacked; 378 | } 379 | 380 | Tensor unpack_absmax_linear_cuda(Tensor data, 381 | int bits, 382 | Tensor absmax, 383 | int64_t num_groups, 384 | int64_t group_size) { 385 | if (bits <= 8) { 386 | return unpack_absmax_linear_8bit_cuda(data, bits, absmax, num_groups, group_size); 387 | } else { 388 | return unpack_absmax_linear_16bit_cuda(data, bits, absmax, num_groups, group_size); 389 | } 390 | } 391 | 392 | 393 | 394 | /******************************************************/ 395 | /***** Pack/Unpack Absmax Non-Linear Quantization *****/ 396 | /******************************************************/ 397 | template 398 | __device__ __forceinline__ int quantize_bsearch(const float* __restrict__ qmap, 399 | int bits, 400 | float x, 401 | float noise) { 402 | int lo = 0; 403 | int hi = 1 << bits; 404 | 405 | if (x <= qmap[lo]) 406 | return lo; 407 | if (qmap[hi - 1] <= x) 408 | return (hi - 1); 409 | 410 | while (lo < hi){ 411 | int mi = (lo + hi) >> 1; 412 | if (qmap[mi] <= x) lo = mi + 1; 413 | else hi = mi; 414 | } 415 | // return lo - 1; 416 | 417 | int rank = 0; 418 | if (STOCHASTIC) { 419 | float proba = (x - qmap[lo - 1]) / (qmap[lo] - qmap[lo - 1]); 420 | int flag = __float2int_rn(proba + noise - 0.5f); 421 | rank = (flag) ? lo : lo - 1; 422 | } else { 423 | float mid_val = (qmap[lo - 1] + qmap[lo]) * 0.5f; 424 | rank = (mid_val < x) ? 
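// (clarifying note, not part of the original source) Round-to-nearest branch:
// after the binary search, x lies in [qmap[lo-1], qmap[lo]), and comparing x
// against the midpoint picks the closer of the two codes. The stochastic
// branch above instead picks qmap[lo] with probability
// (x - qmap[lo-1]) / (qmap[lo] - qmap[lo-1]).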
lo : lo - 1; 425 | } 426 | return rank; 427 | } 428 | 429 | // Pack float16/32 data into int8 bit stream, for bits < 8 and 8 % bit == 0 430 | template 431 | __global__ void pack_nonlinear_4bit_kernel(int32_t bits, 432 | const scalar_t* __restrict__ data, 433 | const float* __restrict__ qmap, 434 | int8_t* __restrict__ packed, 435 | std::pair seeds) { 436 | const int group_id = blockIdx.x; 437 | const int id_in_group = threadIdx.x; 438 | const int64_t global_id = group_id * blockDim.x + id_in_group; 439 | const int work_per_int = 8 / bits; 440 | const int workint_per_thread = 4; 441 | const int work_per_thread = work_per_int << 2; 442 | const int8_t mask = (1 << bits) - 1; 443 | curandStatePhilox4_32_10_t state; 444 | curand_init(seeds.first, global_id, seeds.second, &state); 445 | 446 | for (int i = 0; i < workint_per_thread; i++) { 447 | uint8_t local_packed = 0; 448 | int64_t packed_id = global_id * workint_per_thread + i; 449 | for (int j = 0; j < work_per_int; j++) { 450 | const int64_t data_id = global_id * work_per_thread + i * work_per_int + j; 451 | const float noise = curand_uniform(&state); 452 | const float x = data[data_id]; 453 | const uint8_t qx = (uint8_t)quantize_bsearch(qmap, bits, x, noise); 454 | local_packed |= ((qx & mask) << (j * bits)); 455 | } 456 | 457 | packed[packed_id] = local_packed; 458 | } 459 | } 460 | 461 | Tensor pack_nonlinear_4bit_cuda(Tensor data, 462 | Tensor qmap, 463 | int bits, 464 | bool stochastic) { 465 | int64_t num_groups = data.size(0); 466 | int64_t group_size = data.size(1); 467 | 468 | // Compute total bits 469 | const int work_per_int = 8 / bits; 470 | const int workint_per_thread = 4; 471 | const int work_per_thread = work_per_int * workint_per_thread; 472 | TORCH_CHECK(8 % bits == 0); 473 | TORCH_CHECK(group_size % work_per_thread == 0); 474 | 475 | int64_t total_bits = (int64_t)bits * (num_groups * group_size); 476 | auto options = torch::TensorOptions().dtype(torch::kInt8).device(data.device()); 477 | Tensor packed = torch::empty({(total_bits + 8) / 8,}, options); 478 | 479 | // Random number generator 480 | int threads = group_size; 481 | auto gen = at::check_generator(at::cuda::detail::getDefaultCUDAGenerator()); 482 | std::pair rng_engine_inputs; 483 | { 484 | // See Note [Acquire lock when using random generators] 485 | std::lock_guard lock(gen->mutex_); 486 | rng_engine_inputs = gen->philox_engine_inputs(threads * work_per_thread); 487 | } 488 | // TORCH_CHECK(stochastic); 489 | 490 | // Call pack kernels 491 | if (stochastic) { 492 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(data.scalar_type(), "pack_nonlinear_4bit", ([&] { 493 | pack_nonlinear_4bit_kernel<<>>( 494 | bits, 495 | data.data_ptr(), 496 | qmap.data_ptr(), 497 | packed.data_ptr(), 498 | rng_engine_inputs); 499 | })); 500 | } else { 501 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(data.scalar_type(), "pack_nonlinear_4bit", ([&] { 502 | pack_nonlinear_4bit_kernel<<>>( 503 | bits, 504 | data.data_ptr(), 505 | qmap.data_ptr(), 506 | packed.data_ptr(), 507 | rng_engine_inputs); 508 | })); 509 | } 510 | 511 | 512 | return packed; 513 | } 514 | 515 | // Pack float16/32 data into int8 bit stream, for bits in [5, 6, 7, (8)] 516 | template 517 | __global__ void pack_nonlinear_8bit_kernel(int32_t bits, 518 | const scalar_t* __restrict__ data, 519 | const float* __restrict__ qmap, 520 | int8_t* __restrict__ packed, 521 | std::pair seeds) { 522 | const int group_id = blockIdx.x; 523 | const int id_in_group = threadIdx.x; 524 | const int64_t global_id = group_id * blockDim.x + id_in_group; 525 
| const int work_per_thread = 4; 526 | curandStatePhilox4_32_10_t state; 527 | curand_init(seeds.first, global_id, seeds.second, &state); 528 | 529 | for (int i = 0; i < work_per_thread; i++) { 530 | const int64_t packed_id = global_id * work_per_thread + i; // which is same as data_id 531 | const float noise = curand_uniform(&state); 532 | const float x = data[packed_id]; 533 | const uint8_t qx = (uint8_t)quantize_bsearch(qmap, bits, x, noise); 534 | packed[packed_id] = qx; 535 | } 536 | } 537 | 538 | Tensor pack_nonlinear_8bit_cuda(Tensor data, 539 | Tensor qmap, 540 | int bits, 541 | bool stochastic) { 542 | int64_t num_groups = data.size(0); 543 | int64_t group_size = data.size(1); 544 | 545 | // Compute total bits 546 | const int storage_bits = 8; 547 | const int work_per_int = 1; 548 | const int workint_per_thread = 4; 549 | const int work_per_thread = work_per_int * workint_per_thread; 550 | TORCH_CHECK(group_size % work_per_thread == 0); 551 | 552 | int64_t total_bits = (int64_t)storage_bits * (num_groups * group_size); 553 | auto options = torch::TensorOptions().dtype(torch::kInt8).device(data.device()); 554 | Tensor packed = torch::empty({(total_bits + 8) / 8,}, options); 555 | 556 | // Random number generator 557 | int threads = group_size; 558 | auto gen = at::check_generator(at::cuda::detail::getDefaultCUDAGenerator()); 559 | std::pair rng_engine_inputs; 560 | { 561 | // See Note [Acquire lock when using random generators] 562 | std::lock_guard lock(gen->mutex_); 563 | rng_engine_inputs = gen->philox_engine_inputs(threads * work_per_thread); 564 | } 565 | // TORCH_CHECK(stochastic); 566 | 567 | // Call pack kernels 568 | if (stochastic) { 569 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(data.scalar_type(), "pack_nonlinear_8bit", ([&] { 570 | pack_nonlinear_8bit_kernel<<>>( 571 | bits, 572 | data.data_ptr(), 573 | qmap.data_ptr(), 574 | packed.data_ptr(), 575 | rng_engine_inputs); 576 | })); 577 | } else { 578 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(data.scalar_type(), "pack_nonlinear_8bit", ([&] { 579 | pack_nonlinear_8bit_kernel<<>>( 580 | bits, 581 | data.data_ptr(), 582 | qmap.data_ptr(), 583 | packed.data_ptr(), 584 | rng_engine_inputs); 585 | })); 586 | } 587 | 588 | 589 | return packed; 590 | } 591 | 592 | Tensor pack_nonlinear_cuda(Tensor data, 593 | Tensor qmap, 594 | int bits, 595 | bool stochastic) { 596 | if (8 % bits == 0 && bits < 8) { 597 | return pack_nonlinear_4bit_cuda(data, qmap, bits, stochastic); 598 | } else { // bits <= 8 599 | return pack_nonlinear_8bit_cuda(data, qmap, bits, stochastic); 600 | } 601 | } 602 | 603 | // Unpack int8 bit stream to float16/32 data, for bits < 8 and 8 % bit == 0 604 | template 605 | __global__ void unpack_nonlinear_4bit_kernel(int32_t bits, 606 | const int8_t* __restrict__ data, 607 | const float* __restrict__ qmap, 608 | scalar_t* __restrict__ unpacked) { 609 | const int group_id = blockIdx.x; 610 | const int d = threadIdx.x; 611 | const int64_t global_thread_id = group_id * blockDim.x + d; 612 | const int work_per_int = 8 / bits; 613 | const int workint_per_thread = 4; 614 | const int work_per_thread = work_per_int << 2; 615 | // const scalar_t B = (1 << (bits - 1)) - 1; 616 | const int8_t mask = (1 << bits) - 1; // 00001111 617 | 618 | for (int i = 0; i < workint_per_thread; i++) { 619 | int64_t global_int_id = global_thread_id * workint_per_thread + i; 620 | const uint8_t local_packed = data[global_int_id]; 621 | for (int j = 0; j < work_per_int; j++) { 622 | const int64_t id = global_thread_id * work_per_thread + i * work_per_int 
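// (clarifying note, not part of the original source) The i-th byte handled by
// this thread stores work_per_int codes; code j occupies bit range
// [j * bits, (j + 1) * bits), so the flat output index advances by
// work_per_int elements per byte and by one element per code.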
+ j; 623 | const uint8_t unsigned_val = (local_packed >> (j * bits)) & mask; 624 | unpacked[id] = (scalar_t)qmap[unsigned_val]; 625 | } 626 | } 627 | 628 | } 629 | 630 | Tensor unpack_nonlinear_4bit_cuda(Tensor data, 631 | Tensor qmap, 632 | int bits, 633 | int64_t num_groups, 634 | int64_t group_size) { 635 | auto options = torch::TensorOptions().dtype(qmap.dtype()).device(data.device()); 636 | Tensor unpacked = torch::empty({num_groups, group_size}, options); 637 | 638 | const int work_per_int = 8 / bits; 639 | const int workint_per_thread = 4; 640 | const int work_per_thread = work_per_int * workint_per_thread; 641 | TORCH_CHECK(8 % bits == 0); 642 | 643 | // Call unpack kernels 644 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(qmap.scalar_type(), "unpack_nonlinear_4bit", ([&] { 645 | unpack_nonlinear_4bit_kernel<<>>( 646 | bits, 647 | data.data_ptr(), 648 | qmap.data_ptr(), 649 | unpacked.data_ptr()); 650 | })); 651 | 652 | return unpacked; 653 | } 654 | 655 | // Unpack int8 bit stream to float16/32 data, for bits in [5, 6, 7, (8)] 656 | template 657 | __global__ void unpack_nonlinear_8bit_kernel(int32_t bits, 658 | const int8_t* __restrict__ data, 659 | const float* __restrict__ qmap, 660 | scalar_t* __restrict__ unpacked) { 661 | const int group_id = blockIdx.x; 662 | const int d = threadIdx.x; 663 | const int64_t global_thread_id = group_id * blockDim.x + d; 664 | const int work_per_thread = 4; 665 | 666 | for (int i = 0; i < work_per_thread; i++) { 667 | const int64_t global_int_id = global_thread_id * work_per_thread + i; 668 | const uint8_t local_packed = data[global_int_id]; 669 | unpacked[global_int_id] = (scalar_t)qmap[local_packed]; 670 | } 671 | } 672 | 673 | Tensor unpack_nonlinear_8bit_cuda(Tensor data, 674 | Tensor qmap, 675 | int bits, 676 | int64_t num_groups, 677 | int64_t group_size) { 678 | auto options = torch::TensorOptions().dtype(qmap.dtype()).device(data.device()); 679 | Tensor unpacked = torch::empty({num_groups, group_size}, options); 680 | 681 | const int work_per_thread = 4; 682 | 683 | // Call unpack kernels 684 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(qmap.scalar_type(), "unpack_nonlinear_8bit", ([&] { 685 | unpack_nonlinear_8bit_kernel<<>>( 686 | bits, 687 | data.data_ptr(), 688 | qmap.data_ptr(), 689 | unpacked.data_ptr()); 690 | })); 691 | 692 | return unpacked; 693 | } 694 | 695 | Tensor unpack_nonlinear_cuda(Tensor data, 696 | Tensor qmap, 697 | int bits, 698 | int64_t num_groups, 699 | int64_t group_size) { 700 | if (8 % bits == 0 && bits < 8) { 701 | return unpack_nonlinear_4bit_cuda(data, qmap, bits, num_groups, group_size); 702 | } else { // bits <= 8 703 | return unpack_nonlinear_8bit_cuda(data, qmap, bits, num_groups, group_size); 704 | } 705 | } 706 | 707 | -------------------------------------------------------------------------------- /lpmm/functional.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | import itertools 4 | import lpmm.cpp_extension.quantization as ext_quantization 5 | 6 | lpmm_generator = None 7 | FP_EXPONENT_BIS_MAP = { 8 | 2: 1, 9 | 3: 2, 10 | 4: 2, 11 | 5: 3, 12 | 6: 3, 13 | 7: 4, 14 | 8: 4, 15 | } 16 | 17 | def init_lpmm_generator(gpu, seed): 18 | global lpmm_generator 19 | if lpmm_generator is None: 20 | lpmm_generator = torch.Generator(device=gpu) 21 | if seed is None: 22 | seed = 3407 23 | lpmm_generator.manual_seed(seed) 24 | 25 | 26 | def vectorwise_quant(x, **kwargs): 27 | '''interface quantization function 28 | ''' 29 | qx = x.detach() # keep 
the reference of original tensor 30 | 31 | # save kwargs 32 | generated_metadata = {} 33 | generated_metadata['dtype'] = x.dtype 34 | generated_metadata['stride'] = x.stride() 35 | 36 | # Given a ill-conditioned/quantization-unfriendly tensor, how to normalize and/or avoid outlier? 37 | # scale/noramlize the original tensor 38 | qx, md = quant_scaling(qx, **kwargs) 39 | generated_metadata.update(md) 40 | 41 | # Given a tensor distributed in [-1/0, 1], how to quantize with best error? 42 | # quantize the normalized tensor 43 | quant_type = kwargs['quant_type'] 44 | b, signed = kwargs['b'], kwargs['signed'] 45 | if quant_type == 'linear': 46 | MRQ, lo, hi = prepare_quant_boundary(b, signed) 47 | qx = atom_quant(qx, None, MRQ, lo, hi, round_type=kwargs['round_type']) 48 | elif quant_type in ['nonlinear', 'power-1', 'power-2', 'power-3', 'float-point', 'nonlinear-nozero']: 49 | if isinstance(kwargs['qmap'], torch.Tensor): 50 | qmap = kwargs['qmap'] 51 | else: 52 | qmap = kwargs['qmap'][(b, signed)][quant_type] 53 | qx = nonlinear_quant(qx, qmap, b, round_type=kwargs['round_type']) 54 | else: 55 | raise ValueError( 56 | f"Not support {quant_type} quant type." 57 | ) 58 | 59 | return qx, generated_metadata 60 | 61 | 62 | def vectorwise_dequant(qx, denormalized=True, **kwargs): 63 | '''dequantization function 64 | ''' 65 | x = qx.detach() 66 | 67 | # load kwargs 68 | dtype = kwargs['dtype'] 69 | stride = kwargs['stride'] 70 | 71 | # dequantize the quantized tensor to get a tensor in [-1/0, 1] 72 | quant_type = kwargs['quant_type'] 73 | b, signed = kwargs['b'], kwargs['signed'] 74 | if quant_type == 'linear': 75 | MRQ, lo, hi = prepare_quant_boundary(b, signed) 76 | x = atom_dequant(x, None, MRQ) 77 | elif quant_type in ['nonlinear', 'power-1', 'power-2', 'power-3', 'float-point', 'nonlinear-nozero']: 78 | if isinstance(kwargs['qmap'], torch.Tensor): 79 | qmap = kwargs['qmap'] 80 | else: 81 | qmap = kwargs['qmap'][(b, signed)][quant_type] 82 | x = nonlinear_dequant(x, qmap, b, shape=kwargs['scaled_shape'], round_type=kwargs['round_type']) 83 | else: 84 | raise ValueError( 85 | f"Not support {quant_type} quant type." 
86 | ) 87 | 88 | # only for debug 89 | if not denormalized: 90 | return x 91 | 92 | # scale the dequantized tensor to get the original tensor 93 | scale_type = kwargs['scale_type'] 94 | max1 = kwargs['max1'] 95 | if scale_type in ['tensor', 'dim0', 'dim1']: 96 | x = x.mul(max1) 97 | elif scale_type in ['rank1']: 98 | dim = kwargs['dim'] 99 | if dim == 1: # group 100 | x = x.mul(max1) 101 | shape = kwargs['shape'] 102 | x = recon_grouped_tensor(x, shape) 103 | else: 104 | max_dims = kwargs['max_dims'] 105 | st = _compute_sm3_scale_tensor(max_dims) 106 | x = x.mul(st) 107 | elif scale_type == 'dim01': 108 | x = x.mul(max1) 109 | max_dim0 = kwargs['max_dim0'] 110 | x = x.mul(max_dim0) 111 | elif scale_type == 'dim10': 112 | x = x.mul(max1) 113 | max_dim1 = kwargs['max_dim1'] 114 | x = x.mul(max_dim1) 115 | elif scale_type == 'group': 116 | x = x.mul(max1) 117 | shape = kwargs['shape'] 118 | x = recon_grouped_tensor(x, shape) 119 | elif scale_type == 'rank1-group': 120 | dim = kwargs['dim'] 121 | if dim == 1: # group 122 | x = x.mul(max1) 123 | shape = kwargs['shape'] 124 | x = recon_grouped_tensor(x, shape) 125 | elif dim == 2: 126 | max0 = kwargs['max0'] 127 | gp0_shape = kwargs['gp0_shape'] 128 | st0 = recon_grouped2d_tensor(max0.expand(gp0_shape), kwargs['shape']) 129 | gp1_shape = kwargs['gp1_shape'] 130 | st1 = recon_grouped2d_tensor(max1.expand(gp1_shape), kwargs['Tshape']) 131 | st = torch.min(st0, st1.T) 132 | x = x.mul(st) 133 | else: # rank1 134 | max_dims = kwargs['max_dims'] 135 | st = _compute_sm3_scale_tensor(max_dims) 136 | x = x.mul(st) 137 | elif scale_type == 'id': 138 | pass 139 | else: 140 | raise NotImplementedError 141 | 142 | if x.stride() != stride: 143 | # print(f"[warn] in dequantization, approximator x has not same stride {x.stride()} as original stride {stride}." 
144 | # "Renew a tensor with same memory format.") 145 | recon_x = torch.empty_strided(x.shape, stride, dtype=dtype, layout=torch.strided, device=x.device) 146 | recon_x.copy_(x) 147 | del x 148 | return recon_x 149 | else: 150 | x = x.to(dtype=dtype) 151 | return x 152 | 153 | 154 | def quant_scaling(qx, **kwargs): 155 | scale_type = kwargs['scale_type'] 156 | generated_metadata = {} 157 | # reshape and scaling 158 | if scale_type == 'tensor': 159 | max1 = torch.amax(torch.abs(qx), keepdim=True).to(torch.float32) # (1, 1) 160 | generated_metadata['max1'] = max1 161 | qx = qx.div(max1) 162 | elif scale_type == 'dim0': 163 | max1 = _max_reduce_except_dim(qx.abs(), 0) 164 | generated_metadata['max1'] = max1 165 | qx = qx.div(max1) 166 | elif scale_type == 'dim1': 167 | max1 = _max_reduce_except_dim(qx.abs(), 1) 168 | generated_metadata['max1'] = max1 169 | qx = qx.div(max1) 170 | elif scale_type == 'dim01': 171 | max_dim0 = _max_reduce_except_dim(qx.abs(), 0) 172 | qx = qx.div(max_dim0) 173 | max1 = _max_reduce_except_dim(qx.abs(), 1) 174 | generated_metadata['max_dim0'] = max_dim0 175 | generated_metadata['max1'] = max1 176 | qx = qx.div(max1) 177 | elif scale_type == 'dim10': 178 | max_dim1 = _max_reduce_except_dim(qx.abs(), 1) 179 | qx = qx.div(max_dim1) 180 | max1 = _max_reduce_except_dim(qx.abs(), 0) 181 | generated_metadata['max_dim1'] = max_dim1 182 | generated_metadata['max1'] = max1 183 | qx = qx.div(max1) 184 | elif scale_type == 'group': 185 | gp_sz = kwargs['gp_sz'] 186 | qx = group_tensor(qx, gp_sz) # (num_gp, gp_sz) 187 | max1 = _max_reduce_except_dim(qx.abs(), 0) 188 | qx = qx.div(max1) 189 | generated_metadata['max1'] = max1 190 | elif scale_type == 'rank1': 191 | generated_metadata['dim'] = qx.dim() 192 | if qx.dim() == 1: # group 193 | gp_sz = 128 194 | qx = group_tensor(qx, gp_sz) # (num_gp, gp_sz) 195 | max1 = _max_reduce_except_dim(qx.abs(), 0) 196 | qx = qx.div(max1) 197 | generated_metadata['max1'] = max1 198 | else: 199 | max_dims = get_sm3_statistics(qx.abs()) 200 | st = _compute_sm3_scale_tensor(max_dims) 201 | generated_metadata['max_dims'] = max_dims 202 | generated_metadata['max1'] = None 203 | qx = qx.div(st) 204 | elif scale_type == 'rank1-group': 205 | gp_sz = kwargs['gp_sz'] 206 | generated_metadata['dim'] = qx.dim() 207 | if qx.dim() == 1: # group 208 | gp_sz = 128 209 | qx = group_tensor(qx, gp_sz) # (num_gp, gp_sz) 210 | max1 = _max_reduce_except_dim(qx.abs(), 0) 211 | qx = qx.div(max1) 212 | generated_metadata['max1'] = max1 213 | elif qx.dim() == 2: 214 | generated_metadata['Tshape'] = qx.T.shape 215 | gp0_qx = group2d_tensor(qx, gp_sz) # (num_gp, gp_sz) 216 | max0 = _max_reduce_except_dim(gp0_qx.abs(), 0) 217 | generated_metadata['max0'] = max0 218 | st0 = recon_grouped2d_tensor(max0.expand_as(gp0_qx), qx.shape) 219 | generated_metadata['gp0_shape'] = gp0_qx.shape 220 | del gp0_qx 221 | gp1_qx = group2d_tensor(qx.T, gp_sz) # (num_gp, gp_sz) 222 | max1 = _max_reduce_except_dim(gp1_qx.abs(), 0) 223 | generated_metadata['max1'] = max1 224 | st1 = recon_grouped2d_tensor(max1.expand_as(gp1_qx), qx.T.shape) 225 | generated_metadata['gp1_shape'] = gp1_qx.shape 226 | del gp1_qx 227 | st = torch.min(st0, st1.T) 228 | del st0, st1 229 | qx = qx.div(st) 230 | else: # rank1 231 | max_dims = get_sm3_statistics(qx.abs()) 232 | st = _compute_sm3_scale_tensor(max_dims) 233 | generated_metadata['max_dims'] = max_dims 234 | generated_metadata['max1'] = None 235 | qx = qx.div(st) 236 | elif scale_type == 'id': 237 | generated_metadata['max1'] = None 238 | else: 239 | 
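# (clarifying note, not part of the original source) every other scale_type
# falls through to here; only 'tensor', 'dim0', 'dim1', 'dim01', 'dim10',
# 'group', 'rank1', 'rank1-group' and 'id' are handled above.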
raise NotImplementedError 240 | generated_metadata['scaled_shape'] = qx.shape 241 | return qx, generated_metadata 242 | 243 | 244 | def create_general_qmap(quant_type, bit, signed): 245 | if bit == 1: 246 | return torch.Tensor([-1.0, 1.0]) if signed else torch.Tensor([0.0, 1.0]) 247 | 248 | if quant_type == 'linear': 249 | return None 250 | elif quant_type == 'nonlinear': 251 | return create_dynamic_map(signed, bit - 1, bit if signed else bit - 1) 252 | elif quant_type == 'nonlinear-nozero': 253 | mapping = create_dynamic_map(signed, bit - 1, bit if signed else bit - 1) 254 | if not signed: 255 | mapping[0] = mapping[1] 256 | return mapping 257 | elif quant_type == 'power-1': 258 | return create_pow_map(bit, signed, 1) 259 | elif quant_type == 'power-2': 260 | return create_pow_map(bit, signed, 2) 261 | elif quant_type == 'power-3': 262 | return create_pow_map(bit, signed, 3) 263 | elif quant_type == 'float-point': 264 | return create_fp8_map(signed, FP_EXPONENT_BIS_MAP[bit], bit) 265 | else: 266 | raise ValueError( 267 | f"Not support {quant_type} quant type." 268 | ) 269 | 270 | 271 | # nonlinear quantization utils 272 | def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): 273 | """ 274 | Creates the dynamic quantiztion map. 275 | 276 | The dynamic data type is made up of a dynamic exponent and 277 | fraction. As the exponent increase from 0 to -7 the number 278 | of bits available for the fraction shrinks. 279 | 280 | This is a generalization of the dynamic type where a certain 281 | number of the bits and be reserved for the linear quantization 282 | region (the fraction). n determines the maximum number of 283 | exponent bits. 284 | 285 | For more details see 286 | (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561] 287 | """ 288 | 289 | data = [] 290 | # these are additional items that come from the case 291 | # where all the exponent bits are zero and no 292 | # indicator bit is present 293 | non_sign_bits = total_bits - (1 if signed else 0) 294 | additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 295 | if not signed: 296 | additional_items = 2 * additional_items 297 | for i in range(max_exponent_bits): 298 | fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)) 299 | boundaries = torch.linspace(0.1, 1, fraction_items) 300 | means = (boundaries[:-1] + boundaries[1:]) / 2.0 301 | data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() 302 | if signed: 303 | data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() 304 | 305 | if additional_items > 0: 306 | boundaries = torch.linspace(0.1, 1, additional_items + 1) 307 | means = (boundaries[:-1] + boundaries[1:]) / 2.0 308 | data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() 309 | if signed: 310 | data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() 311 | 312 | data.append(0) 313 | data.append(1.0) 314 | data.sort() 315 | return torch.Tensor(data) 316 | 317 | 318 | def create_fp8_map(signed=True, exponent_bits=5, total_bits=8): 319 | e = exponent_bits 320 | # p = precision_bits 321 | has_sign = 1 if signed else 0 322 | # assert e+p == total_bits-has_sign 323 | precision_bits = total_bits - has_sign - e 324 | # the exponent is biased to 2^(e-1) -1 == 0 325 | evalues = [] 326 | pvalues = [] 327 | for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)): 328 | evalues.append(2**val) 329 | 330 | 331 
| values = [] 332 | lst = list(itertools.product([0, 1], repeat=precision_bits)) 333 | #for ev in evalues: 334 | bias = 2**(exponent_bits-1) 335 | for evalue in range(2**(exponent_bits)): 336 | for bit_pattern in lst: 337 | value = (1 if evalue != 0 else 0) 338 | for i, pval in enumerate(list(bit_pattern)): 339 | value += pval*(2**-(i+1)) 340 | if evalue == 0: 341 | # subnormals 342 | value = value*2**-(bias) 343 | else: 344 | # normals 345 | value = value*2**-(evalue-bias-1) 346 | values.append(value) 347 | if signed: 348 | values.append(-value) 349 | 350 | 351 | assert len(values) == 2**total_bits 352 | values.sort() 353 | code = torch.Tensor(values) 354 | code /= code.max() 355 | 356 | return code 357 | 358 | 359 | def nonlinear_quant(qx, qmap, b, round_type='sr'): 360 | 361 | def real_nonlinear_quant(qx, qmap, b, stochastic): 362 | grouped_qx = group_tensor(qx, 2048) 363 | return ext_quantization.pack_nonlinear(grouped_qx, qmap, b, stochastic) 364 | 365 | qmaplen = len(qmap) 366 | if round_type == 'real-sr': 367 | idx = real_nonlinear_quant(qx, qmap, b, True) 368 | elif round_type == 'real-nearest': 369 | idx = real_nonlinear_quant(qx, qmap, b, False) 370 | elif round_type.startswith('sr'): 371 | qx.clamp_(qmap[0], qmap[-1]) 372 | floor_idx = ((qx.unsqueeze(-1) >= qmap).sum(dim=-1) - 1).clamp_(0, qmaplen - 1) 373 | next_idx = (floor_idx + 1).clamp_max_(qmaplen - 1) 374 | Z = qmap[next_idx] - qmap[floor_idx] 375 | Z[Z <= 0] = 1. 376 | proba = (qx - qmap[floor_idx]) / Z 377 | proba = torch.bernoulli(proba, generator=lpmm_generator) 378 | idx = (floor_idx + proba).round_().to(torch.int) 379 | if round_type == 'sr1': 380 | idx = idx.clamp_min_(1) 381 | elif round_type == 'sr2': 382 | idx = idx.clamp_min_(2) 383 | elif round_type == 'down': 384 | idx = ((qx.unsqueeze(-1) >= qmap).sum(dim=-1) - 1).clamp_(0, qmaplen - 1).to(torch.int) 385 | elif round_type == 'up': 386 | idx = ((qx.unsqueeze(-1) > qmap).sum(dim=-1)).clamp_(0, qmaplen - 1).to(torch.int) 387 | elif round_type == 'nearest': 388 | diff_tensor = torch.abs(qx.unsqueeze(-1) - qmap) 389 | idx = torch.argmin(diff_tensor, dim=-1).to(torch.int) 390 | return idx 391 | 392 | 393 | def nonlinear_dequant(qx, qmap, b, shape, round_type='sr'): 394 | if round_type.startswith('real'): 395 | num_groups = (shape.numel() + 2047) // 2048 396 | grouped_x = ext_quantization.unpack_nonlinear(qx, qmap, b, num_groups, 2048) 397 | x = recon_grouped_tensor(grouped_x, shape) 398 | else: 399 | x = qmap[qx.to(torch.int64)] 400 | return x 401 | 402 | 403 | # group quantization utils 404 | def group_tensor(input: torch.Tensor, gp_sz: int): 405 | r"""Group tensor into subtensors of size 'gp_sz' 406 | """ 407 | if not gp_sz > 0: 408 | raise ValueError("group size need to be a positive integer, but found {}".format(gp_sz)) 409 | 410 | input_flatten = input.flatten() 411 | num_features = input_flatten.shape[0] 412 | 413 | # Reshape the tensor into group 414 | if num_features % gp_sz != 0: 415 | # Padding 416 | new_num_features = (num_features // gp_sz + 1) * gp_sz 417 | delta = new_num_features - num_features 418 | input_flatten = torch.cat([input_flatten, 419 | torch.zeros([delta], dtype=input.dtype, device=input.device)], dim=0) 420 | 421 | input_groups = input_flatten.view(-1, gp_sz) # num_groups, group_size 422 | return input_groups 423 | 424 | 425 | def recon_grouped_tensor(grouped_tensor: torch.Tensor, shape) -> torch.Tensor : 426 | r"""Reconstruction the tensor to original (or specific) shape 427 | """ 428 | numel = shape.numel() 429 | recon_flatten = 
grouped_tensor.flatten()[:numel] 430 | recon = recon_flatten.view(shape) 431 | return recon 432 | 433 | 434 | def group2d_tensor(input: torch.Tensor, gp_sz: int): 435 | r"""Group tensor into subtensors of size 'gp_sz' 436 | """ 437 | if not gp_sz > 0: 438 | raise ValueError("group size need to be a positive integer, but found {}".format(gp_sz)) 439 | if input.dim() != 2: 440 | raise ValueError("") 441 | C0, C1 = input.shape[0], input.shape[1] 442 | # Reshape the tensor into group 443 | if C1 % gp_sz != 0: 444 | # Padding 445 | new_num_features = (C1 // gp_sz + 1) * gp_sz 446 | delta = new_num_features - C1 447 | input = torch.cat([input, torch.zeros([C0, delta], dtype=input.dtype, device=input.device)], 448 | dim=1) 449 | input_groups = input.reshape(-1, gp_sz) # num_groups, group_size 450 | return input_groups 451 | 452 | 453 | def recon_grouped2d_tensor(grouped_tensor: torch.Tensor, shape) -> torch.Tensor : 454 | r"""Reconstruction the tensor to original (or specific) shape 455 | """ 456 | return grouped_tensor.reshape(shape[0], -1)[:, :shape[1]] 457 | 458 | 459 | # deprecated 460 | def sm3_quant(x, **kwargs): 461 | # save normal kwargs already finished 462 | qx = x.abs() 463 | max_dims = [] 464 | for i in range(x.dim()): 465 | nu_max = _max_reduce_except_dim(qx, i) 466 | if isinstance(kwargs['sm3_history'], list): 467 | torch.max(kwargs['sm3_history'][i], nu_max, out=nu_max) 468 | max_dims.append(nu_max) 469 | kwargs['gen'] = (max_dims,) # not changed afterwards 470 | # quantize 471 | signed = kwargs['signed'] 472 | b = kwargs['b'] 473 | if b == 0: 474 | # NOTE: exactly SM3 algorithm 475 | return torch.sign(x), kwargs 476 | else: 477 | st = _compute_sm3_scale_tensor(max_dims) 478 | MRQ, lo, hi = prepare_quant_boundary(b, signed) 479 | qx = atom_quant(x, st, MRQ, lo, hi, kwargs['round_type']) 480 | return qx, kwargs 481 | 482 | 483 | # deprecated 484 | def sm3_dequant(qx, **kwargs): 485 | # self-consistent 486 | max_dims = kwargs['gen'][0] 487 | st = _compute_sm3_scale_tensor(max_dims) 488 | signed = kwargs['signed'] 489 | dtype = kwargs['dtype'] 490 | memory_format = kwargs['memory_format'] 491 | b = kwargs['b'] 492 | if b == 0: 493 | # NOTE: exactly SM3 algorithm 494 | x = st * qx if signed else st 495 | else: 496 | MRQ, lo, hi = prepare_quant_boundary(b, signed) 497 | x = atom_dequant(qx, st, MRQ) 498 | x = x.to(dtype=dtype, memory_format=memory_format) 499 | assert x.shape == kwargs['shape'], f"The original shape is {kwargs['shape']} the dequantized shape is {x.shape}" 500 | return x 501 | 502 | 503 | def _compute_sm3_scale_tensor(max_dims): 504 | rank = len(max_dims) 505 | scale_tensor = max_dims[0].clone() 506 | for i in range(1, rank): 507 | # We rely on broadcasting to get the proper end shape. 508 | scale_tensor = torch.min(scale_tensor, max_dims[i]) 509 | return scale_tensor 510 | 511 | 512 | def get_sm3_statistics(x, **kwargs): 513 | qx = x.abs() 514 | max_dims = [] 515 | for i in range(x.dim()): 516 | nu_max = _max_reduce_except_dim(qx, i) 517 | max_dims.append(nu_max) 518 | return max_dims 519 | 520 | 521 | def _max_reduce_except_dim(tensor, dim): 522 | # Computes max along all dimensions except the given dim. 523 | # If tensor is a scalar, it returns tensor. 
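# (clarifying note, not part of the original source) For a non-scalar tensor t
# this is equivalent to
#   t.amax(dim=[d for d in range(t.dim()) if d != dim], keepdim=True)
# i.e. every dimension except `dim` is max-reduced to size 1 while `dim`
# keeps its full length.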
524 | rank = len(tensor.shape) 525 | result = tensor 526 | if rank > 0: 527 | assert dim < rank 528 | for d in range(rank): 529 | if d != dim: 530 | result = result.max(dim=d, keepdim=True).values 531 | return result 532 | 533 | 534 | # deprecated 535 | def adafactor_quant(x, **kwargs): 536 | # normal kwargs saving already finished 537 | assert x.dim() == 2 538 | qx = x.abs() 539 | sum_dims = [] 540 | for i in range(x.dim()): 541 | one_dim_sum = qx.sum(dim=i, keepdim=True) 542 | sum_dims.append(one_dim_sum) 543 | kwargs['gen'] = (sum_dims,) # not changed afterwards 544 | # quantize 545 | signed = kwargs['signed'] 546 | b = kwargs['b'] 547 | # NOTE: exactly adafactor algorithm 548 | return torch.sign(x), kwargs 549 | 550 | 551 | # deprecated 552 | def adafactor_dequant(qx, **kwargs): 553 | # self-consistent 554 | sum_dims = kwargs['gen'][0] 555 | st = sum_dims[0] * sum_dims[1] / sum_dims[0].sum().item() 556 | signed = kwargs['signed'] 557 | dtype = kwargs['dtype'] 558 | memory_format = kwargs['memory_format'] 559 | b = kwargs['b'] 560 | # NOTE: exactly adafactor algorithm 561 | x = st * qx if signed else st 562 | x = x.to(dtype=dtype, memory_format=memory_format) 563 | assert x.shape == kwargs['shape'], f"The original shape is {kwargs['shape']} the dequantized shape is {x.shape}" 564 | return x 565 | 566 | 567 | # basic quant utils 568 | def atom_quant(x, scale, maximal, lo, hi, round_type='sr'): 569 | if scale is None: 570 | qx = x * maximal 571 | else: 572 | qx = x / scale.expand_as(x) * maximal # scale x to integer unit 573 | if round_type in ['sr', 'real-sr']: 574 | eps = torch.rand(qx.size(), generator=lpmm_generator, device=qx.device) - 0.5 575 | qx = torch.clamp(qx + eps, lo, hi) 576 | qx = qx.round_().to(torch.int) 577 | elif round_type == 'up': 578 | qx = torch.clamp(qx, lo, hi) 579 | qx = qx.ceil_().to(torch.int) 580 | elif round_type == 'down': 581 | qx = torch.clamp(qx, lo, hi) 582 | qx = qx.floor_().to(torch.int) 583 | elif round_type == ['nearest', 'real-nearest']: 584 | qx = torch.clamp(qx, lo, hi) 585 | qx = qx.round_().to(torch.int) 586 | elif round_type == 'sr2': 587 | eps = torch.rand(qx.size(), generator=lpmm_generator, device=qx.device) - 0.5 588 | qx = torch.clamp(qx + eps, 2, hi) 589 | qx = qx.round_().to(torch.int) 590 | elif round_type == 'sr1': 591 | eps = torch.rand(qx.size(), generator=lpmm_generator, device=qx.device) - 0.5 592 | qx = torch.clamp(qx + eps, 1, hi) 593 | qx = qx.round_().to(torch.int) 594 | else: 595 | raise NotImplementedError 596 | return qx 597 | 598 | 599 | def atom_dequant(qx, scale, maximal): 600 | if scale is None: 601 | return qx / maximal 602 | else: 603 | return qx / maximal * scale.expand_as(qx) 604 | 605 | 606 | def prepare_quant_boundary(b, signed): 607 | B = (2 ** (b - 1) - 1) 608 | UB = 2 ** b - 1 609 | hi = MRQ = B if signed else UB # maximal representable quantized integer 610 | lo = -B if signed else 0 611 | return MRQ, lo, hi 612 | 613 | 614 | def symmetric_atom_quantize(x, bit_width, res, **kwargs): 615 | r''' 616 | symmetric quantization excluding zero 617 | only support for signed case, with single scale. 
618 | ======== 619 | Parameters: 620 | x: zero mean tensor 621 | res: 622 | ''' 623 | num_points = 2 ** bit_width - 1 624 | translate_x = x + num_points / 2 * res 625 | qx = atom_quant(translate_x, res * num_points, num_points, 0, num_points, **kwargs) 626 | return qx 627 | 628 | 629 | def symmetric_atom_dequantize(qx, bit_width, res, **kwargs): 630 | r''' 631 | symmetric dequantization excluding zero 632 | only support for signed case, with single scale. 633 | ======== 634 | Parameters: 635 | x: zero mean tensor 636 | res: 637 | ''' 638 | num_points = 2 ** bit_width - 1 639 | translate_x = atom_dequant(qx, res * num_points, num_points) 640 | x = translate_x - num_points / 2 * res 641 | return x 642 | 643 | 644 | def create_pow_map(b, signed, p): 645 | if signed: 646 | qmap = torch.linspace(-1, 1, (2 ** b)) # no zero ver. 647 | # qmap = torch.linspace(-1, 1, (2 ** b) - 1) # less one ver. 648 | # qmap = torch.linspace(-1, 1, (2 ** b) + 1)[1:] # no minimal ver. 649 | if p != 1: 650 | qmap = qmap.sign() * (qmap.abs() ** p) 651 | else: 652 | # qmap = torch.linspace(0, 1, 2 ** b) # default ver. 653 | qmap = torch.linspace(0, 1, (2 ** b) + 1)[1:] # no zero ver. 654 | if p != 1: 655 | qmap = qmap ** p 656 | return qmap 657 | 658 | 659 | def create_exp_map(qx, b, signed): 660 | if signed: 661 | N = (2 ** (b - 1)) 662 | logqx = qx.abs().log() 663 | pos = torch.exp(torch.linspace(logqx.min(), logqx.max(), N)) 664 | neg = -pos.flip(0) 665 | # qmap = torch.cat([neg, torch.as_tensor([0]), pos]) # less one ver. 666 | qmap = torch.cat([neg, pos]) # no zero ver. 667 | else: 668 | N = (2 ** b) - 1 669 | logqx = qx.log() 670 | pos = torch.exp(torch.linspace(logqx.min(), logqx.max(), N)) 671 | qmap = torch.cat([torch.as_tensor([0]), pos]) 672 | return qmap 673 | 674 | 675 | def create_log_map(b, signed): 676 | Z = torch.exp(torch.as_tensor([1])) - 1 677 | if signed: 678 | # positive_quantization_points = (torch.exp(torch.linspace(0, 1, 2 ** (b - 1) + 1)) - 1) / Z # no minimal ver. 679 | positive_quantization_points = (torch.exp(torch.linspace(0, 1, 2 ** (b - 1) + 1)) - 1)[1:] / Z # no zero ver. 680 | # positive_quantization_points = (torch.exp(torch.linspace(0, 1, 2 ** (b - 1))) - 1) / Z # less one ver. 681 | negative_quantization_points = -positive_quantization_points.flip(0) 682 | # qmap = torch.cat((negative_quantization_points[1:-1], positive_quantization_points)) # no minimal ver. 683 | qmap = torch.cat((negative_quantization_points, positive_quantization_points)) # no zero ver. 684 | # qmap = torch.cat((negative_quantization_points[:-1], positive_quantization_points)) # less one ver. 685 | else: 686 | # qmap = (torch.exp(torch.linspace(0, 1, 2 ** b)) - 1) / Z # default ver. 687 | qmap = (torch.exp(torch.linspace(0, 1, (2 ** b) + 1)) - 1)[1:] / Z # no zero ver. 
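# (illustrative note, not part of the original source) e.g. for b=3, signed=True
# the map built above works out to roughly
#   [-1.0, -0.650, -0.378, -0.165, 0.165, 0.378, 0.650, 1.0]
# i.e. 2**b points, symmetric around (but excluding) zero and denser near zero.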
688 | return qmap 689 | 690 | 691 | # --- quantization enable --- 692 | def always_bool(val: bool): 693 | def f(*inputs, **kwargs): 694 | return val 695 | return f 696 | 697 | 698 | always_true = always_bool(val=True) 699 | always_false = always_bool(val=False) 700 | 701 | 702 | def numel_enable(th): 703 | def enable(p_name, state_name, x): 704 | return x.numel() > th if th is not None else True 705 | return enable 706 | 707 | 708 | def suffix_enable(suffix, has_suffix=False): 709 | def enable(p_name, state_name, x): 710 | return not (p_name.endswith(suffix) ^ has_suffix) 711 | return enable 712 | 713 | 714 | def intersect_enable(*func_list): 715 | def enable(p_name, state_name, x): 716 | for func in func_list: 717 | if not func(p_name, state_name, x): 718 | return False 719 | return True 720 | return enable 721 | 722 | 723 | def union_enable(*func_list): 724 | def enable(p_name, state_name, x): 725 | for func in func_list: 726 | if func(p_name, state_name, x): 727 | return True 728 | return False 729 | return enable 730 | 731 | 732 | def shape_enable(dim=2): 733 | def enable(p_name, state_name, x): 734 | return x.dim() >= dim 735 | return enable 736 | 737 | 738 | def get_enable_fn_from_subconfig(subconfig): 739 | if not subconfig.ENABLE: 740 | return always_false 741 | func_list = [] 742 | # if only_suffix is not None: 743 | # assert isinstance(only_suffix, List) 744 | # for suffix in only_suffix: 745 | # func_list.append(suffix_enable(suffix, has_suffix=True)) 746 | # union = union_enable(*func_list) 747 | # return intersect_enable(numel_enable(th), union) 748 | func_list.append(numel_enable(subconfig.THRESHOLD)) 749 | func_list.append(shape_enable()) 750 | for suffix in subconfig.EXCLUDE_SUFFIX: 751 | func_list.append(suffix_enable(suffix, has_suffix=False)) 752 | return intersect_enable(*func_list) 753 | -------------------------------------------------------------------------------- /lpmm/optim/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .adamw import AdamW 3 | from .sgd import SGD 4 | 5 | del adamw 6 | del sgd 7 | del optimizer -------------------------------------------------------------------------------- /lpmm/optim/adamw.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | from torch import Tensor 5 | from typing import List, Optional 6 | import time 7 | 8 | from .optimizer import LowBitOptimizer 9 | from ..functional import vectorwise_dequant, vectorwise_quant 10 | 11 | __all__ = ["AdamW"] 12 | 13 | 14 | class AdamW(LowBitOptimizer): 15 | def __init__( 16 | self, 17 | params, 18 | lr=1e-3, 19 | betas=(0.9, 0.999), 20 | eps=1e-8, 21 | weight_decay=1e-2, 22 | use_first_moment=True, 23 | factor_second_moment=False, 24 | qconfig=None, 25 | *, 26 | fused: Optional[bool] = False, 27 | ): 28 | if not 0.0 <= lr: 29 | raise ValueError("Invalid learning rate: {}".format(lr)) 30 | if not 0.0 <= eps: 31 | raise ValueError("Invalid epsilon value: {}".format(eps)) 32 | if not 0.0 <= betas[0] < 1.0: 33 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 34 | if not 0.0 <= betas[1] < 1.0: 35 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 36 | if not 0.0 <= weight_decay: 37 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 38 | defaults = dict( 39 | lr=lr, 40 | betas=betas, 41 | eps=eps, 42 | weight_decay=weight_decay, 43 | fused=fused, 44 | use_first_moment=use_first_moment, 45 | 
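# (clarifying note, not part of the original source) the entry below,
# factor_second_moment, when True keeps the second moment of parameters with
# 2+ dimensions as Adafactor-style factored row/column statistics instead of
# a full (quantized) exp_avg_sq; see _get_options below.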
factor_second_moment=factor_second_moment, 46 | ) 47 | super().__init__(params, defaults, qconfig) 48 | 49 | def __setstate__(self, state): 50 | super().__setstate__(state) 51 | for group in self.param_groups: 52 | group.setdefault("fused", None) 53 | state_values = list(self.state.values()) 54 | step_is_tensor = (len(state_values) != 0) and torch.is_tensor( 55 | state_values[0]["step"] 56 | ) 57 | if not step_is_tensor: 58 | for s in state_values: 59 | s["step"] = torch.tensor(float(s["step"])) 60 | 61 | def get_subqconfig(self, optimizer_state_name): 62 | if optimizer_state_name == 'exp_avg': 63 | return self.qconfig.QUANT.M 64 | elif optimizer_state_name == 'exp_avg_sq': 65 | return self.qconfig.QUANT.SQM 66 | else: 67 | raise ValueError( 68 | f"" 69 | ) 70 | 71 | @staticmethod 72 | def _get_options(param_group, param_shape): 73 | factored = len(param_shape) >= 2 and param_group["factor_second_moment"] 74 | use_first_moment = param_group["use_first_moment"] 75 | return factored, use_first_moment 76 | 77 | @staticmethod 78 | def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): 79 | # copy from fairseq's adafactor implementation: 80 | # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 81 | r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).unsqueeze(-1) 82 | c_factor = exp_avg_sq_col.unsqueeze(-2) 83 | return torch.mul(r_factor, c_factor) 84 | 85 | def _init_group( 86 | self, 87 | group, 88 | params_with_grad, 89 | grads, 90 | exp_avgs, 91 | exp_avg_sqs, 92 | exp_avg_sqs_factored, 93 | exp_avg_sq_rows, 94 | exp_avg_sq_cols, 95 | state_steps, 96 | exp_avgs_q_enabled, 97 | exp_avg_sqs_q_enabled, 98 | exp_avgs_q_overhead, 99 | exp_avg_sqs_q_overhead, 100 | exp_avgs_qmap, 101 | exp_avg_sqs_qmap, 102 | ): 103 | for p in group["params"]: 104 | if p.grad is None: 105 | continue 106 | params_with_grad.append(p) 107 | if p.grad.is_sparse: 108 | raise RuntimeError("AdamW does not support sparse gradients") 109 | # if p.grad.dtype in {torch.float16, torch.bfloat16}: 110 | # p.grad = p.grad.float() 111 | grads.append(p.grad) 112 | state = self.state[p] 113 | 114 | factored, _ = self._get_options(group, p.shape) 115 | # State initialization 116 | if len(state) == 0: 117 | # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off. 118 | # This is because kernel launches are costly on CUDA and XLA. 
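# (clarifying note, not part of the original source) exp_avg / exp_avg_sq
# start out as 0-dim placeholders; they are materialized (and, when enabled,
# replaced by their quantized byte buffers) lazily inside the first step(),
# while init_qstate attaches the per-state quantization map, enable flag and
# overhead dict used there.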
119 | state["step"] = torch.tensor(0.0) 120 | # Exponential moving average of gradient values 121 | state["exp_avg"] = torch.zeros((), dtype=torch.float, device=p.device) 122 | self.init_qstate(p, "exp_avg") 123 | # Exponential moving average of squared gradient values 124 | if factored: 125 | state["exp_avg_sq_row"] = torch.zeros(p.shape[:-1], device=p.device) 126 | state["exp_avg_sq_col"] = torch.zeros(p.shape[:-2] + p.shape[-1:], device=p.device) 127 | else: 128 | state["exp_avg_sq"] = torch.zeros((), dtype=torch.float, device=p.device) 129 | self.init_qstate(p, "exp_avg_sq") 130 | 131 | state_steps.append(state["step"]) 132 | exp_avgs.append(state["exp_avg"]) 133 | exp_avg_sqs_factored.append(factored) 134 | if factored: 135 | exp_avg_sq_rows.append(state["exp_avg_sq_row"]) 136 | exp_avg_sq_cols.append(state["exp_avg_sq_col"]) 137 | exp_avg_sqs.append(None) 138 | else: 139 | exp_avg_sq_rows.append(None) 140 | exp_avg_sq_cols.append(None) 141 | exp_avg_sqs.append(state["exp_avg_sq"]) 142 | 143 | exp_avgs_q_enabled.append(self.override_q_enable[id(p)] if id(p) in self.override_q_enable else state["exp_avg_qstate"]["enable"]) 144 | exp_avg_sqs_q_enabled.append(self.override_q_enable[id(p)] if id(p) in self.override_q_enable else state["exp_avg_sq_qstate"]["enable"]) 145 | exp_avgs_q_overhead.append(state["exp_avg_qstate"]["overhead"]) 146 | exp_avg_sqs_q_overhead.append(state["exp_avg_sq_qstate"]["overhead"]) 147 | exp_avgs_qmap.append(state["exp_avg_qstate"]["qmap"]) 148 | exp_avg_sqs_qmap.append(state["exp_avg_sq_qstate"]["qmap"]) 149 | 150 | 151 | @torch.no_grad() 152 | def step(self, closure=None): 153 | """Performs a single optimization step. 154 | 155 | Args: 156 | closure (Callable, optional): A closure that reevaluates the model 157 | and returns the loss. 
158 | """ 159 | 160 | loss = None 161 | if closure is not None: 162 | with torch.enable_grad(): 163 | loss = closure() 164 | 165 | for group in self.param_groups: 166 | params_with_grad = [] 167 | grads = [] 168 | exp_avg_sqs_factored = [] 169 | exp_avgs = [] 170 | exp_avg_sqs = [] 171 | exp_avg_sq_rows = [] 172 | exp_avg_sq_cols = [] 173 | state_steps = [] 174 | beta1, beta2 = group["betas"] 175 | exp_avgs_q_enabled = [] 176 | exp_avg_sqs_q_enabled = [] 177 | exp_avgs_q_overhead = [] 178 | exp_avg_sqs_q_overhead = [] 179 | exp_avgs_qmap = [] 180 | exp_avg_sqs_qmap = [] 181 | 182 | self._init_group( 183 | group, 184 | params_with_grad, 185 | grads, 186 | exp_avgs, 187 | exp_avg_sqs, 188 | exp_avg_sqs_factored, 189 | exp_avg_sq_rows, 190 | exp_avg_sq_cols, 191 | state_steps, 192 | exp_avgs_q_enabled, 193 | exp_avg_sqs_q_enabled, 194 | exp_avgs_q_overhead, 195 | exp_avg_sqs_q_overhead, 196 | exp_avgs_qmap, 197 | exp_avg_sqs_qmap, 198 | ) 199 | 200 | kwargs = dict( 201 | params_with_grad=params_with_grad, 202 | grads=grads, 203 | exp_avgs=exp_avgs, 204 | exp_avg_sqs=exp_avg_sqs, 205 | exp_avg_sqs_factored=exp_avg_sqs_factored, 206 | exp_avg_sq_rows=exp_avg_sq_rows, 207 | exp_avg_sq_cols=exp_avg_sq_cols, 208 | state_steps=state_steps, 209 | exp_avgs_q_enabled=exp_avgs_q_enabled, 210 | exp_avg_sqs_q_enabled=exp_avg_sqs_q_enabled, 211 | exp_avgs_q_overhead=exp_avgs_q_overhead, 212 | exp_avg_sqs_q_overhead=exp_avg_sqs_q_overhead, 213 | exp_avgs_qmap=exp_avgs_qmap, 214 | exp_avg_sqs_qmap=exp_avg_sqs_qmap, 215 | exp_avg_qmetadata=self.get_qmetadata_by_state_name("exp_avg"), 216 | exp_avg_sq_qmetadata=self.get_qmetadata_by_state_name("exp_avg_sq"), 217 | beta1=beta1, 218 | beta2=beta2, 219 | lr=group["lr"], 220 | weight_decay=group["weight_decay"], 221 | eps=group["eps"], 222 | ) 223 | 224 | if group["fused"] and torch.jit.is_scripting(): 225 | raise RuntimeError("torch.jit.script not supported with fused optimizers") 226 | 227 | if group["fused"] and not torch.jit.is_scripting(): 228 | _fused_adamw4bit(**kwargs) 229 | else: 230 | _single_tensor_adamw4bit(**kwargs) 231 | 232 | # beta1, beta2 = group["betas"] 233 | # lr = group["lr"] 234 | # weight_decay = group["weight_decay"] 235 | # eps = group["eps"] 236 | 237 | # for p in group["params"]: 238 | # if p.grad is None: 239 | # continue 240 | # grad = p.grad.data 241 | # if grad.dtype in {torch.float16, torch.bfloat16}: 242 | # grad = grad.float() 243 | # if p.grad.is_sparse: 244 | # raise RuntimeError("AdamW does not support sparse gradients") 245 | 246 | # state = self.state[p] 247 | # grad_shape = p.grad.shape 248 | 249 | # factored, use_first_moment = self._get_options(group, grad_shape) 250 | # # State initialization 251 | # if len(state) == 0: 252 | # # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off. 253 | # # This is because kernel launches are costly on CUDA and XLA. 
254 | # state["step"] = 0 255 | # # Exponential moving average of gradient values 256 | # if use_first_moment: 257 | # state["exp_avg"] = torch.tensor(0.0) 258 | # # Exponential moving average of squared gradient values 259 | # if factored: 260 | # state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) 261 | # state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) 262 | # else: 263 | # state["exp_avg_sq"] = torch.tensor(0.0) 264 | # # quantization state 265 | # self.init_qstate(p) 266 | 267 | # # take out optimizer state 268 | # param = p 269 | # # dequantize 270 | # if use_first_moment: 271 | # exp_avg = state["exp_avg"] 272 | # if exp_avg.numel() <= 1: 273 | # exp_avg.data = torch.zeros_like(param, memory_format=torch.preserve_format) 274 | # else: 275 | # hat_exp_avg = self.dequantize(param, 'exp_avg', exp_avg) 276 | # if hat_exp_avg is not None: 277 | # exp_avg.data = hat_exp_avg 278 | # del hat_exp_avg 279 | # else: 280 | # exp_avg = grad 281 | # if factored: 282 | # exp_avg_sq_row = state["exp_avg_sq_row"] 283 | # exp_avg_sq_col = state["exp_avg_sq_col"] 284 | # else: 285 | # exp_avg_sq = state["exp_avg_sq"] 286 | # if exp_avg_sq.numel() <= 1: 287 | # exp_avg_sq.data = torch.zeros_like(param, memory_format=torch.preserve_format) 288 | # else: 289 | # hat_exp_avg_sq = self.dequantize(param, 'exp_avg_sq', exp_avg_sq) 290 | # if hat_exp_avg_sq is not None: 291 | # exp_avg_sq.data = hat_exp_avg_sq 292 | # del hat_exp_avg_sq 293 | 294 | # # update 295 | # state["step"] += 1 296 | # # Perform stepweight decay 297 | # param.mul_(1 - lr * weight_decay) 298 | 299 | # # Decay the first and second moment running average coefficient 300 | # if use_first_moment: 301 | # exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 302 | # if factored: 303 | # update = (grad ** 2) 304 | # exp_avg_sq_row.mul_(beta2).add_(update.mean(dim=-1), alpha=1 - beta2) 305 | # exp_avg_sq_col.mul_(beta2).add_(update.mean(dim=-2), alpha=1 - beta2) 306 | # exp_avg_sq = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) 307 | # else: 308 | # exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 309 | 310 | # step = state["step"] 311 | # bias_correction1 = 1 - beta1 ** step 312 | # bias_correction2 = 1 - beta2 ** step 313 | # step_size = lr / bias_correction1 314 | # bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) 315 | 316 | # denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) 317 | # param.addcdiv_(exp_avg, denom, value=-step_size) 318 | 319 | # # take in optimizer state 320 | # if use_first_moment: 321 | # q_exp_avg = self.quantize(param, 'exp_avg', exp_avg) 322 | # if q_exp_avg is not None: 323 | # exp_avg.data = q_exp_avg 324 | # if not factored: 325 | # q_exp_avg_sq = self.quantize(param, 'exp_avg_sq', exp_avg_sq) 326 | # if q_exp_avg_sq is not None: 327 | # exp_avg_sq.data = q_exp_avg_sq 328 | 329 | return loss 330 | 331 | 332 | def _single_tensor_adamw4bit( 333 | params_with_grad: List[Tensor], 334 | grads: List[Tensor], 335 | exp_avgs: List[Tensor], 336 | exp_avg_sqs: List[Tensor], 337 | exp_avg_sqs_factored: List[bool], 338 | exp_avg_sq_rows: List[Tensor], 339 | exp_avg_sq_cols: List[Tensor], 340 | state_steps: List[Tensor], 341 | exp_avgs_q_enabled: List[bool], 342 | exp_avg_sqs_q_enabled: List[bool], 343 | exp_avgs_q_overhead: List, 344 | exp_avg_sqs_q_overhead: List, 345 | exp_avgs_qmap: List, 346 | exp_avg_sqs_qmap: List, 347 | exp_avg_qmetadata, 348 | exp_avg_sq_qmetadata, 349 | *, 350 | beta1: float, 351 | beta2: float, 352 | lr: float, 353 | 
weight_decay: float, 354 | eps: float 355 | ): 356 | 357 | for i, param in enumerate(params_with_grad): 358 | grad = grads[i] 359 | q_exp_avg = exp_avgs[i] 360 | q_exp_avg_sq = exp_avg_sqs[i] 361 | exp_avg_sq_row = exp_avg_sq_rows[i] 362 | exp_avg_sq_col = exp_avg_sq_cols[i] 363 | factored = exp_avg_sqs_factored[i] 364 | step_t = state_steps[i] 365 | 366 | # update step 367 | step_t += 1 368 | # Perform stepweight decay 369 | param.mul_(1 - lr * weight_decay) 370 | 371 | if factored: 372 | _single_quantized_factored_update( 373 | param, 374 | grad, 375 | q_exp_avg, 376 | exp_avg_sq_row, 377 | exp_avg_sq_col, 378 | exp_avgs_q_enabled[i], 379 | exp_avgs_q_overhead[i], 380 | exp_avgs_qmap[i], 381 | exp_avg_qmetadata, 382 | lr, 383 | beta1, 384 | beta2, 385 | eps, 386 | step_t.item() 387 | ) 388 | 389 | else: 390 | exp_avg_q_overhead = exp_avgs_q_overhead[i] 391 | exp_avg_sq_q_overhead = exp_avg_sqs_q_overhead[i] 392 | 393 | # dequantize 394 | if q_exp_avg.numel() <= 1: 395 | q_exp_avg.data = exp_avg = torch.zeros_like(param, memory_format=torch.preserve_format) 396 | elif exp_avgs_q_enabled[i]: 397 | exp_avg_q_overhead.update(exp_avg_qmetadata) 398 | exp_avg = vectorwise_dequant(q_exp_avg, qmap=exp_avgs_qmap[i], shape=param.shape, **exp_avg_q_overhead) 399 | exp_avg_q_overhead.clear() 400 | else: 401 | exp_avg = q_exp_avg 402 | if q_exp_avg_sq.numel() <= 1: 403 | q_exp_avg_sq.data = exp_avg_sq = torch.zeros_like(param, memory_format=torch.preserve_format) 404 | elif exp_avg_sqs_q_enabled[i]: 405 | exp_avg_sq_q_overhead.update(exp_avg_sq_qmetadata) 406 | exp_avg_sq = vectorwise_dequant(q_exp_avg_sq, qmap=exp_avg_sqs_qmap[i], shape=param.shape, **exp_avg_sq_q_overhead) 407 | exp_avg_sq_q_overhead.clear() 408 | else: 409 | exp_avg_sq = q_exp_avg_sq 410 | 411 | # Decay the first and second moment running average coefficient 412 | exp_avg.lerp_(grad, 1 - beta1) 413 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 414 | 415 | step = step_t.item() 416 | bias_correction1 = 1 - beta1 ** step 417 | bias_correction2 = 1 - beta2 ** step 418 | step_size = lr / bias_correction1 419 | bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) 420 | 421 | denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) 422 | param.addcdiv_(exp_avg, denom, value=-step_size) 423 | 424 | # quantize 425 | if exp_avgs_q_enabled[i]: 426 | qx, gen = vectorwise_quant(exp_avg, qmap=exp_avgs_qmap[i], shape=param.shape, **exp_avg_qmetadata) 427 | q_exp_avg.data = qx 428 | exp_avg_q_overhead.update(gen) 429 | else: 430 | pass 431 | if exp_avg_sqs_q_enabled[i]: 432 | qx, gen = vectorwise_quant(exp_avg_sq, qmap=exp_avg_sqs_qmap[i], shape=param.shape, **exp_avg_sq_qmetadata) 433 | q_exp_avg_sq.data = qx 434 | exp_avg_sq_q_overhead.update(gen) 435 | else: 436 | pass 437 | 438 | 439 | def _fused_adamw4bit( 440 | params_with_grad: List[Tensor], 441 | grads: List[Tensor], 442 | exp_avgs: List[Tensor], 443 | exp_avg_sqs: List[Tensor], 444 | exp_avg_sqs_factored: List[bool], 445 | exp_avg_sq_rows: List[Tensor], 446 | exp_avg_sq_cols: List[Tensor], 447 | state_steps: List[Tensor], 448 | exp_avgs_q_enabled: List[bool], 449 | exp_avg_sqs_q_enabled: List[bool], 450 | exp_avgs_q_overhead: List, 451 | exp_avg_sqs_q_overhead: List, 452 | exp_avgs_qmap: List, 453 | exp_avg_sqs_qmap: List, 454 | exp_avg_qmetadata, 455 | exp_avg_sq_qmetadata, 456 | *, 457 | beta1: float, 458 | beta2: float, 459 | lr: float, 460 | weight_decay: float, 461 | eps: float 462 | ): 463 | for i, param in enumerate(params_with_grad): 464 | grad = 
grads[i] 465 | q_exp_avg = exp_avgs[i] 466 | q_exp_avg_sq = exp_avg_sqs[i] 467 | exp_avg_sq_row = exp_avg_sq_rows[i] 468 | exp_avg_sq_col = exp_avg_sq_cols[i] 469 | factored = exp_avg_sqs_factored[i] 470 | step_t = state_steps[i] 471 | 472 | if factored: 473 | # the fused adamw4bit kernel does not apply to the factored case; fall back to the unfused factored update 474 | 475 | # update step 476 | step_t += 1 477 | # Perform stepweight decay 478 | param.mul_(1 - lr * weight_decay) 479 | 480 | _single_quantized_factored_update( 481 | param, 482 | grad, 483 | q_exp_avg, 484 | exp_avg_sq_row, 485 | exp_avg_sq_col, 486 | exp_avgs_q_enabled[i], 487 | exp_avgs_q_overhead[i], 488 | exp_avgs_qmap[i], 489 | exp_avg_qmetadata, 490 | lr, 491 | beta1, 492 | beta2, 493 | eps, 494 | step_t.item() 495 | ) 496 | else: 497 | # update step 498 | step_t += 1 499 | if exp_avgs_q_enabled[i] != exp_avg_sqs_q_enabled[i]: 500 | raise ValueError(f"For the same tensor, exp_avg and exp_avg_sq must both be quantized or both be unquantized," 501 | f" but got ({exp_avgs_q_enabled[i]} {exp_avg_sqs_q_enabled[i]})") 502 | if exp_avgs_q_enabled[i]: 503 | if exp_avg_qmetadata["scale_type"] != "group": 504 | print(f"Warning: fused_adamw4bit only supports block-wise (group) scaling, but got exp_avg scale_type {exp_avg_qmetadata['scale_type']}.") 505 | if exp_avg_sq_qmetadata["scale_type"] != "group": 506 | print(f"Warning: fused_adamw4bit only supports block-wise (group) scaling, but got exp_avg_sq scale_type {exp_avg_sq_qmetadata['scale_type']}.") 507 | 508 | bytelength = (param.numel() + 1) // 2 509 | if q_exp_avg.numel() <= 1: 510 | q_exp_avg.data = torch.zeros((bytelength,), dtype=torch.int8, device=param.device) 511 | if q_exp_avg_sq.numel() <= 1: 512 | q_exp_avg_sq.data = torch.zeros((bytelength,), dtype=torch.int8, device=param.device) 513 | blocks = (param.numel() + 127) // 128 514 | if "max1" in exp_avgs_q_overhead[i]: 515 | exp_avg_scale = exp_avgs_q_overhead[i]["max1"] 516 | else: 517 | exp_avg_scale = torch.zeros((blocks,), dtype=torch.float32, device=param.device) 518 | exp_avgs_q_overhead[i]["max1"] = exp_avg_scale 519 | if "max1" in exp_avg_sqs_q_overhead[i]: 520 | exp_avg_sq_scale = exp_avg_sqs_q_overhead[i]["max1"] 521 | else: 522 | exp_avg_sq_scale = torch.zeros((blocks,), dtype=torch.float32, device=param.device) 523 | exp_avg_sqs_q_overhead[i]["max1"] = exp_avg_sq_scale 524 | 525 | with torch.cuda.device(param.device): 526 | import lpmm.cpp_extension.fused_adamw as fused_adamw 527 | fused_adamw.adamw4bit_single_tensor( 528 | param, 529 | grad, 530 | q_exp_avg, 531 | q_exp_avg_sq, 532 | exp_avg_scale, 533 | exp_avg_sq_scale, 534 | exp_avgs_qmap[i], 535 | exp_avg_sqs_qmap[i], 536 | beta1, 537 | beta2, 538 | lr, 539 | weight_decay, 540 | eps, 541 | step_t.item(), 542 | ) 543 | else: 544 | if q_exp_avg.numel() <= 1: 545 | q_exp_avg.data = torch.zeros_like(param, memory_format=torch.preserve_format) 546 | if q_exp_avg_sq.numel() <= 1: 547 | q_exp_avg_sq.data = torch.zeros_like(param, memory_format=torch.preserve_format) 548 | with torch.cuda.device(param.device): 549 | import lpmm.cpp_extension.fused_adamw as fused_adamw 550 | fused_adamw.adamw_single_tensor( 551 | param, 552 | grad, 553 | q_exp_avg, 554 | q_exp_avg_sq, 555 | beta1, 556 | beta2, 557 | lr, 558 | weight_decay, 559 | eps, 560 | step_t.item(), 561 | ) 562 | 563 | 564 | def _dispatch_sqrt(x: float): # float annotation is needed because of torchscript type inference 565 | if not torch.jit.is_scripting() and isinstance(x, torch.Tensor): 566 | return x.sqrt() 567 | else: 568 | return math.sqrt(x) 569 | 570 | 571 | def
_approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): 572 | # copied from the Adafactor implementation in HuggingFace Transformers (originally from fairseq): 573 | # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 574 | r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).unsqueeze(-1) 575 | c_factor = exp_avg_sq_col.unsqueeze(-2) 576 | return torch.mul(r_factor, c_factor) 577 | 578 | 579 | def _single_quantized_factored_update( 580 | param, 581 | grad, 582 | q_exp_avg, 583 | exp_avg_sq_row, 584 | exp_avg_sq_col, 585 | exp_avg_q_enabled, 586 | exp_avg_q_overhead, 587 | exp_avg_qmap, 588 | exp_avg_qmetadata, 589 | lr, 590 | beta1, 591 | beta2, 592 | eps, 593 | step, 594 | ): 595 | # dequantize 596 | if q_exp_avg.numel() <= 1: 597 | q_exp_avg.data = exp_avg = torch.zeros_like(param, memory_format=torch.preserve_format) 598 | elif exp_avg_q_enabled: 599 | # merge the static quantization metadata into the per-tensor overhead before dequantizing 600 | exp_avg_q_overhead.update(exp_avg_qmetadata) 601 | exp_avg = vectorwise_dequant(q_exp_avg, qmap=exp_avg_qmap, shape=param.shape, **exp_avg_q_overhead) 602 | exp_avg_q_overhead.clear() 603 | else: 604 | exp_avg = q_exp_avg 605 | 606 | # Decay the first and second moment running average coefficient 607 | exp_avg.lerp_(grad, 1 - beta1) 608 | update = (grad ** 2) 609 | exp_avg_sq_row.mul_(beta2).add_(update.mean(dim=-1), alpha=1 - beta2) 610 | exp_avg_sq_col.mul_(beta2).add_(update.mean(dim=-2), alpha=1 - beta2) 611 | exp_avg_sq = _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) 612 | 613 | bias_correction1 = 1 - beta1 ** step 614 | bias_correction2 = 1 - beta2 ** step 615 | step_size = lr / bias_correction1 616 | bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) 617 | 618 | denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) 619 | param.addcdiv_(exp_avg, denom, value=-step_size) 620 | 621 | # quantize 622 | if exp_avg_q_enabled: 623 | qx, gen = vectorwise_quant(exp_avg, qmap=exp_avg_qmap, shape=param.shape, **exp_avg_qmetadata) 624 | q_exp_avg.data = qx 625 | exp_avg_q_overhead.update(gen) 626 | else: 627 | pass 628 | -------------------------------------------------------------------------------- /lpmm/optim/optimizer.py: -------------------------------------------------------------------------------- 1 | from collections import abc as container_abcs 2 | from collections import defaultdict 3 | from copy import deepcopy 4 | from itertools import chain 5 | 6 | import torch 7 | 8 | from ..functional import create_general_qmap, init_lpmm_generator 9 | from ..utils import get_rank 10 | from ..config import get_config 11 | 12 | compression_time = 0 13 | 14 | class LowBitOptimizer(torch.optim.Optimizer): 15 | def __init__(self, params, defaults, config): 16 | super(LowBitOptimizer, self).__init__(params, defaults) 17 | 18 | # init lpmm generator 19 | if torch.distributed.is_initialized(): 20 | seed = torch.randint(1 << 31, size=[], device=torch.device('cuda')) 21 | torch.distributed.broadcast(seed, src=0) 22 | init_lpmm_generator(get_rank(), seed.item()) # no stochastic rounding 23 | 24 | self.qconfig = get_config(config) 25 | self.override_q_enable = {} 26 | self.qmaps = {} 27 | 28 | def override_quantize_enable(self, module, param_name, enable): 29 | p = getattr(module, param_name) 30 | assert p is not None 31 | assert isinstance(p, (torch.Tensor, torch.nn.Parameter)) 32 | if len(self.state[p]) != 0: 33 | raise ValueError("overriding the quantization enable flag is prohibited after the optimizer state has been initialized.")
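# record the per-parameter override keyed by id(p); the concrete optimizers consult this map in _init_group before falling back to the config-driven qstate enable flag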
self.override_q_enable[id(p)] = enable 35 | 36 | def init_qstate(self, p, state_name): 37 | state = self.state[p] 38 | field = f"{state_name}_qstate" 39 | state[field] = { 40 | "enable": True, 41 | "overhead": dict(), 42 | "qmap": None, 43 | } 44 | subconfig = self.get_subqconfig(state_name) 45 | state[field][ 46 | "enable" 47 | ] = _get_qenable_fn(p, subconfig.ENABLE, subconfig.THRESHOLD) 48 | 49 | md = self.get_qmetadata_by_state_name(state_name) 50 | qmap_key = (md['quant_type'], md['b'], md['signed']) 51 | if qmap_key not in self.qmaps: 52 | self.qmaps[qmap_key] = create_general_qmap(*qmap_key) 53 | self.qmaps[qmap_key] = self.qmaps[qmap_key].to(p.device) 54 | state[field]["qmap"] = self.qmaps[qmap_key] 55 | 56 | def get_qmetadata_by_state_name(self, optimizer_state_name): 57 | subconfig = self.get_subqconfig(optimizer_state_name) 58 | md = dict( 59 | b=subconfig.BITS, 60 | scale_type=subconfig.SCALE_TYPE.DEFAULT, 61 | quant_type=subconfig.QUANT_TYPE.DEFAULT, 62 | round_type=subconfig.ROUND_TYPE, 63 | gp_sz=subconfig.GROUP_SIZE, 64 | signed=subconfig.SIGNED, 65 | ) 66 | return md 67 | 68 | def state_dict(self): 69 | state_dict = super().state_dict() 70 | state_dict['qconfig'] = self.qconfig 71 | return state_dict 72 | 73 | def load_state_dict(self, state_dict): 74 | r"""Loads the optimizer state. 75 | 76 | Args: 77 | state_dict (dict): optimizer state. Should be an object returned 78 | from a call to :meth:`state_dict`. 79 | """ 80 | self.qconfig = state_dict['qconfig'] 81 | 82 | # deepcopy, to be consistent with module API 83 | state_dict = deepcopy(state_dict) 84 | # Validate the state_dict 85 | groups = self.param_groups 86 | saved_groups = state_dict['param_groups'] 87 | 88 | if len(groups) != len(saved_groups): 89 | raise ValueError("loaded state dict has a different number of " 90 | "parameter groups") 91 | param_lens = (len(g['params']) for g in groups) 92 | saved_lens = (len(g['params']) for g in saved_groups) 93 | if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): 94 | raise ValueError("loaded state dict contains a parameter group " 95 | "that doesn't match the size of optimizer's group") 96 | 97 | # Update the state 98 | id_map = dict(zip(chain.from_iterable((g['params'] for g in saved_groups)), 99 | chain.from_iterable((g['params'] for g in groups)))) 100 | 101 | def cast(param, value, key=None): 102 | r"""Make a deep copy of value, casting all tensors to device of param.""" 103 | if isinstance(value, torch.Tensor): 104 | # Floating-point types are a bit special here. They are the only ones 105 | # that are assumed to always match the type of params. 106 | # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 107 | if (key != "step"): 108 | if param.is_floating_point() and value.dtype != torch.int8: 109 | value = value.to(param.dtype) 110 | value = value.to(param.device) 111 | return value 112 | elif isinstance(value, dict): 113 | return {k: cast(param, v, key=k) for k, v in value.items()} 114 | elif isinstance(value, container_abcs.Iterable): 115 | return type(value)(cast(param, v) for v in value) 116 | else: 117 | return value 118 | 119 | # Copy state assigned to params (and cast tensors to appropriate types). 120 | # State that is not assigned to params is copied as is (needed for 121 | # backward compatibility). 
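# note: cast() above deliberately leaves torch.int8 tensors unchanged, so quantized optimizer states are restored without being upcast to the parameter dtype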
122 | state = defaultdict(dict) 123 | for k, v in state_dict['state'].items(): 124 | if k in id_map: 125 | param = id_map[k] 126 | state[param] = cast(param, v) 127 | else: 128 | state[k] = v 129 | 130 | # Update parameter groups, setting their 'params' value 131 | def update_group(group, new_group): 132 | new_group['params'] = group['params'] 133 | return new_group 134 | param_groups = [ 135 | update_group(g, ng) for g, ng in zip(groups, saved_groups)] 136 | self.__setstate__({'state': state, 'param_groups': param_groups}) 137 | 138 | 139 | @torch.no_grad() 140 | def step(self, closure=None): 141 | r"""Performs a single optimization step with quantization. 142 | Args: 143 | closure (callable, optional): A closure that reevaluates the model 144 | and returns the loss. 145 | """ 146 | raise NotImplementedError( 147 | 'The step method needs overriding' 148 | ) 149 | 150 | def get_subqconfig(self, optimizer_state_name): 151 | raise NotImplementedError( 152 | 'The get_subqconfig method needs overriding' 153 | ) 154 | 155 | 156 | def _get_qenable_fn(p, prior_enable, th): 157 | if not prior_enable: 158 | return False 159 | # if p.dim() < 2: 160 | # return False 161 | if th is not None and p.numel() <= th: 162 | return False 163 | return True -------------------------------------------------------------------------------- /lpmm/optim/sgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import List, Optional 4 | 5 | from .optimizer import LowBitOptimizer 6 | from ..functional import vectorwise_dequant, vectorwise_quant 7 | 8 | __all__ = ['SGD'] 9 | 10 | 11 | class SGD(LowBitOptimizer): 12 | def __init__( 13 | self, 14 | params, 15 | lr, 16 | momentum=0, 17 | dampening=0, 18 | weight_decay=0, 19 | nesterov=False, 20 | qconfig=None, 21 | *, 22 | maximize: bool = False, 23 | fused: Optional[bool] = False, 24 | ): 25 | if lr < 0.0: 26 | raise ValueError("Invalid learning rate: {}".format(lr)) 27 | if momentum < 0.0: 28 | raise ValueError("Invalid momentum value: {}".format(momentum)) 29 | if weight_decay < 0.0: 30 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 31 | if fused is True: 32 | raise ValueError("Fused SGD is not supported currently.") 33 | 34 | defaults = dict( 35 | lr=lr, 36 | momentum=momentum, 37 | weight_decay=weight_decay, 38 | dampening=dampening, 39 | nesterov=nesterov, 40 | maximize=maximize, 41 | ) 42 | super().__init__(params, defaults, qconfig) 43 | 44 | def __setstate__(self, state): 45 | super().__setstate__(state) 46 | 47 | def get_subqconfig(self, optimizer_state_name): 48 | if optimizer_state_name == 'momentum_buffer': 49 | return self.qconfig.QUANT.M 50 | else: 51 | raise ValueError( 52 | f"unsupported optimizer state name for SGD: {optimizer_state_name}" 53 | ) 54 | 55 | def _init_group( 56 | self, 57 | group, 58 | params_with_grad, 59 | grads, 60 | momentum_buffer_list, 61 | momentum_buffer_q_enabled, 62 | momentum_buffer_q_overhead, 63 | momentum_buffer_qmap, 64 | ): 65 | for p in group['params']: 66 | if p.grad is None: 67 | continue 68 | params_with_grad.append(p) 69 | if p.grad.is_sparse: 70 | raise RuntimeError("SGD does not support sparse gradients") 71 | grads.append(p.grad) 72 | state = self.state[p] 73 | if len(state) == 0: 74 | state["momentum_buffer"] = torch.zeros((), dtype=torch.float, device=p.device) 75 | self.init_qstate(p, "momentum_buffer") 76 | 77 | momentum_buffer_list.append(state['momentum_buffer']) 78 | momentum_buffer_q_enabled.append(self.override_q_enable[id(p)] if id(p) in
self.override_q_enable else state["momentum_buffer_qstate"]["enable"]) 79 | momentum_buffer_q_overhead.append(state["momentum_buffer_qstate"]["overhead"]) 80 | momentum_buffer_qmap.append(state["momentum_buffer_qstate"]["qmap"]) 81 | 82 | @torch.no_grad() 83 | def step(self, closure=None): 84 | """Performs a single optimization step. 85 | 86 | Args: 87 | closure (Callable, optional): A closure that reevaluates the model 88 | and returns the loss. 89 | """ 90 | loss = None 91 | if closure is not None: 92 | with torch.enable_grad(): 93 | loss = closure() 94 | 95 | for group in self.param_groups: 96 | params_with_grad = [] 97 | grads = [] 98 | momentum_buffer_list = [] 99 | momentum_buffer_q_enabled = [] 100 | momentum_buffer_q_overhead = [] 101 | momentum_buffer_qmap = [] 102 | 103 | self._init_group( 104 | group, 105 | params_with_grad, 106 | grads, 107 | momentum_buffer_list, 108 | momentum_buffer_q_enabled, 109 | momentum_buffer_q_overhead, 110 | momentum_buffer_qmap, 111 | ) 112 | 113 | kwargs = dict( 114 | params_with_grad=params_with_grad, 115 | grads=grads, 116 | momentum_buffer_list=momentum_buffer_list, 117 | momentum_buffer_q_enabled=momentum_buffer_q_enabled, 118 | momentum_buffer_q_overhead=momentum_buffer_q_overhead, 119 | momentum_buffer_qmap=momentum_buffer_qmap, 120 | momentum_buffer_qmetadata=self.get_qmetadata_by_state_name("momentum_buffer"), 121 | weight_decay=group['weight_decay'], 122 | momentum=group['momentum'], 123 | lr=group['lr'], 124 | dampening=group['dampening'], 125 | nesterov=group['nesterov'], 126 | maximize=group['maximize'], 127 | ) 128 | _single_tensor_sgd4bit(**kwargs) 129 | 130 | return loss 131 | 132 | 133 | def _single_tensor_sgd4bit( 134 | params_with_grad: List[Tensor], 135 | grads: List[Tensor], 136 | momentum_buffer_list: List[Tensor], 137 | momentum_buffer_q_enabled: List[bool], 138 | momentum_buffer_q_overhead: List, 139 | momentum_buffer_qmap: List, 140 | momentum_buffer_qmetadata, 141 | *, 142 | weight_decay: float, 143 | momentum: float, 144 | lr: float, 145 | dampening: float, 146 | nesterov: bool, 147 | maximize: bool, 148 | ): 149 | 150 | for i, param in enumerate(params_with_grad): 151 | d_p = grads[i] if not maximize else -grads[i] 152 | 153 | if weight_decay != 0: 154 | d_p = d_p.add(param, alpha=weight_decay) 155 | 156 | if momentum != 0: 157 | # dequantize and decay 158 | q_overhead = momentum_buffer_q_overhead[i] 159 | q_buf = momentum_buffer_list[i] 160 | if q_buf.numel() <= 1: 161 | # decay not needed when initializing first moment 162 | q_buf.data = buf = torch.clone(d_p).detach() 163 | else: 164 | if momentum_buffer_q_enabled[i]: 165 | q_overhead.update(momentum_buffer_qmetadata) 166 | buf = vectorwise_dequant(q_buf, qmap=momentum_buffer_qmap[i], shape=param.shape, **q_overhead) 167 | q_overhead.clear() 168 | else: 169 | buf = q_buf 170 | 171 | # decay 172 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 173 | 174 | # update 175 | if nesterov: 176 | d_p = d_p.add(buf, alpha=momentum) 177 | else: 178 | d_p = buf 179 | 180 | param.add_(d_p, alpha=-lr) 181 | 182 | # quantize 183 | if momentum != 0: 184 | if momentum_buffer_q_enabled[i]: 185 | qx, gen = vectorwise_quant(buf, qmap=momentum_buffer_qmap[i], shape=param.shape, **momentum_buffer_qmetadata) 186 | q_buf.data = qx 187 | q_overhead.update(gen) 188 | -------------------------------------------------------------------------------- /lpmm/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as
F 3 | 4 | import numpy as np 5 | from typing import OrderedDict 6 | import json 7 | import os 8 | 9 | from lpmm.functional import vectorwise_dequant, vectorwise_quant, _max_reduce_except_dim 10 | 11 | 12 | def empty_cache(ratio): 13 | if ratio is None: 14 | return 15 | allocated = torch.cuda.memory_allocated(0) 16 | reserved = torch.cuda.memory_reserved(0) 17 | if reserved > 0 and allocated / reserved < ratio: 18 | torch.cuda.empty_cache() 19 | 20 | 21 | def get_memory_usage(print_info=False): 22 | """Get accurate gpu memory usage by querying torch runtime""" 23 | allocated = torch.cuda.memory_allocated(0) 24 | reserved = torch.cuda.memory_reserved(0) 25 | if print_info: 26 | print("allocated: %.2f MB" % (allocated / 1024 / 1024), flush=True) 27 | print("reserved: %.2f MB" % (reserved / 1024 / 1024), flush=True) 28 | return allocated 29 | 30 | 31 | def compute_tensor_bytes(tensors): 32 | """Compute the bytes used by a list of tensors""" 33 | if not isinstance(tensors, (list, tuple)): 34 | tensors = [tensors] 35 | 36 | ret = 0 37 | for x in tensors: 38 | if x.dtype in [torch.float32, torch.int]: 39 | ret += np.prod(x.size()) * 4 40 | elif x.dtype in [torch.bfloat16, torch.float16, torch.int16]: 41 | ret += np.prod(x.size()) * 2 42 | elif x.dtype in [torch.int8]: 43 | ret += np.prod(x.size()) 44 | 45 | return ret 46 | 47 | 48 | def get_rank(): 49 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 50 | 51 | 52 | def get_world_size(): 53 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 54 | 55 | 56 | def print0(*args, **kwargs): 57 | if get_rank() == 0: 58 | print(*args, **kwargs) 59 | 60 | 61 | def min_fn(a, b): 62 | return a < b 63 | 64 | 65 | def max_fn(a, b): 66 | return a > b 67 | 68 | 69 | def get_metric_fn(metric_op): 70 | if metric_op == 'min': 71 | return min_fn 72 | elif metric_op == 'max': 73 | return max_fn 74 | else: 75 | raise NotImplementedError 76 | 77 | 78 | def sqnr(x, qx): 79 | Ps = torch.norm(x) 80 | Pn = torch.norm(x-qx) 81 | return 20 * torch.log10(Ps/Pn) 82 | 83 | 84 | def relerr(x, qx): 85 | abs_error = torch.abs(x - qx) 86 | rel_error = abs_error.norm() / torch.abs(x).norm() 87 | return rel_error 88 | 89 | 90 | def jsd(x, qx): 91 | x = x.flatten() 92 | qx = qx.flatten() 93 | m = 0.5 * (x + qx) 94 | jsd = 0.5 * (F.kl_div(x, m) + F.kl_div(qx, m)) 95 | return jsd 96 | 97 | 98 | def abserr(x, qx): 99 | return torch.abs(x - qx).mean() 100 | 101 | 102 | def get_metric_from_q_and_dq(x, op, average, **kwargs): 103 | metric_fn_map = { 104 | 'sqnr': sqnr, 105 | 'relerr': relerr, 106 | 'abserr': abserr, 107 | } 108 | metric_fn = metric_fn_map['relerr'] 109 | total_metric = 0. 110 | for _ in range(average): 111 | qx, md = vectorwise_quant(x, **kwargs) 112 | x_hat = vectorwise_dequant(qx, **md) 113 | total_metric += metric_fn(op(x), op(x_hat)) 114 | total_metric /= average 115 | return total_metric 116 | 117 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ninja 2 | pyyaml 3 | pynvml 4 | yacs 5 | numpy 6 | torch >= 1.13 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import glob 6 | import os 7 | 8 | from setuptools import find_packages, setup 9 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 10 | 11 | 12 | 13 | 14 | setup( 15 | name="lpmm", 16 | description="low-bit optimizers.", 17 | keywords="gpu optimizers optimization low-bit quantization compression", 18 | url="https://github.com/thu-ml/low-bit-optimizers", 19 | packages=find_packages(), 20 | cmdclass={'build_ext': BuildExtension}, 21 | ext_modules=[ 22 | CUDAExtension( 23 | 'lpmm.cpp_extension.quantization', 24 | ['lpmm/cpp_extension/quantization.cc', 'lpmm/cpp_extension/quantization_kernel.cu'] 25 | ), 26 | CUDAExtension( 27 | 'lpmm.cpp_extension.fused_adamw', 28 | ['lpmm/cpp_extension/fused_adamw.cc', 'lpmm/cpp_extension/fused_adamw_kernel.cu'] 29 | ), 30 | ], 31 | ) --------------------------------------------------------------------------------
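A minimal usage sketch for the 4-bit SGD optimizer defined in lpmm/optim/sgd.py above. It assumes the package has been built and installed via this setup.py, that a CUDA device is available for the quantization extensions, and that the default quantization config applies when no qconfig is passed; the model and shapes are illustrative only.

import torch
from lpmm.optim.sgd import SGD

model = torch.nn.Linear(1024, 1024).cuda()
# qconfig is omitted here; the default quantization config shipped under lpmm/configs is assumed to apply
optimizer = SGD(model.parameters(), lr=1e-2, momentum=0.9)

# optionally keep one parameter's momentum buffer unquantized; this must happen before
# the first step (see override_quantize_enable in lpmm/optim/optimizer.py)
optimizer.override_quantize_enable(model, "bias", False)

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(16, 1024, device="cuda")).sum()
    loss.backward()
    optimizer.step()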