├── third_party └── CMakeLists.txt ├── media ├── logo.png ├── bench1.png ├── bench2.png └── bench3.png ├── python ├── quant_benchmark.png ├── setup.cfg ├── pyproject.toml ├── example │ ├── example_torch.py │ └── plot_stochastic_rounding_acc.py ├── benchmark │ ├── throughput_avg.py │ └── benchmark.py ├── tests │ └── test_torch.py ├── setup.py ├── src │ └── piquant │ │ ├── _bootstrap.py │ │ ├── torch.py │ │ └── __init__.py ├── README.md └── .gitignore ├── benchmark ├── CMakeLists.txt └── bench.cpp ├── src ├── kernel_generic.cpp ├── amd64 │ ├── kernel_amd64_avx2.cpp │ ├── kernel_amd64_sse42.cpp │ ├── kernel_amd64_avx512f.cpp │ └── kernel_amd64_avx512f_bf16.cpp ├── piquant_internal.hpp ├── capi.cpp ├── kernels │ ├── dequantize.inl │ ├── quantize.inl │ └── kernels.inl └── piquant.cpp ├── .gitmodules ├── .gitignore ├── test ├── CMakeLists.txt ├── quant_config.cpp ├── requant.cpp ├── dequant.cpp ├── naive.hpp └── quant.cpp ├── .github └── workflows │ ├── build-wheels.yml │ └── cmake-multi-platform.yml ├── include ├── piquant.h └── piquant.hpp ├── CMakeLists.txt └── README.md /third_party/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(threadpool) -------------------------------------------------------------------------------- /media/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrimeIntellect-ai/pi-quant/HEAD/media/logo.png -------------------------------------------------------------------------------- /media/bench1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrimeIntellect-ai/pi-quant/HEAD/media/bench1.png -------------------------------------------------------------------------------- /media/bench2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PrimeIntellect-ai/pi-quant/HEAD/media/bench2.png -------------------------------------------------------------------------------- /media/bench3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrimeIntellect-ai/pi-quant/HEAD/media/bench3.png -------------------------------------------------------------------------------- /python/quant_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrimeIntellect-ai/pi-quant/HEAD/python/quant_benchmark.png -------------------------------------------------------------------------------- /benchmark/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(quant_benchmark bench.cpp) 2 | 3 | target_link_libraries(quant_benchmark PRIVATE piquant) 4 | -------------------------------------------------------------------------------- /src/kernel_generic.cpp: -------------------------------------------------------------------------------- 1 | #define QUANT_KERNEL_IMPL install_quant_generic 2 | #include "kernels/kernels.inl" 3 | #undef QUANT_KERNEL_IMPL 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "test/googletest"] 2 | path = test/googletest 3 | url = https://github.com/google/googletest 4 | 5 | [submodule "third_party/threadpool"] 6 | path = third_party/threadpool 7 | url = https://github.com/PrimeIntellect-ai/threadpool 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | CMakeLists.txt.user 2 | CMakeCache.txt 3 | CMakeFiles 4 | CMakeScripts 5 | Testing 6 | Makefile 7 | cmake_install.cmake 8 | install_manifest.txt 9 | 
compile_commands.json 10 | CTestTestfile.cmake 11 | _deps 12 | CMakeUserPresets.json 13 | .idea/ 14 | bin/ 15 | build/ 16 | .vs/ 17 | -------------------------------------------------------------------------------- /src/amd64/kernel_amd64_avx2.cpp: -------------------------------------------------------------------------------- 1 | #ifndef __AVX2__ 2 | #error "Spec flag not enabled" 3 | #endif 4 | #ifdef __AVX512F__ 5 | #error "Spec level too high" 6 | #endif 7 | 8 | #define QUANT_KERNEL_IMPL install_quant_amd64_avx2 9 | #include "../kernels/kernels.inl" 10 | #undef QUANT_KERNEL_IMPL 11 | -------------------------------------------------------------------------------- /src/amd64/kernel_amd64_sse42.cpp: -------------------------------------------------------------------------------- 1 | #ifndef _MSC_VER 2 | #ifndef __SSE4_2__ 3 | #error "Spec flag not enabled" 4 | #endif 5 | #endif 6 | #ifdef __AVX__ 7 | #error "Spec level too high" 8 | #endif 9 | 10 | #define QUANT_KERNEL_IMPL install_quant_amd64_sse42 11 | #include "../kernels/kernels.inl" 12 | #undef QUANT_KERNEL_IMPL 13 | -------------------------------------------------------------------------------- /src/amd64/kernel_amd64_avx512f.cpp: -------------------------------------------------------------------------------- 1 | #if !defined(__AVX512F__) || !defined(__AVX512BW__) 2 | #error "Spec flag not enabled" 3 | #endif 4 | #ifdef __AVX10__ 5 | #error "Spec level too high" 6 | #endif 7 | 8 | #define QUANT_KERNEL_IMPL install_quant_amd64_avx512f 9 | #include "../kernels/kernels.inl" 10 | #undef QUANT_KERNEL_IMPL 11 | -------------------------------------------------------------------------------- /test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | enable_testing() 2 | add_subdirectory(googletest) 3 | file(GLOB_RECURSE TEST_SOURCES *.cpp) 4 | add_executable(piquant_tests ${TEST_SOURCES}) 5 | target_include_directories(piquant_tests PRIVATE 
${CMAKE_SOURCE_DIR}/include) 6 | target_link_libraries(piquant_tests PRIVATE piquant GTest::gtest_main) 7 | add_test(NAME piquant_tests COMMAND piquant_tests) 8 | -------------------------------------------------------------------------------- /src/amd64/kernel_amd64_avx512f_bf16.cpp: -------------------------------------------------------------------------------- 1 | #if !defined(__AVX512F__) || !defined(__AVX512BW__) || !defined(__AVX512BF16__) 2 | #error "Spec flag not enabled" 3 | #endif 4 | #ifdef __AVX10__ 5 | #error "Spec level too high" 6 | #endif 7 | 8 | #define QUANT_KERNEL_IMPL install_quant_amd64_avx512f_bf16 9 | #include "../kernels/kernels.inl" 10 | #undef QUANT_KERNEL_IMPL 11 | -------------------------------------------------------------------------------- /python/setup.cfg: -------------------------------------------------------------------------------- 1 | [project] 2 | license = "MIT" 3 | 4 | [metadata] 5 | name = pypiquant 6 | version = 1.0.0 7 | author = Mario Sieg, Michael Keiblinger 8 | author_email = mario@primeintellect.ai, mike@primeintellect.ai 9 | description = Fast, multithreaded CPU quantization kernels 10 | long_description = file: README.md 11 | long_description_content_type = text/markdown 12 | url = https://github.com/PrimeIntellect-ai/pi-quant 13 | project_urls = 14 | Bug Tracker = https://github.com/PrimeIntellect-ai/pi-quant/issues 15 | python_requires = >=3.12 16 | include_package_data = True 17 | zip_safe = False -------------------------------------------------------------------------------- /python/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "pypiquant" 7 | dynamic = ["version"] 8 | authors = [ 9 | {name = "Mario Sieg", email = "mario.sieg.64@gmail.com"}, 10 | ] 11 | description = "Multithreaded SIMD int8 and int4 quantization kernels." 
12 | dependencies = ["cffi", "torch", "numpy"] 13 | readme = "README.md" 14 | 15 | [project.urls] 16 | Documentation = "https://github.com/PrimeIntellect-ai/pi-quant" 17 | Repository = "https://github.com/PrimeIntellect-ai/pi-quant" 18 | Issues = "https://github.com/PrimeIntellect-ai/pi-quant/issues" 19 | 20 | [project.optional-dependencies] 21 | dev = ["pytest","torch","numpy","pre-commit","ruff", "matplotlib", "twine"] 22 | 23 | [tool.ruff] 24 | line-length = 120 25 | target-version = "py38" 26 | 27 | [tool.setuptools.dynamic] 28 | version = {attr = "piquant.__version__"} 29 | 30 | [tool.ruff.lint] 31 | ignore = ["F403"] 32 | select = ["ANN"] 33 | 34 | [tool.ruff.format] 35 | quote-style = "single" 36 | 37 | [tool.ruff.lint.per-file-ignores] 38 | "setup.py" = ["ANN"] 39 | -------------------------------------------------------------------------------- /python/example/example_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import piquant 3 | 4 | # Quantize and back with: bfloat16 -> uint4 -> bfloat16 5 | # In torch, quint4x2 means two 4-bit quantized integers per byte. 
6 | tensor = torch.rand(1000, dtype=torch.bfloat16, device='cpu') 7 | 8 | # Compute quantization parameters for uint4 (needed for quantization and dequantization) 9 | scale, zero_point = piquant.torch.compute_quant_params(tensor, dtype=torch.quint4x2) 10 | 11 | # Quantize the tensor to uint4 12 | quantized = piquant.torch.quantize(tensor, scale=scale, zero_point=zero_point, dtype=torch.quint4x2) 13 | 14 | # Dequantize back to bfloat16 15 | dequantized = piquant.torch.dequantize(quantized, scale=scale, zero_point=zero_point, dtype=torch.bfloat16) 16 | 17 | # Check if the dequantized tensor is close to the original tensor 18 | assert torch.allclose(dequantized, tensor, atol=scale/2 + 1e-3), "Dequantization did not match original tensor" 19 | 20 | # Print parts of original and dequantized tensors for verification 21 | print("Original tensor (first 10 elements):", tensor[:10].tolist()) 22 | print("Dequant tensor (first 10 elements):", dequantized[:10].tolist()) 23 | 24 | -------------------------------------------------------------------------------- /benchmark/bench.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include "../test/naive.hpp" 11 | 12 | #define ANKERL_NANOBENCH_IMPLEMENT 13 | #include "nanobench.hpp" 14 | 15 | auto main() -> int { 16 | const std::size_t nt {std::max(1u, std::thread::hardware_concurrency())}; 17 | volatile std::size_t numel {(1ull<<30)}; 18 | std::vector data_in {}; 19 | std::vector data_out {}; 20 | data_in.resize(numel); 21 | data_out.resize((numel+1)/2); 22 | std::random_device rd {}; 23 | std::mt19937 gen {rd()}; 24 | std::uniform_real_distribution dist {-1.0f, 1.0f}; 25 | std::ranges::generate(data_in, [&] { return dist(gen); }); 26 | ankerl::nanobench::Bench bench {}; 27 | piquant::context ctx {nt}; 28 | bench.run("quantize bf16 -> uint4", [&] { 29 | ctx.quantize_generic(data_in, 
data_out, 0.2f, 127, piquant::round_mode::nearest); 30 | }); 31 | bench.run("dequantize uint4 -> bf16", [&] { 32 | ctx.dequantize_generic(data_out, data_in, 0.2f, 127, piquant::reduce_op::set); 33 | }); 34 | return 0; 35 | } 36 | -------------------------------------------------------------------------------- /src/piquant_internal.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "piquant.hpp" 4 | 5 | namespace piquant { 6 | 7 | /* Aborts with a formatted message. Because not all tested C++ compilers support std::format, C-style formatting is used for now. Should be replaced later. Pulling in fmt::format just for abort seems a bit too much... */ 8 | [[noreturn]] void panic(const char *msg, ...); 9 | 10 | #define QUANT_STRINGIZE2(x) #x 11 | #define QUANT_STRINGIZE(x) QUANT_STRINGIZE2(x) 12 | #define QUANT_SRC_NAME __FILE__ ":" QUANT_STRINGIZE(__LINE__) 13 | 14 | #define piquant_assert(expr, msg, ...) \ 15 | if ((!(expr))) [[unlikely]] { \ 16 | ::piquant::panic("%s:%d Assertion failed: " #expr " <- " msg, __FILE__, __LINE__, ## __VA_ARGS__);\ 17 | } 18 | #define piquant_assert2(expr) piquant_assert(expr, "") 19 | 20 | #ifdef _MSC_VER 21 | #define PIQUANT_HOT 22 | #define PIQUANT_AINLINE __forceinline 23 | #define PIQUANT_RESTRICT __restrict 24 | #else 25 | #define PIQUANT_HOT __attribute__((hot)) 26 | #define PIQUANT_AINLINE __attribute__((always_inline)) inline 27 | #define PIQUANT_RESTRICT __restrict__ 28 | #endif 29 | 30 | struct kernel_registry final { 31 | auto (*quant_kernel)( 32 | const void* x, 33 | void* o, 34 | std::int64_t numel, 35 | const context::quant_descriptor& desc 36 | ) noexcept -> void; 37 | auto (*find_min_max_float32)(std::span x) noexcept -> std::array; 38 | auto (*find_min_max_bfloat16)(std::span x) noexcept -> std::array; 39 | }; 40 | 41 | [[nodiscard]] constexpr auto packed_numel(std::size_t ne, const dtype_info& dto) noexcept -> std::size_t { 42 | std::size_t per_byte 
{8u / dto.bit_size}; 43 | return (ne + per_byte-1)/per_byte; 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /.github/workflows/build-wheels.yml: -------------------------------------------------------------------------------- 1 | name: Build and Publish Wheels 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*.*.*' 7 | 8 | jobs: 9 | build-and-publish: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | with: 14 | submodules: recursive 15 | 16 | - name: Set up Python for cibuildwheel 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.11' 20 | 21 | - name: Get version from tag 22 | id: get_version 23 | run: | 24 | VERSION=${GITHUB_REF#refs/tags/v} 25 | echo "VERSION=$VERSION" 26 | echo "VERSION=$VERSION" >> $GITHUB_ENV 27 | shell: bash 28 | 29 | - name: Update version in __init__.py 30 | run: | 31 | sed '/__version__ =/d' python/src/piquant/__init__.py > python/src/piquant/__init__.py.tmp 32 | mv python/src/piquant/__init__.py.tmp python/src/piquant/__init__.py 33 | echo "__version__ = \"$VERSION\"" >> python/src/piquant/__init__.py 34 | shell: bash 35 | 36 | - name: Install build tools 37 | run: | 38 | python -m pip install --upgrade pip 39 | python -m pip install cibuildwheel==3.0.0b1 auditwheel 40 | - name: Build wheels 41 | run: | 42 | python -m cibuildwheel --output-dir wheelhouse python 43 | env: 44 | CIBW_ARCHS_LINUX: "x86_64" 45 | 46 | - name: Publish Artifact 47 | uses: actions/upload-artifact@v4 48 | with: 49 | name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} 50 | path: ./wheelhouse/*.whl 51 | 52 | - name: Publish to PyPI 53 | if: startsWith(github.ref, 'refs/tags/v') 54 | uses: pypa/gh-action-pypi-publish@v1.4.2 55 | with: 56 | password: ${{ secrets.PYPI_API_TOKEN }} 57 | packages_dir: wheelhouse 58 | skip_existing: true 59 | -------------------------------------------------------------------------------- /python/benchmark/throughput_avg.py: 
def measure_throughput(dq_type: torch.dtype, q_type: torch.dtype) -> None:
    """Benchmark piquant quantize/dequantize throughput and print the averages.

    A float tensor of ``TOTAL_GIB`` GiB with dtype ``dq_type`` is quantized to
    ``q_type`` and dequantized back, ``ITERATIONS`` times each, after one
    untimed warm-up round trip.

    Quantize throughput is reported relative to the float input size.
    Dequantize throughput is reported relative to the packed quantized size,
    which depends on how many elements ``q_type`` stores per byte. The previous
    code always assumed two per byte (uint4), which misreported ``quint2x4``.

    Args:
        dq_type: Floating-point dtype of the unquantized tensor.
        q_type: Quantized torch dtype to benchmark.
    """
    # Logical elements stored per byte of quantized data. Unknown dtypes fall
    # back to one element per byte.
    # NOTE(review): assumes piquant packs these dtypes tightly - confirm.
    elems_per_byte = {torch.quint8: 1, torch.quint4x2: 2, torch.quint2x4: 4}.get(q_type, 1)

    bytes_per_elem = torch.tensor([], dtype=dq_type).element_size()
    total_bytes = TOTAL_GIB * (1 << 30)
    n_elems = total_bytes // bytes_per_elem
    packed_source_gib = (n_elems // elems_per_byte) / (1 << 30)

    with torch.inference_mode():
        x = torch.rand(n_elems, dtype=dq_type)
        scale, zero_point = piquant.torch.compute_quant_params(x, dtype=q_type)
    zero_point = int(zero_point)

    # Untimed warm-up round trip.
    with torch.inference_mode():
        q = piquant.torch.quantize(x, scale=scale, zero_point=zero_point, dtype=q_type)
        _ = piquant.torch.dequantize(q, scale=scale, zero_point=zero_point, dtype=dq_type)

    quant_results, dequant_results = [], []
    for _iteration in range(ITERATIONS):
        # Timed quantization pass.
        t0 = time.perf_counter()
        with torch.inference_mode():
            _ = piquant.torch.quantize(x, scale=scale, zero_point=zero_point, dtype=q_type)
        t1 = time.perf_counter()
        quant_results.append(TOTAL_GIB / (t1 - t0))

        # Re-create the quantized input outside the timed region, as the
        # original benchmark did, then time only the dequantization.
        with torch.inference_mode():
            q = piquant.torch.quantize(x, scale=scale, zero_point=zero_point, dtype=q_type)
        t0 = time.perf_counter()
        with torch.inference_mode():
            _ = piquant.torch.dequantize(q, scale=scale, zero_point=zero_point, dtype=dq_type)
        t1 = time.perf_counter()
        dequant_results.append(packed_source_gib / (t1 - t0))

    print(f'Quant {dq_type} -> {q_type} avg throughput: {sum(quant_results)/ITERATIONS:.2f} GiB/s')
    print(f'Dequant {q_type} -> {dq_type} avg throughput: {sum(dequant_results)/ITERATIONS:.2f} GiB/s')
43 | -------------------------------------------------------------------------------- /test/quant_config.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | constexpr std::size_t iters {100}; 14 | constexpr int stochastic_epsilon {3}; 15 | 16 | using namespace piquant; 17 | 18 | #define test_quant_range(ti, to, rnd) \ 19 | TEST(quantize_range, quantize_range_##ti##_to_##to##_##rnd) { \ 20 | std::random_device rd {}; \ 21 | std::mt19937 gen {rd()}; \ 22 | std::uniform_real_distribution dist {-1.0, 1.0}; \ 23 | \ 24 | for (std::size_t n {}; n < iters; ++n) { \ 25 | std::size_t numel {std::uniform_int_distribution{5000, 1'5000}(gen)}; \ 26 | std::size_t numel_out {std::is_same_v ? (numel+3)>>2 : std::is_same_v ? (numel+1)>>1 : numel}; \ 27 | \ 28 | std::vector data_out {}; \ 29 | data_out.resize(numel_out); \ 30 | std::vector data_in {}; \ 31 | data_in.resize(numel); \ 32 | std::ranges::generate(data_in, [&] { return dist(gen); }); \ 33 | piquant::context ctx {std::max(1u, 4u)}; \ 34 | auto [scale, zero_point] {ctx.compute_quant_config_from_data(data_in, dtype_traits::type_code)}; \ 35 | ASSERT_GT(scale, 0.0f); \ 36 | ASSERT_TRUE(std::isfinite(scale)); \ 37 | ctx.quantize_generic(data_in, data_out, scale, zero_point, piquant::round_mode::rnd); \ 38 | } \ 39 | } 40 | 41 | test_quant_range(fp32_t, uint2_t, nearest) 42 | test_quant_range(fp32_t, uint2_t, stochastic) 43 | test_quant_range(fp32_t, uint4_t, nearest) 44 | test_quant_range(fp32_t, uint4_t, stochastic) 45 | test_quant_range(fp32_t, uint8_t, nearest) 46 | test_quant_range(fp32_t, uint8_t, stochastic) 47 | test_quant_range(bfp16_t, uint2_t, nearest) 48 | test_quant_range(bfp16_t, uint2_t, stochastic) 49 | test_quant_range(bfp16_t, uint4_t, nearest) 50 | test_quant_range(bfp16_t, uint4_t, stochastic) 51 | test_quant_range(bfp16_t, 
@pytest.mark.parametrize('dtype_in', TORCH_FLOAT_TYPES)
@pytest.mark.parametrize('dtype_quantized', TORCH_QUANT_TYPES)
def test_compute_quant_config(dtype_in: torch.dtype, dtype_quantized: torch.dtype) -> None:
    """Check that the quantization parameters computed for random data are sane.

    The input is a random 4-D tensor drawn uniformly from [-1, 1). The scale
    must be a positive, finite number and the zero point must be non-zero.
    """
    tensor = torch.empty(numel(), numel(), numel(), numel(), dtype=dtype_in)
    tensor.uniform_(-1.0, 1.0, generator=gen)
    scale, zero_point = piquant.torch.compute_quant_params(tensor, dtype=dtype_quantized)
    assert scale > 0
    # This comparison used to be a bare expression, so it never ran as a check.
    # NOTE(review): presumably a non-zero zero point is expected because the
    # data contains negative values - confirm this holds for very small tensors.
    assert zero_point != 0
    # isfinite() rejects both NaN and +/-inf.
    assert math.isfinite(scale)
quantized_torch.dequantize().to(dtype_in) 48 | dequantized_pi = piquant.torch.dequantize(quantized_pi, scale=scale, zero_point=zero_point, dtype=dtype_in) 49 | assert dequantized_torch.dtype == dequantized_pi.dtype 50 | assert dequantized_pi.dtype == input.dtype 51 | assert torch.allclose(dequantized_torch, dequantized_pi, atol=1e-3) 52 | assert torch.allclose(dequantized_torch, input, atol=scale*0.5 + 1e-3) 53 | assert torch.allclose(dequantized_pi, input, atol=scale*0.5 + 1e-3) 54 | -------------------------------------------------------------------------------- /include/piquant.h: -------------------------------------------------------------------------------- 1 | /* Minimal C99 API used from the Python CFFI bindings but also useable from normal C. 2 | * For docs / a more complete C++ API, see piquant.hpp. 3 | */ 4 | 5 | #ifndef PIQUANT_H 6 | #define PIQUANT_H 7 | 8 | #include 9 | #include 10 | 11 | #ifdef __cplusplus 12 | extern "C" { 13 | #endif 14 | 15 | #ifdef _MSC_VER 16 | #define PIQUANT_EXPORT __declspec(dllexport) 17 | #else 18 | #define PIQUANT_EXPORT __attribute__((visibility("default"))) 19 | #endif 20 | 21 | typedef struct piquant_context_t piquant_context_t; 22 | 23 | typedef enum piquant_round_mode_t { 24 | PIQUANT_NEAREST, 25 | PIQUANT_STOCHASTIC 26 | } piquant_round_mode_t; 27 | 28 | typedef enum piquant_reduce_op_t { 29 | PIQUANT_REDUCE_OP_SET, 30 | PIQUANT_REDUCE_OP_ADD, 31 | } piquant_reduce_op_t; 32 | 33 | typedef enum piquant_dtype_t { 34 | PIQUANT_DTYPE_F32 = 0, 35 | PIQUANT_DTYPE_BF16, 36 | 37 | PIQUANT_DTYPE_UINT2, 38 | PIQUANT_DTYPE_UINT4, 39 | PIQUANT_DTYPE_UINT8 40 | } piquant_dtype_t; 41 | 42 | extern PIQUANT_EXPORT piquant_context_t* piquant_context_create(size_t num_threads); 43 | extern PIQUANT_EXPORT void piquant_context_destroy(piquant_context_t* ctx); 44 | 45 | extern PIQUANT_EXPORT void piquant_quantize( 46 | piquant_context_t* ctx, 47 | const void* in, 48 | piquant_dtype_t dtype_in, 49 | void* out, 50 | piquant_dtype_t 
class CMakeBuildExecutor(build_ext):
    """``build_ext`` command that builds the native library with CMake.

    ``run`` checks that CMake is available and then defers to the base class,
    which calls :meth:`build_extension` for each registered extension.
    """

    def initialize_options(self):
        super().initialize_options()

    def run(self):
        """Check for CMake, then run the regular ``build_ext`` pipeline.

        Raises:
            BuildException: If the ``cmake`` executable cannot be launched.
        """
        try:
            version = subprocess.check_output(['cmake', '--version'])
        except OSError as err:
            raise BuildException(
                'CMake must be installed to build the piquant binaries from source. Please install CMake and try again.'
            ) from err
        print(version.decode(errors='replace'))
        # The base class already calls build_extension() for every extension.
        # An explicit loop here used to wipe the build directory and rebuild
        # everything a second time.
        super().run()

    def build_extension(self, ext):
        """Configure and build the ``piquant`` CMake target for ``ext``.

        The temporary build directory is recreated from scratch. The library
        output directory is set to the ``piquant`` folder under ``build_lib``.
        """
        import shutil

        if os.path.exists(self.build_temp):
            shutil.rmtree(self.build_temp)
        os.makedirs(self.build_temp)

        cmake_args = [
            '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + os.path.abspath(os.path.join(self.build_lib, 'piquant')),
            '-DCMAKE_BUILD_TYPE=Release',
        ]
        build_args = [
            # Option and value are separate argv entries. The combined
            # '--target piquant' string relied on lenient parsing in cmake.
            '--target',
            'piquant',
            '-j' + str(NUM_JOBS),
            '-v',
        ]
        subprocess.check_call(['cmake', ext.root_dir] + cmake_args, cwd=self.build_temp)
        subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
std::random_device rd {}; \ 22 | std::mt19937 gen {rd()}; \ 23 | std::uniform_real_distribution dist {-1.0, 1.0}; \ 24 | const auto adjusted_epsilon {std::is_same_v ? 0.7f : std::is_same_v ? 0.2f : 1e-1f}; \ 25 | for (std::size_t n {}; n < iters; ++n) { \ 26 | std::size_t numel {std::uniform_int_distribution{5000, 1'5000}(gen)}; \ 27 | \ 28 | std::vector data_in {}; \ 29 | data_in.resize(numel); \ 30 | std::ranges::generate(data_in, [&] { return dist(gen); }); \ 31 | piquant::context ctx {std::max(1u, 4u)}; \ 32 | auto [scale, zero_point] {ctx.compute_quant_config_from_data(data_in, dtype_traits::type_code)}; \ 33 | std::vector requantized {}; \ 34 | requantized.resize(numel); \ 35 | ti prev {piquant::reduce_op::reduce == piquant::reduce_op::add ? dist(gen) : 0.0f}; \ 36 | std::ranges::fill(requantized, prev); \ 37 | ctx.quantize_dequantize_fused_generic(data_in, requantized, scale, zero_point, piquant::round_mode::rnd, piquant::reduce_op::reduce); \ 38 | for (std::size_t i {}; i < numel; ++i) { \ 39 | ASSERT_NEAR(static_cast(data_in[i]), static_cast(requantized[i]-prev), adjusted_epsilon); \ 40 | } \ 41 | } \ 42 | } 43 | 44 | test_requant(fp32_t, uint2_t, nearest, set) 45 | test_requant(fp32_t, uint2_t, stochastic, set) 46 | test_requant(fp32_t, uint2_t, nearest, add) 47 | test_requant(fp32_t, uint2_t, stochastic, add) 48 | test_requant(fp32_t, uint4_t, nearest, set) 49 | test_requant(fp32_t, uint4_t, stochastic, set) 50 | test_requant(fp32_t, uint4_t, nearest, add) 51 | test_requant(fp32_t, uint4_t, stochastic, add) 52 | test_requant(fp32_t, uint8_t, nearest, set) 53 | test_requant(fp32_t, uint8_t, stochastic, set) 54 | test_requant(fp32_t, uint8_t, nearest, add) 55 | test_requant(fp32_t, uint8_t, stochastic, add) 56 | test_requant(bfp16_t, uint4_t, nearest, set) 57 | test_requant(bfp16_t, uint4_t, stochastic, set) 58 | test_requant(bfp16_t, uint4_t, nearest, add) 59 | test_requant(bfp16_t, uint4_t, stochastic, add) 60 | test_requant(bfp16_t, uint8_t, nearest, 
def cdf_values(x: torch.Tensor):
    """Return the points of an empirical CDF for the values in ``x``.

    The tensor is flattened and sorted ascending. The i-th sorted sample is
    paired with the fraction ``i / n``, so fractions start at 0.0 and stop
    short of 1.0.

    Returns:
        A ``(samples, fractions)`` pair of NumPy arrays of equal length.
    """
    ordered, _ = torch.sort(x.flatten())
    samples = ordered.cpu().numpy()
    fractions = np.linspace(0.0, 1.0, num=samples.size, endpoint=False)
    return samples, fractions
(err_sto ** 2).mean().item() 44 | 45 | print(f"scale={scale:.8g} zero_point={zero_point}") 46 | print(f"Nearest : MAE={mae_near:.6e} MSE={mse_near:.6e}") 47 | print(f"Stochastic : MAE={mae_sto:.6e} MSE={mse_sto:.6e}") 48 | 49 | step = float(scale) 50 | tol = step/2 + 1e-3 51 | print(f"Sanity tol: {tol:.6g}") 52 | print("Allclose-nearest?", torch.allclose(dq_near, tensor, atol=tol)) 53 | print("Allclose-stochastic?", torch.allclose(dq_sto, tensor, atol=tol)) 54 | 55 | s_near, y_near = cdf_values(err_near) 56 | s_sto, y_sto = cdf_values(err_sto) 57 | 58 | plt.figure() 59 | plt.plot(s_near, y_near, label=f"Nearest (MAE={mae_near:.3e}, MSE={mse_near:.3e})") 60 | plt.plot(s_sto, y_sto, label=f"Stochastic (MAE={mae_sto:.3e}, MSE={mse_sto:.3e})") 61 | plt.xlabel("Absolute error") 62 | plt.ylabel("CDF") 63 | plt.title("uint4 Quantization: Nearest vs Stochastic (Dequant error CDF)") 64 | plt.legend() 65 | plt.grid(True, linestyle="--", alpha=0.4) 66 | plt.tight_layout() 67 | plt.savefig("quant_error_cdf.png", dpi=160) 68 | plt.show() 69 | 70 | print("Original (10): ", tensor[:10].tolist()) 71 | print("Nearest (10): ", dq_near[:10].tolist()) 72 | print("Stochastic (10):", dq_sto[:10].tolist()) -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.20) 2 | 3 | project(piquant LANGUAGES CXX) 4 | 5 | set(CMAKE_CXX_STANDARD 20) 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | 8 | option(QUANT_BUILD_TESTS "Build tests" ON) 9 | option(QUANT_BUILD_BENCHMARKS "Build benchmarks" ON) 10 | 11 | if(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)") 12 | set(IS_AMD64 TRUE) 13 | else() 14 | set(IS_AMD64 FALSE) 15 | endif() 16 | 17 | if(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)") 18 | set(IS_ARM64 TRUE) 19 | else() 20 | set(IS_ARM64 FALSE) 21 | endif() 22 | 23 | message(STATUS "Building for ${CMAKE_SYSTEM_PROCESSOR}") 
24 | file(GLOB QUANT_SOURCES include/*.hpp src/*.cpp src/*.hpp src/*.inl) 25 | if (${IS_AMD64}) 26 | function(set_file_opts filename posix_arch msvc_arch) 27 | message(STATUS "BLAS CPU permutation ${filename} ${posix_arch} / ${msvc_arch}") 28 | if (WIN32) 29 | set_property(SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/${filename}" APPEND PROPERTY COMPILE_FLAGS "${msvc_arch}") 30 | else() 31 | set_property(SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/${filename}" APPEND PROPERTY COMPILE_FLAGS "${posix_arch}") 32 | endif() 33 | endfunction() 34 | 35 | set(QUANT_SOURCES_AMD64 36 | src/amd64/kernel_amd64_sse42.cpp 37 | src/amd64/kernel_amd64_avx2.cpp 38 | src/amd64/kernel_amd64_avx512f.cpp 39 | src/amd64/kernel_amd64_avx512f_bf16.cpp 40 | ) 41 | set(QUANT_SOURCES ${QUANT_SOURCES} ${QUANT_SOURCES_AMD64}) 42 | 43 | set_file_opts("amd64/kernel_amd64_sse42.cpp" "-mtune=nehalem -msse4.2" "/arch:SSE4.2") 44 | set_file_opts("amd64/kernel_amd64_avx2.cpp" "-mtune=skylake -mavx -mavx2 -mfma -mf16c" "/arch:AVX2") 45 | set_file_opts("amd64/kernel_amd64_avx512f.cpp" "-mtune=cannonlake -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512bw" "/arch:AVX512") 46 | set_file_opts("amd64/kernel_amd64_avx512f_bf16.cpp" "-mtune=cannonlake -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512bf16" "/arch:AVX512") 47 | 48 | endif() 49 | 50 | if (QUANT_COMPILE_STATIC) 51 | add_library(piquant STATIC ${QUANT_SOURCES}) 52 | else() 53 | add_library(piquant SHARED ${QUANT_SOURCES}) 54 | target_compile_definitions(piquant PRIVATE QUANT_BUILD_SHARED) 55 | add_compile_options(-fPIC) 56 | endif () 57 | 58 | target_compile_options(piquant PRIVATE -fomit-frame-pointer -fno-rtti) 59 | target_include_directories(piquant PUBLIC include) 60 | 61 | # add threadpool library 62 | if (NOT TARGET threadpool) 63 | add_subdirectory(third_party/threadpool) 64 | endif() 65 | 66 | target_link_libraries(piquant PUBLIC threadpool) 67 | 68 | # release mode opt 69 | if (CMAKE_BUILD_TYPE STREQUAL "Release") 70 | target_compile_options(piquant 
PRIVATE -O3) 71 | endif() 72 | 73 | if (${QUANT_BUILD_TESTS}) 74 | enable_testing() 75 | add_subdirectory(test) 76 | endif() 77 | 78 | if (${QUANT_BUILD_BENCHMARKS}) 79 | add_subdirectory(benchmark) 80 | endif() 81 | -------------------------------------------------------------------------------- /python/src/piquant/_bootstrap.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from typing import List, Tuple 5 | from cffi import FFI 6 | 7 | import sys 8 | 9 | _NATIVE_MODULES: List[Tuple[str, str]] = [ 10 | ('win32', 'piquant.dll'), 11 | ('linux', 'libpiquant.so'), 12 | ('darwin', 'libpiquant.dylib'), 13 | ] 14 | 15 | _CDECLS: str = """ 16 | 17 | typedef struct piquant_context_t piquant_context_t; 18 | 19 | typedef enum piquant_round_mode_t { 20 | PIQUANT_NEAREST, 21 | PIQUANT_STOCHASTIC 22 | } piquant_round_mode_t; 23 | 24 | typedef enum piquant_reduce_op_t { 25 | PIQUANT_REDUCE_OP_SET, 26 | PIQUANT_REDUCE_OP_ADD, 27 | } piquant_reduce_op_t; 28 | 29 | typedef enum piquant_dtype_t { 30 | PIQUANT_DTYPE_F32 = 0, 31 | PIQUANT_DTYPE_BF16, 32 | 33 | PIQUANT_DTYPE_UINT2, 34 | PIQUANT_DTYPE_UINT4, 35 | PIQUANT_DTYPE_UINT8 36 | } piquant_dtype_t; 37 | 38 | extern piquant_context_t* piquant_context_create(size_t num_threads); 39 | extern void piquant_context_destroy(piquant_context_t* ctx); 40 | 41 | extern void piquant_quantize( 42 | piquant_context_t* ctx, 43 | const void* in, 44 | piquant_dtype_t dtype_in, 45 | void* out, 46 | piquant_dtype_t dtype_out, 47 | size_t numel, 48 | float scale, 49 | int64_t zero_point, 50 | piquant_round_mode_t mode 51 | ); 52 | 53 | extern void piquant_dequantize( 54 | piquant_context_t* ctx, 55 | const void* in, 56 | piquant_dtype_t dtype_in, 57 | void* out, 58 | piquant_dtype_t dtype_out, 59 | size_t numel, 60 | float scale, 61 | int64_t zero_point, 62 | piquant_reduce_op_t op 63 | ); 64 | 65 | extern void 
def _load_native_module() -> Tuple[FFI, object]:
    """Locate and load the piquant shared library that ships next to this file.

    The library file name is chosen by matching ``sys.platform`` against the
    prefixes listed in ``_NATIVE_MODULES``; the first matching entry wins.

    Returns:
        A ``(ffi, lib)`` pair: the ``FFI`` instance on which ``_CDECLS`` was
        declared, and the handle returned by ``ffi.dlopen``.

    Raises:
        AssertionError: if no entry matches the current platform, or if the
            expected library file does not exist in the package directory.
    """
    platform = sys.platform

    # Pick the library name of the first platform prefix that matches.
    lib_name = None
    for prefix, candidate in _NATIVE_MODULES:
        if platform.startswith(prefix):
            lib_name = candidate
            break
    assert lib_name, f'Unsupported platform: {platform}'

    # The shared library is expected to sit in the same directory as this module.
    lib_path = Path(__file__).parent / lib_name
    assert lib_path.exists(), f'piquant shared library not found: {lib_path}'

    ffi = FFI()
    ffi.cdef(_CDECLS)  # Register the C declarations before opening the library
    lib = ffi.dlopen(str(lib_path))
    return ffi, lib
def quantize_piquant(t: torch.Tensor, scale: float, zp: int, dtype: torch.dtype) -> torch.Tensor:
    """Quantize ``t`` with piquant; this is the implementation being benchmarked.

    Forwards to ``piquant.torch.quantize`` using the given scale, zero point
    and target quantized dtype, and returns whatever that call returns.
    """
    # The return annotation is the ``torch.Tensor`` type; ``torch.tensor`` is a
    # factory function and is not valid as a type annotation.
    return piquant.torch.quantize(t, scale=scale, zero_point=zp, dtype=dtype)
print(f'{dtype_labels[-1]:<10} | torch: {torch_time:.6f}s | piquant: {piquant_time:.6f}s') 73 | 74 | 75 | x = np.arange(len(dtype_labels)) 76 | width = 0.35 77 | plt.figure(figsize=(8, 5)) 78 | plt.bar(x - width / 2, torch_times, width, label='torch') 79 | plt.bar(x + width / 2, piquant_times, width, label='piquant') 80 | plt.ylabel(f'Total time for {NUM_RUNS} runs (s)') 81 | plt.xticks(x, dtype_labels) 82 | plt.title('Quantization Benchmark: PyTorch vs. piquant') 83 | plt.legend() 84 | plt.tight_layout() 85 | plt.savefig('quant_benchmark.png', dpi=300) 86 | plt.show() 87 | -------------------------------------------------------------------------------- /python/README.md: -------------------------------------------------------------------------------- 1 | # pi-quant: Prime Intellect Fast Quantization Library 2 | ![logo.png](../media/logo.png) 3 | ## Overview 4 | 5 | **Fast, multithreaded CPU quantization kernels** with various rounding modes, outperforming PyTorch’s built-in quantization routines by **more than 2 times** on all tested hardware. 6 | The kernels are optimized with SIMD intrinsics for different CPU architectures, including **AMD64** (SSE4.2, AVX2, AVX512F) and **ARM64** (Neon). The most optimal kernel is selected at runtime using runtime CPU detection. 7 | 8 | ## What is Quantization? 9 | 10 | Quantization is the process of mapping continuous values into a finite, discrete set of values. In machine learning and signal processing, it is commonly used to **reduce the precision of numerical data**, lowering memory usage and improving computational efficiency while maintaining acceptable accuracy. 11 | 12 | ## Features 13 | 14 | ✅ **Parallel De/Quantization**: Efficiently quantizes and de-quantizes data using multiple threads. 15 | 16 | ✅ **Rich Datatype Support:** Provides f32, f64 ↔ (u)int8/16/32/64. 17 | 18 | ✅ **Modern Python API:** Use the library from Python with PyTorch, numpy or standalone. 
19 | 20 | ✅ **Architecture-Specific Optimizations**: The kernels are optimized with SIMD intrinsics for different CPU architectures, including **AMD64** (SSE4.2, AVX2, AVX512F) and **ARM64** (Neon). 21 | 22 | ✅ **Thread Pool**: Reuses threads for minimal overhead. 23 | 24 | ✅ **Flexible Rounding Modes**: Supports both **nearest** and **stochastic** rounding modes. 25 | 26 | ✅ **C99 API**: Provides a C99 API for C projects or foreign language bindings (see `piquant.h`). 27 | 28 | ✅ **Store Operators:** Supports multiple store modes (SET, ADD) during dequantization — useful for ring-reduction operations. 29 | 30 | ✅ **Quantization Parameters:** Efficient SIMD-parallel computation of quantization scale and zero point from input data. 31 | 32 | # Benchmarks 33 | 34 | ## Benchmark 35 | 36 | The benchmarks were run on a variety of hardware. We benchmark against PyTorch’s [**torch.quantize_per_tensor**](https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor.html) and [**torch.ao.quantization.fx._decomposed.quantize_per_tensor**](https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py). Each benchmark quantized float32 to uint8 across **1000 runs**. The number of elements and other details can be seen in the [benchmark code](https://github.com/PrimeIntellect-ai/quantization-kernels/blob/main/python/benchmark/benchmark.py). 37 | 38 | ### Benchmark 1 (AMD EPYC 9654, 360 vCPUs) 39 | 40 | 1000 runs with numel 27264000
41 | CPU: AMD EPYC 9654 96-Core Processor, Runtime: AVX512-F
42 | Memory: 1485 GB
43 | Linux: 6.8.0-57-generic
44 | 45 | ![bench1.png](media/bench1.png) 46 | **Torch FX Quant** refers to **[torch.ao.quantization.fx._decomposed.quantize_per_tensor](https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py),** 47 | **Torch Builtin Quant** to [**torch.quantize_per_tensor**](https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor.html) and **Fast Quant** to **pi-quant’s [piquant.quantize_torch](https://github.com/PrimeIntellect-ai/piquant/blob/4bcf6ebc69bf9b44f89b13965f010a1d025a59f6/python/src/piquant/_torch.py#L52).** 48 | 49 | ### Benchmark 2 (AMD EPYC 7742, 128 vCPUs) 50 | 51 | 1000 runs with numel 27264000
52 | CPU: AMD EPYC 7742 64-Core Processor, Runtime: AVX2
53 | Memory: 528 GB
54 | Linux: 6.8.0-1023-nvidia
55 | ![bench2.png](media/bench2.png) 56 | 57 | ### Benchmark 3 (Apple M3 Pro) 58 | 59 | 1000 runs with numel 27264000
60 | CPU: Apple M3 Pro, Runtime: Neon
61 | Memory: 18 GB
62 | OSX: 15.4 (24E248)
def piquant_to_torch_dtype(dtype: DataType) -> torch.dtype:
    """Map a piquant ``DataType`` back to a torch dtype.

    This is a reverse lookup over ``_TORCH_DTYPE_MAP``. Several torch dtypes
    can map to the same piquant dtype; in that case the first matching entry
    in the map's insertion order is returned.

    Args:
        dtype: The piquant data type to translate.

    Returns:
        The first torch dtype whose mapped piquant dtype equals ``dtype``.

    Raises:
        ValueError: if no torch dtype maps to ``dtype``.
    """
    # The loop variable must not reuse the name ``dtype``: shadowing the
    # parameter would compare each map value against its own key, which never
    # matches, so the function would always raise.
    for torch_dtype, piquant_dtype in _TORCH_DTYPE_MAP.items():
        if piquant_dtype == dtype:
            return torch_dtype
    raise ValueError(f'Unsupported quantized dtype: {dtype}')
def quantize(
    tensor: torch.Tensor,
    *,
    scale: float,
    zero_point: int,
    dtype: torch.dtype,
    round_mode: str = 'nearest',
    ctx: Context = Context.get(),
) -> torch.Tensor:
    """Quantize ``tensor`` into a newly allocated tensor of quantized ``dtype``.

    Args:
        tensor: Input tensor; a contiguous copy is made if it is not contiguous.
        scale: Quantization scale passed through to the context.
        zero_point: Quantization zero point passed through to the context.
        dtype: Target dtype; must be a member of ``_QUANT_TYPES``.
        round_mode: Key into ``_ROUND_MODES`` (``'nearest'`` or ``'stochastic'``).
        ctx: Context that performs the work. The default is evaluated once,
            when the function is defined.

    Returns:
        A new tensor with the shape of ``tensor`` and dtype ``dtype``, filled
        by ``ctx.quantize_ptr``.
    """
    assert dtype in _QUANT_TYPES, f'Unsupported quantized dtype: {dtype}. Must be one of {list(_QUANT_TYPES)}'

    # The context is handed a raw data pointer plus an element count, so a
    # contiguous layout is ensured first.
    src = tensor if tensor.is_contiguous() else tensor.contiguous()

    src_code = torch_to_piquant_dtype(src.dtype)
    dst_code = torch_to_piquant_dtype(dtype)

    result = torch.empty(src.shape, dtype=dtype)

    ctx.quantize_ptr(
        src.data_ptr(),
        src_code,
        result.data_ptr(),
        dst_code,
        numel=src.numel(),
        scale=scale,
        zero_point=zero_point,
        round_mode=_ROUND_MODES[round_mode],
    )
    return result
Must be one of {list(_DEQUANT_TYPES)}') 113 | 114 | if not tensor.is_contiguous(): 115 | tensor = tensor.contiguous() 116 | 117 | out = torch.empty(tensor.shape, dtype=dtype) 118 | 119 | ctx.dequantize_ptr( 120 | tensor.data_ptr(), 121 | torch_to_piquant_dtype(tensor.dtype), 122 | out.data_ptr(), 123 | torch_to_piquant_dtype(out.dtype), 124 | numel=tensor.numel(), 125 | scale=scale, 126 | zero_point=zero_point, 127 | reduce_op=_REDUCE_OPS[reduce_op], 128 | ) 129 | return out 130 | -------------------------------------------------------------------------------- /.github/workflows/cmake-multi-platform.yml: -------------------------------------------------------------------------------- 1 | name: Multi-Platform CMake Build 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "*" ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ${{ matrix.os }} 12 | 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | os: [ ubuntu-latest, macos-latest, ubuntu-22.04 ] 17 | build_type: [ Release ] 18 | c_compiler: [ gcc, clang ] 19 | cpp_compiler: [ g++, clang++ ] 20 | include: 21 | - os: ubuntu-latest 22 | c_compiler: gcc 23 | cpp_compiler: g++ 24 | - os: ubuntu-latest 25 | c_compiler: clang 26 | cpp_compiler: clang++ 27 | - os: ubuntu-22.04 28 | c_compiler: gcc 29 | cpp_compiler: g++ 30 | - os: ubuntu-22.04 31 | c_compiler: clang 32 | cpp_compiler: clang++ 33 | - os: macos-latest 34 | c_compiler: clang 35 | cpp_compiler: clang++ 36 | 37 | steps: 38 | - uses: actions/checkout@v4 39 | with: 40 | submodules: recursive 41 | 42 | - name: Set reusable strings 43 | id: strings 44 | shell: bash 45 | run: | 46 | echo "build-output-dir=${{ github.workspace }}/build" >> "$GITHUB_OUTPUT" 47 | 48 | - name: Configure CMake 49 | run: > 50 | cmake -B ${{ steps.strings.outputs.build-output-dir }} 51 | -DCMAKE_CXX_COMPILER=${{ matrix.cpp_compiler }} 52 | -DCMAKE_C_COMPILER=${{ matrix.c_compiler }} 53 | -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} 54 | -S ${{ github.workspace }} 55 
| 56 | - name: Build C++ Runtime 57 | run: cmake --build ${{ steps.strings.outputs.build-output-dir }} --config ${{ matrix.build_type }} -j4 58 | 59 | - name: Run C++ Runtime Tests 60 | working-directory: ${{ steps.strings.outputs.build-output-dir }} 61 | run: ctest --build-config ${{ matrix.build_type }} --verbose 62 | 63 | - name: Setup Python 64 | uses: actions/setup-python@v5 65 | with: 66 | python-version: '3.12' 67 | 68 | - name: Create Virtual Environment 69 | if: runner.os != 'windows' 70 | shell: bash 71 | run: | 72 | cd ${{ github.workspace }}/python/ 73 | python -m venv venv 74 | 75 | - name: Build Python wheel 76 | if: runner.os != 'windows' 77 | shell: bash 78 | run: | 79 | cd ${{ github.workspace }}/python/ 80 | ${{ github.workspace }}/python/venv/bin/python -m pip wheel --verbose -w dist . 81 | 82 | - name: Install dev dependencies 83 | if: runner.os != 'windows' 84 | shell: bash 85 | run: | 86 | cd ${{ github.workspace }}/python/ 87 | ${{ github.workspace }}/python/venv/bin/python -m pip install .[dev] 88 | 89 | - name: Install Python wheel 90 | if: runner.os != 'windows' 91 | shell: bash 92 | run: | 93 | cd ${{ github.workspace }}/python/ 94 | ${{ github.workspace }}/python/venv/bin/python -m pip install ${{ github.workspace }}/python/dist/*.whl 95 | 96 | - name: Run Python Tests 97 | if: runner.os != 'windows' 98 | shell: bash 99 | run: | 100 | cd ${{ github.workspace }}/python/ 101 | ${{ github.workspace }}/python/venv/bin/python -m pytest tests/* 102 | 103 | windows: 104 | runs-on: windows-latest 105 | steps: 106 | - uses: actions/checkout@v4 107 | with: 108 | submodules: recursive 109 | 110 | - name: Configure CMake (MSVC) 111 | run: cmake -S ${{ github.workspace }} -B ${{ github.workspace }}\build -G "Visual Studio 17 2022" -A x64 -DQUANT_COMPILE_STATIC=ON 112 | 113 | - name: Build C++ Runtime (MSVC) 114 | run: cmake --build ${{ github.workspace }}\build --config Release 115 | 116 | - name: Run C++ Runtime Tests (MSVC) 117 | run: | 118 | cd 
${{ github.workspace }}\build 119 | ctest --build-config Release --verbose -------------------------------------------------------------------------------- /src/capi.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "piquant_internal.hpp" 4 | 5 | #include 6 | 7 | using namespace piquant; 8 | 9 | static_assert(static_cast(dtype::f32) == PIQUANT_DTYPE_F32); 10 | static_assert(static_cast(dtype::bf16) == PIQUANT_DTYPE_BF16); 11 | static_assert(static_cast(dtype::uint2) == PIQUANT_DTYPE_UINT2); 12 | static_assert(static_cast(dtype::uint4) == PIQUANT_DTYPE_UINT4); 13 | static_assert(static_cast(dtype::uint8) == PIQUANT_DTYPE_UINT8); 14 | 15 | struct piquant_context_t final { 16 | context* ctx {}; 17 | }; 18 | 19 | extern "C" auto piquant_context_create(const std::size_t num_threads) -> piquant_context_t* { 20 | auto* ctx {new context{num_threads}}; // todo 21 | return std::bit_cast(ctx); 22 | } 23 | 24 | extern "C" auto piquant_context_destroy(piquant_context_t* ctx) -> void { 25 | delete std::bit_cast(ctx); 26 | } 27 | 28 | extern "C" auto piquant_quantize( 29 | piquant_context_t* ctx, 30 | const void* in, 31 | piquant_dtype_t dtype_in, 32 | void* out, 33 | piquant_dtype_t dtype_out, 34 | size_t numel, 35 | fp32_t scale, 36 | int64_t zero_point, 37 | piquant_round_mode_t mode 38 | ) -> void { 39 | const auto& dti {dtype_info_of(static_cast(dtype_in))}; 40 | const auto& dto {dtype_info_of(static_cast(dtype_out))}; 41 | std::size_t in_bytes {numel*dti.stride}; 42 | std::size_t out_bytes {dto.bit_size == 8 ? 
numel*dto.stride : packed_numel(numel, dto)*dto.stride}; 43 | std::span in_span {static_cast(in), in_bytes}; 44 | std::span out_span {static_cast(out), out_bytes}; 45 | std::bit_cast(ctx)->quantize( 46 | in_span, 47 | static_cast(dtype_in), 48 | out_span, 49 | static_cast(dtype_out), 50 | scale, 51 | zero_point, 52 | static_cast(mode) 53 | ); 54 | } 55 | 56 | extern "C" auto piquant_dequantize( 57 | piquant_context_t* ctx, 58 | const void* in, 59 | piquant_dtype_t dtype_in, 60 | void* out, 61 | piquant_dtype_t dtype_out, 62 | size_t numel, 63 | fp32_t scale, 64 | int64_t zero_point, 65 | piquant_reduce_op_t op 66 | ) -> void { 67 | const auto& dti {dtype_info_of(static_cast(dtype_in))}; 68 | const auto& dto {dtype_info_of(static_cast(dtype_out))}; 69 | std::size_t in_bytes {dti.bit_size == 8 ? numel*dti.stride : packed_numel(numel, dti)*dti.stride}; 70 | std::size_t out_bytes {numel*dto.stride}; 71 | std::span in_span {static_cast(in), in_bytes}; 72 | std::span out_span {static_cast(out), out_bytes}; 73 | std::bit_cast(ctx)->dequantize( 74 | in_span, 75 | static_cast(dtype_in), 76 | out_span, 77 | static_cast(dtype_out), 78 | scale, 79 | zero_point, 80 | static_cast(op) 81 | ); 82 | } 83 | 84 | extern "C" auto piquant_compute_quant_params_float32(piquant_context_t* ctx, const fp32_t* const x, const std::size_t n, const piquant_dtype_t target_quant_dtype, fp32_t* const out_scale, int64_t* const out_zero_point) -> void { 85 | const auto [scale, zero_point] { 86 | std::bit_cast(ctx)->compute_quant_config_from_data( 87 | std::span{x, n}, 88 | static_cast(target_quant_dtype) 89 | ) 90 | }; 91 | *out_scale = scale; 92 | *out_zero_point = zero_point; 93 | } 94 | 95 | extern "C" auto piquant_compute_quant_params_bfloat16(piquant_context_t* ctx, const uint16_t* const x, const std::size_t n, const piquant_dtype_t target_quant_dtype, fp32_t* const out_scale, int64_t* const out_zero_point) -> void { 96 | const auto [scale, zero_point] { 97 | 
std::bit_cast(ctx)->compute_quant_config_from_data( 98 | std::span{reinterpret_cast(x), n}, 99 | static_cast(target_quant_dtype) 100 | ) 101 | }; 102 | *out_scale = scale; 103 | *out_zero_point = zero_point; 104 | } 105 | -------------------------------------------------------------------------------- /python/.gitignore: -------------------------------------------------------------------------------- 1 | uv.lock 2 | .vscode/* 3 | logs/* 4 | wandb/* 5 | datasets/* 6 | 7 | benchmark/benchmark.png 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # _C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 118 | .pdm.toml 119 | .pdm-python 120 | .pdm-build/ 121 | 122 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | # PyCharm 166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 168 | # and can be added to the global gitignore or merged into this file. For a more nuclear 169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
170 | #.idea/ 171 | 172 | # Aider 173 | .aider* 174 | 175 | # Files created while testing 176 | debug_I2_zero_band 177 | -------------------------------------------------------------------------------- /test/dequant.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | constexpr std::size_t iters {10}; 15 | 16 | using namespace piquant; 17 | 18 | #define test_dequant(ti, to, rnd, reduce) \ 19 | TEST(dequantize, dequantize_##ti##_to_##to##_##rnd##_##reduce) { \ 20 | std::mt19937 gen {0x9032002}; \ 21 | std::uniform_real_distribution dist {-1.0, 1.0}; \ 22 | const auto adjusted_epsilon {std::is_same_v ? 2.0f : std::is_same_v ? 0.2f : 0.05}; \ 23 | for (std::size_t n {}; n < iters; ++n) { \ 24 | std::size_t numel {std::uniform_int_distribution{5000, 1'5000}(gen)}; \ 25 | std::size_t numel_out {std::is_same_v ? (numel+3)>>2 : std::is_same_v ? (numel+1)>>1 : numel}; \ 26 | \ 27 | std::vector data_in {}; \ 28 | std::vector quantized {}; \ 29 | data_in.resize(numel); \ 30 | quantized.resize(numel_out); \ 31 | std::ranges::generate(data_in, [&] { return dist(gen); }); \ 32 | piquant::context ctx {4}; \ 33 | auto [scale, zero_point] {ctx.compute_quant_config_from_data(data_in, dtype_traits::type_code)}; \ 34 | ctx.quantize_generic(data_in, quantized, scale, zero_point, piquant::round_mode::rnd); \ 35 | std::vector dequantized {}; \ 36 | dequantized.resize(numel); \ 37 | ti prev {piquant::reduce_op::reduce == piquant::reduce_op::add ? 
dist(gen) : 0.0f}; \ 38 | std::ranges::fill(dequantized, prev); \ 39 | ctx.dequantize_generic(quantized, dequantized, scale, zero_point, piquant::reduce_op::reduce); \ 40 | for (std::size_t i {}; i < numel; ++i) { \ 41 | const auto a {static_cast(data_in[i])}; \ 42 | const auto b {static_cast(dequantized[i]-prev)}; \ 43 | const auto delta {std::abs(a - b)}; \ 44 | bool is_near {delta <= adjusted_epsilon}; \ 45 | if (!is_near) { \ 46 | std::cout << "Mismatch at index " << i << ": " << a << " != " << b << std::endl; \ 47 | std::cout << "Numel in: " << numel << " Numel out: " << numel_out << std::endl; \ 48 | std::cout << "Delta: " << delta << " ZP: " << zero_point << " Scale: " << scale << std::endl; \ 49 | std::cout << "Zero point: " << zero_point << " Scale: " << scale << std::endl; \ 50 | std::cout << "IN: ["; \ 51 | for (std::size_t j {}; j < numel; ++j) { \ 52 | std::cout << static_cast(data_in[j]) << ", "; \ 53 | } \ 54 | std::cout << "]" << std::endl; \ 55 | std::cout << "OT: ["; \ 56 | for (std::size_t j {}; j < numel; ++j) { \ 57 | std::cout << static_cast(dequantized[j]) << ", "; \ 58 | } \ 59 | std::cout << "]" << std::endl; \ 60 | ASSERT_TRUE(is_near); \ 61 | } \ 62 | } \ 63 | } \ 64 | } 65 | 66 | test_dequant(fp32_t, uint2_t, nearest, set) 67 | test_dequant(fp32_t, uint2_t, stochastic, set) 68 | test_dequant(fp32_t, uint2_t, nearest, add) 69 | test_dequant(fp32_t, uint2_t, stochastic, add) 70 | test_dequant(fp32_t, uint4_t, nearest, set) 71 | test_dequant(fp32_t, uint4_t, stochastic, set) 72 | test_dequant(fp32_t, uint4_t, nearest, add) 73 | test_dequant(fp32_t, uint4_t, stochastic, add) 74 | test_dequant(fp32_t, uint8_t, nearest, set) 75 | test_dequant(fp32_t, uint8_t, stochastic, set) 76 | test_dequant(fp32_t, uint8_t, nearest, add) 77 | test_dequant(fp32_t, uint8_t, stochastic, add) 78 | test_dequant(bfp16_t, uint2_t, nearest, set) 79 | test_dequant(bfp16_t, uint2_t, stochastic, set) 80 | test_dequant(bfp16_t, uint2_t, nearest, add) 81 | 
test_dequant(bfp16_t, uint2_t, stochastic, add) 82 | test_dequant(bfp16_t, uint4_t, nearest, set) 83 | test_dequant(bfp16_t, uint4_t, stochastic, set) 84 | test_dequant(bfp16_t, uint4_t, nearest, add) 85 | test_dequant(bfp16_t, uint4_t, stochastic, add) 86 | test_dequant(bfp16_t, uint8_t, nearest, set) 87 | test_dequant(bfp16_t, uint8_t, stochastic, set) 88 | test_dequant(bfp16_t, uint8_t, nearest, add) 89 | test_dequant(bfp16_t, uint8_t, stochastic, add) 90 | -------------------------------------------------------------------------------- /test/naive.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include "piquant.hpp" 14 | #include "../src/piquant_internal.hpp" 15 | 16 | // Xorshift 128 plus PRNG (scalar) used for stochastic rounding. 17 | // Generates a canonical float ∈ [0, 1) using a 64-bit state. 
18 | struct xs128p_state final { 19 | std::uint64_t p1 {}; 20 | std::uint64_t p2 {}; 21 | 22 | constexpr xs128p_state(std::uint64_t p1, std::uint64_t p2) noexcept : p1{p1}, p2{p2} {} 23 | 24 | [[nodiscard]] auto operator ()() noexcept -> std::uint64_t { 25 | std::uint64_t s1 {p1}; 26 | std::uint64_t s0 {p2}; 27 | p1 = s0; 28 | s1 ^= s1<<23; 29 | p2 = s1^s0^(s1>>18)^(s0>>5); 30 | return p2 + s0; 31 | } 32 | 33 | [[nodiscard]] auto canonical() noexcept -> float { 34 | static constexpr auto bias_scale {1.0f/static_cast(0x800000)}; 35 | std::uint64_t y {~0u & (*this)()}; 36 | return (bias_scale*(static_cast(y>>9) + 0.5f)); 37 | } 38 | }; 39 | 40 | static constinit xs128p_state s_sprng {0x123456789abcdef0, 0x0fedcba987654321}; 41 | 42 | [[nodiscard]] inline auto xs32_canonical() noexcept -> float { 43 | return s_sprng.canonical(); 44 | } 45 | 46 | [[nodiscard]] static constexpr auto pack_nibbles(piquant::uint4_t x, piquant::uint4_t y) noexcept -> piquant::uint4_t { 47 | auto xi {x.bits}; 48 | auto yi {y.bits}; 49 | return xi&15 | ((yi&15)<<4); 50 | } 51 | 52 | template requires requires { 53 | requires piquant::is_float_type; 54 | requires piquant::is_quant_type; 55 | } 56 | auto quantize_naive( 57 | std::span x, 58 | std::span o, 59 | float scale, 60 | std::int64_t zero_point 61 | ) noexcept -> void { /* Original implementation */ 62 | float inv_scale {1.0f / scale}; 63 | auto Q{[&](const IN x) noexcept -> OUT { 64 | if constexpr (RND == piquant::round_mode::nearest) { 65 | float rnd {std::round(static_cast(x) * inv_scale)}; 66 | auto integral {static_cast(rnd) + zero_point}; 67 | return static_cast(std::clamp(integral, piquant::dtype_limits::min, piquant::dtype_limits::max)); 68 | } else { 69 | float rnd {x * inv_scale}; 70 | float dec {std::abs(rnd - std::trunc(rnd))}; 71 | float xi {xs32_canonical()}; 72 | float adj {xi < dec ? 
1.0f : 0.0f}; 73 | if (rnd < 0.0f) adj = -1.0f * adj; 74 | rnd = std::trunc(rnd) + adj; 75 | const auto integral {static_cast(rnd) + zero_point}; 76 | return static_cast(std::clamp(integral, piquant::dtype_limits::min, piquant::dtype_limits::max)); 77 | } 78 | }}; 79 | if constexpr (std::is_same_v) { 80 | std::size_t numel {x.size()}; 81 | std::size_t i {}; 82 | for (i=0; i+1 < numel; i += 2) { 83 | IN a {x[i]}; 84 | IN b {x[i+1]}; 85 | o[i>>1] = pack_nibbles(Q(a), Q(b)); 86 | } 87 | if (numel & 1) { // Handle odd numel 88 | o[i>>1] = pack_nibbles(Q(x[numel-1]), OUT{0}); 89 | o[i>>1].bits &= 15; 90 | } 91 | } else { 92 | for (std::int64_t i {}; i < x.size(); ++i) { 93 | o[i] = Q(x[i]); 94 | } 95 | } 96 | } 97 | 98 | template requires std::is_floating_point_v 99 | [[nodiscard]] auto compute_quant_config_from_data_naive(const T* p, std::int64_t numel, std::int64_t tmax) -> std::pair { 100 | if (!numel) [[unlikely]] return {0.0, 0.0}; 101 | auto mean {static_cast(std::accumulate(p, p+numel, 0.0) / static_cast(numel))}; 102 | auto sq_delta {static_cast(std::transform_reduce( 103 | p, p+numel, 104 | 0.0, 105 | std::plus{}, 106 | [mean](const T value) noexcept -> T { 107 | T delta {value - mean}; 108 | return delta*delta; 109 | } 110 | ))}; 111 | auto std {static_cast(std::sqrt(sq_delta / static_cast(numel-1)))}; 112 | auto scale {static_cast(12.0*std/static_cast(tmax))}; 113 | std::int64_t zp {(tmax>>1) - static_cast(std::round(mean/scale))}; 114 | return {scale, zp}; 115 | } 116 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pi-quant: Prime Intellect Fast Quantization Library 2 | ![logo.png](https://raw.githubusercontent.com/PrimeIntellect-ai/pi-quant/main/media/logo.png) 3 | 4 | ## Overview 5 | 6 | **Fast, multithreaded CPU quantization kernels** with various rounding modes, outperforming PyTorch’s built-in quantization routines by 
**more than 2 times** on all tested hardware. 7 | The kernels are optimized with SIMD intrinsics for different CPU architectures, including **AMD64** (SSE4.2, AVX2, AVX512F) and **ARM64** (Neon). The best available kernel is selected at runtime via CPU feature detection. 8 | 9 | ## What is Quantization? 10 | 11 | Quantization is the process of mapping continuous values into a finite, discrete set of values. In machine learning and signal processing, it is commonly used to **reduce the precision of numerical data**, lowering memory usage and improving computational efficiency while maintaining acceptable accuracy. 12 | 13 | ## Features 14 | 15 | ✅ **Parallel De/Quantization**: Efficiently quantizes and de-quantizes data using multiple threads. 16 | 17 | ✅ **Rich Datatype Support:** Provides f32, bf16 ↔ uint2/4/8. 18 | 19 | ✅ **Modern Python API:** Use the library from Python with PyTorch, NumPy or standalone. 20 | 21 | ✅ **Architecture-Specific Optimizations**: The kernels are optimized with SIMD intrinsics for different CPU architectures, including **AMD64** (SSE4.2, AVX2, AVX512F) and **ARM64** (Neon). 22 | 23 | ✅ **Thread Pool**: Reuses threads for minimal overhead. 24 | 25 | ✅ **Flexible Rounding Modes**: Supports both **nearest** and **stochastic** rounding modes. 26 | 27 | ✅ **C99 API**: Provides a C99 API for C projects or foreign language bindings (see `piquant.h`). 28 | 29 | ✅ **Store Operators:** Supports multiple store modes (SET, ADD) during dequantization — useful for ring-reduction operations. 30 | 31 | ✅ **Quantization Parameters:** Efficient SIMD-parallel computation of quantization scale and zero point from input data. 32 | 33 | ## Installation 34 | 35 | To install pi-quant from PyPI, run the following command: 36 | ```bash 37 | pip install pypiquant 38 | ``` 39 | 40 | ## Examples 41 | piquant is torch compatible.
Here are some examples of how to use it with PyTorch: 42 | 43 | ```python 44 | import torch 45 | import piquant 46 | 47 | # Quantize and back with: bfloat16 -> uint4 -> bfloat16 48 | # In torch, quint4x2 means two 4-bit quantized integers per byte. 49 | tensor = torch.rand(1000, dtype=torch.bfloat16, device='cpu') 50 | 51 | # Compute quantization parameters for uint4 (needed for quantization and dequantization) 52 | scale, zero_point = piquant.torch.compute_quant_params(tensor, dtype=torch.quint4x2) 53 | 54 | # Quantize the tensor to uint4 55 | quantized = piquant.torch.quantize(tensor, scale=scale, zero_point=zero_point, dtype=torch.quint4x2) 56 | 57 | # Dequantize back to bfloat16 58 | dequantized = piquant.torch.dequantize(quantized, scale=scale, zero_point=zero_point, dtype=torch.bfloat16) 59 | 60 | # Check if the dequantized tensor is close to the original tensor 61 | assert torch.allclose(dequantized, tensor, atol=scale*0.5 + 1e-3), "Dequantization did not match original tensor" 62 | 63 | # Print parts of original and dequantized tensors for verification 64 | print("Original tensor (first 10 elements):", tensor[:10].tolist()) 65 | print("Dequant tensor (first 10 elements):", dequantized[:10].tolist()) 66 | ``` 67 | 68 | ## Benchmark 69 | 70 | The benchmarks were run on a variety of hardware. We benchmark against PyTorch’s [**torch.quantize_per_tensor**](https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor.html) and [**torch.ao.quantization.fx._decomposed.quantize_per_tensor**](https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py). Each benchmark quantized float32 to uint8 across **1000 runs**. The number of elements and other details can be seen in the [benchmark code](https://github.com/PrimeIntellect-ai/quantization-kernels/blob/main/python/benchmark/benchmark.py). 71 | 72 | ### Benchmark 1 (AMD EPYC 9654, 360 vCPUs) 73 | 74 | 1000 runs with numel 27264000
75 | CPU: AMD EPYC 9654 96-Core Processor, Runtime: AVX512-F
76 | Memory: 1485 GB
77 | Linux: 6.8.0-57-generic
78 | 79 | ![bench1.png](https://raw.githubusercontent.com/PrimeIntellect-ai/pi-quant/main/media/bench1.png) 80 | **Torch FX Quant** refers to **[torch.ao.quantization.fx._decomposed.quantize_per_tensor](https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py),** 81 | **Torch Builtin Quant** to **[torch.quantize_per_tensor](https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor.html)** and **Fast Quant** to **pi-quant’s [piquant.quantize_torch](https://github.com/PrimeIntellect-ai/piquant/blob/4bcf6ebc69bf9b44f89b13965f010a1d025a59f6/python/src/piquant/_torch.py#L52).** 82 | 83 | ### Benchmark 2 (AMD EPYC 7742, 128 vCPUs) 84 | 85 | 1000 runs with numel 27264000
86 | CPU: AMD EPYC 7742 64-Core Processor, Runtime: AVX2
87 | Memory: 528 GB
88 | Linux: 6.8.0-1023-nvidia
89 | ![bench2.png](https://raw.githubusercontent.com/PrimeIntellect-ai/pi-quant/main/media/bench2.png) 90 | 91 | ### Benchmark 3 (Apple M3 Pro) 92 | 93 | 1000 runs with numel 27264000
94 | CPU: Apple M3 Pro, Runtime: Neon
95 | Memory: 18 GB
96 | OSX: 15.4 (24E248)
97 | ![bench3.png](https://raw.githubusercontent.com/PrimeIntellect-ai/pi-quant/main/media/bench3.png) 98 | -------------------------------------------------------------------------------- /python/src/piquant/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | __version__ = '0.1.0' 4 | __author__ = 'Mario Sieg' 5 | __email__ = 'mario.sieg.64@gmail.com' 6 | __author_email__ = 'mario.sieg.64@gmail.com' 7 | 8 | import importlib.util 9 | 10 | import weakref 11 | import multiprocessing 12 | 13 | from enum import Enum, unique 14 | from typing import Union, Tuple 15 | from functools import lru_cache 16 | 17 | from piquant._bootstrap import ffi, C 18 | 19 | 20 | @unique 21 | class RoundMode(Enum): 22 | NEAREST = C.PIQUANT_NEAREST 23 | STOCHASTIC = C.PIQUANT_STOCHASTIC 24 | 25 | 26 | @unique 27 | class ReduceOp(Enum): 28 | SET = C.PIQUANT_REDUCE_OP_SET 29 | ADD = C.PIQUANT_REDUCE_OP_ADD 30 | 31 | 32 | @unique 33 | class DataType(Enum): 34 | F32 = C.PIQUANT_DTYPE_F32 35 | BF16 = C.PIQUANT_DTYPE_BF16 36 | UINT2 = C.PIQUANT_DTYPE_UINT2 37 | UINT4 = C.PIQUANT_DTYPE_UINT4 38 | UINT8 = C.PIQUANT_DTYPE_UINT8 39 | 40 | @property 41 | def bit_size(self) -> int: 42 | _BIT_SIZES = { 43 | self.F32: 32, 44 | self.BF16: 16, 45 | self.UINT2: 2, 46 | self.UINT4: 4, 47 | self.UINT8: 8, 48 | } 49 | return _BIT_SIZES[self] 50 | 51 | @property 52 | def is_quantized(self) -> bool: 53 | return self in (self.UINT2, self.UINT4, self.UINT8) 54 | 55 | @property 56 | def is_dequantized(self) -> bool: 57 | return self in (self.F32, self.BF16) 58 | 59 | @property 60 | def stride(self) -> int: 61 | return min(8, self.bit_size) >> 3 62 | 63 | 64 | class Context: 65 | def __init__(self, num_threads: Union[int, None] = None) -> None: 66 | """Initialize a quantization context with a given number of threads. 
If num_threads is None, the number of threads is set to the number of available CPUs minus 1.""" 67 | if num_threads is None: 68 | num_threads = max(multiprocessing.cpu_count() - 1, 1) 69 | self._num_threads = num_threads 70 | self._ctx = C.piquant_context_create(self._num_threads) 71 | self._finalizer = weakref.finalize(self, C.piquant_context_destroy, self._ctx) 72 | 73 | @staticmethod 74 | @lru_cache(maxsize=1) 75 | def get() -> Context: 76 | """ 77 | Default context for quantization. 78 | This is a singleton that is used to avoid creating multiple contexts. 79 | """ 80 | return Context() 81 | 82 | def quantize_ptr( 83 | self, 84 | ptr_in: int, 85 | dtype_in: DataType, 86 | ptr_out: int, 87 | dtype_out: DataType, 88 | numel: int, 89 | scale: float, 90 | zero_point: int, 91 | round_mode: RoundMode, 92 | ) -> None: 93 | assert dtype_in.is_dequantized, f'Input dtype must be a dequantized type, but is: {dtype_in}' 94 | assert dtype_out.is_quantized, f'Output dtype must be a quantized type, but is: {dtype_out}' 95 | assert ptr_in != 0, 'Input arr pointer must not be NULL' 96 | assert ptr_out != 0, 'Output arr pointer must not be NULL' 97 | ptr_in: ffi.CData = ffi.cast('const void*', ptr_in) 98 | ptr_out: ffi.CData = ffi.cast('void*', ptr_out) 99 | C.piquant_quantize( 100 | self._ctx, ptr_in, dtype_in.value, ptr_out, dtype_out.value, numel, scale, zero_point, round_mode.value 101 | ) 102 | 103 | def dequantize_ptr( 104 | self, 105 | ptr_in: int, 106 | dtype_in: DataType, 107 | ptr_out: int, 108 | dtype_out: DataType, 109 | numel: int, 110 | scale: float, 111 | zero_point: int, 112 | reduce_op: ReduceOp, 113 | ) -> None: 114 | assert dtype_in.is_quantized, f'Input dtype must be a quantized type, but is: {dtype_in}' 115 | assert dtype_out.is_dequantized, f'Output dtype must be a dequantized type, but is: {dtype_out}' 116 | assert ptr_in != 0, 'Input arr pointer must not be NULL' 117 | assert ptr_out != 0, 'Output arr pointer must not be NULL' 118 | ptr_in: ffi.CData = 
ffi.cast('const void*', ptr_in) 119 | ptr_out: ffi.CData = ffi.cast('void*', ptr_out) 120 | C.piquant_dequantize( 121 | self._ctx, ptr_in, dtype_in.value, ptr_out, dtype_out.value, numel, scale, zero_point, reduce_op.value 122 | ) 123 | 124 | def compute_quant_params_ptr_float32(self, ptr: int, target_quant_dtype: DataType, numel: int) -> Tuple[float, int]: 125 | assert target_quant_dtype.is_quantized, f'Target dtype must be a quantized type, but is: {target_quant_dtype}' 126 | assert ptr != 0, 'Input arr pointer must not be NULL' 127 | ptr: ffi.CData = ffi.cast('const float*', ptr) 128 | scale: ffi.CData = ffi.new('float*') 129 | zero_point: ffi.CData = ffi.new('int64_t*') 130 | C.piquant_compute_quant_params_float32(self._ctx, ptr, numel, target_quant_dtype.value, scale, zero_point) 131 | return scale[0], zero_point[0] 132 | 133 | def compute_quant_params_ptr_bfloat16( 134 | self, ptr: int, target_quant_dtype: DataType, numel: int 135 | ) -> Tuple[float, int]: 136 | assert target_quant_dtype.is_quantized, f'Target dtype must be a quantized type, but is: {target_quant_dtype}' 137 | assert ptr != 0, 'Input arr pointer must not be NULL' 138 | ptr: ffi.CData = ffi.cast('const uint16_t*', ptr) 139 | scale: ffi.CData = ffi.new('float*') 140 | zero_point: ffi.CData = ffi.new('int64_t*') 141 | C.piquant_compute_quant_params_bfloat16(self._ctx, ptr, numel, target_quant_dtype.value, scale, zero_point) 142 | return scale[0], zero_point[0] 143 | 144 | 145 | if importlib.util.find_spec('torch') is not None: 146 | try: 147 | from . import torch 148 | except ImportError: 149 | pass 150 | -------------------------------------------------------------------------------- /src/kernels/dequantize.inl: -------------------------------------------------------------------------------- 1 | // This inline file is directly included into the kernels.inl file, which is cloned (recompiled) in multiple compilation units for different CPU architectures. 2 | // ! 
Make sure all functions are static, to make them local to the compilation unit. 3 | 4 | #include "../piquant_internal.hpp" 5 | 6 | using namespace piquant; 7 | 8 | template requires is_quant_type && is_float_type 9 | [[nodiscard]] static auto dequant_step(fp32_t scale, std::int64_t zp, const In x) noexcept -> Out { 10 | return static_cast(static_cast(x) - zp)*scale; 11 | } 12 | 13 | template requires std::is_same_v && is_float_type 14 | static auto PIQUANT_HOT dequant_uint4( 15 | const In* x, 16 | Out* o, 17 | std::int64_t numel, 18 | fp32_t scale, 19 | std::int64_t zp 20 | ) noexcept -> void { 21 | std::int64_t i{}; 22 | for (std::int64_t j {}; i+1 < numel; i += 2, ++j) { 23 | auto p {x[j].bits}; 24 | auto qa {p & 15}; 25 | auto qb {p >> 4}; 26 | if constexpr (ReduceOp == reduce_op::set) { 27 | o[i] = dequant_step(scale, zp, qa); 28 | o[i+1] = dequant_step(scale, zp, qb); 29 | } else if constexpr (ReduceOp == reduce_op::add) { 30 | o[i] += dequant_step(scale, zp, qa); 31 | o[i+1] += dequant_step(scale, zp, qb); 32 | } 33 | } 34 | if (numel & 1) { 35 | auto qa {x[i>>1].bits & 15}; 36 | Out r = dequant_step(scale, zp, qa); 37 | if constexpr (ReduceOp == reduce_op::set) o[numel-1] = r; 38 | else if constexpr (ReduceOp == reduce_op::add) o[numel-1] += r; 39 | } 40 | } 41 | 42 | template requires std::is_same_v && is_float_type 43 | static auto PIQUANT_HOT dequant_uint2( 44 | const In* x, 45 | Out* o, 46 | std::int64_t numel, 47 | fp32_t scale, 48 | std::int64_t zp 49 | ) noexcept -> void { 50 | std::int64_t i {}; 51 | std::int64_t j {}; 52 | for (; i+3 < numel; i += 4, ++j) { 53 | auto p {x[j].bits}; 54 | auto qa {p & 3}; 55 | auto qb {p>>2 & 3}; 56 | auto qc {p>>4 & 3}; 57 | auto qd {p>>6 & 3}; 58 | if constexpr (ReduceOp == reduce_op::set) { 59 | o[i] = dequant_step(scale, zp, qa); 60 | o[i+1] = dequant_step(scale, zp, qb); 61 | o[i+2] = dequant_step(scale, zp, qc); 62 | o[i+3] = dequant_step(scale, zp, qd); 63 | } else if constexpr (ReduceOp == reduce_op::add) { 64 
| o[i] += dequant_step(scale, zp, qa); 65 | o[i+1] += dequant_step(scale, zp, qb); 66 | o[i+2] += dequant_step(scale, zp, qc); 67 | o[i+3] += dequant_step(scale, zp, qd); 68 | } else { 69 | static_assert(ReduceOp == reduce_op::set || ReduceOp == reduce_op::add, "Invalid reduce operation"); 70 | } 71 | } 72 | // Tail of 1-3 values. Read the tail byte only when a tail exists: if numel is a multiple of 4, x[i>>2] may lie past the end of the packed input. 73 | if (const auto tail {numel&3}; tail) { 74 | const auto p {x[i>>2].bits}; 75 | for (std::int64_t k {}; k < tail; ++k) { 76 | const auto q {(p >> (k<<1)) & 3}; // k-th 2-bit field, lowest bits first 77 | if constexpr (ReduceOp == reduce_op::set) { 78 | o[i+k] = dequant_step(scale, zp, q); 79 | } else if constexpr (ReduceOp == reduce_op::add) { 80 | o[i+k] += dequant_step(scale, zp, q); // Accumulate like the main loop instead of overwriting 81 | } else { 82 | static_assert(ReduceOp == reduce_op::set || ReduceOp == reduce_op::add, "Invalid reduce operation"); 83 | } 84 | } 85 | } 86 | } 87 | 88 | 89 | template requires is_quant_type && is_float_type 90 | static auto PIQUANT_HOT dequant_generic( 91 | const void* in, 92 | void* out, 93 | std::int64_t numel, 94 | fp32_t scale, 95 | std::int64_t zp 96 | ) noexcept -> void { 97 | const auto* PIQUANT_RESTRICT x {static_cast(in)}; 98 | auto* PIQUANT_RESTRICT o {static_cast(out)}; 99 | 100 | // Use SIMD optimized kernels for some dtype permutations 101 | if constexpr (std::is_same_v && std::is_same_v) { 102 | dequant_uint8_to_f32(static_cast(in), static_cast(out), numel, scale, static_cast(zp)); 103 | return; 104 | } 105 | if constexpr (std::is_same_v && std::is_same_v) { 106 | dequant_uint4_to_f32(static_cast(in), static_cast(out), numel, scale, static_cast(zp)); 107 | return; 108 | } 109 | if constexpr (std::is_same_v && std::is_same_v) { 110 | dequant_uint8_to_bf16(static_cast(in), static_cast(out), numel, scale, static_cast(zp)); 111 | return; 112 | } 113 | if constexpr (std::is_same_v && std::is_same_v) { 114 | dequant_uint4_to_bf16(static_cast(in), static_cast(out), numel, scale, static_cast(zp)); 115 | return; 116 | } 117 | if constexpr (std::is_same_v && std::is_same_v) { 118 | dequant_uint2_to_bf16(static_cast(in), static_cast(out), numel, scale, static_cast(zp)); 119 | return; 120 | } 121 | 122 
| if constexpr (std::is_same_v) { // Special case for int4 123 | dequant_uint4(x, o, numel, scale, zp); 124 | return; 125 | } 126 | 127 | if constexpr (std::is_same_v) { // Special case for int2 128 | dequant_uint2(x, o, numel, scale, zp); 129 | return; 130 | } 131 | 132 | // Generic case for other quantized types 133 | if constexpr (ReduceOp == reduce_op::set) { 134 | for (std::int64_t i {}; i < numel; ++i) 135 | o[i] = dequant_step(scale, zp, x[i]); 136 | } else if constexpr (ReduceOp == reduce_op::add) { 137 | for (std::int64_t i {}; i < numel; ++i) 138 | o[i] += dequant_step(scale, zp, x[i]); 139 | } 140 | } -------------------------------------------------------------------------------- /src/kernels/quantize.inl: -------------------------------------------------------------------------------- 1 | // This inline file is directly included into the kernels.inl file, which is cloned (recompiled) in multiple compilation units for different CPU architectures. 2 | // ! Make sure all functions are static, to make them local to the compilation unit. 3 | 4 | #include "../piquant_internal.hpp" 5 | 6 | using namespace piquant; 7 | 8 | template requires is_float_type && is_quant_type 9 | [[nodiscard]] static auto PIQUANT_AINLINE quant_step_scalar_stochastic(In x, fp32_t rnd_threshold, fp32_t inv_scale, std::int64_t zp) noexcept -> Out { 10 | fp32_t rnd {static_cast(x) * inv_scale}; 11 | fp32_t dec {std::abs(rnd - std::trunc(rnd))}; 12 | fp32_t adj {rnd_threshold < dec ? 
1.0f : 0.0f}; 13 | if (rnd < 0.0f) adj = -1.0f * adj; 14 | rnd = std::trunc(rnd) + adj; 15 | auto integral {static_cast(rnd) + zp}; 16 | const auto min = dtype_limits::min; 17 | const auto max = dtype_limits::max; 18 | return static_cast(std::clamp(integral, min, max)); 19 | } 20 | 21 | template requires is_float_type && is_quant_type 22 | [[nodiscard]] static auto PIQUANT_AINLINE quant_step_scalar_nearest(In x, fp32_t inv_scale, std::int64_t zp) noexcept -> Out { 23 | fp32_t rnd {std::round(static_cast(x) * inv_scale)}; 24 | auto integral {static_cast(rnd) + zp}; 25 | return static_cast(std::clamp(integral, dtype_limits::min, dtype_limits::max)); 26 | } 27 | 28 | template requires is_float_type && is_quant_type 29 | [[nodiscard]] static auto PIQUANT_AINLINE quant_step_scalar(In x, fp32_t rnd_threshold, fp32_t inv_scale, std::int64_t zp) noexcept -> Out { 30 | if constexpr (RoundMode == round_mode::stochastic) 31 | return quant_step_scalar_stochastic(x, rnd_threshold, inv_scale, zp); 32 | else 33 | return quant_step_scalar_nearest(x, inv_scale, zp); 34 | } 35 | 36 | template requires is_float_type && is_quant_type 37 | [[nodiscard]] static auto PIQUANT_AINLINE quant_step_packed(In a, In b, fp32_t rnd_threshold, fp32_t inv_scale, std::int64_t zp) noexcept -> Out { 38 | auto qa {quant_step_scalar(a, rnd_threshold, inv_scale, zp).bits}; 39 | auto qb {quant_step_scalar(b, rnd_threshold, inv_scale, zp).bits}; 40 | return qa & 15 | (qb & 15)<<4; 41 | } 42 | 43 | template requires is_float_type && is_quant_type 44 | [[nodiscard]] static auto PIQUANT_AINLINE quant_step_packed(In a, In b, In c, In d, fp32_t rnd_threshold, fp32_t inv_scale, std::int64_t zp) noexcept -> Out { 45 | auto qa {quant_step_scalar(a, rnd_threshold, inv_scale, zp).bits}; 46 | auto qb {quant_step_scalar(b, rnd_threshold, inv_scale, zp).bits}; 47 | auto qc {quant_step_scalar(c, rnd_threshold, inv_scale, zp).bits}; 48 | auto qd {quant_step_scalar(d, rnd_threshold, inv_scale, zp).bits}; 49 | return qa & 
3 | (qb & 3)<<2 | (qc & 3)<<4 | (qd & 3)<<6; 50 | } 51 | 52 | template requires is_float_type && std::is_same_v 53 | static auto PIQUANT_HOT quant_uint4( 54 | const In* PIQUANT_RESTRICT x, 55 | Out* PIQUANT_RESTRICT o, 56 | std::int64_t numel, 57 | fp32_t rnd_threshold, 58 | fp32_t inv_scale, 59 | std::int64_t zp 60 | ) noexcept -> void { 61 | std::int64_t i {}; 62 | for (; i+1 < numel; i += 2) { 63 | In a {x[i]}; 64 | In b {x[i+1]}; 65 | o[i>>1] = quant_step_packed(a, b, rnd_threshold, inv_scale, zp); 66 | } 67 | if (numel & 1) { 68 | o[i>>1] = quant_step_packed(x[numel-1], 0, rnd_threshold, inv_scale, zp); 69 | o[i>>1].bits &= 15; 70 | } 71 | } 72 | 73 | template requires is_float_type && std::is_same_v 74 | static auto PIQUANT_HOT quant_uint2( 75 | const In* PIQUANT_RESTRICT x, 76 | Out* PIQUANT_RESTRICT o, 77 | std::int64_t numel, 78 | fp32_t rnd_threshold, 79 | fp32_t inv_scale, 80 | std::int64_t zp 81 | ) noexcept -> void { 82 | std::int64_t i {}; 83 | for (; i+3 < numel; i += 4) { 84 | In a {x[i]}; 85 | In b {x[i+1]}; 86 | In c {x[i+2]}; 87 | In d {x[i+3]}; 88 | o[i>>2] = quant_step_packed(a, b, c, d, rnd_threshold, inv_scale, zp); 89 | } 90 | if (numel & 3) { /* Handle 1-, 2- or 3-value tail */ 91 | typename Out::packed_storage p {}; 92 | switch (numel & 3) { 93 | case 3: p |= (quant_step_scalar(x[i+2], rnd_threshold, inv_scale, zp).bits&3) << 4; 94 | case 2: p |= (quant_step_scalar(x[i+1], rnd_threshold, inv_scale, zp).bits&3) << 2; 95 | case 1: p |= (quant_step_scalar(x[i], rnd_threshold, inv_scale, zp).bits&3); 96 | } 97 | o[i>>2] = p; 98 | } 99 | } 100 | 101 | template requires is_float_type && is_quant_type 102 | static auto PIQUANT_HOT quant_generic( 103 | const void* in, 104 | void* out, 105 | std::int64_t numel, 106 | fp32_t scale, 107 | fp32_t rnd_threshold, 108 | std::int64_t zp 109 | ) noexcept -> void { 110 | // Use SIMD optimized kernels for some dtype permutations 111 | if constexpr (std::is_same_v && std::is_same_v && RoundMode == 
round_mode::nearest) { 112 | quant_f32_to_uint8_nearest(static_cast(in), static_cast(out), numel, scale, zp); 113 | return; 114 | } 115 | if constexpr (std::is_same_v && std::is_same_v && RoundMode == round_mode::nearest) { 116 | quant_f32_to_uint4_nearest(static_cast(in), static_cast(out), numel, scale, zp); 117 | return; 118 | } 119 | if constexpr (std::is_same_v && std::is_same_v && RoundMode == round_mode::nearest) { 120 | quant_bf16_to_uint8_nearest(static_cast(in), static_cast(out), numel, scale, zp); 121 | return; 122 | } 123 | if constexpr (std::is_same_v && std::is_same_v && RoundMode == round_mode::nearest) { 124 | quant_bf16_to_uint4_nearest(static_cast(in), static_cast(out), numel, scale, zp); 125 | return; 126 | } 127 | if constexpr (std::is_same_v && std::is_same_v && RoundMode == round_mode::nearest) { 128 | quant_bf16_to_uint2_nearest(static_cast(in), static_cast(out), numel, scale, zp); 129 | return; 130 | } 131 | 132 | const auto* PIQUANT_RESTRICT x {static_cast(in)}; 133 | auto* PIQUANT_RESTRICT o {static_cast(out)}; 134 | fp32_t inv_scale {1.0f / scale}; // We multiply by reciprocal 135 | 136 | if constexpr (std::is_same_v) { // Special case for int4 137 | quant_uint4(x, o, numel, rnd_threshold, inv_scale, zp); 138 | return; 139 | } 140 | 141 | if constexpr (std::is_same_v) { // Special case for int2 142 | quant_uint2(x, o, numel, rnd_threshold, inv_scale, zp); 143 | return; 144 | } 145 | 146 | // Generic quantization for other dtypes 147 | for (std::int64_t i=0; i < numel; ++i) 148 | o[i] = quant_step_scalar(x[i], rnd_threshold, inv_scale, zp); 149 | } -------------------------------------------------------------------------------- /src/kernels/kernels.inl: -------------------------------------------------------------------------------- 1 | // This inline file is directly included into the kernels.inl file, which is cloned (recompiled) in multiple compilation units for different CPU architectures. 2 | // ! 
Make sure all functions are static, to make them local to the compilation unit. 3 | 4 | #ifndef QUANT_KERNEL_IMPL 5 | #error "Kernel impl is not defined" 6 | #endif 7 | #include 8 | #include 9 | #include "../piquant_internal.hpp" 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | namespace piquant { 17 | struct kernel_registry; 18 | } 19 | 20 | #define concat(a, b) a ## b 21 | #define impl_namespace(a, b) piquant::concat(a, _impl) 22 | 23 | namespace impl_namespace(QUANT_KERNEL_IMPL, _) { 24 | 25 | // Include order matters, implementations are cloned per specialized compilation unit 26 | #include "kernels_specialized.inl" 27 | #include "quantize.inl" 28 | #include "dequantize.inl" 29 | 30 | template 31 | static auto PIQUANT_HOT requant_generic( 32 | const void* in, 33 | void* out, 34 | std::int64_t numel, 35 | fp32_t scale, 36 | fp32_t rnd_threshold, 37 | std::int64_t zp 38 | ) noexcept -> void { 39 | const auto* PIQUANT_RESTRICT x {static_cast(in)}; 40 | auto* PIQUANT_RESTRICT o {static_cast(out)}; 41 | fp32_t inv_scale {1.0f / scale}; 42 | if constexpr (ReduceOp == reduce_op::set) { 43 | for (std::int64_t i {}; i < numel; ++i) 44 | o[i] = dequant_step(scale, zp, quant_step_scalar(x[i], rnd_threshold, inv_scale, zp)); 45 | return; 46 | } 47 | if constexpr (ReduceOp == reduce_op::add) { 48 | for (std::int64_t i {}; i < numel; ++i) 49 | o[i] += dequant_step(scale, zp, quant_step_scalar(x[i], rnd_threshold, inv_scale, zp)); 50 | return; 51 | } 52 | } 53 | }; 54 | 55 | namespace piquant { 56 | using quant_fn = auto (*)(const void*, void*, std::int64_t, fp32_t, fp32_t, std::int64_t) noexcept -> void; 57 | using dequant_fn = auto (*)(const void*, void*, std::int64_t, fp32_t, std::int64_t) noexcept -> void; 58 | 59 | template 60 | [[nodiscard]] consteval auto quant_entry() noexcept -> quant_fn { return &impl_namespace(QUANT_KERNEL_IMPL, _)::quant_generic; } 61 | 62 | template 63 | [[nodiscard]] consteval auto dequant_entry() noexcept -> dequant_fn { 
return &impl_namespace(QUANT_KERNEL_IMPL, _)::dequant_generic; } 64 | 65 | template 66 | [[nodiscard]] consteval auto requant_entry() noexcept -> quant_fn { return &impl_namespace(QUANT_KERNEL_IMPL, _)::requant_generic; } 67 | 68 | template struct type_set {}; 69 | 70 | template struct type_set_size; 71 | template struct type_set_size> : std::integral_constant {}; 72 | 73 | using quant_types = type_set; 74 | 75 | using fp_types = type_set; 76 | 77 | template struct make_quant_row; 78 | template 79 | struct make_quant_row> { 80 | static constexpr std::array::value + sizeof...(Dst)> value = 81 | {nullptr, nullptr, quant_entry()...}; 82 | }; 83 | 84 | template struct make_dequant_row; 85 | template 86 | struct make_dequant_row> { 87 | static constexpr std::array::value + sizeof...(Dst)> value = 88 | {nullptr, nullptr, dequant_entry()...}; 89 | }; 90 | 91 | template struct make_requant_row; 92 | template 93 | struct make_requant_row> { 94 | static constexpr std::array::value + sizeof...(Dst)> value = 95 | {nullptr, nullptr, requant_entry()...}; 96 | }; 97 | 98 | template struct make_requant_block; 99 | 100 | template 101 | struct make_requant_block> { 102 | static constexpr std::array::value + type_set_size::value>, sizeof...(Src)> value { 103 | make_requant_row::value... 104 | }; 105 | }; 106 | 107 | // 3D Dispatch table for quantization kernels. Order matters. 108 | static constexpr std::array quant_functions { 109 | std::array { 110 | std::array { 111 | make_quant_row::value, 112 | make_quant_row::value, 113 | }, 114 | }, 115 | std::array { 116 | std::array { 117 | make_quant_row::value, 118 | make_quant_row::value 119 | }, 120 | } 121 | }; 122 | 123 | // 3D Dispatch table for dequantization kernels. Order matters. 
124 | static constexpr std::array dequant_functions { 125 | std::array { 126 | std::array { 127 | make_dequant_row::value, 128 | make_dequant_row::value 129 | }, 130 | }, 131 | std::array { 132 | std::array { 133 | make_dequant_row::value, 134 | make_dequant_row::value 135 | }, 136 | } 137 | }; 138 | 139 | // 4D Dispatch table for requantization kernels. Order matters. 140 | static constexpr std::array requant_functions { 141 | std::array{ 142 | make_requant_block::value, 143 | make_requant_block::value 144 | }, 145 | std::array{ 146 | make_requant_block::value, 147 | make_requant_block::value 148 | } 149 | }; 150 | 151 | static void dispatch_quantize(const void* in, void* out, std::int64_t range, const context::quant_descriptor& desc) { 152 | const dtype_info& dt_in {dtype_info_of(desc.dt_in)}; 153 | const dtype_info& dt_out {dtype_info_of(desc.dt_out)}; 154 | piquant_assert2(!(dt_in.flags & dtype_flags::is_quant)); 155 | piquant_assert2(dt_out.flags & dtype_flags::is_quant); 156 | const auto& stubs_round_mode {quant_functions[static_cast(desc.rounding)]}; 157 | const auto& stubs_dtype_fp {stubs_round_mode[static_cast(desc.dt_in)]}; 158 | auto* kernel {stubs_dtype_fp[static_cast(desc.dt_out)]}; 159 | piquant_assert(kernel != nullptr, "invalid quantization types: %s -> %s", dtype_info_of(desc.dt_in).name, dtype_info_of(desc.dt_out).name); 160 | (*kernel)(in, out, range, desc.scale, desc.rnd_threshold, desc.zero_point); 161 | } 162 | 163 | static void dispatch_dequantize(const void* in, void* out, std::int64_t range, const context::quant_descriptor& desc) { 164 | const dtype_info& dt_in {dtype_info_of(desc.dt_in)}; 165 | const dtype_info& dt_out {dtype_info_of(desc.dt_out)}; 166 | piquant_assert2(dt_in.flags & dtype_flags::is_quant); 167 | piquant_assert2(!(dt_out.flags & dtype_flags::is_quant)); 168 | const auto& stubs_reduce_mode {dequant_functions[static_cast(desc.reducing)]}; 169 | const auto& stubs_dtype_fp {stubs_reduce_mode[static_cast(desc.dt_out)]}; 170 | 
auto* kernel {stubs_dtype_fp[static_cast(desc.dt_in)]}; 171 | piquant_assert(kernel != nullptr, "invalid dequantization types: %s -> %s", dtype_info_of(desc.dt_in).name, dtype_info_of(desc.dt_out).name); 172 | (*kernel)(in, out, range, desc.scale, desc.zero_point); 173 | } 174 | 175 | static void dispatch_requantize(const void* in, void* out, std::int64_t range, const context::quant_descriptor& desc) { 176 | using enum dtype; 177 | const dtype_info& dt_in {dtype_info_of(desc.dt_in)}; 178 | const dtype_info& dt_out {dtype_info_of(desc.dt_out)}; 179 | piquant_assert2(!(dt_in.flags & dtype_flags::is_quant)); 180 | piquant_assert2(dt_out.flags & dtype_flags::is_quant); 181 | const auto& stubs_round_mode {requant_functions[static_cast(desc.rounding)]}; 182 | const auto& stubs_reduce_op {stubs_round_mode[static_cast(desc.reducing)]}; 183 | const auto& stubs_fp {stubs_reduce_op[static_cast(desc.dt_in)]}; 184 | auto* kernel {stubs_fp[static_cast(desc.dt_out)]}; 185 | piquant_assert(kernel != nullptr, "invalid requantization types: %s -> %s", dtype_info_of(desc.dt_in).name, dtype_info_of(desc.dt_out).name); 186 | (*kernel)(in, out, range, desc.scale, desc.rnd_threshold, desc.zero_point); 187 | } 188 | 189 | static auto PIQUANT_HOT quantize_dispatch(const void* in, void* out, std::int64_t range, const context::quant_descriptor& desc) noexcept -> void { 190 | switch (desc.type) { 191 | case context::command_type::quant: dispatch_quantize(in, out, range, desc); return; 192 | case context::command_type::dequant: dispatch_dequantize(in, out, range, desc); return; 193 | case context::command_type::quant_dequant: dispatch_requantize(in, out, range, desc); return; 194 | default: panic("invalid quantization command type: %d", static_cast(desc.type)); 195 | } 196 | } 197 | 198 | auto QUANT_KERNEL_IMPL() noexcept -> kernel_registry { 199 | return kernel_registry { 200 | .quant_kernel = &quantize_dispatch, 201 | .find_min_max_float32 = &impl_namespace(QUANT_KERNEL_IMPL, 
_)::find_min_max_f32, 202 | .find_min_max_bfloat16 = &impl_namespace(QUANT_KERNEL_IMPL, _)::find_min_max_bf16, 203 | }; 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /test/quant.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | #include "naive.hpp" 14 | 15 | constexpr std::size_t iters {10}; 16 | constexpr std::int32_t stochastic_epsilon {1}; 17 | 18 | using namespace piquant; 19 | 20 | template 21 | [[nodiscard]] constexpr auto unpack_nibble(T val, bool is_signed) noexcept -> std::int32_t { 22 | const uint8_t raw = (static_cast(val.bits) >> (IDX<<2)) & 0xF; 23 | if (!is_signed) { 24 | return raw; 25 | } 26 | return (raw & 0x8) ? static_cast(raw) - 16 : static_cast(raw); 27 | } 28 | 29 | #define test_quant(ti, to, rnd) \ 30 | TEST(quantize, quantize_##ti##_to_##to##_##rnd) { \ 31 | std::mt19937 gen {0x9032002}; \ 32 | std::uniform_real_distribution dist {-1.0, 1.0}; \ 33 | \ 34 | for (std::size_t n {}; n < iters; ++n) { \ 35 | fp32_t scale {std::uniform_real_distribution{0.1, 1.0}(gen)}; \ 36 | std::int32_t zero_point {std::is_same_v ? std::uniform_int_distribution{-8, 7}(gen) : \ 37 | std::uniform_int_distribution{-128, 127}(gen)}; \ 38 | std::size_t numel {std::uniform_int_distribution{5000, 1'5000}(gen)}; \ 39 | std::size_t numel_out {std::is_same_v ? 
(numel+1)>>1 : numel}; \ 40 | \ 41 | std::vector data_in {}; \ 42 | std::vector data_out_naive {}; \ 43 | std::vector data_out {}; \ 44 | data_in.resize(numel); \ 45 | data_out.resize(numel_out); \ 46 | data_out_naive.resize(numel_out); \ 47 | std::ranges::generate(data_in, [&] { return dist(gen); }); \ 48 | quantize_naive(data_in, data_out_naive, scale, zero_point); \ 49 | piquant::context ctx {std::max(1u, 4u)}; \ 50 | ctx.quantize_generic(data_in, data_out, scale, zero_point, piquant::round_mode::rnd); \ 51 | for (std::size_t i {}; i < numel_out; ++i) { \ 52 | bool eq {eq = data_out[i] == data_out_naive[i]}; \ 53 | eq |= std::abs(static_cast(data_out[i]) - static_cast(data_out_naive[i])) <= stochastic_epsilon; \ 54 | if (!eq) { \ 55 | std::cout << "Mismatch at index " << i << ": " << static_cast(data_out[i]) << " != " << static_cast(data_out_naive[i]) << std::endl; \ 56 | std::cout << "Input: " << static_cast(data_in[i]) << std::endl; \ 57 | ASSERT_TRUE(eq); \ 58 | } \ 59 | } \ 60 | } \ 61 | } 62 | 63 | #define test_quant_int4(ti, to, rnd, is_stochastic, is_signed) \ 64 | TEST(quantize, quantize_##ti##_to_##to##_##rnd) { \ 65 | std::mt19937 gen {0x9032002}; \ 66 | std::uniform_real_distribution dist {-1.0, 1.0}; \ 67 | \ 68 | for (std::size_t n {}; n < iters; ++n) { \ 69 | std::cout << "Iteration " << n << std::endl; \ 70 | fp32_t scale {std::uniform_real_distribution{0.1, 1.0}(gen)}; \ 71 | std::size_t numel {std::uniform_int_distribution{5000, 1'5000}(gen)}; \ 72 | std::size_t numel_out {std::is_same_v ? (numel+1)>>1 : numel}; \ 73 | std::int32_t zero_point {std::is_same_v ? 
std::uniform_int_distribution{-8, 7}(gen) : \ 74 | std::uniform_int_distribution{-128, 127}(gen)}; \ 75 | \ 76 | std::vector data_in {}; \ 77 | std::vector data_out_naive {}; \ 78 | std::vector data_out {}; \ 79 | data_in.resize(numel); \ 80 | data_out.resize(numel_out); \ 81 | data_out_naive.resize(numel_out); \ 82 | std::ranges::generate(data_in, [&] { return dist(gen); }); \ 83 | quantize_naive(data_in, data_out_naive, scale, zero_point); \ 84 | piquant::context ctx {std::max(1u, std::thread::hardware_concurrency())}; \ 85 | ctx.quantize_generic(data_in, data_out, scale, zero_point, piquant::round_mode::rnd); \ 86 | for (std::size_t i {}; i < numel_out; ++i) { \ 87 | bool eq {eq = data_out[i] == data_out_naive[i]}; \ 88 | std::int32_t a {unpack_nibble<0>(data_out[i], is_signed)}; \ 89 | std::int32_t b {unpack_nibble<1>(data_out[i], is_signed)}; \ 90 | std::int32_t a_naive {unpack_nibble<0>(data_out_naive[i], is_signed)}; \ 91 | std::int32_t b_naive {unpack_nibble<1>(data_out_naive[i], is_signed)}; \ 92 | if (is_stochastic) { \ 93 | eq |= std::abs(a - a_naive) <= stochastic_epsilon; \ 94 | eq |= std::abs(b - b_naive) <= stochastic_epsilon; \ 95 | } else { \ 96 | eq = eq && (a == a_naive) && (b == b_naive); \ 97 | } \ 98 | if (!eq) { \ 99 | std::cout << "Mismatch at index " << i << ": " << "(" << a << ", " << b << ") != (" << a_naive << ", " << b_naive << ") -> " << (std::int32_t)(data_out[i].bits) << " != " << (std::int32_t)(data_out_naive[i].bits) << std::endl; \ 100 | std::cout << "Input: " << static_cast(data_in[i]) << std::endl; \ 101 | std::cout << "Data in: ["; \ 102 | for (std::size_t j {}; j < numel; ++j) { \ 103 | std::cout << static_cast(data_in[j]) << " "; \ 104 | } \ 105 | std::cout << "]" << std::endl; \ 106 | std::cout << "Data out (f): ["; \ 107 | for (std::size_t j {}; j < numel_out; ++j) { \ 108 | std::cout << static_cast(data_out[j].bits) << " "; \ 109 | } \ 110 | std::cout << "]" << std::endl; \ 111 | std::cout << "Data out (n): ["; \ 112 | for 
(std::size_t j {}; j < numel_out; ++j) { \ 113 | std::cout << static_cast(data_out_naive[j].bits) << " "; \ 114 | } \ 115 | std::cout << "]" << std::endl; \ 116 | std::cout << "Num el out: " << numel_out << std::endl; \ 117 | std::cout << "Num el in: " << numel << std::endl; \ 118 | ASSERT_TRUE(eq); \ 119 | } \ 120 | } \ 121 | } \ 122 | } 123 | 124 | #define test_quant_int2(ti, to, rnd, is_stochastic, is_signed) \ 125 | TEST(quantize, quantize_##ti##_to_##to##_##rnd) { \ 126 | std::mt19937 gen {0x9032002}; \ 127 | std::uniform_real_distribution dist {-1.0, 1.0}; \ 128 | \ 129 | for (std::size_t n {}; n < iters; ++n) { \ 130 | std::cout << "Iteration " << n << std::endl; \ 131 | fp32_t scale {std::uniform_real_distribution{0.1, 1.0}(gen)}; \ 132 | std::size_t numel {std::uniform_int_distribution{5000, 1'5000}(gen)}; \ 133 | std::size_t numel_out {std::is_same_v ? (numel+3)>>2 : numel}; \ 134 | std::int32_t zero_point {std::is_same_v ? std::uniform_int_distribution{-1, 2}(gen) : \ 135 | std::uniform_int_distribution{-128, 127}(gen)}; \ 136 | \ 137 | std::vector data_in {}; \ 138 | std::vector data_out_naive {}; \ 139 | std::vector data_out {}; \ 140 | data_in.resize(numel); \ 141 | data_out.resize(numel_out); \ 142 | data_out_naive.resize(numel_out); \ 143 | std::ranges::generate(data_in, [&] { return dist(gen); }); \ 144 | quantize_naive(data_in, data_out_naive, scale, zero_point); \ 145 | piquant::context ctx {std::max(1u, std::thread::hardware_concurrency())}; \ 146 | ctx.quantize_generic(data_in, data_out, scale, zero_point, piquant::round_mode::rnd); \ 147 | for (std::size_t i {}; i < numel_out; ++i) { \ 148 | bool eq {eq = data_out[i] == data_out_naive[i]}; \ 149 | std::int32_t a {unpack_nibble<0>(data_out[i], is_signed)}; \ 150 | std::int32_t b {unpack_nibble<1>(data_out[i], is_signed)}; \ 151 | std::int32_t a_naive {unpack_nibble<0>(data_out_naive[i], is_signed)}; \ 152 | std::int32_t b_naive {unpack_nibble<1>(data_out_naive[i], is_signed)}; \ 153 | if 
(is_stochastic) { \ 154 | eq |= std::abs(a - a_naive) <= stochastic_epsilon; \ 155 | eq |= std::abs(b - b_naive) <= stochastic_epsilon; \ 156 | } else { \ 157 | eq = eq && (a == a_naive) && (b == b_naive); \ 158 | } \ 159 | if (!eq) { \ 160 | std::cout << "Mismatch at index " << i << ": " << "(" << a << ", " << b << ") != (" << a_naive << ", " << b_naive << ") -> " << (std::int32_t)(data_out[i].bits) << " != " << (std::int32_t)(data_out_naive[i].bits) << std::endl; \ 161 | std::cout << "Input: " << static_cast(data_in[i]) << std::endl; \ 162 | std::cout << "Data in: ["; \ 163 | for (std::size_t j {}; j < numel; ++j) { \ 164 | std::cout << static_cast(data_in[j]) << " "; \ 165 | } \ 166 | std::cout << "]" << std::endl; \ 167 | std::cout << "Data out (f): ["; \ 168 | for (std::size_t j {}; j < numel_out; ++j) { \ 169 | std::cout << static_cast(data_out[j].bits) << " "; \ 170 | } \ 171 | std::cout << "]" << std::endl; \ 172 | std::cout << "Data out (n): ["; \ 173 | for (std::size_t j {}; j < numel_out; ++j) { \ 174 | std::cout << static_cast(data_out_naive[j].bits) << " "; \ 175 | } \ 176 | std::cout << "]" << std::endl; \ 177 | std::cout << "Num el out: " << numel_out << std::endl; \ 178 | std::cout << "Num el in: " << numel << std::endl; \ 179 | ASSERT_TRUE(eq); \ 180 | } \ 181 | } \ 182 | } \ 183 | } 184 | 185 | //test_quant_int2(fp32_t, uint2_t, nearest, false, false) 186 | //test_quant_int2(fp32_t, uint2_t, stochastic, true, false) 187 | //test_quant_int2(bfp16_t, uint2_t, nearest, false, false) 188 | //test_quant_int2(bfp16_t, uint2_t, stochastic, true, false) 189 | test_quant_int4(fp32_t, uint4_t, nearest, false, false) 190 | test_quant_int4(fp32_t, uint4_t, stochastic, true, false) 191 | test_quant_int4(bfp16_t, uint4_t, nearest, false, false) 192 | test_quant_int4(bfp16_t, uint4_t, stochastic, true, false) 193 | test_quant(fp32_t, uint8_t, nearest) 194 | test_quant(fp32_t, uint8_t, stochastic) 195 | test_quant(bfp16_t, uint8_t, nearest) 196 | 
test_quant(bfp16_t, uint8_t, stochastic) 197 | 198 | TEST(quantize, requantize_float_to_uint8_identity_data) { 199 | std::random_device rd {}; 200 | std::mt19937 gen {rd()}; 201 | std::size_t numel {std::uniform_int_distribution{5000, 1'5000}(gen)}; 202 | std::size_t numel_out {numel}; 203 | std::vector data_in {}; 204 | std::vector quantized {}; 205 | data_in.resize(numel); 206 | quantized.resize(numel_out); 207 | std::ranges::fill(data_in, 42.0f); 208 | context ctx {std::max(1u, std::thread::hardware_concurrency())}; 209 | auto [scale, zero_point] {ctx.compute_quant_config_from_data(data_in, dtype_traits::type_code)}; 210 | ctx.quantize_generic(data_in, quantized, scale, zero_point, round_mode::nearest); 211 | std::vector dequantized {}; 212 | dequantized.resize(numel); 213 | ctx.dequantize_generic(quantized, dequantized, scale, zero_point, reduce_op::add); 214 | for (std::size_t i {}; i < numel; ++i) { 215 | ASSERT_NEAR(data_in[i], dequantized[i], 1e-6f); 216 | } 217 | } 218 | -------------------------------------------------------------------------------- /include/piquant.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #ifdef _MSC_VER 11 | #ifdef QUANT_BUILD_SHARED 12 | #define QUANT_EXPORT __declspec(dllexport) 13 | #else 14 | #define QUANT_EXPORT 15 | #endif 16 | #else 17 | #define QUANT_EXPORT __attribute__((visibility("default"))) 18 | #endif 19 | 20 | namespace piquant { 21 | enum class round_mode { 22 | nearest, 23 | stochastic, 24 | 25 | count_ 26 | }; 27 | 28 | enum class reduce_op { 29 | set, // output[i] = dequantize(input[i]) 30 | add, // output[i] += dequantize(input[i]) 31 | 32 | count_ 33 | }; 34 | 35 | // All supported data types for quantization and dequantization. Order matters: the enumerator values are used as indices into dtype_infos.
36 | enum class dtype { 37 | f32 = 0, 38 | bf16, 39 | 40 | uint2, // 2-bit unsigned int 41 | uint4, // 4-bit unsigned int 42 | uint8, // 8-bit unsigned int (uint8_t) 43 | 44 | count_ 45 | }; 46 | static_assert(static_cast>(dtype::count_) <= 0xff); 47 | static_assert(static_cast>(dtype::f32) == 0); 48 | static_assert(static_cast>(dtype::bf16) == 1); 49 | 50 | struct uint2_t final { 51 | using packed_storage = std::uint8_t; 52 | packed_storage bits; 53 | 54 | constexpr uint2_t() noexcept : bits{} {} 55 | constexpr uint2_t(int u8) noexcept : bits{static_cast(u8)} {} 56 | constexpr auto operator == (uint2_t rhs) const noexcept -> bool { return bits == rhs.bits; } 57 | constexpr auto operator != (uint2_t rhs) const noexcept -> bool { return !(*this == rhs); } 58 | constexpr auto operator == (packed_storage rhs) const noexcept -> bool { return bits == rhs; } 59 | constexpr auto operator != (packed_storage rhs) const noexcept -> bool { return !(*this == rhs); } 60 | constexpr explicit operator std::uint8_t() const noexcept { return bits; } 61 | constexpr explicit operator std::int64_t() const noexcept { return bits; } 62 | }; 63 | 64 | struct uint4_t final { 65 | using packed_storage = std::uint8_t; 66 | packed_storage bits; 67 | 68 | constexpr uint4_t() noexcept : bits {} {} 69 | constexpr uint4_t(int u8) noexcept : bits {static_cast(u8)} {} 70 | constexpr auto operator == (uint4_t rhs) const noexcept -> bool { return bits == rhs.bits; } 71 | constexpr auto operator != (uint4_t rhs) const noexcept -> bool { return !(*this == rhs); } 72 | constexpr auto operator == (packed_storage rhs) const noexcept -> bool { return bits == rhs; } 73 | constexpr auto operator != (packed_storage rhs) const noexcept -> bool { return !(*this == rhs); } 74 | constexpr explicit operator std::uint8_t() const noexcept { return bits; } 75 | constexpr explicit operator std::int64_t() const noexcept { return bits; } 76 | }; 77 | 78 | using fp32_t = float; // IEEE 754 binary 32 79 | 80 | // Google 
Brain Float 16 81 | struct bfp16_t final { 82 | using packed_storage = std::uint16_t; 83 | packed_storage bits; 84 | 85 | constexpr bfp16_t() noexcept : bits {} {} 86 | constexpr bfp16_t(fp32_t s) noexcept { 87 | auto u32 {std::bit_cast(s)}; 88 | if ((u32 & 0x7fffffff) > 0x7f800000) bits = u32>>16|64; // Force quiet NaN 89 | else bits = (u32 + (0x7fff + ((u32>>16)&1)))>>16; 90 | } 91 | constexpr auto operator == (bfp16_t rhs) const noexcept -> bool { return bits == rhs.bits; } 92 | constexpr auto operator != (bfp16_t rhs) const noexcept -> bool { return !(*this == rhs); } 93 | constexpr auto operator == (packed_storage rhs) const noexcept -> bool { return bits == rhs; } 94 | constexpr auto operator != (packed_storage rhs) const noexcept -> bool { return !(*this == rhs); } 95 | constexpr explicit operator fp32_t() const noexcept { return std::bit_cast(static_cast(bits)<<16); } 96 | 97 | constexpr auto operator + (bfp16_t rhs) const noexcept -> bfp16_t { 98 | return {static_cast(*this) + static_cast(rhs)}; 99 | } 100 | constexpr auto operator += (bfp16_t rhs) noexcept -> bfp16_t& { 101 | *this = *this + rhs; 102 | return *this; 103 | } 104 | constexpr auto operator - (bfp16_t rhs) const noexcept -> bfp16_t { 105 | return {static_cast(*this) - static_cast(rhs)}; 106 | } 107 | constexpr auto operator -= (bfp16_t rhs) noexcept -> bfp16_t& { 108 | *this = *this - rhs; 109 | return *this; 110 | } 111 | constexpr auto operator * (bfp16_t rhs) const noexcept -> bfp16_t { 112 | return {static_cast(*this) * static_cast(rhs)}; 113 | } 114 | constexpr auto operator *= (bfp16_t rhs) noexcept -> bfp16_t& { 115 | *this = *this * rhs; 116 | return *this; 117 | } 118 | constexpr auto operator / (bfp16_t rhs) const noexcept -> bfp16_t { 119 | return {static_cast(*this) / static_cast(rhs)}; 120 | } 121 | constexpr auto operator /= (bfp16_t rhs) noexcept -> bfp16_t& { 122 | *this = *this / rhs; 123 | return *this; 124 | } 125 | }; 126 | 127 | static_assert(sizeof(uint2_t) == 1); 128 | 
static_assert(sizeof(uint4_t) == 1); 129 | static_assert(sizeof(bfp16_t) == 2); 130 | 131 | struct dtype_flags final { 132 | enum $ { 133 | none = 0, 134 | is_quant = 1<<0, 135 | is_float = 1<<1, 136 | is_int = 1<<2, 137 | is_signed = 1<<3, 138 | is_packed = 1<<4, 139 | }; 140 | }; 141 | 142 | struct dtype_info final { 143 | std::string_view name; 144 | std::size_t stride; 145 | std::size_t bit_size; 146 | std::underlying_type_t flags; 147 | }; 148 | 149 | constexpr std::array dtype_infos { 150 | dtype_info{.name="f32", .stride=sizeof(fp32_t), .bit_size=8*sizeof(fp32_t), .flags=dtype_flags::is_float|dtype_flags::is_signed}, // f32 151 | dtype_info{.name="bf16", .stride=sizeof(bfp16_t), .bit_size=8*sizeof(bfp16_t), .flags=dtype_flags::is_float|dtype_flags::is_signed}, // bf16 152 | dtype_info{.name="uint2", .stride=sizeof(std::uint8_t), .bit_size=2, .flags=dtype_flags::is_quant|dtype_flags::is_int|dtype_flags::is_packed}, // uint2 153 | dtype_info{.name="uint4", .stride=sizeof(std::uint8_t), .bit_size=4, .flags=dtype_flags::is_quant|dtype_flags::is_int|dtype_flags::is_packed}, // uint4 154 | dtype_info{.name="uint8", .stride=sizeof(std::uint8_t), .bit_size=8, .flags=dtype_flags::is_quant|dtype_flags::is_int}, // uint8 155 | }; 156 | static_assert([]() -> bool { 157 | for (auto&& info : dtype_infos) { 158 | if (!info.bit_size || info.bit_size & (info.bit_size-1)) return false; // bit_size must be a non-zero power of two 159 | if (!((info.flags & dtype_flags::is_float) ^ (info.flags & dtype_flags::is_int))) return false; // Either is_float or is_int must be set. NOTE(review): the XOR acts on the masked values (2 and 4), so it is also non-zero when both flags are set; only the neither-set case is rejected here - confirm intent 160 | } 161 | return true; 162 | }()); 163 | [[nodiscard]] constexpr auto dtype_info_of(dtype dtype) noexcept -> const dtype_info& { return dtype_infos[static_cast(dtype)]; } 164 | 165 | template struct dtype_limits final {}; 166 | 167 | template<> struct dtype_limits final { 168 | static constexpr fp32_t min {-std::numeric_limits::max()}; // Refers to the lowest (most negative) finite value, so it's like
std::numeric_limits::lowest() 169 | static constexpr fp32_t max {std::numeric_limits::max()}; 170 | }; 171 | template<> struct dtype_limits final { 172 | static constexpr bfp16_t min {0xFF7F}; // Refers to the lowest (most negative) finite value, so it's like std::numeric_limits::lowest(). NOTE(review): the only converting constructor of bfp16_t takes fp32_t, so this literal is presumably converted as the float value 65407 rather than stored as the raw bit pattern 0xFF7F - confirm 173 | static constexpr bfp16_t max {0x7F7F}; // NOTE(review): same concern as min - presumably meant as the raw bit pattern of the largest finite bf16; confirm 174 | }; 175 | template<> struct dtype_limits final { 176 | static constexpr std::uint8_t min {0}; 177 | static constexpr std::uint8_t max {3}; 178 | }; 179 | template<> struct dtype_limits final { 180 | static constexpr std::uint8_t min {0}; 181 | static constexpr std::uint8_t max {15}; 182 | }; 183 | template<> struct dtype_limits final { 184 | static constexpr std::uint8_t min {0}; 185 | static constexpr std::uint8_t max {255}; 186 | }; 187 | 188 | template concept is_float_type = std::is_floating_point_v || std::is_same_v; 189 | template concept is_quant_type = std::is_integral_v || std::is_same_v || std::is_same_v;; 190 | template concept is_dtype = is_float_type || is_quant_type; 191 | template requires is_dtype struct dtype_traits final {}; 192 | 193 | template <> struct dtype_traits { static constexpr dtype type_code {dtype::f32}; }; 194 | template <> struct dtype_traits { static constexpr dtype type_code {dtype::bf16}; }; 195 | template <> struct dtype_traits { static constexpr dtype type_code {dtype::uint2}; }; 196 | template <> struct dtype_traits { static constexpr dtype type_code {dtype::uint4}; }; 197 | template <> struct dtype_traits { static constexpr dtype type_code {dtype::uint8}; }; 198 | 199 | class QUANT_EXPORT context final { 200 | public: 201 | explicit context(std::size_t num_threads); 202 | context(const context&) = delete; 203 | context(context&&) = delete; 204 | auto operator=(const context&) -> context& = delete; 205 | auto operator=(context&&) -> context& = delete; 206 | ~context(); 207 | 208 | auto quantize( 209 | std::span in, 210 | dtype dtype_in, 211 | std::span out, 212 | dtype dtype_out, 213 | fp32_t
scale, 214 | std::int64_t zero_point, 215 | round_mode mode 216 | ) const -> void; 217 | 218 | template requires requires { 219 | requires is_dtype; 220 | requires is_dtype; 221 | dtype_info_of(dtype_traits::type_code).flags & dtype_flags::is_quant; 222 | !(dtype_info_of(dtype_traits::type_code).flags & dtype_flags::is_quant); 223 | } 224 | auto quantize_generic( 225 | std::span in, 226 | std::span out, 227 | fp32_t scale, 228 | std::int64_t zero_point, 229 | round_mode mode 230 | ) -> void { 231 | quantize( 232 | {reinterpret_cast(in.data()), in.size_bytes()}, 233 | dtype_traits::type_code, 234 | {reinterpret_cast(out.data()), out.size_bytes()}, 235 | dtype_traits::type_code, 236 | scale, 237 | zero_point, 238 | mode 239 | ); 240 | } 241 | 242 | auto dequantize( 243 | std::span in, 244 | dtype dtype_in, 245 | std::span out, 246 | dtype dtype_out, 247 | fp32_t scale, 248 | std::int64_t zero_point, 249 | reduce_op op 250 | ) const -> void; 251 | 252 | template requires requires { 253 | requires is_dtype; 254 | requires is_dtype; 255 | !(dtype_info_of(dtype_traits::type_code).flags & dtype_flags::is_quant); 256 | dtype_info_of(dtype_traits::type_code).flags & dtype_flags::is_quant; 257 | } 258 | auto dequantize_generic( 259 | std::span in, 260 | std::span out, 261 | fp32_t scale, 262 | std::int64_t zero_point, 263 | reduce_op op 264 | ) -> void { 265 | dequantize( 266 | {reinterpret_cast(in.data()), in.size_bytes()}, 267 | dtype_traits::type_code, 268 | {reinterpret_cast(out.data()), out.size_bytes()}, 269 | dtype_traits::type_code, 270 | scale, 271 | zero_point, 272 | op 273 | ); 274 | } 275 | 276 | auto quantize_dequantize_fused( 277 | std::span in, 278 | dtype dtype_in_out, 279 | std::span out, 280 | dtype quant_type, 281 | fp32_t scale, 282 | std::int64_t zero_point, 283 | round_mode mode, 284 | reduce_op op 285 | ) const -> void; 286 | 287 | template requires requires { 288 | requires is_dtype; 289 | !(dtype_info_of(dtype_traits::type_code).flags & 
dtype_flags::is_quant); 290 | dtype_info_of(dtype_traits::type_code).flags & dtype_flags::is_quant; 291 | } 292 | auto quantize_dequantize_fused_generic( 293 | std::span in, 294 | std::span out, 295 | fp32_t scale, 296 | std::int64_t zero_point, 297 | round_mode mode, 298 | reduce_op op 299 | ) -> void { 300 | quantize_dequantize_fused( 301 | {reinterpret_cast(in.data()), in.size_bytes()}, 302 | dtype_traits::type_code, 303 | {reinterpret_cast(out.data()), out.size_bytes()}, 304 | dtype_traits::type_code, 305 | scale, 306 | zero_point, 307 | mode, 308 | op 309 | ); 310 | } 311 | 312 | [[nodiscard]] auto compute_quant_config_from_data(std::span x, dtype quant_dst_dtype) const -> std::pair; 313 | [[nodiscard]] auto compute_quant_config_from_data(std::span x, dtype quant_dst_dtype) const -> std::pair; 314 | 315 | class pimpl; 316 | 317 | enum class command_type { 318 | quant, // out[i] = quantize(in[i]) 319 | dequant, // out[i] = dequantize(in[i]) 320 | quant_dequant // out[i] = dequantize(quantize(in[i])) 321 | }; 322 | 323 | struct quant_descriptor final { 324 | command_type type {command_type::quant}; // which operation the kernel dispatcher runs 325 | const std::byte* in {}; // input buffer; each worker offsets into it for its partition 326 | std::byte* out {}; // output buffer; each worker offsets into it for its partition 327 | std::int64_t numel {}; // total element count that is split across threads 328 | fp32_t scale {}; 329 | std::int64_t zero_point {}; 330 | dtype dt_in {}; 331 | dtype dt_out {}; 332 | round_mode rounding {}; 333 | reduce_op reducing {}; 334 | fp32_t rnd_threshold {}; // overwritten with a fresh random value per invocation when rounding is stochastic 335 | }; 336 | 337 | private: 338 | std::shared_ptr m_pimpl; 339 | }; 340 | } 341 | -------------------------------------------------------------------------------- /src/piquant.cpp: -------------------------------------------------------------------------------- 1 | #include "piquant.hpp" 2 | #include "piquant_internal.hpp" 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | namespace piquant { 19 | #define decl_quant_kernel_installer_fn(impl) \ 20 | [[nodiscard]]
extern auto impl() noexcept -> kernel_registry 21 | 22 | decl_quant_kernel_installer_fn(install_quant_generic); 23 | 24 | #ifdef __x86_64__ 25 | #include 26 | 27 | [[nodiscard]] static auto check_sse42_support() noexcept -> bool { 28 | int info[4] = {-1}; 29 | __cpuid(0, info[0], info[1], info[2], info[3]); 30 | if (info[0] < 1) return false; 31 | __cpuid(1, info[0], info[1], info[2], info[3]); 32 | return (info[2] & (1<<20)) != 0; 33 | } 34 | 35 | [[nodiscard]] static auto check_avx2_support() noexcept -> bool { 36 | int info[4] = {-1}; 37 | __cpuid(0, info[0], info[1], info[2], info[3]); 38 | if (info[0] < 7) return false; 39 | __cpuid(1, info[0], info[1], info[2], info[3]); 40 | if ((info[2] & 0x38081001) != 0x38081001) return false; 41 | __cpuid_count(7, 0, info[0], info[1], info[2], info[3]); 42 | if ((info[1] & 0x20) != 0x20) return false; 43 | std::uint32_t lo, hi; 44 | asm volatile("xgetbv\n\t" : "=a" (lo), "=d" (hi) : "c" (0)); 45 | return ((static_cast(lo)|(static_cast(hi) << 32)) & 6) == 6; 46 | } 47 | 48 | [[nodiscard]] static auto check_avx512f_support() noexcept -> bool { 49 | int info[4] = {-1}; 50 | __cpuid(0, info[0], info[1], info[2], info[3]); 51 | if (info[0] < 7) return false; 52 | __cpuid(1, info[0], info[1], info[2], info[3]); 53 | if ((info[2] & 0x8000000) == 0 || (info[2] & 0x10000000) == 0) return false; 54 | __cpuid_count(7, 0, info[0], info[1], info[2], info[3]); 55 | if ((info[1] & 0x10000) == 0) return false; 56 | std::uint32_t lo, hi; 57 | asm volatile("xgetbv\n\t" : "=a" (lo), "=d" (hi) : "c" (0)); 58 | return ((static_cast(lo)|(static_cast(hi)<<32))&0xe0) == 0xe0; 59 | } 60 | 61 | [[nodiscard]] static auto check_avx512f_bf16_support() noexcept -> bool { 62 | int info[4] = {-1}; 63 | __cpuid(0, info[0], info[1], info[2], info[3]); 64 | if (info[0] < 7) return false; 65 | __cpuid(1, info[0], info[1], info[2], info[3]); 66 | if ((info[2] & (1<<27|1<<28)) != (1<<27|1<<28)) return false; 67 | __cpuid_count(7, 0, info[0], info[1], 
info[2], info[3]); 68 | if (!(info[1] & 1<<16)) return false; 69 | __cpuid_count(7, 1, info[0], info[1], info[2], info[3]); 70 | if (!(info[0] & 1<<5)) return false; 71 | std::uint32_t lo, hi; 72 | asm volatile("xgetbv" : "=a"(lo), "=d"(hi) : "c"(0)); 73 | return ((static_cast(hi)<<32|lo)&0xe0) == 0xe0; 74 | } 75 | 76 | decl_quant_kernel_installer_fn(install_quant_amd64_sse42); 77 | decl_quant_kernel_installer_fn(install_quant_amd64_avx2); 78 | decl_quant_kernel_installer_fn(install_quant_amd64_avx512f); 79 | decl_quant_kernel_installer_fn(install_quant_amd64_avx512f_bf16); 80 | 81 | #endif 82 | 83 | #undef decl_kernel_pair 84 | 85 | template 86 | struct overloads final : T... { using T::operator()...; }; 87 | 88 | auto panic(const char* msg, ...) -> void { 89 | std::va_list args; 90 | va_start(args, msg); 91 | std::array tmp {}; 92 | int delta{std::snprintf(tmp.data(), sizeof(tmp), "%s", "\x1b[31m")}; 93 | delta += std::vsnprintf(tmp.data()+delta, sizeof(tmp)-delta, msg, args); 94 | std::snprintf(tmp.data()+delta, sizeof(tmp)-delta, "%s", "\x1b[0m"); 95 | std::cerr << tmp.data() << std::endl; 96 | va_end(args); 97 | std::abort(); 98 | } 99 | 100 | static constexpr std::size_t cache_line { 101 | #ifdef __cpp_lib_hardware_interference_size 102 | std::hardware_destructive_interference_size 103 | #else 104 | 64 105 | #endif 106 | }; 107 | 108 | struct partition { 109 | std::int64_t ti {}; // thread index 110 | std::int64_t tc {}; // thread count 111 | }; 112 | 113 | class context::pimpl final { 114 | public: 115 | explicit pimpl(std::size_t num_threads); 116 | pimpl(const pimpl&) = delete; 117 | pimpl(pimpl&&) = delete; 118 | auto operator = (const pimpl&) -> pimpl& = delete; 119 | auto operator = (pimpl&&) -> pimpl& = delete; 120 | ~pimpl(); 121 | 122 | kernel_registry registry {}; 123 | std::size_t num_threads; 124 | pi::threadpool::ThreadPool m_pool; 125 | 126 | auto operator ()(const quant_descriptor& base_desc) const -> void; // Quant/Dequant dispatcher 127 | 
auto operator ()(std::span x, dtype quant_dst_type) -> std::pair; // Quant config dispatcher 128 | auto operator ()(std::span x, dtype quant_dst_type) -> std::pair; // Quant config dispatcher 129 | auto job_entry(partition& pl, const quant_descriptor& cmd) const -> void; 130 | }; 131 | 132 | auto context::pimpl::job_entry(partition& pl, const quant_descriptor& cmd) const -> void { 133 | const std::int64_t tc {std::max(std::int64_t{1}, pl.tc)}; 134 | const std::int64_t ti {pl.ti}; 135 | const auto partition_row {[&] () noexcept -> std::optional> { 136 | std::int64_t bs_in {static_cast(dtype_info_of(cmd.dt_in).bit_size)}; 137 | std::int64_t bs_out {static_cast(dtype_info_of(cmd.dt_out).bit_size)}; 138 | std::int64_t packed_bits {8}; 139 | switch (cmd.type) { 140 | case command_type::quant: packed_bits = bs_out; break; 141 | case command_type::dequant: packed_bits = bs_in; break; 142 | case command_type::quant_dequant: packed_bits = std::max(bs_in, bs_out); break; 143 | default: break; 144 | } 145 | std::int64_t pack_elems {packed_bits < 8 ? 8/packed_bits : 1}; 146 | std::int64_t n {cmd.numel}; 147 | std::int64_t tcm {std::max(1, pl.tc)}; 148 | std::int64_t t {pl.ti}; 149 | std::int64_t raw_begin {n*t / tcm}; 150 | std::int64_t raw_end {n*(t+1) / tcm}; 151 | const auto align_down {[&](std::int64_t v) noexcept -> std::int64_t { 152 | return pack_elems == 1 ? v : v - (v % pack_elems); 153 | }}; 154 | std::int64_t begin {pack_elems == 1 ? raw_begin : align_down(raw_begin)}; 155 | std::int64_t end {t+1 == tcm || pack_elems == 1 ? 
raw_end : align_down(raw_end)}; 156 | if (begin >= end) [[unlikely]] return {}; 157 | return {{begin, begin, end - begin}}; 158 | }}; 159 | const auto dispatch_quant {[&](const std::int64_t oa, const std::int64_t ob, const std::int64_t range, const quant_descriptor& cmd) noexcept -> void { 160 | auto* const kernel {®istry.quant_kernel}; 161 | piquant_assert2(kernel != nullptr); 162 | const auto si {dtype_info_of(cmd.dt_in).bit_size}; 163 | const auto so {cmd.type == command_type::quant_dequant ? si : dtype_info_of(cmd.dt_out).bit_size}; 164 | (*kernel)( 165 | cmd.in + (si*oa) / 8, 166 | cmd.out + (so*ob) / 8, 167 | range, 168 | cmd 169 | ); 170 | }}; 171 | 172 | if (const auto partition {partition_row()}; partition) [[likely]] { 173 | const auto [oa, ob, n] {*partition}; 174 | dispatch_quant(oa, ob, n, cmd); 175 | } 176 | } 177 | 178 | context::pimpl::pimpl(const std::size_t num_threads) : num_threads(num_threads), 179 | m_pool{static_cast(num_threads), 64} { 180 | registry = install_quant_generic(); 181 | m_pool.startup(); 182 | #ifdef __x86_64__ 183 | if (check_avx512f_bf16_support()) registry = install_quant_amd64_avx512f_bf16(); 184 | else if (check_avx512f_support()) registry = install_quant_amd64_avx512f(); 185 | else if (check_avx2_support()) registry = install_quant_amd64_avx2(); 186 | else if (check_sse42_support()) registry = install_quant_amd64_sse42(); 187 | #endif 188 | } 189 | 190 | context::pimpl::~pimpl() { 191 | m_pool.shutdown(); 192 | } 193 | 194 | static thread_local std::random_device rd; 195 | static thread_local std::mt19937_64 rng {rd()}; 196 | 197 | auto context::pimpl::operator()(const quant_descriptor& base_desc) const -> void { 198 | quant_descriptor desc {base_desc}; 199 | if (desc.rounding == round_mode::stochastic) { // Set random threshold for stochastic rounding on every invocation 200 | desc.rnd_threshold = std::uniform_real_distribution{0.f, 1.f}(rng); 201 | } 202 | const size_t num_threads {this->num_threads}; 203 | const 
pi::threadpool::MultiTaskResult jobs_future = m_pool.scheduleSequence(0u, num_threads, [this, &desc, num_threads](const std::size_t ti) { 204 | partition pl { 205 | .ti = static_cast(ti), 206 | .tc = static_cast(num_threads) 207 | }; 208 | job_entry(pl, desc); 209 | }); 210 | jobs_future.join(); 211 | } 212 | 213 | [[nodiscard]] static auto compute_type_max(dtype dt) noexcept -> std::uint64_t { 214 | dtype_info info {dtype_info_of(dt)}; 215 | piquant_assert(info.flags & dtype_flags::is_quant && info.flags & dtype_flags::is_int, "type %s is not a quantization type", info.name.data()); 216 | std::size_t width {dtype_info_of(dt).bit_size}; 217 | piquant_assert(width > 0 && width <= 64, "invalid width %zu for type %s", width, info.name.data()); 218 | if (info.flags & dtype_flags::is_signed) --width; 219 | return (1ull< requires is_float_type 223 | static auto compute_quant_config( 224 | pi::threadpool::ThreadPool& pool, 225 | F&& kernel, 226 | std::span x, 227 | dtype quant_dst_type 228 | ) -> std::pair { 229 | const auto* base {x.data()}; 230 | auto callback {[base, &kernel](std::size_t start, std::size_t end) -> std::array { 231 | std::size_t numel {end - start}; 232 | if (numel <= 0) return {0.0, 0.0}; 233 | std::span x {base + start, numel}; 234 | return std::invoke(kernel, x); 235 | }}; 236 | pi::threadpool::MultiTaskResult jobs_future {pool.scheduleBlocks(0u, x.size(), callback)}; 237 | jobs_future.join(); 238 | double r_min {std::numeric_limits::max()}; 239 | double r_max {std::numeric_limits::lowest()}; 240 | for (std::size_t i {}; i < jobs_future.size(); ++i) { 241 | auto [min, max] {jobs_future.get(i)}; 242 | r_min = std::min(r_min, static_cast(min)); 243 | r_max = std::max(r_max, static_cast(max)); 244 | } 245 | std::uint64_t type_max {compute_type_max(quant_dst_type)}; 246 | std::int64_t type_min {0}; 247 | if (dtype_info_of(quant_dst_type).flags & dtype_flags::is_signed) 248 | type_min = -static_cast(type_max) - 1; 249 | if (r_max == r_min) [[unlikely]] { 
250 | auto mid {static_cast((type_max + type_min) >> 1)}; 251 | return {1.0f, mid}; 252 | } 253 | double q_min {static_cast(type_min)}; 254 | double q_max {static_cast(type_max)}; 255 | double scale = (r_max - r_min)/(q_max - q_min); 256 | double zero_point = q_min - r_min/scale; 257 | zero_point = std::max(std::min(static_cast(static_cast(std::round(zero_point))), q_max), q_min); 258 | return {scale, zero_point}; 259 | } 260 | 261 | auto context::pimpl::operator()(std::span x, dtype quant_dst_type) -> std::pair { 262 | auto& kernel {(*registry.find_min_max_float32)}; 263 | return compute_quant_config(m_pool, kernel, x, quant_dst_type); 264 | } 265 | 266 | auto context::pimpl::operator()(std::span x, dtype quant_dst_type) -> std::pair { 267 | auto& kernel {(*registry.find_min_max_bfloat16)}; 268 | return compute_quant_config(m_pool, kernel, x, quant_dst_type); 269 | } 270 | 271 | context::context(std::size_t num_threads) { 272 | m_pimpl = std::make_shared(num_threads); 273 | } 274 | 275 | context::~context() = default; 276 | 277 | auto context::quantize( 278 | const std::span in, 279 | const dtype dtype_in, 280 | const std::span out, 281 | const dtype dtype_out, 282 | const fp32_t scale, 283 | const std::int64_t zero_point, 284 | const round_mode mode 285 | ) const -> void { 286 | const auto& dti {dtype_info_of(dtype_in)}; 287 | const auto& dto {dtype_info_of(dtype_out)}; 288 | piquant_assert(!(dti.flags & dtype_flags::is_quant), "input dtype (%s) must be a dequantized type", dti.name.data()); 289 | piquant_assert(dto.flags & dtype_flags::is_quant, "output dtype (%s) must be a quantized type", dto.name.data()); 290 | std::size_t ne_in {in.size() / dti.stride}; 291 | std::size_t expected_out_bytes {dto.bit_size == 8 ? 
ne_in*dto.stride : packed_numel(ne_in, dto)*dto.stride}; 292 | piquant_assert(out.size() == expected_out_bytes, 293 | "quantize: expected output buffer to hold %zu byte(s) for %zu element(s) " 294 | "of %s (bit_size=%u), but got %zu", 295 | expected_out_bytes, ne_in, dto.name.data(), dto.bit_size, out.size()); 296 | quant_descriptor info { 297 | .type = command_type::quant, 298 | .in = in.data(), 299 | .out = out.data(), 300 | .numel = static_cast(ne_in), 301 | .scale = scale, 302 | .zero_point = zero_point, 303 | .dt_in = dtype_in, 304 | .dt_out = dtype_out, 305 | .rounding = mode 306 | }; 307 | (*this->m_pimpl)(info); 308 | } 309 | 310 | auto context::dequantize( 311 | const std::span in, 312 | const dtype dtype_in, 313 | const std::span out, 314 | const dtype dtype_out, 315 | const fp32_t scale, 316 | const std::int64_t zero_point, 317 | const reduce_op op 318 | ) const -> void { 319 | const auto& dti {dtype_info_of(dtype_in)}; 320 | const auto& dto {dtype_info_of(dtype_out)}; 321 | piquant_assert(dti.flags & dtype_flags::is_quant, "input dtype (%s) must be a quantized type", dto.name.data()); 322 | piquant_assert(!(dto.flags & dtype_flags::is_quant), "output dtype (%s) must be a dequantized type", dti.name.data()); 323 | std::size_t ne_out {out.size() / dto.stride}; 324 | std::size_t min_in_bytes {packed_numel(ne_out, dti) * dti.stride}; 325 | piquant_assert(in.size() == min_in_bytes, 326 | "dequantize: need %zu byte(s) of %s for %zu element(s), but got %zu", 327 | min_in_bytes, dti.name.data(), ne_out, in.size()); 328 | quant_descriptor info { 329 | .type = command_type::dequant, 330 | .in = in.data(), 331 | .out = out.data(), 332 | .numel = static_cast(ne_out), 333 | .scale = scale, 334 | .zero_point = zero_point, 335 | .dt_in = dtype_in, 336 | .dt_out = dtype_out, 337 | .reducing = op 338 | }; 339 | (*this->m_pimpl)(info); 340 | } 341 | 342 | auto context::quantize_dequantize_fused( 343 | const std::span in, 344 | const dtype dtype_in_out, 345 | const 
std::span out, 346 | const dtype quant_type, 347 | const fp32_t scale, 348 | const std::int64_t zero_point, 349 | const round_mode mode, 350 | const reduce_op op 351 | ) const -> void { 352 | const auto& dti{dtype_info_of(dtype_in_out)}; 353 | piquant_assert(!(dti.flags & dtype_flags::is_quant), "input dtype must be a dequantized type"); 354 | piquant_assert(dtype_info_of(quant_type).flags & dtype_flags::is_quant, "quant dtype must be a quantized type"); 355 | piquant_assert(in.size() == out.size(), "input and output spans must have the same length, but %zu != %zu", in.size(), out.size()); 356 | quant_descriptor info { 357 | .type = command_type::quant_dequant, 358 | .in = in.data(), 359 | .out = out.data(), 360 | .numel = static_cast(in.size()/dti.stride), 361 | .scale = scale, 362 | .zero_point = zero_point, 363 | .dt_in = dtype_in_out, 364 | .dt_out = quant_type, 365 | .rounding = mode, 366 | .reducing = op 367 | }; 368 | (*this->m_pimpl)(info); 369 | } 370 | 371 | auto context::compute_quant_config_from_data(std::span x, dtype quant_dst_dtype) const -> std::pair { 372 | auto result {(*this->m_pimpl)(x, quant_dst_dtype)}; 373 | piquant_assert(!std::isnan(result.first) && result.first >= 0.0f, "scale must be positive"); 374 | return result; 375 | } 376 | 377 | auto context::compute_quant_config_from_data(std::span x, dtype quant_dst_dtype) const -> std::pair { 378 | auto result {(*this->m_pimpl)(x, quant_dst_dtype)}; 379 | piquant_assert(!std::isnan(result.first) && result.first >= 0.0f, "scale must be positive"); 380 | return result; 381 | } 382 | } 383 | --------------------------------------------------------------------------------