├── requirements.txt ├── install.bash ├── .gitignore ├── LICENSE ├── .clang-format ├── setup.py ├── examples └── speed_test_mistral_7b.py ├── csrc ├── torch_fp4.cpp ├── dequant_fp4_optimized.cu └── gemv_fp4_optimized.cu ├── sanity_check.py ├── README.md └── torch_bnb_fp4 └── __init__.py /requirements.txt: -------------------------------------------------------------------------------- 1 | bitsandbytes<0.43 2 | prettytable 3 | accelerate -------------------------------------------------------------------------------- /install.bash: -------------------------------------------------------------------------------- 1 | MAX_JOBS=4 TORCH_CUDA_ARCH_LIST="8.9" python setup.py install -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | torch_bnb_fp4.egg-info 2 | dist 3 | build 4 | __pycache__ 5 | .misc 6 | .vscode -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Alex Redden 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | # https://clang.llvm.org/docs/ClangFormatStyleOptions.html 2 | 3 | ColumnLimit: 140 4 | ContinuationIndentWidth: 4 5 | IndentWidth: 4 6 | TabWidth: 4 7 | ConstructorInitializerIndentWidth: 4 8 | IndentCaseLabels: true 9 | UseTab: Never 10 | SortIncludes: true 11 | SortUsingDeclarations: false 12 | AlignConsecutiveMacros: false 13 | AlignEscapedNewlines: DontAlign 14 | AlignAfterOpenBracket: BlockIndent 15 | AlignOperands: false 16 | AlignTrailingComments: false 17 | BinPackArguments: false 18 | BinPackParameters: false 19 | SpacesInContainerLiterals: false 20 | Cpp11BracedListStyle: true 21 | AllowShortFunctionsOnASingleLine: Empty 22 | AllowShortIfStatementsOnASingleLine: Always 23 | FixNamespaceComments: false 24 | ReflowComments: false 25 | NamespaceIndentation: All 26 | IncludeBlocks: Merge 27 | BreakStringLiterals: false 28 | BreakConstructorInitializers: AfterColon 29 | IndentPPDirectives: BeforeHash 30 | BreakTemplateDeclarations: No 31 | Standard: c++17 32 | Language: Cpp 33 | 34 | IncludeCategories: 35 | - Regex: '"[[:alnum:]._-]+"' 36 | Priority: 1 37 | SortPriority: 1 38 | - Regex: '^((<|").*/)' 39 | Priority: 2 40 | SortPriority: 2 41 | - Regex: '<[[:alnum:]._-]+>' 42 | Priority: 3 43 | SortPriority: 3 44 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from setuptools import find_packages, setup 5 | from torch.utils.cpp_extension import ( 6 | BuildExtension, 7 | CUDAExtension, 8 | _get_cuda_arch_flags, 9 | ) 10 | 11 | 12 | def check_device_capability_minimum_allowed(): 13 | if os.environ.get("IGNORE_DEVICE_CHECK", "0") == "1": 14 | print("Ignoring device check (intended for docker builds / CI)") 15 | else: 16 | count = torch.cuda.device_count() 17 | for c in range(count): 18 | i = torch.cuda.get_device_capability(c) 19 | print(f"Device {c}: {i}") 20 | if i[0] < 8: 21 | raise ValueError( 22 | "Minimum compute capability is 80, if you are compiling this extension without a device, such as in a docker container or CI, set the IGNORE_DEVICE_CHECK environment variable to 1 to ignore this check" 23 | ) 24 | if os.getenv("TORCH_CUDA_ARCH_LIST", "") != "": 25 | archs = _get_cuda_arch_flags() 26 | archs = [int(x.rsplit("_", 1)[-1]) for x in archs if ("+" not in x and "-" not in x and "ptx" not in x.lower())] 27 | for arch in archs: 28 | if arch < 80: 29 | raise ValueError( 30 | "Minimum compute capability is 80, if you are compiling this extension without a device, such as in a docker container or CI, set the IGNORE_DEVICE_CHECK environment variable to 1 to ignore this check" 31 | ) 32 | 33 | 34 | flags = [ 35 | "-O3", 36 | "-std=c++17", 37 | "-U__CUDA_NO_HALF_OPERATORS__", 38 | "-U__CUDA_NO_HALF_CONVERSIONS__", 39 | "-U__CUDA_NO_HALF2_OPERATORS__", 40 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 41 | "--expt-relaxed-constexpr", 42 | "--expt-extended-lambda", 43 | "--use_fast_math", 44 | "--resource-usage", 45 | "--ptxas-options=-allow-expensive-optimizations=true" 46 | ] 47 | 48 | 49 | def append_nvcc_threads(nvcc_extra_args): 50 | return nvcc_extra_args + ["--threads", "4"] 51 | 52 | 53 | # Make sure the device is capable 54 | check_device_capability_minimum_allowed() 55 | 56 | setup( 57 | name="torch_bnb_fp4", 58 | version="0.0.9", 59 | 
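    # The C++/CUDA sources under csrc/ (listed in ext_modules below) are compiled into the
    # torch_bnb_fp4_ext extension; the target SM architectures come from TORCH_CUDA_ARCH_LIST
    # (see install.bash and the README installation notes).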
packages=find_packages( 60 | exclude=[ 61 | "csrc", 62 | "csrc/*", 63 | ".misc", 64 | ".misc/*", 65 | ] 66 | ), 67 | requires=["bitsandbytes<0.43", "prettytable", "accelerate"], 68 | ext_modules=[ 69 | CUDAExtension( 70 | name="torch_bnb_fp4_ext", 71 | sources=[ 72 | "csrc/gemv_fp4_optimized.cu", 73 | "csrc/dequant_fp4_optimized.cu", 74 | "csrc/torch_fp4.cpp", 75 | ], 76 | extra_compile_args={ 77 | "cxx": ["-O3", "-std=c++17"], 78 | "nvcc": flags, 79 | }, 80 | ) 81 | ], 82 | cmdclass={"build_ext": BuildExtension}, 83 | ) -------------------------------------------------------------------------------- /examples/speed_test_mistral_7b.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from transformers import ( 5 | BitsAndBytesConfig, 6 | GenerationConfig, 7 | LlamaTokenizer, 8 | MistralForCausalLM, 9 | TextStreamer, 10 | ) 11 | 12 | from torch_bnb_fp4 import recursively_replace_with_fp4_linear 13 | 14 | # Change this to your desired dtype 15 | DTYPE = torch.float16 16 | 17 | model_path = "mistralai/Mistral-7B-Instruct-v0.2" 18 | 19 | # Load weights as bnb fp4 20 | model: MistralForCausalLM = MistralForCausalLM.from_pretrained( 21 | model_path, 22 | torch_dtype=DTYPE, 23 | quantization_config=BitsAndBytesConfig( 24 | load_in_4bit=True, 25 | bnb_4bit_compute_dtype=DTYPE, 26 | # Must use "fp4" for this library 27 | bnb_4bit_quant_type="fp4", 28 | # double quant is also unsupported, set this to false 29 | bnb_4bit_use_double_quant=False, 30 | ), 31 | ) 32 | 33 | tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained(model_path) 34 | tokenizer.pad_token_id = tokenizer.eos_token_id 35 | model.config.pad_token_id = tokenizer.eos_token_id 36 | 37 | silly_example = ( 38 | "I want to create sentences which sound correct and are grammatically correct, but don't make any sense such as the following:\n" 39 | + "Now basically the only new principle involved is that instead of power being generated by the relative motion of conductors and fluxes, " 40 | + "it's produced by the modial interaction of magneto reluctance and capacitive duractants. The original machine had a base plate of pre-famulated amulite, " 41 | + "surmounted by a malleable logarithmic casing in such a way that the two spurving bearings were in a direct line with a panometric fam. " 42 | + "The lineup consisted simply of six hydroscoptic marsal veins so fitted to the ambifacent lunar wane shaft that side fumbling was effectively prevented. " 43 | + "The main winding was of the normal lotus-o-deltoid type placed in panendermic semi-boloid slots of the stator. " 44 | + "Every seventh conductor being connected by a non-reversible tremi pipe to the differential girdle spring on the up end of the gram meters. 
" 45 | + "Moreover, whenever fluorescent score motion is required, it may also be employed in conjunction with a drawn reciprocation dingle arm to reduce sinusoidal depleneration.\n" 46 | + "Could you help make a sentence that is grammatically correct but doesn't make any sense using overly verbose language relating to some new technology?\n" 47 | ) 48 | 49 | ctx = tokenizer.apply_chat_template( 50 | [{"role": "user", "content": silly_example}], 51 | add_generation_prompt=True, 52 | tokenize=True, 53 | return_tensors="pt", 54 | ).to(model.device) 55 | 56 | gen_kwargs = GenerationConfig( 57 | **{ 58 | "max_new_tokens": 256, 59 | "min_new_tokens": 255, 60 | "temperature": 0.78, 61 | "do_sample": True, 62 | "top_k": 40, 63 | "top_p": 0.9, 64 | "num_return_sequences": 1, 65 | "use_cache": True, 66 | } 67 | ) 68 | 69 | streamer = TextStreamer(tokenizer) 70 | 71 | with torch.inference_mode(): 72 | st = time.perf_counter() 73 | out = model.generate( 74 | ctx, 75 | generation_config=gen_kwargs, 76 | pad_token_id=tokenizer.eos_token_id, 77 | ) 78 | nd = time.perf_counter() - st 79 | print("\nRun # 1 (warmup) BNB\n") 80 | print("Time to generate: ", nd) 81 | print("Total new tokens: ", len(out[0]) - ctx.shape[1]) 82 | print("Total ctx+tokens: ", len(out[0])) 83 | print("Generation tok/s: ", (len(out[0]) - ctx.shape[1]) / nd, "\n") 84 | st = time.perf_counter() 85 | out = model.generate( 86 | ctx, 87 | generation_config=gen_kwargs, 88 | pad_token_id=tokenizer.eos_token_id, 89 | ) 90 | nd = time.perf_counter() - st 91 | print("\nRun # 2 BNB\n") 92 | print("Time to generate: ", nd) 93 | print("Total new tokens: ", len(out[0]) - ctx.shape[1]) 94 | print("Total ctx+tokens: ", len(out[0])) 95 | print("Generation tok/s: ", (len(out[0]) - ctx.shape[1]) / nd, "\n") 96 | 97 | 98 | # # # # Replace layers with torch-bnb-fp4 layers in-place 99 | recursively_replace_with_fp4_linear( 100 | model, 101 | as_dtype=DTYPE, 102 | use_codebook_dequant=True, # or False for fp4 tree dequant, though is much slower. 
103 | only_replace_bnb_layers=True, 104 | ) 105 | 106 | with torch.inference_mode(): 107 | st = time.perf_counter() 108 | out = model.generate( 109 | ctx, 110 | generation_config=gen_kwargs, 111 | pad_token_id=tokenizer.eos_token_id, 112 | ) 113 | nd = time.perf_counter() - st 114 | print("\nRun # 1 (warmup) TorchFP4\n") 115 | print("Time to generate: ", nd) 116 | print("Total new tokens: ", len(out[0]) - ctx.shape[1]) 117 | print("Total ctx+tokens: ", len(out[0])) 118 | print("Generation tok/s: ", (len(out[0]) - ctx.shape[1]) / nd, "\n") 119 | st = time.perf_counter() 120 | out = model.generate( 121 | ctx, 122 | generation_config=gen_kwargs, 123 | pad_token_id=tokenizer.eos_token_id, 124 | ) 125 | nd = time.perf_counter() - st 126 | print("\nRun # 2 TorchFP4\n") 127 | print("Time to generate: ", nd) 128 | print("Total new tokens: ", len(out[0]) - ctx.shape[1]) 129 | print("Total ctx+tokens: ", len(out[0])) 130 | print("Generation tok/s: ", (len(out[0]) - ctx.shape[1]) / nd, "\n") 131 | -------------------------------------------------------------------------------- /csrc/torch_fp4.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | void dequantize_blockwise_fp4(torch::Tensor A, torch::Tensor absmax, int M, int N, int blocksize, int n, torch::Tensor out); 6 | torch::Tensor dequantize_blockwise_codebook_fp4( 7 | torch::Tensor A, torch::Tensor absmax, torch::Tensor codebook, int M, int N, int blocksize, int n, torch::ScalarType dtype 8 | ); 9 | torch::Tensor gemv_4bit_inference( 10 | torch::Tensor A, 11 | torch::Tensor B, 12 | torch::Tensor absmax, 13 | torch::Tensor datatype, 14 | int blocksize, 15 | torch::ScalarType dtype, 16 | std::vector Bshape 17 | ); 18 | 19 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), " must be contiguous") 21 | 22 | enum class ScalarTypeEnum { 23 | float16, 24 | float32, 25 | bfloat16, 26 | }; 27 | 28 | torch::ScalarType get_scalar_type(ScalarTypeEnum type_enum) { 29 | switch (type_enum) { 30 | case ScalarTypeEnum::float16: 31 | return torch::kFloat16; 32 | case ScalarTypeEnum::float32: 33 | return torch::kFloat32; 34 | case ScalarTypeEnum::bfloat16: 35 | return torch::kBFloat16; 36 | default: 37 | throw py::type_error("Unsupported scalar type"); 38 | } 39 | }; 40 | 41 | torch::Tensor dequantize_fp4(torch::Tensor A, torch::Tensor absmax, int blocksize, int M, int N, ScalarTypeEnum o_type) { 42 | CHECK_CUDA(A); 43 | CHECK_CUDA(absmax); 44 | CHECK_CONTIGUOUS(A); 45 | CHECK_CONTIGUOUS(absmax); 46 | 47 | torch::Tensor out = torch::empty({M, N}, torch::TensorOptions().dtype(get_scalar_type(o_type)).device(A.device())); 48 | dequantize_blockwise_fp4(A, absmax, M, N, blocksize, M * N, out); 49 | return out; 50 | } 51 | 52 | torch::Tensor dequantize_fp4_codebook( 53 | torch::Tensor A, torch::Tensor absmax, torch::Tensor codebook, int M, int N, int blocksize, int n, ScalarTypeEnum dtype 54 | ) { 55 | CHECK_CUDA(A); 56 | CHECK_CUDA(absmax); 57 | CHECK_CUDA(codebook); 58 | CHECK_CONTIGUOUS(A); 59 | CHECK_CONTIGUOUS(absmax); 60 | CHECK_CONTIGUOUS(codebook); 61 | return dequantize_blockwise_codebook_fp4(A, absmax, codebook, M, N, blocksize, n, get_scalar_type(dtype)); 62 | } 63 | 64 | torch::Tensor qlinear(torch::Tensor A_in, torch::Tensor A, torch::Tensor absmax, int M, int N, int blocksize) { 65 | CHECK_CUDA(A); 66 | CHECK_CUDA(absmax); 67 | CHECK_CONTIGUOUS(A); 68 | CHECK_CONTIGUOUS(absmax); 69 | 
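    // qlinear: dequantize the packed fp4 weight A (scaled per block by absmax) into a temporary
    // M x N tensor matching the input's dtype/device, then run a standard torch linear on it.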
torch::Tensor out = torch::empty({M, N}, A_in.options()); 70 | dequantize_blockwise_fp4(A, absmax, M, N, blocksize, M * N, out); 71 | return torch::nn::functional::linear(A_in, out); 72 | } 73 | 74 | torch::Tensor qlinear_bias(torch::Tensor A_in, torch::Tensor A, torch::Tensor absmax, int M, int N, int blocksize, torch::Tensor bias) { 75 | CHECK_CUDA(A); 76 | CHECK_CUDA(absmax); 77 | CHECK_CONTIGUOUS(A); 78 | CHECK_CONTIGUOUS(absmax); 79 | torch::Tensor out = torch::empty({M, N}, A_in.options()); 80 | dequantize_blockwise_fp4(A, absmax, M, N, blocksize, M * N, out); 81 | return torch::nn::functional::linear(A_in, out, bias); 82 | } 83 | 84 | torch::Tensor 85 | qlinear_codebook(torch::Tensor A_in, torch::Tensor A, torch::Tensor absmax, torch::Tensor codebook, int M, int N, int blocksize) { 86 | CHECK_CUDA(A); 87 | CHECK_CUDA(absmax); 88 | CHECK_CONTIGUOUS(A); 89 | CHECK_CONTIGUOUS(absmax); 90 | torch::Tensor weight = dequantize_blockwise_codebook_fp4(A, absmax, codebook, M, N, blocksize, A.numel(), A_in.scalar_type()); 91 | return torch::nn::functional::linear(A_in, weight); 92 | } 93 | 94 | torch::Tensor qlinear_codebook_bias( 95 | torch::Tensor A_in, torch::Tensor A, torch::Tensor absmax, torch::Tensor codebook, int M, int N, int blocksize, torch::Tensor bias 96 | ) { 97 | CHECK_CUDA(A); 98 | CHECK_CUDA(absmax); 99 | CHECK_CONTIGUOUS(A); 100 | CHECK_CONTIGUOUS(absmax); 101 | torch::Tensor weight = dequantize_blockwise_codebook_fp4(A, absmax, codebook, M, N, blocksize, A.numel(), A_in.scalar_type()); 102 | return torch::nn::functional::linear(A_in, weight, bias); 103 | } 104 | 105 | torch::Tensor gemv_fp4( 106 | torch::Tensor A, 107 | torch::Tensor B, 108 | torch::Tensor absmax, 109 | torch::Tensor datatype, 110 | int blocksize, 111 | ScalarTypeEnum dtype, 112 | std::vector Bshape 113 | ) { 114 | CHECK_CUDA(A); 115 | CHECK_CUDA(B); 116 | CHECK_CUDA(absmax); 117 | CHECK_CUDA(datatype); 118 | CHECK_CONTIGUOUS(A); 119 | CHECK_CONTIGUOUS(B); 120 | CHECK_CONTIGUOUS(absmax); 121 | CHECK_CONTIGUOUS(datatype); 122 | return gemv_4bit_inference(A, B, absmax, datatype, blocksize, get_scalar_type(dtype), Bshape); 123 | } 124 | 125 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 126 | pybind11::enum_(m, "ScalarType") 127 | .value("bfloat16", ScalarTypeEnum::bfloat16) 128 | .value("float16", ScalarTypeEnum::float16) 129 | .value("float32", ScalarTypeEnum::float32) 130 | .export_values(); 131 | 132 | m.def("dequantize_fp4", &dequantize_fp4, "A test function for dequantize_fp4"); 133 | m.def("dequantize_fp4_codebook", &dequantize_fp4_codebook, "A test function for dequantize_fp4_interface"); 134 | m.def("gemv_fp4", &gemv_fp4, "A test function for gemm_4bit_inference_impl"); 135 | m.def("qlinear", &qlinear, "A test function for qlinear"); 136 | m.def("qlinear_bias", &qlinear_bias, "A test function for qlinear with bias"); 137 | m.def("qlinear_codebook", &qlinear_codebook, "A test function for qlinear with codebook"); 138 | m.def("qlinear_codebook_bias", &qlinear_codebook_bias, "A test function for qlinear with codebook and bias"); 139 | } -------------------------------------------------------------------------------- /sanity_check.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from accelerate.utils.bnb import BnbQuantizationConfig, replace_with_bnb_layers 4 | from torch import nn 5 | from torch.utils.benchmark import Timer 6 | from prettytable import PrettyTable 7 | from torch_bnb_fp4 import recursively_replace_with_fp4_linear 
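# This script benchmarks a small MLP three ways (plain nn.Linear, bitsandbytes 4-bit, and
# torch-bnb-fp4 replacements) and reports the average elementwise difference of the quantized
# outputs against nn.Linear for GEMV- and GEMM-shaped inputs.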
8 | 9 | 10 | def replace_with_bnb(model, dtype=torch.float16): 11 | str_dtype = { 12 | torch.float16: "fp16", 13 | torch.float32: "fp32", 14 | torch.bfloat16: "bf16", 15 | }[dtype] 16 | 17 | qconfig = BnbQuantizationConfig( 18 | load_in_4bit=True, 19 | torch_dtype=dtype, 20 | bnb_4bit_compute_dtype=str_dtype, 21 | bnb_4bit_quant_type="fp4", 22 | ) 23 | 24 | replace_with_bnb_layers(model, qconfig) 25 | 26 | return model 27 | 28 | 29 | class TinyModel(nn.Module): 30 | def __init__(self, in_dim, out_dim) -> None: 31 | super().__init__() 32 | self.in_proj = nn.Linear(in_dim, out_dim) 33 | 34 | def forward(self, x): 35 | return self.in_proj(x) 36 | 37 | 38 | class TestModel(nn.Module): 39 | def __init__(self, in_dim, hidden, num_hidden, out_dim) -> None: 40 | super().__init__() 41 | self.in_proj = nn.Linear(in_dim, hidden) 42 | self.blocks = nn.Sequential( 43 | *([nn.GELU(), nn.Linear(hidden, hidden)] * num_hidden) 44 | ) 45 | self.out_proj = nn.Linear(hidden, out_dim) 46 | 47 | def forward(self, x): 48 | x = self.in_proj(x) 49 | x = self.blocks(x) 50 | return self.out_proj(x) 51 | 52 | 53 | def time_run(model, inputs, label): 54 | timer = Timer("model(inputs)", globals=locals(), label=label) 55 | measure = timer.adaptive_autorange() 56 | return measure 57 | 58 | 59 | def get_avg(measurements, attribute, mul=1000000, round=5): 60 | return ( 61 | np.mean([getattr(m, attribute) * mul for m in measurements]).round(round).item() 62 | ) 63 | 64 | 65 | def check_speed(dtype=torch.float16, gemm_type="gemm"): 66 | torch.set_printoptions(precision=3, sci_mode=False, linewidth=180) 67 | torch.cuda.manual_seed_all(10) 68 | torch.manual_seed(10) 69 | generator = torch.Generator("cuda").manual_seed(10) 70 | model = TestModel(768, 2048, 4, 64).cuda().type(dtype) 71 | if gemm_type == "gemv": 72 | input_gemm = torch.randn(1, 768, generator=generator, device="cuda").type(dtype) 73 | else: 74 | input_gemm = torch.randn(2, 768, generator=generator, device="cuda").type(dtype) 75 | table = PrettyTable( 76 | field_names=["type", "mean (us)", "median (us)", "iqr (us)"], 77 | title=f"GEMM Speed Benchmark for {dtype} and matmul type [{gemm_type.upper()}] W/ 6 Layer MLP", 78 | ) 79 | with torch.inference_mode(): 80 | _ = time_run(model, input_gemm, "NORMAL") 81 | result1 = time_run(model, input_gemm, "NORMAL") 82 | result2 = time_run(model, input_gemm, "NORMAL") 83 | result_original = result1.merge([result2]) 84 | replace_with_bnb(model, dtype=dtype) 85 | model.cuda() 86 | 87 | _ = time_run(model, input_gemm, "BNB") 88 | result1 = time_run(model, input_gemm, "BNB") 89 | result2 = time_run(model, input_gemm, "BNB") 90 | result_bnb = result1.merge([result2]) 91 | 92 | model = recursively_replace_with_fp4_linear( 93 | model, as_dtype=dtype, device=model.in_proj.weight.device 94 | ) 95 | 96 | _ = time_run(model, input_gemm, "ZIPPY") 97 | result1 = time_run(model, input_gemm, "ZIPPY") 98 | result2 = time_run(model, input_gemm, "ZIPPY") 99 | result_zippy = result1.merge([result2]) 100 | 101 | result_dicts = [ 102 | [ 103 | "pytorch", 104 | get_avg(result_original, "mean"), 105 | get_avg(result_original, "median"), 106 | get_avg(result_original, "iqr"), 107 | ], 108 | [ 109 | "bitsandbytes", 110 | get_avg(result_bnb, "mean"), 111 | get_avg(result_bnb, "median"), 112 | get_avg(result_bnb, "iqr"), 113 | ], 114 | [ 115 | "torch-bnb-fp4", 116 | get_avg(result_zippy, "mean"), 117 | get_avg(result_zippy, "median"), 118 | get_avg(result_zippy, "iqr"), 119 | ], 120 | ] 121 | table.add_rows(result_dicts) 122 | print(table.get_string()) 
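

# A minimal extra check (not part of the original script): compare the raw torch_bnb_fp4
# dequant kernel against a bitsandbytes round trip. This is a sketch under two assumptions:
# a bitsandbytes 0.42.x-style QuantState exposing .absmax, and the
# dequantize_fp4(qweight, absmax, blocksize, M, N, dtype) signature documented in
# torch_bnb_fp4/__init__.py. The helper name check_raw_dequant is not part of the library.
def check_raw_dequant(dtype=torch.float16, M=256, N=256, blocksize=64):
    from bitsandbytes import functional as BF
    from torch_bnb_fp4 import dequantize_fp4

    weight = torch.randn(M, N, device="cuda", dtype=dtype)
    qweight, state = BF.quantize_fp4(weight, blocksize=blocksize)
    # Dequantize with this library's kernel and with bitsandbytes, then compare.
    ours = dequantize_fp4(qweight, state.absmax, blocksize, M, N, dtype=dtype)
    theirs = BF.dequantize_fp4(qweight, state)
    print("Max |torch-bnb-fp4 - bnb| raw fp4 dequant diff:", (ours - theirs).abs().max().item())
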
123 | 124 | 125 | def simple_fwd(model, input): 126 | weight = model.in_proj.quant_data.dequantize() 127 | return torch.nn.functional.linear(input, weight, model.in_proj.quant_data.bias) 128 | 129 | 130 | def check(dtype=torch.float16): 131 | torch.set_printoptions(precision=2, sci_mode=False, linewidth=180) 132 | torch.cuda.manual_seed_all(10) 133 | torch.manual_seed(10) 134 | generator = torch.Generator("cuda").manual_seed(10) 135 | model = TinyModel(256, 256).cuda().type(dtype) 136 | modelhijack = TinyModel(256, 256).cuda().type(dtype) 137 | modelhijack.in_proj.weight.data = model.in_proj.weight.data.clone() 138 | modelhijack.in_proj.bias.data = model.in_proj.bias.data.clone() 139 | 140 | hijack = recursively_replace_with_fp4_linear(modelhijack).to("cuda", dtype=dtype) 141 | input_gemv_3dim = torch.randn(1, 1, 256, generator=generator, device="cuda").type( 142 | dtype 143 | ) 144 | input_gemv = torch.randn(1, 256, generator=generator, device="cuda").type(dtype) 145 | input_gemm_3dim = torch.randn( 146 | 1, 2048, 256, generator=generator, device="cuda" 147 | ).type(dtype) 148 | with torch.inference_mode(): 149 | 150 | output_gemv_3dim = model(input_gemv_3dim) 151 | output_gemv_3dim_hijack = hijack(input_gemv_3dim) 152 | difference_avg = (output_gemv_3dim - output_gemv_3dim_hijack).abs().mean() 153 | print( 154 | "Elementwise Diff. Avg Between nn.Linear & Quant GEMV 3dim:", 155 | difference_avg.item(), 156 | ) 157 | output_gemv = model(input_gemv) 158 | output_gemv_hijack = hijack(input_gemv) 159 | difference_avg = (output_gemv - output_gemv_hijack).abs().mean() 160 | print( 161 | "Elementwise Diff. Avg Between nn.Linear & Quant GEMV 2dim:", 162 | difference_avg.item(), 163 | ) 164 | 165 | output_gemm_3dim = model(input_gemm_3dim) 166 | output_gemm_3dim_hijack = hijack(input_gemm_3dim) 167 | difference_avg = (output_gemm_3dim - output_gemm_3dim_hijack).abs().mean() 168 | print( 169 | "Elementwise Diff. 
Avg Between nn.Linear & Quant GEMM 3dim:", 170 | difference_avg.item(), 171 | ) 172 | 173 | 174 | if __name__ == "__main__": 175 | print("\n============ Running Sanity Checks ============\n") 176 | print() 177 | print( 178 | " NOTE: The acceptable range for the elementwise difference avg\n is around 0.045-0.065, which is the same as bitsandbytes.\n" 179 | ) 180 | print("== Running sanity check for torch-bnb-fp4 fp32 ==\n") 181 | dt = torch.float32 182 | check_speed(dt, gemm_type="gemv") 183 | check_speed(dt, gemm_type="gemm") 184 | check(dt) 185 | print("\n== Running sanity check for torch-bnb-fp4 fp16 ==\n") 186 | dt = torch.float16 187 | check_speed(dt, gemm_type="gemv") 188 | check_speed(dt, gemm_type="gemm") 189 | check(dt) 190 | print("\n== Running sanity check for torch-bnb-fp4 bf16 ==\n") 191 | dt = torch.bfloat16 192 | check_speed(dt, gemm_type="gemv") 193 | check_speed(dt, gemm_type="gemm") 194 | check(dt) 195 | print("\n============= Sanity Checks Compete =============\n") 196 | -------------------------------------------------------------------------------- /csrc/dequant_fp4_optimized.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #define CDIV(x, y) (((x) + (y)-1) / (y)) 20 | 21 | using namespace cooperative_groups; 22 | namespace cg = cooperative_groups; 23 | 24 | typedef struct { 25 | float param[16]; 26 | } param_large_t; 27 | 28 | static const param_large_t CODE_PARAM = { 29 | .param = 30 | {0.00000f, 31 | 5.208333e-03f, 32 | 0.6666667f, 33 | 1.000000f, 34 | 0.333333f, 35 | 0.500000f, 36 | 0.1666667f, 37 | 0.250000f, 38 | -0.000000f, 39 | -5.208333e-03f, 40 | -0.6666667f, 41 | -1.000000f, 42 | -0.333333f, 43 | -0.500000f, 44 | -0.1666667f, 45 | -0.250000f} 46 | }; 47 | 48 | void CUDA_CHECK_RETURN_(cudaError_t cudaStatus) { 49 | if (cudaStatus != cudaSuccess) { 50 | printf("CUDA Failure: %s\n", cudaGetErrorString(cudaStatus)); 51 | // exit(EXIT_FAILURE); // so many segfaults before being able to print out actual crap because of this stupidity 52 | } 53 | } 54 | 55 | __device__ float dequantize_fp4_tree(unsigned char val, float absmax) { 56 | float sign = (val & 0b1000) == 8 ? 
-1.0f : 1.0f; 57 | if ((val & 0b0100) == 4) // 0 58 | if ((val & 0b0010) == 2) // 01 59 | if ((val & 0b0001) == 1) // 111 60 | return 0.25000000f * absmax * sign; // 1111 61 | else 62 | return 0.16666667f * absmax * sign; // 1110 63 | else if ((val & 0b0001) == 1) // 110 64 | return 0.50000000f * absmax * sign; // 1101 65 | else 66 | return 0.33333333f * absmax * sign; // 1100 67 | else if ((val & 0b0010) == 2) // 10 68 | if ((val & 0b0001) == 1) // 101 69 | return 1.00000000f * absmax * sign; // 1011 70 | else 71 | return 0.66666667f * absmax * sign; // 1010 72 | else if ((val & 0b0001) == 1) // 100 73 | return 5.208333333e-03f * absmax * sign; // 1001 74 | else 75 | return 0.00000000f * absmax * sign; // 1000 76 | } 77 | 78 | template __device__ __forceinline__ T convert_to_ty(float val); 79 | template <> __device__ __forceinline__ nv_bfloat16 convert_to_ty(float val) { 80 | return __float2bfloat16_rn(val); 81 | } 82 | template <> __device__ __forceinline__ nv_half convert_to_ty(float val) { 83 | return __float2half_rn(val); 84 | } 85 | template <> __device__ __forceinline__ float convert_to_ty(float val) { 86 | return val; 87 | } 88 | 89 | template 90 | __global__ void dequantize_blockwise_kernel_fp4(unsigned char *A, float *absmax, T *out, const int blocksize, const int n) { 91 | const int n_load = (gridDim.x * TILE_SIZE); 92 | int valid_items_load = 0; 93 | int valid_items_store = 0; 94 | const int base_idx = (blockIdx.x * TILE_SIZE); 95 | T vals[NUM_PER_TH * 2]; 96 | unsigned char qvals[NUM_PER_TH]; 97 | float local_abs_max; 98 | 99 | valid_items_load = 0; 100 | valid_items_store = 0; 101 | local_abs_max = -FLT_MAX; 102 | typedef cub::BlockLoad LoadChar; 103 | typedef cub::BlockStore StoreT; 104 | 105 | __shared__ typename LoadChar::TempStorage loadchar; 106 | __shared__ typename StoreT::TempStorage storet; 107 | for (unsigned int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) { 108 | valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i; 109 | valid_items_store = n - i * 2 > TILE_SIZE * 2 ? 
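        /* each loaded byte packs two fp4 values, so up to 2*TILE_SIZE output elements are written per tile */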
TILE_SIZE * 2 : n - i * 2; 110 | local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (blocksize)]); 111 | __syncthreads(); 112 | 113 | LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); 114 | #pragma unroll NUM_PER_TH 115 | for (int j = 0; j < NUM_PER_TH; j++) { 116 | 117 | vals[j * 2] = convert_to_ty(dequantize_fp4_tree(qvals[j] >> 4, local_abs_max)); 118 | vals[j * 2 + 1] = convert_to_ty(dequantize_fp4_tree(qvals[j] & 0x0F, local_abs_max)); 119 | } 120 | __syncthreads(); 121 | StoreT(storet).Store(&(out[i * 2]), vals, valid_items_store); 122 | } 123 | } 124 | 125 | template 126 | __global__ void dequantize_blockwise_codebook_kernel_fp4( 127 | const unsigned char *A, const float *absmax, T *out, __grid_constant__ const param_large_t code, const int blocksize, const int n 128 | ) { 129 | 130 | const int n_load = (gridDim.x * TILE_SIZE); 131 | int valid_items_load = 0; 132 | int valid_items_store = 0; 133 | const int base_idx = (blockIdx.x * TILE_SIZE); 134 | T vals[NUM_PER_TH * 2]; 135 | unsigned char qvals[NUM_PER_TH]; 136 | float local_abs_max; 137 | 138 | auto block = cg::this_thread_block(); 139 | const int warp_idx = threadIdx.x / 32; 140 | const int warp_lane = threadIdx.x % 32; 141 | __shared__ float local_code[16]; 142 | auto convert_ty = convert_to_ty; 143 | 144 | valid_items_load = 0; 145 | valid_items_store = 0; 146 | local_abs_max = -FLT_MAX; 147 | typedef cub::BlockLoad LoadChar; 148 | typedef cub::BlockStore StoreT; 149 | 150 | if (block.thread_rank() < 16) { 151 | local_code[block.thread_rank()] = code.param[block.thread_rank()]; 152 | } 153 | __syncthreads(); 154 | __shared__ typename LoadChar::TempStorage loadchar; 155 | __shared__ typename StoreT::TempStorage storet; 156 | for (unsigned int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) { 157 | valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i; 158 | valid_items_store = n - i * 2 > TILE_SIZE * 2 ? 
TILE_SIZE * 2 : n - i * 2; 159 | local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (blocksize)]); 160 | __syncthreads(); 161 | 162 | LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); 163 | #pragma unroll NUM_PER_TH 164 | for (int j = 0; j < NUM_PER_TH; j++) { 165 | vals[j * 2] = convert_ty(local_code[int(qvals[j] >> 4)] * local_abs_max); 166 | vals[j * 2 + 1] = convert_ty(local_code[int(qvals[j] & 0x0F)] * local_abs_max); 167 | } 168 | __syncthreads(); 169 | StoreT(storet).Store(&(out[i * 2]), vals, valid_items_store); 170 | } 171 | } 172 | 173 | template 174 | void launch_dequantize_blockwise_kernel_fp4(torch::Tensor A, torch::Tensor absmax, torch::Tensor out, int blocksize, int n) { 175 | const int blocks = CDIV(n, 1024); 176 | dequantize_blockwise_kernel_fp4<<>>( 177 | (unsigned char *)A.data_ptr(), (float *)absmax.data_ptr(), (T *)out.mutable_data_ptr(), (const int)(blocksize / 2), (const int)n 178 | ); 179 | CUDA_CHECK_RETURN_(cudaGetLastError()); 180 | } 181 | 182 | void dequantize_blockwise_fp4(torch::Tensor A, torch::Tensor absmax, int M, int N, int blocksize, int n, torch::Tensor out) { 183 | TORCH_CHECK(A.dtype() == torch::kUInt8, "A must be uint8"); 184 | TORCH_CHECK(absmax.dtype() == torch::kFloat32, "absmax must be float32"); 185 | TORCH_CHECK(A.is_cuda(), "A must be cuda"); 186 | TORCH_CHECK(absmax.is_cuda(), "absmax must be cuda"); 187 | TORCH_CHECK(out.is_cuda(), "out must be cuda"); 188 | switch (out.scalar_type()) { 189 | case torch::kFloat16: { 190 | launch_dequantize_blockwise_kernel_fp4(A, absmax, out, blocksize, n); 191 | break; 192 | } 193 | case torch::kFloat32: { 194 | launch_dequantize_blockwise_kernel_fp4(A, absmax, out, blocksize, n); 195 | break; 196 | } 197 | case torch::kBFloat16: { 198 | launch_dequantize_blockwise_kernel_fp4(A, absmax, out, blocksize, n); 199 | break; 200 | } 201 | default: { 202 | std::cout << "NO APPLICABLE DEQUANT DTYPE!" << std::endl; 203 | } 204 | } 205 | } 206 | 207 | torch::Tensor dequantize_blockwise_codebook_fp4( 208 | torch::Tensor A, torch::Tensor absmax, torch::Tensor codebook, int M, int N, int blocksize, int n, torch::ScalarType dtype 209 | ) { 210 | TORCH_CHECK(A.dtype() == torch::kUInt8, "A must be uint8"); 211 | TORCH_CHECK(absmax.dtype() == torch::kFloat32, "absmax must be float32"); 212 | TORCH_CHECK(A.is_cuda(), "A must be cuda"); 213 | TORCH_CHECK(absmax.is_cuda(), "absmax must be cuda"); 214 | torch::Tensor out = torch::empty({M, N}, torch::dtype(dtype).device(A.device())); 215 | const int blocks = CDIV(n, 1024); 216 | switch (dtype) { 217 | case torch::kFloat32: { 218 | dequantize_blockwise_codebook_kernel_fp4<<>>( 219 | (unsigned char *)A.data_ptr(), 220 | (float *)absmax.data_ptr(), 221 | (float *)out.mutable_data_ptr(), 222 | CODE_PARAM, 223 | (const int)(blocksize / 2), 224 | (const int)n 225 | ); 226 | break; 227 | } 228 | case torch::kFloat16: { 229 | dequantize_blockwise_codebook_kernel_fp4<<>>( 230 | (unsigned char *)A.data_ptr(), 231 | (float *)absmax.data_ptr(), 232 | (nv_half *)out.mutable_data_ptr(), 233 | CODE_PARAM, 234 | (const int)(blocksize / 2), 235 | (const int)n 236 | ); 237 | break; 238 | } 239 | case torch::kBFloat16: { 240 | dequantize_blockwise_codebook_kernel_fp4<<>>( 241 | (unsigned char *)A.data_ptr(), 242 | (float *)absmax.data_ptr(), 243 | (nv_bfloat16 *)out.mutable_data_ptr(), 244 | CODE_PARAM, 245 | (const int)(blocksize / 2), 246 | (const int)n 247 | ); 248 | break; 249 | } 250 | default: { 251 | std::cout << "NO APPLICABLE DTYPE!" 
<< std::endl; 252 | } 253 | } 254 | return out; 255 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TORCH BNB FP4 2 | 3 | torch_bnb_fp4 is a library that provides a Torch C++ extension for faster nn.Linear FP4 ops, via streamlining bitsandbytes [`kgemm_4bit_inference_naive`](https://github.com/TimDettmers/bitsandbytes/blob/e820409c095ea7cbb5ce156992307b84352cbf90/csrc/kernels.cu#L3533-L3649) and [`kDequantizeBlockwise`](https://github.com/TimDettmers/bitsandbytes/blob/e820409c095ea7cbb5ce156992307b84352cbf90/csrc/kernels.cu#L832-L896) kernels. 4 | 5 | ## Overview 6 | 7 | TORCH BNB FP4 is a high-performance library designed to accelerate quantized `nn.Linear` ops, by utilizing bitsandbytes fp4 quantized weights. This library is built as a Torch C++ extension instead of being linked via ctypes as with bitsandbytes. This library is designed to be used in conjunction with bitsandbytes, and is not a replacement for bitsandbytes. 8 | 9 | ## Requirements 10 | 11 | System: 12 | 13 | - CUDA capable device with compute >= 8.0, so only Ampere / Ada / Hopper and above. 14 | - [System cudatoolkit](https://developer.nvidia.com/cuda-downloads) with the same major version eg 11.x, 12.x as their installed pytorch's cuda. Minor version mismatches dont matter as much, as in, 12.1 pytorch will work fine with system cudatoolkit 12.3, etc. This is specifically for the libs & headers of NVIDIA CUB. 15 | 16 | Note: 17 | 18 | - _I am 100% unsure whether this works on (non-wsl) windows at all._ 19 | - I have only tested this on a 4090 on linux with cudatoolkit 12.3 w/ pytorch2.2+cuda=12.1, a 4080 with cudatoolkit 12.2 & pytorch2.2+cuda=12.1 on windows w/ wsl, and a 3090 on linux with cudatoolkit 12.2 & pytorch2.2+cuda=12.1. Other setups are not guaranteed to work, but only because I have not tested them. If you find issues, feel welcome to submit an issue with your cudatoolkit version, cuda device and the errors you had. 20 | 21 | Libraries: 22 | 23 | - Pytorch 24 | - [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) 25 | 26 | ### UPDATE 3.16.2024 27 | 28 | Noticed that with bitsandbytes 0.43.x precision dramatically drops, have to update requirements to ensure `bitsandbytes<0.43` 29 | 30 | 31 | ## Installation 32 | 33 | To install torch_bnb_fp4, follow these steps: 34 | 35 | 1. Clone the repository: 36 | 37 | ```bash 38 | git clone https://github.com/aredden/torch-bnb-fp4 39 | ``` 40 | 41 | 2. Navigate to the project directory: 42 | 43 | ```bash 44 | cd torch-bnb-fp4 45 | ``` 46 | 47 | 3. To reduce the chance of issues finding the correct libraries / headers, I recommend finding your cuda library directory and referencing them in the install command, since frequently your PATH env variable ends up overwriting your system cudatoolkit library / include dirs with older cudatoolkit installations. 48 | 49 | - You will need to specify the actual compute architecture in the TORCH_CUDA_ARCH_LIST environment variable, Ampere consumer gpus are 8.6, Ada rtx 40xx gpus and workstation cards are 8.9, and hopper datacenter gpus are 9.0. 50 | 51 | - For an ampere A100 I would use `TORCH_CUDA_ARCH_LIST="8.0"` 52 | - ampere datacenter cards are a special case for ampere, I am unsure whether all are 8.0 or just the A100, so be sure to check. 
53 | - For an ampere RTX 3070 I would use `TORCH_CUDA_ARCH_LIST="8.6"` 54 | - For an ada RTX 4080 I would use `TORCH_CUDA_ARCH_LIST="8.9"` 55 | - For a hopper H100 I would use `TORCH_CUDA_ARCH_LIST="9.0"` 56 | - ... 57 | 58 | - On linux and wsl, the library directory it is usually `/usr/local/cuda-x.y/lib64`, and the nvcc nvidia compiler is usually `/usr/local/cuda-x.y/bin/nvcc`, where `x` is the cudatoolkit major version, and `y` is the minor version, eg for cudatoolkit 12.2, you would use `/usr/local/cuda-12.2/lib64` then you can use the install command: 59 | 60 | ```bash 61 | # assuming cudatoolkit 12.2 and cuda your device is a 3090 (aka compute 8.6) 62 | 63 | export LD_LIBRARY_PATH=/usr/local/cuda-12.2/lib64 64 | export CUDACXX=/usr/local/cuda-12.2/bin/nvcc 65 | export CUDA_HOME=/usr/local/cuda-12.2 66 | TORCH_CUDA_ARCH_LIST="8.6" python setup.py install 67 | ``` 68 | 69 | OR, if you're feeling lucky / know your system has all libs / headers properly set up: 70 | 71 | ```bash 72 | # assuming your device is a 4090 (aka compute 8.9) 73 | 74 | TORCH_CUDA_ARCH_LIST="8.9" python setup.py install 75 | ``` 76 | 77 | ## Usage 78 | 79 | Once the library is installed, you can use it in your Torch projects by importing the `torch_bnb_fp4` module, which provides access to the pytorch extension. 80 | 81 | To make sure things are working correctly, you can use the script `sanity_check.py` in the root of this repository, which tests the speed and accuracy of the library. For reference, the output from my gpu is as follows: 82 | 83 | ``` 84 | 85 | ❯ python sanity_check.py 86 | 87 | ============ Running Sanity Checks ============ 88 | 89 | 90 | NOTE: The acceptable range for the elementwise difference avg 91 | is around 0.045-0.065, which is the same as bitsandbytes. 92 | 93 | == Running sanity check for torch-bnb-fp4 fp32 == 94 | 95 | +------------------------------------------------------------------------------+ 96 | | GEMM Speed Benchmark for torch.float32 and matmul type [GEMV] W/ 6 Layer MLP | 97 | +-----------------------+----------------+-------------------+-----------------+ 98 | | type | mean (us) | median (us) | iqr (us) | 99 | +-----------------------+----------------+-------------------+-----------------+ 100 | | pytorch | 53.18113 | 53.09262 | 0.12039 | 101 | | bitsandbytes | 92.71299 | 92.70629 | 0.16016 | 102 | | torch-bnb-fp4 | 63.77637 | 63.78534 | 0.0904 | 103 | +-----------------------+----------------+-------------------+-----------------+ 104 | +------------------------------------------------------------------------------+ 105 | | GEMM Speed Benchmark for torch.float32 and matmul type [GEMM] W/ 6 Layer MLP | 106 | +-----------------------+----------------+-------------------+-----------------+ 107 | | type | mean (us) | median (us) | iqr (us) | 108 | +-----------------------+----------------+-------------------+-----------------+ 109 | | pytorch | 68.58508 | 68.58716 | 0.02236 | 110 | | bitsandbytes | 155.64296 | 155.13446 | 1.37504 | 111 | | torch-bnb-fp4 | 93.45283 | 93.4459 | 0.02174 | 112 | +-----------------------+----------------+-------------------+-----------------+ 113 | Elementwise Diff. Avg Between nn.Linear & Quant GEMV 3dim: 0.05073589086532593 114 | Elementwise Diff. Avg Between nn.Linear & Quant GEMV 2dim: 0.056356318295001984 115 | Elementwise Diff. 
Avg Between nn.Linear & Quant GEMM 3dim: 0.05096859857439995 116 | 117 | == Running sanity check for torch-bnb-fp4 fp16 == 118 | 119 | +------------------------------------------------------------------------------+ 120 | | GEMM Speed Benchmark for torch.float16 and matmul type [GEMV] W/ 6 Layer MLP | 121 | +-----------------------+----------------+-------------------+-----------------+ 122 | | type | mean (us) | median (us) | iqr (us) | 123 | +-----------------------+----------------+-------------------+-----------------+ 124 | | pytorch | 54.0681 | 53.92455 | 0.28024 | 125 | | bitsandbytes | 93.89957 | 93.93588 | 0.22058 | 126 | | torch-bnb-fp4 | 64.42346 | 64.4473 | 0.04361 | 127 | +-----------------------+----------------+-------------------+-----------------+ 128 | +------------------------------------------------------------------------------+ 129 | | GEMM Speed Benchmark for torch.float16 and matmul type [GEMM] W/ 6 Layer MLP | 130 | +-----------------------+----------------+-------------------+-----------------+ 131 | | type | mean (us) | median (us) | iqr (us) | 132 | +-----------------------+----------------+-------------------+-----------------+ 133 | | pytorch | 79.42544 | 79.41179 | 0.0154 | 134 | | bitsandbytes | 130.14084 | 130.1941 | 0.54197 | 135 | | torch-bnb-fp4 | 98.83817 | 98.83849 | 0.0185 | 136 | +-----------------------+----------------+-------------------+-----------------+ 137 | Elementwise Diff. Avg Between nn.Linear & Quant GEMV 3dim: 0.04998779296875 138 | Elementwise Diff. Avg Between nn.Linear & Quant GEMV 2dim: 0.05657958984375 139 | Elementwise Diff. Avg Between nn.Linear & Quant GEMM 3dim: 0.05096435546875 140 | 141 | == Running sanity check for torch-bnb-fp4 bf16 == 142 | 143 | +-------------------------------------------------------------------------------+ 144 | | GEMM Speed Benchmark for torch.bfloat16 and matmul type [GEMV] W/ 6 Layer MLP | 145 | +-----------------------+----------------+--------------------+-----------------+ 146 | | type | mean (us) | median (us) | iqr (us) | 147 | +-----------------------+----------------+--------------------+-----------------+ 148 | | pytorch | 54.3889 | 54.14199 | 0.39099 | 149 | | bitsandbytes | 94.2237 | 93.96561 | 0.60638 | 150 | | torch-bnb-fp4 | 64.3852 | 64.35706 | 0.21559 | 151 | +-----------------------+----------------+--------------------+-----------------+ 152 | +-------------------------------------------------------------------------------+ 153 | | GEMM Speed Benchmark for torch.bfloat16 and matmul type [GEMM] W/ 6 Layer MLP | 154 | +-----------------------+----------------+--------------------+-----------------+ 155 | | type | mean (us) | median (us) | iqr (us) | 156 | +-----------------------+----------------+--------------------+-----------------+ 157 | | pytorch | 81.96011 | 81.94626 | 0.01879 | 158 | | bitsandbytes | 152.93054 | 152.84844 | 0.50242 | 159 | | torch-bnb-fp4 | 101.29481 | 101.28148 | 0.02136 | 160 | +-----------------------+----------------+--------------------+-----------------+ 161 | Elementwise Diff. Avg Between nn.Linear & Quant GEMV 3dim: 0.049072265625 162 | Elementwise Diff. Avg Between nn.Linear & Quant GEMV 2dim: 0.05712890625 163 | Elementwise Diff. Avg Between nn.Linear & Quant GEMM 3dim: 0.051025390625 164 | 165 | ============= Sanity Checks Compete ============= 166 | 167 | ``` 168 | 169 | The library provides a `TorchFP4Linear` class that can be used to replace standard PyTorch nn.Linear layers via bitsandbytes FP4 quantized layers. 
170 | 171 | ```py 172 | from torch import nn 173 | from torch_bnb_fp4 import TorchFP4Linear, swap_linear_with_bnb_linear 174 | 175 | # Define your original linear layer 176 | # NOTE: this lib supports float16, bfloat16 and float32 tensors. 177 | original_linear_layer = nn.Linear( 178 | in_features=512, 179 | out_features=1024, 180 | bias=True 181 | ).to(device='cuda', dtype=torch.float16) 182 | 183 | original_linear_layer = swap_linear_with_bnb_linear( 184 | original_linear_layer, 185 | dtype=torch.float16 186 | ).cuda() # cuda must be called to quantize the linear weights via bnb. 187 | 188 | # wrap the linear layer via passing to the constructor of the TorchFP4Linear layer. 189 | quantized_linear_layer = TorchFP4Linear( 190 | original_linear_layer, 191 | use_codebook_dequant=True # or False for fp4 tree dequant, though doesn't make much difference. 192 | ).to(device='cuda', dtype=torch.float16) 193 | 194 | # Use the quantized layer as you would with a standard nn.Linear layer 195 | input_tensor = torch.randn(10, 512).to(device='cuda', dtype=torch.float16) 196 | output = quantized_linear_layer(input_tensor) 197 | 198 | # output is now a torch.float16 tensor, and can be used as input to other torch-bnb-fp4'd layers or models. 199 | 200 | ``` 201 | 202 | For huggingface models, I recommend loading as bitsandbytes fp4 quantized model, and then recursively replacing the BNB layers with the TorchFP4Linear layers. 203 | 204 | ```py 205 | import torch 206 | from torch_bnb_fp4 import recursively_replace_with_fp4_linear 207 | from transformers import AutoModelForCausalLM, BitsAndBytesConfig 208 | 209 | # Change this to your desired dtype 210 | DTYPE = torch.float16 211 | 212 | # Load weights as bnb fp4 213 | model = AutoModelForCausalLM.from_pretrained( 214 | "meta-llama/Llama-2-7b-hf", 215 | device_map={"": 0}, 216 | torch_dtype=DTYPE, 217 | load_in_4bit=True, 218 | quantization_config=BitsAndBytesConfig( 219 | load_in_4bit=True, 220 | bnb_4bit_compute_dtype=DTYPE, 221 | # Must use "fp4" for this library 222 | bnb_4bit_quant_type="fp4", 223 | # double quant is also unsupported, set this to false 224 | bnb_4bit_use_double_quant=False, 225 | ) 226 | ) 227 | 228 | # Replace layers with torch-bnb-fp4 layers in-place 229 | recursively_replace_with_fp4_linear( 230 | model, 231 | as_dtype=DTYPE, 232 | # or False for fp4 tree dequant, though doesn't make much difference. 233 | use_codebook_dequant=True, 234 | # Flag to only replace the layers which are bnb layers 235 | only_replace_bnb_layers=True, 236 | # Optional list of model keys to ignore. 237 | # This is useful in the case where your model is not a huggingface model / you want 238 | # to ignore swapping of certain layers or modules. 239 | ignore_layer_names=["lm_head"] 240 | ) 241 | 242 | 243 | # Now your model is torch-bnb-fp4'd 244 | 245 | 246 | 247 | ``` 248 | 249 | 250 | ## Acknowledgements 251 | 252 | I would like to thank Tim Dettmers for creating bitsandbytes and providing 99.99% of the foundation for this library. For more detailed information on the underlying quantization techniques, refer to the [bitsandbytes GitHub repository](https://github.com/TimDettmers/bitsandbytes). 
253 | -------------------------------------------------------------------------------- /csrc/gemv_fp4_optimized.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #define HLF_MAX 65504 22 | #define TH 1024 23 | #define NUM 4 24 | #define NUM_BLOCK 4096 25 | #define num_values_4bit 32 26 | #define num_values_4bit 32 27 | 28 | typedef struct { 29 | float param[16]; 30 | } param_large_t; 31 | 32 | static const param_large_t CODE_PARAM = { 33 | .param = 34 | {0.00000f, 35 | 5.208333e-03f, 36 | 0.6666667f, 37 | 1.000000f, 38 | 0.333333f, 39 | 0.500000f, 40 | 0.1666667f, 41 | 0.250000f, 42 | -0.000000f, 43 | -5.208333e-03f, 44 | -0.6666667f, 45 | -1.000000f, 46 | -0.333333f, 47 | -0.500000f, 48 | -0.1666667f, 49 | -0.250000f} 50 | }; 51 | 52 | #define CDIV(x, y) (((x) + (y)-1) / (y)) 53 | 54 | void CUDA_CHECK_RETURN(cudaError_t cudaStatus) { 55 | if (cudaStatus != cudaSuccess) { 56 | printf("CUDA Failure: %s\n", cudaGetErrorString(cudaStatus)); 57 | } 58 | } 59 | 60 | template 61 | __global__ void gemv_4bit_inference_kernel( 62 | int M, 63 | int N, 64 | int K, 65 | T *__restrict__ const A, 66 | unsigned char *B, 67 | T_REDUCE *absmax, 68 | __grid_constant__ const param_large_t datatype, 69 | T *out, 70 | int lda, 71 | int ldb, 72 | int ldc, 73 | int blocksize 74 | ) { 75 | // per threadblock: 76 | // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] 77 | // 4 warps -> 4 loads per iter 78 | // 1x32 * 32x4 -> 1x4 outputs per thread block 79 | 80 | typedef cub::WarpReduce WarpReduce; 81 | __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32]; 82 | 83 | const int warp_idx = threadIdx.x / 32; 84 | const int warp_lane = threadIdx.x % 32; 85 | const int row_B = (THREADS / 32) * blockIdx.x + warp_idx; 86 | const int num_values_8bit = num_values_4bit / 2; 87 | T local_C = T(0.0f); 88 | 89 | unsigned char local_B_4bit[num_values_8bit]; 90 | T local_B[num_values_4bit / 4]; 91 | T local_A[num_values_4bit / 4]; 92 | __shared__ T quant_map[16]; 93 | T local_absmax = T(0.0f); 94 | 95 | if (warp_lane < 16 && warp_idx == 0) quant_map[warp_lane] = T(datatype.param[warp_lane]); 96 | __syncthreads(); 97 | // A: [1, K] 98 | // B: [N, K] 99 | for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) { 100 | int inner_idx_halved = inner_idx / 2; 101 | int offset_B = ldb * row_B; 102 | int absidx = ((2 * offset_B) + inner_idx) / blocksize; 103 | local_absmax = __ldg(&(absmax[absidx])); 104 | 105 | if (row_B < M) { 106 | if ((inner_idx_halved + num_values_8bit) < (K / 2)) { 107 | reinterpret_cast(local_B_4bit)[0] = 108 | reinterpret_cast(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; 109 | } else { 110 | #pragma unroll 111 | for (int j = 0; j < (num_values_8bit); j++) { 112 | if ((inner_idx_halved) + j < (K / 2)) { 113 | local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; 114 | } else { 115 | local_B_4bit[j] = 0b01110111; 116 | } 117 | } 118 | } 119 | } else { 120 | #pragma unroll 121 | for (int j = 0; j < (num_values_8bit); j++) { 122 | local_B_4bit[j] = 0b01110111; 123 | } 124 | } 125 | for (int i = 0; i < 4; i++) { 126 | #pragma unroll 127 | for (int k = 0; k < num_values_8bit / 4; k++) { 128 | local_B[k * 2] = 
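                // unpack the high and low nibbles of each packed byte through the shared-memory quant_map, scaled by this block's absmax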
quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; 129 | local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; 130 | } 131 | if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { 132 | // this is also relatively important for performance 133 | reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx / (num_values_4bit / 4) + i]; 134 | } else { 135 | #pragma unroll 136 | for (int k = 0; k < num_values_4bit / 4; k++) { 137 | if (inner_idx + (i * num_values_4bit / 4) + k < K) { 138 | local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; 139 | } else { 140 | local_A[k] = T(0.0f); 141 | } 142 | } 143 | } 144 | // accumulate in float; small performance hit for Ampere, but lower error for outputs 145 | #pragma unroll 146 | for (int k = 0; k < num_values_4bit / 4; k++) { 147 | local_C += local_A[k] * local_B[k]; 148 | } 149 | } 150 | } 151 | 152 | local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); 153 | 154 | if (row_B < M && warp_lane == 0) { 155 | out[row_B] = T(local_C); 156 | }; 157 | } 158 | 159 | template 160 | __global__ void gemv_4bit_inference_kernel_float( 161 | int M, 162 | int N, 163 | int K, 164 | T *__restrict__ const A, 165 | unsigned char *B, 166 | T_REDUCE *absmax, 167 | __grid_constant__ const param_large_t datatype, 168 | T *out, 169 | int lda, 170 | int ldb, 171 | int ldc, 172 | int blocksize 173 | ) { 174 | // per threadblock: 175 | // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] 176 | // 4 warps -> 4 loads per iter 177 | // 1x32 * 32x4 -> 1x4 outputs per thread block 178 | 179 | typedef cub::WarpReduce WarpReduce; 180 | __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32]; 181 | 182 | const int warp_idx = threadIdx.x / 32; 183 | const int warp_lane = threadIdx.x % 32; 184 | const int row_B = (THREADS / 32) * blockIdx.x + warp_idx; 185 | const int num_values_8bit = num_values_4bit / 2; 186 | T local_C = T(0.0f); 187 | 188 | unsigned char local_B_4bit[num_values_8bit]; 189 | T local_B[num_values_4bit / 4]; 190 | T local_A[num_values_4bit / 4]; 191 | __shared__ T quant_map[16]; 192 | T local_absmax = T(0.0f); 193 | 194 | if (warp_lane < 16 && warp_idx == 0) quant_map[warp_lane] = T(datatype.param[warp_lane]); 195 | __syncthreads(); 196 | // A: [1, K] 197 | // B: [N, K] 198 | for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) { 199 | int inner_idx_halved = inner_idx / 2; 200 | int offset_B = ldb * row_B; 201 | int absidx = ((2 * offset_B) + inner_idx) / blocksize; 202 | local_absmax = __ldg(&(absmax[absidx])); 203 | 204 | if (row_B < M) { 205 | if ((inner_idx_halved + num_values_8bit) < (K / 2)) { 206 | reinterpret_cast(local_B_4bit)[0] = 207 | reinterpret_cast(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; 208 | } else { 209 | #pragma unroll 210 | for (int j = 0; j < (num_values_8bit); j++) { 211 | if ((inner_idx_halved) + j < (K / 2)) { 212 | local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; 213 | } else { 214 | local_B_4bit[j] = 0b01110111; 215 | } 216 | } 217 | } 218 | } else { 219 | #pragma unroll 220 | for (int j = 0; j < (num_values_8bit); j++) { 221 | local_B_4bit[j] = 0b01110111; 222 | } 223 | } 224 | for (int i = 0; i < 4; i++) { 225 | #pragma unroll 226 | for (int k = 0; k < num_values_8bit / 4; k++) { 227 | local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; 228 | local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * 
num_values_8bit / 4) + k] & 0x0F] * local_absmax; 229 | } 230 | if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { 231 | // this is also relatively important for performance 232 | reinterpret_cast(local_A)[0] = 233 | reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; 234 | reinterpret_cast(local_A)[1] = 235 | reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; 236 | } else { 237 | #pragma unroll 238 | for (int k = 0; k < num_values_4bit / 4; k++) { 239 | if (inner_idx + (i * num_values_4bit / 4) + k < K) { 240 | local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; 241 | } else { 242 | local_A[k] = T(0.0f); 243 | } 244 | } 245 | } 246 | // accumulate in float; small performance hit for Ampere, but lower error for outputs 247 | #pragma unroll 248 | for (int k = 0; k < num_values_4bit / 4; k++) { 249 | local_C += local_A[k] * local_B[k]; 250 | } 251 | } 252 | } 253 | 254 | local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); 255 | 256 | if (row_B < M && warp_lane == 0) { 257 | out[row_B] = T(local_C); 258 | }; 259 | } 260 | 261 | template 262 | void gemv_4bit_inference_launch( 263 | int m, int n, int k, T *A, unsigned char *B, T_REDUCE *absmax, T_REDUCE *datatype, T *out, int lda, int ldb, int ldc, int blocksize 264 | ) { 265 | int num_blocks = CDIV(m, 4); 266 | gemv_4bit_inference_kernel<<>>(m, n, k, A, B, absmax, CODE_PARAM, out, lda, ldb, ldc, blocksize); 267 | } 268 | 269 | void gemv_4bit_inference_launch_float( 270 | int m, int n, int k, float *A, unsigned char *B, float *absmax, float *datatype, float *out, int lda, int ldb, int ldc, int blocksize 271 | ) { 272 | int num_blocks = CDIV(m, 4); 273 | gemv_4bit_inference_kernel_float 274 | <<>>(m, n, k, A, B, absmax, CODE_PARAM, out, lda, ldb, ldc, blocksize); 275 | } 276 | 277 | torch::Tensor gemv_4bit_inference( 278 | torch::Tensor A, 279 | torch::Tensor B, 280 | torch::Tensor absmax, 281 | torch::Tensor datatype, 282 | int blocksize, 283 | torch::ScalarType dtype, 284 | std::vector Bshape 285 | ) { 286 | torch::TensorOptions topts = A.options(); 287 | auto bout = Bshape[0]; 288 | 289 | int n = 1; 290 | int m = Bshape[0]; 291 | int k = Bshape[1]; 292 | int lda = Bshape[0]; 293 | int ldc = Bshape[0]; 294 | int ldb = CDIV(A.sizes()[1], 2); 295 | 296 | torch::Tensor out_; 297 | if (A.sizes().size() == 3) out_ = torch::empty({A.sizes()[0], A.sizes()[1], bout}, topts); 298 | else 299 | out_ = torch::empty({A.sizes()[0], bout}, topts); 300 | 301 | switch (dtype) { 302 | case torch::kFloat16: { 303 | TORCH_CHECK(dtype == torch::kFloat16, "Only fp16 dtype is supported for not reduced precision fp16 accumulation") 304 | TORCH_CHECK(absmax.scalar_type() == torch::kFloat32, "Only fp32 absmax is supported for not reduced precision accumulation") 305 | TORCH_CHECK(datatype.scalar_type() == torch::kFloat32, "Only fp32 code is supported for not reduced precision accumulation") 306 | gemv_4bit_inference_launch( 307 | m, 308 | n, 309 | k, 310 | (nv_half *)A.data_ptr(), 311 | (unsigned char *)B.data_ptr(), 312 | (float *)absmax.data_ptr(), 313 | (float *)datatype.data_ptr(), 314 | (nv_half *)out_.mutable_data_ptr(), 315 | lda, 316 | ldb, 317 | ldc, 318 | blocksize 319 | ); 320 | break; 321 | } 322 | case torch::kBFloat16: { 323 | TORCH_CHECK(dtype == torch::kBFloat16, "Only bf16 dtype is supported for not reduced precision accumulation") 324 | TORCH_CHECK(absmax.scalar_type() == torch::kFloat32, "Only fp32 absmax is supported for not reduced precision accumulation") 325 | 
TORCH_CHECK(datatype.scalar_type() == torch::kFloat32, "Only fp32 code is supported for not reduced precision accumulation") 326 | gemv_4bit_inference_launch( 327 | m, 328 | n, 329 | k, 330 | (nv_bfloat16 *)A.data_ptr(), 331 | (unsigned char *)B.data_ptr(), 332 | (float *)absmax.data_ptr(), 333 | (float *)datatype.data_ptr(), 334 | (nv_bfloat16 *)out_.mutable_data_ptr(), 335 | lda, 336 | ldb, 337 | ldc, 338 | blocksize 339 | ); 340 | break; 341 | } 342 | case torch::kFloat32: { 343 | TORCH_CHECK(dtype == torch::kFloat32, "Only float32 dtype is supported for float32 gemv_4bit_inference") 344 | TORCH_CHECK(absmax.scalar_type() == torch::kFloat32, "Only float32 absmax is supported for float32 gemv_4bit_inference") 345 | TORCH_CHECK(datatype.scalar_type() == torch::kFloat32, "Only float32 code is supported for float32 gemv_4bit_inference") 346 | gemv_4bit_inference_launch_float( 347 | m, 348 | n, 349 | k, 350 | (float *)A.data_ptr(), 351 | (unsigned char *)B.data_ptr(), 352 | (float *)absmax.data_ptr(), 353 | (float *)datatype.data_ptr(), 354 | (float *)out_.mutable_data_ptr(), 355 | lda, 356 | ldb, 357 | ldc, 358 | blocksize 359 | ); 360 | break; 361 | } 362 | default: 363 | throw std::runtime_error("Unsupported datatype"); 364 | } 365 | CUDA_CHECK_RETURN(cudaGetLastError()); 366 | 367 | return out_; 368 | } -------------------------------------------------------------------------------- /torch_bnb_fp4/__init__.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import logging 3 | from math import prod 4 | from typing import List, Optional, Tuple, Union, TypeVar 5 | 6 | import torch 7 | import bitsandbytes as bnb 8 | from bitsandbytes import functional as BF 9 | from bitsandbytes.nn.modules import Linear4bit, Params4bit, LinearFP4 10 | from torch import nn 11 | from torch_bnb_fp4_ext import ScalarType as ScalarType_ # type: ignore 12 | from torch_bnb_fp4_ext import dequantize_fp4 as dequantize_fp4_ # type: ignore 13 | from torch_bnb_fp4_ext import gemv_fp4 as gemv_fp4_ # type: ignore 14 | from torch_bnb_fp4_ext import qlinear as qlinear_ # type: ignore 15 | from torch_bnb_fp4_ext import qlinear_bias as qlinear_bias_ # type: ignore 16 | from torch_bnb_fp4_ext import dequantize_fp4_codebook as dequantize_fp4_codebook_ # type: ignore 17 | from torch_bnb_fp4_ext import qlinear_codebook as qlinear_codebook_ # type: ignore 18 | from torch_bnb_fp4_ext import qlinear_codebook_bias as qlinear_codebook_bias_ # type: ignore 19 | 20 | T_Model = TypeVar("T_Model", bound=nn.Module) 21 | 22 | class ScalarType(Enum): 23 | """ 24 | Enum encapsulating c++ bound torch scalar types for fp32, fp16, and bf16. 25 | """ 26 | 27 | bfloat16 = ScalarType_.bfloat16 28 | float16 = ScalarType_.float16 29 | float32 = ScalarType_.float32 30 | 31 | @classmethod 32 | def from_torch_dtype( 33 | cls, dtype: torch.dtype 34 | ) -> Union["ScalarType.bfloat16", "ScalarType.float16", "ScalarType.float32"]: 35 | """ 36 | Convert a torch dtype to a ScalarType. 37 | 38 | Args: 39 | dtype (torch.dtype): The torch dtype to be converted. 40 | 41 | Returns: 42 | Union[ScalarType.bfloat16, ScalarType.float16, ScalarType.float32]: The corresponding ScalarType. 
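Example (illustrative doctest-style usage; not part of the test suite):
>>> ScalarType.from_torch_dtype(torch.float16) is ScalarType.float16
True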
43 | """ 44 | if dtype == torch.bfloat16: 45 | return cls.bfloat16 46 | elif dtype == torch.float16: 47 | return cls.float16 48 | elif dtype == torch.float32: 49 | return cls.float32 50 | else: 51 | raise ValueError(f"Unsupported dtype {dtype}") 52 | 53 | @classmethod 54 | def from_str( 55 | cls, dtype: str 56 | ) -> Union["ScalarType.bfloat16", "ScalarType.float16", "ScalarType.float32"]: 57 | """ 58 | Convert a string to a ScalarType. 59 | 60 | Args: 61 | dtype (str): The string to be converted. 62 | 63 | Returns: 64 | Union[ScalarType.bfloat16, ScalarType.float16, ScalarType.float32]: The corresponding ScalarType. 65 | """ 66 | if dtype == "bfloat16": 67 | return cls.bfloat16 68 | elif dtype == "float16": 69 | return cls.float16 70 | elif dtype == "float32": 71 | return cls.float32 72 | else: 73 | raise ValueError(f"Unsupported dtype {dtype}") 74 | 75 | @property 76 | def torch_dtype(self) -> torch.dtype: 77 | if self == ScalarType.BFloat16: 78 | return torch.bfloat16 79 | elif self == ScalarType.Float16: 80 | return torch.float16 81 | elif self == ScalarType.Float32: 82 | return torch.float32 83 | else: 84 | raise ValueError(f"Unsupported dtype {self}") 85 | 86 | 87 | @torch.no_grad 88 | def dequantize_fp4( 89 | qweight: torch.ByteTensor, 90 | absmax: torch.Tensor, 91 | blocksize: int, 92 | M: int, 93 | N: int, 94 | dtype=torch.float16, 95 | ) -> torch.FloatTensor: 96 | """ 97 | Dequantizes 4-bit quantized weights to floating-point representation. 98 | 99 | This function is designed to convert the 4-bit quantized weights back into their original 100 | floating-point format. Allows for reduced model size and potentially faster computation on 101 | compatible hardware, while still being able to perform operations in the model's original 102 | precision. 103 | 104 | Parameters: 105 | - qweight (torch.ByteTensor): The quantized weights, stored in a byte tensor. 106 | - absmax (torch.Tensor): The maximum absolute value of the weights, used for scaling during dequantization. 107 | - blocksize (int): The size of the block used for quantization. This affects how the weights were originally quantized. 108 | - M (int): The first dimension of the weight matrix. 109 | - N (int): The second dimension of the weight matrix. 110 | - dtype (torch.dtype, optional): The target data type for the dequantized weights. Defaults to torch.float16. 111 | 112 | Returns: 113 | - torch.FloatTensor: The dequantized weights, converted back to floating-point representation. 114 | 115 | The function internally calls a CUDA implementation `dequantize_fp4_` with the appropriate scalar type 116 | derived from the given dtype to perform the dequantization. This operation is performed without 117 | gradient tracking to ensure it is purely computational and does not affect backpropagation. 118 | """ 119 | return dequantize_fp4_( 120 | qweight, absmax, blocksize, M, N, ScalarType.from_torch_dtype(dtype).value 121 | ) 122 | 123 | 124 | @torch.no_grad 125 | def dequantize_fp4_codebook_invoke_qtype( 126 | qweight: torch.ByteTensor, 127 | absmax: torch.FloatTensor, 128 | code: torch.FloatTensor, 129 | blocksize: int, 130 | M: int, 131 | N: int, 132 | numel: int, 133 | qtype: ScalarType, 134 | ) -> torch.FloatTensor: 135 | """ 136 | Dequantizes 4-bit quantized weights to floating-point representation using codebook. 137 | 138 | This function is designed to convert the 4-bit quantized weights back into their original 139 | floating-point format. 
Allows for reduced model size and potentially faster computation on 140 | compatible hardware, while still being able to perform operations in the model's original 141 | precision. 142 | 143 | Parameters: 144 | - qweight (torch.ByteTensor): The quantized weights, stored in a byte tensor. 145 | - absmax (torch.Tensor): The maximum absolute value of the weights, used for scaling during dequantization. 146 | - code (torch.FloatTensor): The 16 element codebook used for dequantization. 147 | - blocksize (int): The size of the block used for quantization. This affects how the weights were originally quantized. 148 | - M (int): The first dimension of the weight matrix. 149 | - N (int): The second dimension of the weight matrix. 150 | - numel (int): The number of elements in the weight matrix. 151 | - qtype (torch_bnb_fp4.ScalarType): The quantization type. 152 | 153 | Returns: 154 | - torch.FloatTensor: The dequantized weights, converted back to floating-point representation using codebook. 155 | 156 | The function internally calls a CUDA implementation `dequantize_fp4_codebook_` with the appropriate scalar type 157 | derived from the given qtype to perform the dequantization. 158 | """ 159 | return dequantize_fp4_codebook_( 160 | qweight, 161 | absmax, 162 | code, 163 | M, 164 | N, 165 | blocksize, 166 | numel, 167 | qtype, 168 | ) 169 | 170 | 171 | @torch.no_grad 172 | def dequantize_fp4_codebook_invoke( 173 | qweight: torch.ByteTensor, 174 | absmax: torch.FloatTensor, 175 | code: torch.FloatTensor, 176 | blocksize: int, 177 | M: int, 178 | N: int, 179 | numel: int, 180 | qtype: torch.dtype, 181 | ) -> torch.FloatTensor: 182 | """ 183 | Dequantizes 4-bit quantized weights to floating-point representation using codebook and invokes the CUDA implementation. 184 | 185 | This function is designed to convert the 4-bit quantized weights back into their original 186 | floating-point format. Allows for reduced model size and potentially faster computation on 187 | compatible hardware, while still being able to perform operations in the model's original 188 | precision. 189 | 190 | Parameters: 191 | - qweight (torch.ByteTensor): The quantized uint8 fp4 weights, stored as a byte tensor, 192 | each byte represents two four bit weight indices in the codebook (code). 193 | - absmax (torch.Tensor): The maximum absolute value of the weights, 1 absmax per blocksize weights, 194 | used for scaling during dequantization. 195 | - code (torch.FloatTensor): The 16 element codebook used for dequantization. 196 | - blocksize (int): The number of elements per absmax. 197 | - M (int): The first dimension of the weight matrix. 198 | - N (int): The second dimension of the weight matrix. 199 | - numel (int): The number of elements in the weight matrix. 200 | - qtype (ScalarType): The quantization type. 201 | 202 | Returns: 203 | - torch.FloatTensor: The dequantized weights, converted back to floating-point representation using codebook. 204 | 205 | The function internally calls a CUDA implementation `dequantize_fp4_codebook_` with the appropriate scalar type 206 | derived from the given qtype to perform the dequantization. 
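Example (illustrative; sizes are arbitrary, a CUDA device is required, and the inputs are produced
with `bitsandbytes.functional.quantize_fp4`, imported above as BF):
>>> w = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
>>> qweight, state = BF.quantize_fp4(w)
>>> w_dq = dequantize_fp4_codebook_invoke(
...     qweight, state.absmax.float(), state.code.float(), state.blocksize,
...     4096, 4096, w.numel(), torch.float16,
... )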
207 | """ 208 | return dequantize_fp4_codebook_( 209 | qweight, 210 | absmax, 211 | code, 212 | M, 213 | N, 214 | blocksize, 215 | numel, 216 | ScalarType.from_torch_dtype(qtype).value, 217 | ) 218 | 219 | 220 | @torch.no_grad 221 | def gemm_4bit_inference( 222 | A: torch.Tensor, 223 | B: torch.ByteTensor, 224 | absmax: torch.Tensor, 225 | code: torch.Tensor, 226 | blocksize: int, 227 | dtype=torch.float16, 228 | Bshape=None, 229 | ) -> torch.FloatTensor: 230 | """ 231 | Performs 4-bit quantized matrix multiplication using a GEMV algorithm. 232 | 233 | This function is designed to perform matrix multiplication on 4-bit quantized matrices using the GEMM algorithm. 234 | It takes two input matrices A and B, the maximum absolute value of the weights (absmax), the codebook used for 235 | dequantization (code), the size of the block used for quantization (blocksize), the data type for the output (dtype), 236 | and the shape of matrix B (Bshape). 237 | 238 | Parameters: 239 | - A (torch.Tensor): The first input matrix, of shape (1, hidden) or (1, 1, hidden), where the last dimension is always 240 | equal to the total number of it's elements. 241 | - B (torch.ByteTensor): The quantized uint8 fp4 weights, stored as a byte tensor, each byte represents two four bit weight 242 | indices in the codebook. 243 | - absmax (torch.Tensor): The maximum absolute value of the weights, used for scaling during dequantization. 244 | - code (torch.Tensor): The 16 element codebook used for dequantization. 245 | - blocksize (int): The size of the block used for quantization. This affects how the weights were originally quantized. 246 | - dtype (torch.dtype): The data type for the output. 247 | - Bshape (List[int]): The shape of matrix B. 248 | 249 | Returns: 250 | - torch.FloatTensor: The result of the matrix multiplication, in the specified data type. 251 | 252 | The function internally calls a CUDA implementation `gemv_fp4_` with the appropriate scalar type 253 | derived from the given dtype to perform the matrix multiplication. 254 | """ 255 | return gemv_fp4_( 256 | A, B, absmax, code, blocksize, ScalarType.from_torch_dtype(dtype).value, Bshape 257 | ) 258 | 259 | 260 | @torch.no_grad 261 | def gemm_4bit_inference_qtype( 262 | A: torch.Tensor, 263 | B: torch.ByteTensor, 264 | absmax: torch.FloatTensor, 265 | code: torch.FloatTensor, 266 | blocksize: int, 267 | dtype: ScalarType = ScalarType.bfloat16.value, 268 | Bshape: List[int] = None, 269 | ) -> torch.FloatTensor: 270 | """ 271 | Performs 4-bit quantized matrix multiplication using a GEMV algorithm. 272 | 273 | This function is designed to perform matrix multiplication on 4-bit quantized matrices using the GEMM algorithm. 274 | It takes two input matrices A and B, the maximum absolute value of the weights (absmax), the codebook used for 275 | dequantization (code), the size of the block used for quantization (blocksize), the data type for the output (dtype), 276 | and the shape of matrix B (Bshape). 277 | 278 | Parameters: 279 | - A (torch.Tensor): The first input matrix, of shape (1, hidden) or (1, 1, hidden), where the last dimension is always 280 | equal to the total number of it's elements. 281 | - B (torch.ByteTensor): The quantized uint8 fp4 weights, stored as a byte tensor, each byte represents two four bit weight 282 | indices in the codebook. 283 | - absmax (torch.Tensor): The maximum absolute value of the weights, used for scaling during dequantization. 284 | - code (torch.Tensor): The 16 element codebook used for dequantization. 
285 | - blocksize (int): The size of the block used for quantization. This affects how the weights were originally quantized. 286 | - dtype (torch.dtype): The data type for the output. 287 | - Bshape (List[int]): The original shape of the unquantized matrix B. 288 | 289 | Returns: 290 | - torch.FloatTensor: The result of the matrix multiplication, in the specified data type. 291 | 292 | The function internally calls a CUDA implementation `gemv_fp4_` with the appropriate scalar type 293 | derived from the given dtype to perform the matrix multiplication. 294 | """ 295 | return gemv_fp4_(A, B, absmax, code, blocksize, dtype, Bshape) 296 | 297 | 298 | @torch.no_grad 299 | def dequantize_fp4_qtype( 300 | qweight: torch.ByteTensor, 301 | absmax: torch.Tensor, 302 | blocksize: int, 303 | M: int, 304 | N: int, 305 | dtype: ScalarType = ScalarType.bfloat16.value, 306 | ) -> torch.FloatTensor: 307 | """ 308 | Dequantizes the 4-bit quantized weights. 309 | 310 | This function is designed to dequantize the 4-bit quantized weights. 311 | It takes the quantized weights (qweight), the maximum absolute value of the weights (absmax), 312 | the size of the block used for quantization (blocksize), the number of rows in the matrix (M), 313 | the number of columns in the matrix (N), and the data type for the output (dtype). 314 | 315 | Parameters: 316 | - qweight (torch.ByteTensor): The quantized uint8 fp4 weights, stored as a byte tensor, each byte represents two four bit weight 317 | indices in the codebook. 318 | - absmax (torch.Tensor): The maximum absolute value of the weights, used for scaling during dequantization. 319 | - blocksize (int): The size of the block used for quantization. This affects how the weights were originally quantized. 320 | - M (int): The number of rows in the matrix. 321 | - N (int): The number of columns in the matrix. 322 | - dtype (torch.dtype): The data type for the output. 323 | 324 | Returns: 325 | - torch.FloatTensor: The dequantized weights, in the specified data type. 326 | 327 | The function internally calls a CUDA implementation `dequantize_fp4_` with the appropriate scalar type 328 | derived from the given dtype to perform the dequantization. 329 | """ 330 | return dequantize_fp4_( 331 | qweight, 332 | absmax, 333 | blocksize, 334 | M, 335 | N, 336 | dtype, 337 | ) 338 | 339 | 340 | class QuantData: 341 | """ 342 | This class is used to store quantized data and implements the forward pass of a quantized linear layer. 343 | """ 344 | 345 | def __init__( 346 | self, 347 | A: torch.ByteTensor, 348 | state: BF.QuantState, 349 | shape: Tuple[int, int], 350 | original_lin: Union[LinearFP4, Linear4bit], 351 | bias: Optional[torch.FloatTensor] = None, 352 | use_codebook_dequant: Optional[bool] = True, 353 | allow_reduced_precision_linear: Optional[bool] = False, 354 | ): 355 | """ 356 | Initializes the QuantData class. 357 | 358 | This function is used to initialize the QuantData class. 359 | It takes the quantized data (A), the quantization state (bitsandbytes.functional.QuantState), the shape of the data (shape), 360 | the bias (bias), the original bitsandbytes layer (original_lin), a flag to use codebook dequantization (use_codebook_dequant), 361 | a flag to allow reduced precision linear (allow_reduced_precision_linear), and the type of reduced precision linear dequantization (reduced_precision_linear_dequant_type). 362 | 363 | Parameters: 364 | - A `(torch.ByteTensor)` `REQUIRED` : The quantized data, stored as a byte tensor. 
365 | - state `(bitsandbytes.functional.QuantState)` `REQUIRED` : The quantization state. 366 | - shape `(Tuple[int, int])` `REQUIRED` : The shape of the data. 367 | - original_lin `(nn.Linear)` `REQUIRED` : The original linear layer. 368 | - bias `(Optional[torch.FloatTensor])` `default: None` : The bias of the original linear layer, not necessary if original_lin has bias. 369 | - use_codebook_dequant `(Optional[bool])` `default: True` : A flag to use codebook dequantization vs fp4 tree dequantization, which is the bitsandbytes default. 370 | - allow_reduced_precision_linear `(Optional[bool])` `default: False` : A flag to allow reduced precision linear, will speed up full gemm (not gemv) forwards at the expense of loss of precision. 371 | * Typically ~0.35 elementwise error for matmul vs between ~0.04 to ~0.06 elementwise error when `False`. I do not recommend using this in general. 372 | * It is only applicable for input shapes where `(B, L, H), L > 1 or B > 1` or `(B, H), B > 1`, other types of gemms will remain with low elementwise error. 373 | 374 | Returns: 375 | - None 376 | """ 377 | self.use_codebook_dequant = use_codebook_dequant 378 | self.A = A 379 | self.absmax = state.absmax.float() 380 | self.blocksize = state.blocksize 381 | self.M = shape[0] 382 | self.N = shape[1] 383 | self.code = state.code.float() 384 | self.o_type = None 385 | self.qtype = None 386 | self.quant_state = state 387 | self.bias = original_lin.bias if hasattr(original_lin, "bias") else bias 388 | self.original_lin = original_lin 389 | self.compute_dtype_set = False 390 | self.numel = prod(shape) 391 | if allow_reduced_precision_linear: 392 | if self.use_codebook_dequant: 393 | self.qlinear = self._qlinear_low_precision_codebook 394 | else: 395 | self.qlinear = self._qlinear_low_precision_normal 396 | else: 397 | self.qlinear = self._dequant_linear 398 | if self.use_codebook_dequant: 399 | self.dequantize = self._dequantize_codebook 400 | else: 401 | self.dequantize = self._dequantize_normal 402 | 403 | def set_compute_type(self, x: torch.Tensor) -> None: 404 | """ 405 | Sets the compute type for the input tensor. 406 | 407 | This function is used to set the compute type for the input tensor. 408 | It takes the input tensor (x) and sets the output type (o_type) and quantization type (qtype) based on the input tensor's dtype. 409 | If the bias is not None, it also sets the bias to the output type. 410 | 411 | Parameters: 412 | - x `(torch.Tensor)` `REQUIRED` : The input tensor. 413 | 414 | Returns: 415 | - None 416 | """ 417 | self.o_type = x.dtype 418 | self.qtype = ScalarType.from_torch_dtype(x.dtype).value 419 | if self.bias is not None: 420 | self.bias = self.bias.to(dtype=self.o_type) 421 | self.compute_dtype_set = True 422 | 423 | def _dequant_linear(self, A: torch.Tensor) -> torch.FloatTensor: 424 | """ 425 | Dequantizes the input tensor and performs a linear transformation. 426 | 427 | This function is used to dequantize the input tensor (A) and perform a matrix multiply + add bias. 428 | It takes the input tensor (A) and dequantizes it using the dequantize function. 429 | It then performs an nn.Linear matmul+bias using the dequantized tensor and the original linear layer's bias. 430 | 431 | Parameters: 432 | - A `(torch.Tensor)` `REQUIRED` : The input tensor. 433 | 434 | - torch.Tensor : The output tensor after the linear transformation. 
435 | """ 436 | return torch.nn.functional.linear(A, self.dequantize(), self.bias) 437 | 438 | def _dequantize_codebook(self) -> torch.FloatTensor: 439 | """ 440 | Dequantizes this QuantData's weights using the codebook. 441 | 442 | Used as a wrapped simplification of the dequantize_fp4_codebook_invoke_qtype function, pre-configured with the correct arguments. 443 | """ 444 | 445 | return dequantize_fp4_codebook_invoke_qtype( 446 | self.A, 447 | self.absmax, 448 | self.code, 449 | self.blocksize, 450 | self.M, 451 | self.N, 452 | self.numel, 453 | self.qtype, 454 | ) 455 | 456 | def _dequantize_normal(self) -> torch.FloatTensor: 457 | """ 458 | Dequantizes this QuantData's weights using the normal method. 459 | 460 | Used as a wrapped simplification of the dequantize_fp4_qtype function, pre-configured with the correct arguments. 461 | """ 462 | return dequantize_fp4_qtype( 463 | self.A, 464 | self.absmax, 465 | self.blocksize, 466 | self.M, 467 | self.N, 468 | self.qtype, 469 | ) 470 | 471 | def _qgemv(self, A: torch.Tensor) -> torch.FloatTensor: 472 | """ 473 | This function performs a Quantized GEMV operation. 474 | 475 | It takes the input tensor (A) and performs a matrix multiply with the transposed weight tensor (self.A). 476 | It then quantizes the result using the absmax, code, and blocksize parameters. 477 | 478 | Parameters: 479 | - A `(torch.Tensor)` `REQUIRED` : The input tensor. 480 | 481 | Returns: 482 | - torch.Tensor : The output tensor after the quantized GEMV operation. 483 | """ 484 | return gemm_4bit_inference_qtype( 485 | A=A, 486 | B=self.A.t(), 487 | absmax=self.absmax, 488 | code=self.code, 489 | blocksize=self.blocksize, 490 | dtype=self.qtype, 491 | Bshape=self.quant_state.shape, 492 | ) 493 | 494 | def _qlinear_low_precision_normal(self, A: torch.Tensor) -> torch.FloatTensor: 495 | """ 496 | Quantized nn.Linear operation using the low precision fp4 tree dequant method. 497 | This method is faster than the QuantData._dequant_linear method, but has a higher elementwise error. 498 | It is only applicable for input shapes where `(B, L, H), L > 1 or B > 1` or `(B, H), B > 1`, other types of gemms will remain with low elementwise error. 499 | 500 | Parameters: 501 | - A `(torch.Tensor)` `REQUIRED` : The input tensor. 502 | 503 | Returns: 504 | - torch.Tensor : The output tensor after the quantized GEMV operation. 505 | """ 506 | if self.bias is None: 507 | return qlinear_( 508 | A, 509 | self.A, 510 | self.absmax, 511 | self.M, 512 | self.N, 513 | self.blocksize, 514 | ) 515 | else: 516 | return qlinear_bias_( 517 | A, 518 | self.A, 519 | self.absmax, 520 | self.M, 521 | self.N, 522 | self.blocksize, 523 | self.bias, 524 | ) 525 | 526 | def _qlinear_low_precision_codebook(self, A: torch.Tensor) -> torch.FloatTensor: 527 | """ 528 | Quantized nn.Linear operation using the low precision codebook dequant method. 529 | This method is faster than the QuantData._dequant_linear method, but has a higher elementwise error. 530 | It is only applicable for input shapes where `(B, L, H), L > 1 or B > 1` or `(B, H), B > 1`, other types of gemms will remain with low elementwise error. 531 | 532 | Parameters: 533 | - A `(torch.Tensor)` `REQUIRED` : The input tensor. 534 | 535 | Returns: 536 | - torch.Tensor : The output tensor after the quantized GEMV operation. 
537 | """ 538 | if self.bias is None: 539 | return qlinear_codebook_( 540 | A, 541 | self.A, 542 | self.absmax, 543 | self.code, 544 | self.M, 545 | self.N, 546 | self.blocksize, 547 | ) 548 | else: 549 | return qlinear_codebook_bias_( 550 | A, 551 | self.A, 552 | self.absmax, 553 | self.code, 554 | self.M, 555 | self.N, 556 | self.blocksize, 557 | self.bias, 558 | ) 559 | 560 | def forward(self, A: torch.FloatTensor) -> torch.FloatTensor: 561 | """ 562 | Faux nn.Linear forward pass. 563 | This function is used to perform a forward pass of a quantized linear layer. 564 | It takes the input tensor (A) and performs a quantized matmul+bias operation using the quantized weights and bias. 565 | If the input tensor is not contiguous, it will be made contiguous before the operation. 566 | 567 | Special Cases Handled: 568 | - If the input tensor is not contiguous, it will be made contiguous before the operation. 569 | - If the input tensor shape contains a zero, the output will be a tensor of zeros with the correct (0 element) shape. 570 | - If the input tensor's number of elements is equal to the last dimension of itself, and is divisible by the quantized weight's block size 571 | 572 | Parameters: 573 | - A `(torch.Tensor)` `REQUIRED` : The input tensor. 574 | 575 | Returns: 576 | - torch.Tensor : The output tensor after the quantized matmul+bias operation. 577 | """ 578 | prodshape = prod(A.shape) 579 | is_contig = A.is_contiguous() 580 | if prodshape == 0: 581 | B_shape = self.quant_state.shape 582 | if A.shape[-1] == B_shape[0]: 583 | return torch.empty( 584 | A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device 585 | ) 586 | else: 587 | return torch.empty( 588 | A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device 589 | ) 590 | if not self.compute_dtype_set: 591 | self.set_compute_type(A) 592 | if prodshape == A.shape[-1]: 593 | if A.shape[-1] % self.blocksize != 0: 594 | out = self.qlinear(A) 595 | else: 596 | if not is_contig: 597 | A = A.contiguous() 598 | # gemm 4bit only works when the input is a single batch, 599 | # with 1 token and batch size 1 600 | # aka- (1, 1, hidden_dim) 601 | # or (1, hidden_dim) 602 | 603 | if A.ndim == 3: 604 | N_batch = A.shape[0] 605 | A = A.view(-1, A.shape[-1]) 606 | out = self._qgemv(A) 607 | out = out.view(N_batch, 1, -1) 608 | if self.bias is not None: 609 | out += self.bias 610 | elif A.ndim == 2: 611 | out = self._qgemv(A) 612 | if self.bias is not None: 613 | out += self.bias 614 | else: 615 | out = self.qlinear(A) 616 | else: 617 | out = self.qlinear(A) 618 | return out 619 | 620 | 621 | class TorchFP4Linear(nn.Module): 622 | """ 623 | A wrapper for bitsandbytes.nn.LinearFP4 and bitsandbytes.nn.Linear4bit layers. 624 | """ 625 | 626 | def __init__( 627 | self, 628 | lin: Union[Linear4bit, LinearFP4], 629 | use_codebook_dequant: bool = True, 630 | name="", 631 | ): 632 | """ 633 | Initializes the TorchFP4Linear class. 634 | This class is used to wrap a bitsandbytes.nn.LinearFP4 or bitsandbytes.nn.Linear4bit layer and replace it with a torch-bnb-fp4 version. 635 | It takes the original linear layer (lin) and a flag for whether to use codebook dequantization (use_codebook_dequant). 636 | 637 | Parameters: 638 | - lin (Union[LinearFP4, Linear4bit]) `REQUIRED` : The original linear layer to wrap. 639 | - use_codebook_dequant (bool) `OPTIONAL` : Whether to use codebook dequantization in the TorchFP4Linear layer. 640 | Default is False. 
641 | 642 | """ 643 | super().__init__() 644 | self.lin = [lin] 645 | self.in_features = lin.in_features 646 | self.out_features = lin.out_features 647 | self.use_codebook_dequant = use_codebook_dequant 648 | self.name = name 649 | if isinstance(lin.weight, Params4bit): 650 | if ( 651 | lin.weight.quant_state is not None 652 | and lin.weight.device.type == "cuda" 653 | and lin.weight.data.dtype == torch.uint8 654 | ): 655 | self.quant_data = QuantData( 656 | lin.weight.data, 657 | lin.weight.quant_state, 658 | lin.weight.quant_state.shape, 659 | bias=lin.bias, 660 | original_lin=lin, 661 | use_codebook_dequant=self.use_codebook_dequant, 662 | ) 663 | else: 664 | raise ValueError( 665 | f"Linear weights are not quantized, and I have no idea what to do with that rn. Weights are {lin.weight.data.dtype}" 666 | ) 667 | 668 | else: 669 | raise ValueError( 670 | f"Linear is not a bnb linear and is not quantized, and I have no idea what to do with that rn." 671 | ) 672 | 673 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 674 | """ 675 | Calls this TorchFP4Linear's quantized linear layer's forward method. 676 | If the input tensor is not contiguous, it will be made contiguous before the operation. 677 | If the input tensor shape contains a zero, the output will be a tensor of zeros with the correct (0 element) shape. 678 | If the input tensor's number of elements is equal to the last dimension of itself, and is divisible by the quantized weight's block size, an optimized GEMV operation will be used. 679 | If the input tensor's number of elements is not equal to the last dimension of itself, a full dequantize + matmul + bias operation will be used. 680 | 681 | Parameters: 682 | - x `(torch.Tensor)` `REQUIRED` : The input tensor. 683 | 684 | Returns: 685 | - torch.Tensor : The output tensor after the quantized matmul+bias operation. 686 | """ 687 | return self.quant_data.forward(x) 688 | 689 | def __repr__(self) -> str: 690 | if hasattr(self, "quant_data"): 691 | return f"TorchFP4Linear(in_features={self.lin[0].in_features}, out_features={self.lin[0].out_features}, bias={self.lin[0].bias is not None}, dtype={self.quant_data.o_type})" 692 | else: 693 | return f"TorchFP4Linear(in_features={self.lin[0].in_features}, out_features={self.lin[0].out_features}, bias={self.lin[0].bias is not None})" 694 | 695 | @classmethod 696 | def from_linear( 697 | cls, 698 | linear: Union[LinearFP4, Linear4bit], 699 | use_codebook_dequant: bool = False, 700 | name="", 701 | ) -> "TorchFP4Linear": 702 | """ 703 | Initializes a TorchFP4Linear layer from a bitsandbytes.nn.LinearFP4, or bitsandbytes.nn.Linear4bit layer. 704 | If the input layer must be quantized prior to initialization! 705 | 706 | Parameters: 707 | - linear (Union[LinearFP4, Linear4bit]): The linear layer to initialize the TorchFP4Linear layer from. 708 | - use_codebook_dequant (bool): Whether to use codebook dequantization in the TorchFP4Linear layer. 709 | Default is False. 710 | 711 | Returns: 712 | - TorchFP4Linear: The TorchFP4Linear layer initialized from the linear layer. 713 | """ 714 | return cls(linear, use_codebook_dequant=use_codebook_dequant, name=name) 715 | 716 | 717 | @torch.no_grad 718 | def swap_linear_with_bnb_linear( 719 | linear: nn.Linear, 720 | dtype=torch.float16, 721 | ) -> LinearFP4: 722 | """ 723 | Swaps a torch.nn.Linear layer with a bitsandbytes.nn.LinearFP4 layer. 724 | 725 | Swaps and initializes a `bitsandbytes.nn.LinearFP4` layer with the weights 726 | and biases of a `torch.nn.Linear` layer. 
727 | 728 | Parameters: 729 | - linear (nn.Linear): The linear layer to swap. 730 | - dtype (torch.dtype): The data type to use for the weights of the LinearFP4 layer. 731 | Default is torch.float16. 732 | 733 | Returns: 734 | - LinearFP4: The LinearFP4 layer with the weights and biases of the Linear layer. 735 | """ 736 | bnb_module = bnb.nn.LinearFP4( 737 | input_features=linear.in_features, 738 | output_features=linear.out_features, 739 | bias=linear.bias is not None, 740 | compute_dtype=dtype, 741 | ) 742 | 743 | bnb_module.weight.data = linear.weight.data.clone().detach() 744 | if linear.bias is not None: 745 | bnb_module.bias.data = linear.bias.data.clone().detach() 746 | bnb_module.requires_grad_(False) 747 | return bnb_module 748 | 749 | 750 | def check_if_name_contained_in_list(name, names_list): 751 | is_contained = False 752 | for name_i in names_list: 753 | if name_i in name: 754 | is_contained = True 755 | break 756 | return is_contained 757 | 758 | 759 | def todevice_if_necessary(module, device): 760 | if module.weight.data.dtype != torch.uint8 and isinstance( 761 | module, (Linear4bit, LinearFP4) 762 | ): 763 | module.weight = module.weight.to(device) 764 | try: 765 | assert ( 766 | module.weight.data.device == device 767 | and module.weight.data.dtype == torch.uint8 768 | ), f"AAAAAAAAAAAAAHHH {module.weight.data.device} {module.weight.data.dtype} {device}" 769 | except Exception as e: 770 | logging.debug( 771 | "bnb/accelerate done messed up, idk how they did it, but they did it. " 772 | + "They literally forgot to quantize a layer despite not having any restrictions on which layers to quantize. " 773 | + "How?" 774 | ) 775 | qweight, qstate = BF.quantize_fp4(module.weight.data) 776 | module.weight.data = qweight 777 | module.weight.quant_state = qstate 778 | return module 779 | 780 | 781 | def recursively_replace_with_fp4_linear( 782 | module: T_Model, 783 | as_dtype=torch.float16, 784 | use_codebook_dequant=True, 785 | device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), 786 | return_final_module: bool = True, 787 | only_replace_bnb_layers: bool = False, 788 | ignore_layer_names: List[str] = ["lm_head"], 789 | parent="", 790 | debug: bool = False, 791 | ) -> Optional[T_Model]: 792 | """Function to replace all bnb linear layers with torch-bnb-fp4 linear layers. 793 | 794 | Recursively replaces all nn.Linear, LinearFP4, and Linear4bit layers 795 | within a given PyTorch module with TorchFP4Linear layers. 796 | 797 | This function traverses the module hierarchy of the given PyTorch module and replaces each 798 | nn.Linear, LinearFP4, and Linear4bit layer it finds with an equivalent TorchFP4Linear layer 799 | that uses FP4 quantization. This can be useful for reducing the memory footprint of a model 800 | or for accelerating inference on hardware that supports FP4 operations. 801 | 802 | Parameters: 803 | - module (nn.Module): The root module to traverse and modify. 804 | - as_dtype (torch.dtype): The default data type to use for the forward pass of the TorchFP4Linear layers. 805 | Default is torch.float16. 806 | - use_codebook_dequant (bool): Whether to use codebook dequantization in the TorchFP4Linear layers. 807 | Default is True. 808 | - device (torch.device): The device to move the TorchFP4Linear layers to. Default is the CUDA 809 | device if available, otherwise CPU, though it will error on CPU. 810 | - return_final_module (bool): Whether to return the modified module. Default is True. 
811 | - only_replace_bnb_layers (bool): Whether to only replace bnb layers with TorchFP4Linear layers. Default is False.
812 | - ignore_layer_names (List[str]): List of keys to ignore when replacing layers. Default is ["lm_head"],
813 | which is the default for transformers when swapping layers with LLMs. You can also pass in a list of
814 | strings to ignore, such as ["lm_head", "pooler", "classifier", "model.final_mlp.to_out"], etc.
815 | - debug (bool): Print debugging output. Default is False.
816 | """
817 | assert (
818 | (device.type == "cuda")
819 | if hasattr(device, "type")
820 | else (device.split(":")[0] == "cuda")
821 | ), "Device type must be cuda!"
822 |
823 | if parent != "":
824 | parent = parent + "."
825 |
826 | # Only need to clean cache when we swap an nn.Linear (not Linear4bit or LinearFP4) layer with a TorchFP4Linear,
827 | # otherwise we don't need to clean cache.
828 | should_clean_cache = False
829 | for name, child in module.named_children():
830 | child_name = parent + name
831 | if check_if_name_contained_in_list(
832 | name, ignore_layer_names
833 | ) or check_if_name_contained_in_list(child_name, ignore_layer_names):
834 | if debug:
835 | print(f"Ignoring name: {child_name}, as it is in the ignore list")
836 | continue
837 | if isinstance(child, (nn.Linear, LinearFP4, Linear4bit)):
838 | if isinstance(child, (LinearFP4, Linear4bit)):
839 | if debug:
840 | print(
841 | f"Replacing BNB layer {child_name}, swapping with TorchFP4Linear."
842 | )
843 | module._modules[name] = module._modules[name].to(device=device)
844 | module._modules[name] = TorchFP4Linear(
845 | lin=module._modules[name],
846 | use_codebook_dequant=use_codebook_dequant,
847 | name=child_name,
848 | )
849 | elif isinstance(child, nn.Linear):
850 | if only_replace_bnb_layers:
851 | if debug:
852 | print(f"Ignoring {child_name}, as only_replace_bnb_layers=True")
853 | else:
854 | if debug:
855 | print(
856 | f"Replacing {child_name} with BNB linear, and then swapping with TorchFP4Linear."
857 | ) 858 | # Must call cuda(device) to initialize the bnb linear's quant state 859 | module._modules[name] = swap_linear_with_bnb_linear( 860 | module._modules[name], dtype=as_dtype 861 | ).to(device) 862 | if ( 863 | not hasattr(module._modules[name].weight, "quant_state") 864 | or module._modules[name].weight.quant_state is None 865 | ): 866 | module._modules[name] = module._modules[name].to(device) 867 | if module._modules[name].weight.quant_state is None: 868 | module._modules[name] = todevice_if_necessary( 869 | module._modules[name], device 870 | ) 871 | 872 | module._modules[name] = TorchFP4Linear( 873 | lin=module._modules[name], 874 | use_codebook_dequant=use_codebook_dequant, 875 | name=child_name, 876 | ) 877 | should_clean_cache = True 878 | 879 | elif isinstance(child, nn.Module): 880 | recursively_replace_with_fp4_linear( 881 | child, 882 | as_dtype=as_dtype, 883 | use_codebook_dequant=use_codebook_dequant, 884 | device=device, 885 | return_final_module=False, 886 | only_replace_bnb_layers=only_replace_bnb_layers, 887 | ignore_layer_names=ignore_layer_names, 888 | parent=child_name, 889 | debug=debug, 890 | ) 891 | if isinstance(module, (nn.Linear, LinearFP4, Linear4bit)): 892 | if isinstance(module, (LinearFP4, Linear4bit)): 893 | if debug: 894 | print(f"Replacing {parent} with TorchFP4Linear.") 895 | module = TorchFP4Linear.from_linear( 896 | linear=module, use_codebook_dequant=use_codebook_dequant, name=name 897 | ).to(device=device) 898 | elif isinstance(module, nn.Linear): 899 | if only_replace_bnb_layers: 900 | if debug: 901 | print(f"Ignoring {parent}, as only_replace_bnb_layers=True") 902 | else: 903 | # Must call cuda(device) to initialize the bnb linear's quant state 904 | if debug: 905 | print( 906 | f"Replacing {parent} with bnb linear, and then swapping with TorchFP4Linear." 907 | ) 908 | module = TorchFP4Linear.from_linear( 909 | linear=todevice_if_necessary( 910 | swap_linear_with_bnb_linear(module, dtype=as_dtype).to( 911 | device=device 912 | ), 913 | device, 914 | ), 915 | use_codebook_dequant=use_codebook_dequant, 916 | name=parent, 917 | ) 918 | should_clean_cache = True 919 | if should_clean_cache: 920 | torch.cuda.empty_cache() 921 | if return_final_module: 922 | return module 923 | --------------------------------------------------------------------------------
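Minimal end-to-end usage sketch for the package listed above (illustrative, not part of the repository; it assumes the extension is installed, a supported CUDA device is available, and the toy model and sizes are arbitrary):

import torch
from torch import nn
from torch_bnb_fp4 import recursively_replace_with_fp4_linear

# A toy float16 MLP; any module tree containing nn.Linear layers is handled the same way.
model = nn.Sequential(
    nn.Linear(4096, 11008, bias=False),
    nn.GELU(),
    nn.Linear(11008, 4096, bias=False),
).to(dtype=torch.float16, device="cuda")

# Quantize every nn.Linear to FP4 via bitsandbytes, then swap in TorchFP4Linear wrappers.
model = recursively_replace_with_fp4_linear(
    model,
    as_dtype=torch.float16,
    use_codebook_dequant=True,
    device=torch.device("cuda"),
)

with torch.inference_mode():
    x = torch.randn(1, 4096, device="cuda", dtype=torch.float16)
    y = model(x)  # a single-token input takes the fused GEMV path
print(y.shape)  # expected: torch.Size([1, 4096])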