├── .bazelrc ├── .clang-tidy ├── .github ├── metrics │ ├── __init__.py │ ├── datatypes.py │ ├── github.py │ ├── reporters.py │ ├── requirements.txt │ └── scrape.py ├── scripts │ ├── fbgemm_build.bash │ ├── fbgemm_gpu_benchmarks.bash │ ├── fbgemm_gpu_build.bash │ ├── fbgemm_gpu_docs.bash │ ├── fbgemm_gpu_install.bash │ ├── fbgemm_gpu_lint.bash │ ├── fbgemm_gpu_postbuild.bash │ ├── fbgemm_gpu_test.bash │ ├── filter_nova_matrix.py │ ├── nova_dir.bash │ ├── nova_postscript.bash │ ├── nova_prescript.bash │ ├── setup_env.bash │ ├── test_torchrec.bash │ ├── utils_base.bash │ ├── utils_build.bash │ ├── utils_conda.bash │ ├── utils_cuda.bash │ ├── utils_pip.bash │ ├── utils_pytorch.bash │ ├── utils_rocm.bash │ ├── utils_system.bash │ ├── utils_torchrec.bash │ └── utils_triton.bash └── workflows │ ├── build_wheels_genai_linux_aarch64.yml │ ├── build_wheels_genai_linux_x86.yml │ ├── build_wheels_linux_aarch64.yml │ ├── build_wheels_linux_x86.yml │ ├── fbgemm_ci.yml │ ├── fbgemm_gpu_benchmark_cpu.yml │ ├── fbgemm_gpu_benchmark_cuda.yml │ ├── fbgemm_gpu_benchmark_rocm.yml │ ├── fbgemm_gpu_ci_cpu.yml │ ├── fbgemm_gpu_ci_cuda.yml │ ├── fbgemm_gpu_ci_genai_generic_infra.yml │ ├── fbgemm_gpu_ci_rocm.yml │ ├── fbgemm_gpu_docs.yml │ ├── fbgemm_gpu_lint.yml │ ├── fbgemm_gpu_pip.yml │ ├── fbgemm_gpu_release_cpu.yml │ ├── fbgemm_gpu_release_cuda.yml │ └── fbgemm_gpu_release_genai.yml ├── .gitignore ├── .gitmodules ├── BUILD.bazel ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MODULE.bazel ├── README.md ├── WORKSPACE.bazel ├── bench ├── AlignedVec.h ├── BenchUtils.cc ├── BenchUtils.h ├── CMakeLists.txt ├── ConvUnifiedBenchmark.cc ├── ConvertBenchmark.cc ├── Depthwise3DBenchmark.cc ├── DepthwiseBenchmark.cc ├── EmbeddingIndexRemappingBenchmark.cc ├── EmbeddingQuantizeBenchmark.cc ├── EmbeddingQuantizeFloatToFloatOrHalfBenchmark.cc ├── EmbeddingSpMDM8BitBenchmark.cc ├── EmbeddingSpMDMBenchmark.cc ├── EmbeddingSpMDMNBit2Benchmark.cc ├── EmbeddingSpMDMNBitBenchmark.cc ├── EmbeddingSpMDMNBitRowWiseSparseBenchmark.cc ├── FP16Benchmark.cc ├── FP32Benchmark.cc ├── GEMMsBenchmark.cc ├── GEMMsTunableBenchmark.cc ├── GroupwiseConvRequantizeBenchmark.cc ├── I64Benchmark.cc ├── I8SpmdmBenchmark.cc ├── Im2ColFusedRequantizeBenchmark.cc ├── PackedFloatInOutBenchmark.cc ├── PackedRequantizeAcc16Benchmark.cc ├── PackedRequantizeAcc32Benchmark.cc ├── RequantizeBenchmark.cc ├── RowOffsetBenchmark.cc ├── RowwiseAdagradBenchmark.cc ├── RowwiseAdagradFusedBenchmark.cc ├── SparseAdagradBenchmark.cc ├── SparseDenseMMFP32Benchmark.cc ├── SparseDenseMMInt8Benchmark.cc └── TransposeBenchmark.cc ├── cmake └── modules │ ├── CudaSetup.cmake │ ├── CxxCompilerSetup.cmake │ ├── FindAVX.cmake │ ├── FindGnuH2fIeee.cmake │ ├── FindMKL.cmake │ ├── FindSphinx.cmake │ ├── GpuCppLibrary.cmake │ ├── PyTorchSetup.cmake │ ├── RocmSetup.cmake │ └── Utilities.cmake ├── defs.bzl ├── docs ├── CMakeLists.txt ├── Doxyfile.in ├── conf.py ├── index.rst └── requirements.txt ├── external ├── asmjit.BUILD └── cpuinfo.BUILD ├── fbgemm_gpu ├── CMakeLists.txt ├── FbgemmGpu.cmake ├── README.md ├── bench │ ├── README.md │ ├── batched_unary_embeddings_benchmark.py │ ├── bench_utils.py │ ├── histogram_binning_calibration_benchmark.py │ ├── jagged_tensor_benchmark.py │ ├── merge_embeddings_benchmark.py │ ├── quantize_ops_benchmark.py │ ├── sparse_ops_benchmark.py │ ├── stride_gemm_benchmark.py │ ├── tbe │ │ ├── README.md │ │ ├── batch_benchmark_run.py │ │ ├── run_tbe_benchmark.py │ │ ├── split_table_batched_embeddings_benchmark.py │ │ ├── 
tbe_cache_benchmark.py │ │ ├── tbe_inference_benchmark.py │ │ ├── tbe_ssd_benchmark.py │ │ ├── tbe_training_benchmark.py │ │ └── tbe_utils_benchmark.py │ └── verify_fp16_stochastic_benchmark.cu ├── cmake │ ├── Asmjit.cmake │ ├── Fbgemm.cmake │ ├── Hip.cmake │ ├── TbeInference.cmake │ ├── TbeTraining.cmake │ └── tbe_sources.py ├── codegen │ ├── genscript │ │ ├── __init__.py │ │ ├── common.py │ │ ├── generate_backward_split.py │ │ ├── generate_embedding_optimizer.py │ │ ├── generate_forward_quantized.py │ │ ├── generate_forward_split.py │ │ ├── generate_index_select.py │ │ ├── jinja_environment.py │ │ ├── optimizer_args.py │ │ ├── optimizers.py │ │ ├── scripts_argsparse.py │ │ └── torch_type_utils.py │ ├── inference │ │ ├── embedding_forward_quantized_cpu_template.cpp │ │ ├── embedding_forward_quantized_host.cpp │ │ ├── embedding_forward_quantized_host_cpu.cpp │ │ ├── embedding_forward_quantized_split_lookup.cu │ │ ├── embedding_forward_quantized_split_nbit_host_template.cu │ │ └── embedding_forward_quantized_split_nbit_kernel_template.cu │ ├── training │ │ ├── backward │ │ │ ├── embedding_backward_dense_host_cpu.cpp │ │ │ ├── embedding_backward_split_cpu_approx_template.cpp │ │ │ ├── embedding_backward_split_cpu_template.cpp │ │ │ ├── embedding_backward_split_device_kernel_template.cuh │ │ │ ├── embedding_backward_split_grad_template.cu │ │ │ ├── embedding_backward_split_host_cpu_template.cpp │ │ │ ├── embedding_backward_split_host_template.cpp │ │ │ ├── embedding_backward_split_indice_weights_template.cu │ │ │ ├── embedding_backward_split_kernel_cta_template.cu │ │ │ ├── embedding_backward_split_kernel_warp_template.cu │ │ │ ├── embedding_backward_split_meta_template.cpp │ │ │ ├── embedding_backward_split_template.cu │ │ │ └── rocm │ │ │ │ └── embedding_backward_split_device_kernel_template.hip │ │ ├── embedding_ops_placeholder.cpp │ │ ├── forward │ │ │ ├── embedding_forward_split_cpu.cpp │ │ │ ├── embedding_forward_split_kernel_nobag_small_template.cu │ │ │ ├── embedding_forward_split_kernel_template.cu │ │ │ ├── embedding_forward_split_kernel_v2_template.cu │ │ │ ├── embedding_forward_split_meta_template.cpp │ │ │ └── embedding_forward_split_template.cu │ │ ├── index_select │ │ │ ├── batch_index_select_dim0_cpu_host.cpp │ │ │ ├── batch_index_select_dim0_host.cpp │ │ │ └── batch_index_select_dim0_ops.cpp │ │ ├── optimizer │ │ │ ├── embedding_optimizer_split_device_kernel_template.cuh │ │ │ ├── embedding_optimizer_split_host_template.cpp │ │ │ ├── embedding_optimizer_split_kernel_template.cu │ │ │ └── embedding_optimizer_split_template.cu │ │ ├── pt2 │ │ │ ├── embedding_split_host_pt2_autograd_template.cpp │ │ │ ├── embedding_split_host_pt2_cpu_wrapper_template.cpp │ │ │ ├── embedding_split_host_pt2_cuda_wrapper_template.cpp │ │ │ ├── pt2_arg_utils_template.h │ │ │ └── pt2_autograd_utils.cpp │ │ └── python │ │ │ ├── __init__.template │ │ │ ├── lookup_args.template │ │ │ ├── optimizer_args.py │ │ │ ├── split_embedding_codegen_lookup_invoker.template │ │ │ └── split_embedding_optimizer_codegen.template │ └── utils │ │ ├── embedding_bounds_check_host.cpp │ │ ├── embedding_bounds_check_host_cpu.cpp │ │ ├── embedding_bounds_check_v1.cu │ │ └── embedding_bounds_check_v2.cu ├── docs │ ├── Doxyfile.in │ ├── Makefile │ ├── README.md │ ├── requirements.txt │ └── src │ │ ├── conf.py │ │ ├── fbgemm │ │ ├── cpp-api │ │ │ ├── QuantUtils.rst │ │ │ └── tbe_cpu_autovec.rst │ │ ├── development │ │ │ └── BuildInstructions.rst │ │ └── index.rst │ │ ├── fbgemm_genai │ │ ├── development │ │ │ ├── 
BuildInstructions.rst │ │ │ ├── InstallationInstructions.rst │ │ │ └── TestInstructions.rst │ │ └── index.rst │ │ ├── fbgemm_gpu │ │ ├── cpp-api │ │ │ ├── embedding_ops.rst │ │ │ ├── experimental_ops.rst │ │ │ ├── feature_gates.rst │ │ │ ├── input_combine.rst │ │ │ ├── jagged_tensor_ops.rst │ │ │ ├── layout_transform_ops.rst │ │ │ ├── memory_utils.rst │ │ │ ├── merge_pooled_embeddings.rst │ │ │ ├── quantize_ops.rst │ │ │ ├── sparse_ops.rst │ │ │ ├── split_table_batched_embeddings.rst │ │ │ └── ssd_embedding_ops.rst │ │ ├── development │ │ │ ├── BuildInstructions.rst │ │ │ ├── FeatureGates.rst │ │ │ ├── InstallationInstructions.rst │ │ │ └── TestInstructions.rst │ │ ├── index.rst │ │ ├── overview │ │ │ └── jagged-tensor-ops │ │ │ │ ├── JaggedTensorConversion1.png │ │ │ │ ├── JaggedTensorConversion2.png │ │ │ │ ├── JaggedTensorConversion3.png │ │ │ │ ├── JaggedTensorExample.png │ │ │ │ └── JaggedTensorOps.rst │ │ ├── python-api │ │ │ ├── feature_gates.rst │ │ │ ├── jagged_tensor_ops.rst │ │ │ ├── pooled_embedding_modules.rst │ │ │ ├── pooled_embedding_ops.rst │ │ │ ├── quantize_ops.rst │ │ │ ├── sparse_ops.rst │ │ │ ├── tbe_ops_inference.rst │ │ │ └── tbe_ops_training.rst │ │ └── stable-api │ │ │ └── python_api.rst │ │ ├── general │ │ ├── ContactUs.rst │ │ ├── Contributing.rst │ │ ├── License.rst │ │ ├── Releases.rst │ │ ├── documentation │ │ │ ├── Cpp.rst │ │ │ ├── ExampleGraph.dot │ │ │ ├── Overview.rst │ │ │ ├── Python.rst │ │ │ └── Sphinx.rst │ │ └── index.rst │ │ ├── index.rst │ │ └── nitpick.ignore ├── experimental │ ├── example │ │ ├── CMakeLists.txt │ │ ├── example │ │ │ ├── __init__.py │ │ │ └── utils.py │ │ ├── src │ │ │ ├── cutlass_sgemm_nn.cu │ │ │ ├── example_nccl.cpp │ │ │ └── example_ops.cpp │ │ └── test │ │ │ ├── __init__.py │ │ │ ├── add_tensors_float_test.py │ │ │ ├── sgemm_float_test.py │ │ │ └── triton_example_test.py │ ├── gemm │ │ ├── CMakeLists.txt │ │ ├── test │ │ │ ├── __init__.py │ │ │ ├── fp4_quantize_test.py │ │ │ ├── fp8_gemm_benchmark.py │ │ │ ├── fp8_gemm_test.py │ │ │ └── grouped_gemm_test.py │ │ └── triton_gemm │ │ │ ├── __init__.py │ │ │ ├── fp4_quantize.py │ │ │ ├── fp8_gemm.py │ │ │ ├── grouped_gemm.py │ │ │ ├── matmul_perf_model.py │ │ │ └── utils.py │ ├── gen_ai │ │ ├── CMakeLists.txt │ │ ├── README.md │ │ ├── bench │ │ │ ├── __init__.py │ │ │ ├── ck_bf16_bench.py │ │ │ ├── comm_bench.py │ │ │ ├── gather_scatter_bench.py │ │ │ ├── quantize_bench.py │ │ │ └── quantize_ops.py │ │ ├── gen_ai │ │ │ ├── __init__.py │ │ │ ├── moe │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── activation.py │ │ │ │ ├── gather_scatter.py │ │ │ │ ├── layers.py │ │ │ │ └── shuffling.py │ │ │ └── quantize.py │ │ ├── src │ │ │ ├── attention │ │ │ │ ├── attention.cpp │ │ │ │ └── gqa_attn_splitk.cu │ │ │ ├── coalesce │ │ │ │ ├── coalesce.cpp │ │ │ │ ├── coalesce.cu │ │ │ │ └── coalesce.h │ │ │ ├── comm │ │ │ │ ├── car.cpp │ │ │ │ └── car.cu │ │ │ ├── gather_scatter │ │ │ │ ├── gather_scatter.cpp │ │ │ │ └── gather_scatter.cu │ │ │ ├── gemm │ │ │ │ ├── ck_extensions.hip │ │ │ │ └── gemm.cpp │ │ │ ├── kv_cache │ │ │ │ ├── kv_cache.cpp │ │ │ │ ├── kv_cache.cu │ │ │ │ └── kv_cache.h │ │ │ ├── moe │ │ │ │ ├── index_shuffling.cpp │ │ │ │ └── index_shuffling.cu │ │ │ └── quantize │ │ │ │ ├── ck_extensions │ │ │ │ ├── bf16_grouped │ │ │ │ │ ├── bf16_grouped_gemm.hip │ │ │ │ │ └── kernels │ │ │ │ │ │ ├── bf16_grouped_128x16x32x128_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── 
bf16_grouped_128x16x32x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_128x16x32x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v1.hip │ │ │ │ │ │ ├── bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip │ │ │ │ │ │ ├── bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_128x16x96x64_16x16_1x3_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── bf16_grouped_128x32x16x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── bf16_grouped_128x32x96x128_16x16_2x3_16x8x1_16x8x1_1x32x1x4_8x8x1_2x1_intrawave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_128x64x128x64_32x32_2x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_128x64x96x64_16x16_4x3_8x16x1_8x16x1_1x32x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_256x128x128x128_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip │ │ │ │ │ │ ├── bf16_grouped_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_256x128x224x64_16x16_4x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_256x128x256x64_32x32_4x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_256x128x96x64_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v1.hip │ │ │ │ │ │ ├── bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_256x224x256x32_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_256x256x128x32_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip │ │ │ │ │ │ ├── bf16_grouped_256x256x160x64_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_256x256x192x64_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_256x256x224x64_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_256x256x256x64_32x32_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_256x32x128x128_16x16_1x4_16x16x1_16x16x1_1x32x1x8_8x8x1_1x2_intrawave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_256x32x224x64_16x16_1x7_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── bf16_grouped_256x32x96x64_16x16_1x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v1.hip │ │ │ │ │ │ ├── 
bf16_grouped_256x32x96x64_16x16_1x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_256x64x128x128_32x32_2x1_16x16x1_16x16x1_1x16x1x16_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_256x64x96x64_16x16_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── bf16_grouped_64x16x16x128_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_64x16x16x128_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── bf16_grouped_64x16x16x64_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_64x16x32x128_16x16_1x2_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_64x16x48x128_16x16_1x3_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── bf16_grouped_64x16x64x128_16x16_1x4_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v2.hip │ │ │ │ │ │ ├── bf16_grouped_common.h │ │ │ │ │ │ └── bf16_grouped_kernel_manifest.h │ │ │ │ ├── ck_utility.hip │ │ │ │ ├── fp8_blockwise_gemm.hip │ │ │ │ ├── fp8_rowwise │ │ │ │ │ ├── fp8_rowwise_gemm.hip │ │ │ │ │ └── kernels │ │ │ │ │ │ ├── fp8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_4_split_k.hip │ │ │ │ │ │ ├── fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8_split_k.hip │ │ │ │ │ │ ├── fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2_8_split_k.hip │ │ │ │ │ │ ├── fp8_rowwise_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2_2_split_k.hip │ │ │ │ │ │ ├── fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2_2_split_k.hip │ │ │ │ │ │ ├── fp8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2_16.split_k │ │ │ │ │ │ ├── fp8_rowwise_128x32x16x256_16x16_1x1_16x8x1_16x8x1_1x32x1x4_4x4x1_1x1_interwave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip │ │ │ │ │ │ 
├── fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2_16.split_k │ │ │ │ │ │ ├── fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_16.split_k │ │ │ │ │ │ ├── fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x160x128_16x16_4x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x64x256_32x32_2x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x96x128_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x160x128x128_16x16_5x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x160x256x128_16x16_5x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x160x96x128_16x16_5x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x16x64x128_16x16_1x1_16x16x1_8x32x1_1x16x1x16_4x4x1_1x1_intrawave_v2_8_split_k.hip │ │ │ │ │ │ ├── fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x192x128x128_16x16_6x4_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x192x192x128_16x16_6x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x192x224x128_16x16_6x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x224x160x128_16x16_7x5_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x224x192x128_16x16_7x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── 
fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip │ │ │ │ │ │ ├── fp8_rowwise_256x256x96x128_16x16_8x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x64x128x128_32x32_1x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x64x16x512_16x16_1x1_32x8x1_32x8x1_1x64x1x4_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_256x64x192x128_32x32_1x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x64x192x256_32x32_1x3_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x64x256x128_32x32_1x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x64x96x256_16x16_2x3_16x16x1_16x16x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x80x128x256_16x16_5x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_256x96x128x128_16x16_3x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1_16.split_k │ │ │ │ │ │ ├── fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_common.h │ │ │ │ │ │ └── fp8_rowwise_kernel_manifest.h │ │ │ │ ├── fp8_rowwise_batched │ │ │ │ │ ├── fp8_rowwise_batched_gemm.hip │ │ │ │ │ └── kernels │ │ │ │ │ │ ├── fp8_rowwise_batched_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip │ │ │ │ │ │ ├── 
fp8_rowwise_batched_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x64x192x256_32x32_1x3_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip │ │ │ │ 
│ │ ├── fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4_1x1_interwave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_batched_common.h │ │ │ │ │ │ └── fp8_rowwise_batched_kernel_manifest.h │ │ │ │ ├── fp8_rowwise_grouped │ │ │ │ │ ├── fp8_rowwise_grouped_gemm.hip │ │ │ │ │ └── kernels │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x16x64x256_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x16x64x256_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x16x64x256_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x16x96x256_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x32x64x256_16x16_1x4_16x8x1_16x8x1_1x32x1x4_8x8x1_1x2_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x32x64x256_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_interwave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x64x64x256_32x32_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_128x64x64x256_32x32_2x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x128x224x128_16x16_4x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x128x256x128_32x32_4x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x128x96x128_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x16x128x256_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x16x128x256_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x16x128x256_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x16x64x256_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x16x64x256_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x16x64x256_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── 
fp8_rowwise_grouped_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x192x96x128_16x16_6x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x256x160x128_32x32_2x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x256x256x128_32x32_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x256x256x128_32x32_8x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x32x128x128_16x16_1x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x32x160x128_16x16_1x5_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x32x160x128_16x16_1x5_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x32x256x128_16x16_1x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x32x32x512_16x16_1x1_32x8x1_32x8x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x32x32x512_16x16_1x1_32x8x1_32x8x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x32x64x512_16x16_2x1_32x8x1_32x8x1_1x32x1x8_8x8x1_2x1_intrawave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x64x128x256_32x32_2x1_16x16x1_16x16x1_1x16x1x16_8x8x1_1x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x64x160x128_16x16_2x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_256x64x192x128_16x16_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_2x1_intrawave_v3.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_64x16x32x256_16x16_1x2_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_64x16x64x256_16x16_1x4_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_interwave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_64x16x64x256_16x16_1x4_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v1.hip │ │ │ │ │ │ ├── fp8_rowwise_grouped_common.h │ │ │ │ │ │ └── fp8_rowwise_grouped_kernel_manifest.h │ │ │ │ ├── fp8_tensorwise_gemm.hip │ │ │ │ └── fused_moe │ │ │ │ │ ├── CMakeLists.txt │ │ │ │ │ ├── fused_moe.hpp │ │ │ │ │ ├── fused_moe_kernel.hip │ │ │ │ │ ├── fused_moe_op.cpp │ │ │ │ │ ├── fused_moegemm.hpp │ │ │ │ │ ├── fused_moesorting.hpp │ │ │ │ │ ├── instances │ │ │ │ │ ├── fused_moe_api.hip │ │ │ │ │ ├── fused_moegemm_api.hip │ │ │ │ │ ├── fused_moegemm_api_internal.hpp │ │ │ │ │ ├── fused_moegemm_api_traits.hpp │ │ │ │ │ ├── fused_moegemm_bf16_m32.hip │ │ │ │ │ ├── fused_moegemm_fp16_m32.hip │ │ │ │ │ └── fused_moesorting_api.hip │ │ │ │ │ ├── main.hip │ │ │ │ │ └── 
run.py │ │ │ │ ├── cublas_utils.h │ │ │ │ ├── cutlass_extensions │ │ │ │ ├── bf16bf16bf16_grouped.cu │ │ │ │ ├── bf16bf16bf16_grouped │ │ │ │ │ ├── bf16bf16bf16_grouped_128_128_128_1_1_1_f.cu │ │ │ │ │ ├── bf16bf16bf16_grouped_128_128_128_2_1_1_t.cu │ │ │ │ │ ├── bf16bf16bf16_grouped_128_16_128_1_1_1_f.cu │ │ │ │ │ ├── bf16bf16bf16_grouped_128_16_128_2_1_1_f.cu │ │ │ │ │ ├── bf16bf16bf16_grouped_128_256_128_1_1_1_f.cu │ │ │ │ │ ├── bf16bf16bf16_grouped_128_256_128_2_1_1_f.cu │ │ │ │ │ ├── bf16bf16bf16_grouped_128_32_128_1_1_1_f.cu │ │ │ │ │ ├── bf16bf16bf16_grouped_128_32_128_2_1_1_f.cu │ │ │ │ │ ├── bf16bf16bf16_grouped_128_32_128_2_1_1_t.cu │ │ │ │ │ ├── bf16bf16bf16_grouped_128_64_128_1_1_1_f.cu │ │ │ │ │ ├── bf16bf16bf16_grouped_128_64_128_2_1_1_t.cu │ │ │ │ │ ├── bf16bf16bf16_grouped_256_128_128_2_1_1_f.cu │ │ │ │ │ ├── bf16bf16bf16_grouped_common.cuh │ │ │ │ │ └── bf16bf16bf16_grouped_manifest.cuh │ │ │ │ ├── bf16i4bf16.cu │ │ │ │ ├── bf16i4bf16_rowwise_batched.cu │ │ │ │ ├── bf16i4bf16_shuffled_grouped.cu │ │ │ │ ├── f4f4bf16.cu │ │ │ │ ├── f4f4bf16 │ │ │ │ │ ├── f4f4bf16_128_128_4_1_1_f.cu │ │ │ │ │ ├── f4f4bf16_128_128_4_1_1_t.cu │ │ │ │ │ ├── f4f4bf16_128_192_2_2_1_f.cu │ │ │ │ │ ├── f4f4bf16_128_192_2_2_1_t.cu │ │ │ │ │ ├── f4f4bf16_128_256_2_1_1_f.cu │ │ │ │ │ ├── f4f4bf16_128_256_2_1_1_t.cu │ │ │ │ │ ├── f4f4bf16_256_128_2_2_1_f.cu │ │ │ │ │ ├── f4f4bf16_256_128_2_2_1_t.cu │ │ │ │ │ ├── f4f4bf16_256_128_2_4_1_f.cu │ │ │ │ │ ├── f4f4bf16_256_128_2_4_1_t.cu │ │ │ │ │ ├── f4f4bf16_256_192_2_2_1_f.cu │ │ │ │ │ ├── f4f4bf16_256_192_2_2_1_t.cu │ │ │ │ │ ├── f4f4bf16_256_192_2_4_1_f.cu │ │ │ │ │ ├── f4f4bf16_256_192_2_4_1_t.cu │ │ │ │ │ ├── f4f4bf16_256_192_4_1_1_f.cu │ │ │ │ │ ├── f4f4bf16_256_192_4_1_1_t.cu │ │ │ │ │ ├── f4f4bf16_256_256_2_1_1_f.cu │ │ │ │ │ ├── f4f4bf16_256_256_2_1_1_t.cu │ │ │ │ │ ├── f4f4bf16_256_256_2_2_1_f.cu │ │ │ │ │ ├── f4f4bf16_256_256_2_2_1_t.cu │ │ │ │ │ ├── f4f4bf16_256_256_2_4_1_f.cu │ │ │ │ │ ├── f4f4bf16_256_256_2_4_1_t.cu │ │ │ │ │ ├── f4f4bf16_256_256_4_1_1_f.cu │ │ │ │ │ ├── f4f4bf16_256_256_4_1_1_t.cu │ │ │ │ │ ├── f4f4bf16_common.cuh │ │ │ │ │ └── f4f4bf16_manifest.cuh │ │ │ │ ├── f4f4bf16_grouped.cu │ │ │ │ ├── f8f8bf16.cu │ │ │ │ ├── f8f8bf16_blockwise.cu │ │ │ │ ├── f8f8bf16_cublas.cu │ │ │ │ ├── f8f8bf16_lite.cu │ │ │ │ ├── f8f8bf16_rowwise.cu │ │ │ │ ├── f8f8bf16_rowwise │ │ │ │ │ ├── f8f8bf16_rowwise_128_128_128_2_1_1_9_t_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_128_128_128_2_2_1_10_f_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_128_256_128_1_2_1_10_f_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t.cu │ │ │ │ │ ├── f8f8bf16_rowwise_128_256_128_4_4_1_9_f_t.cu │ │ │ │ │ ├── f8f8bf16_rowwise_128_32_128_1_1_1_10_f_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_128_64_128_1_1_1_10_f_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_64_256_128_2_1_1_9_f_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_common.cuh │ │ │ │ │ └── f8f8bf16_rowwise_manifest.cuh │ │ │ │ ├── f8f8bf16_rowwise_batched.cu │ │ │ │ ├── f8f8bf16_rowwise_batched │ │ │ │ │ ├── f8f8bf16_rowwise_batched_128_128_128_1_2_1_10_t.cu │ │ │ │ │ ├── f8f8bf16_rowwise_batched_128_128_128_1_2_1_9_t.cu │ │ │ │ │ ├── f8f8bf16_rowwise_batched_128_128_128_2_1_1_10_t.cu │ │ │ │ │ ├── 
f8f8bf16_rowwise_batched_128_128_128_2_1_1_9_t.cu │ │ │ │ │ ├── f8f8bf16_rowwise_batched_64_128_128_1_2_1_10_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_batched_64_128_128_1_2_1_9_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_batched_64_128_128_2_1_1_10_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_batched_common.cuh │ │ │ │ │ └── f8f8bf16_rowwise_batched_manifest.cuh │ │ │ │ ├── f8f8bf16_rowwise_grouped.cu │ │ │ │ ├── f8f8bf16_rowwise_grouped │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_128_128_128_1_1_1_9_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_128_16_128_1_1_1_9_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_128_256_128_2_1_1_9_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_128_32_128_1_1_1_9_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_128_64_128_1_1_1_9_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_256_128_128_2_1_1_9_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_common.cuh │ │ │ │ │ └── f8f8bf16_rowwise_grouped_manifest.cuh │ │ │ │ ├── f8f8bf16_rowwise_grouped_sm100 │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_128_128_128_2_1_1_10_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_128_32_128_2_1_1_10_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_128_64_128_2_1_1_10_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_256_128_128_2_1_1_10_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_256_256_128_2_1_1_10_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_256_32_128_2_1_1_10_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_256_64_128_2_1_1_10_f.cu │ │ │ │ │ ├── f8f8bf16_rowwise_grouped_common.cuh │ │ │ │ │ └── f8f8bf16_rowwise_grouped_manifest.cuh │ │ │ │ ├── f8f8bf16_tensorwise.cu │ │ │ │ ├── f8i4bf16_rowwise.cu │ │ │ │ ├── f8i4bf16_shuffled.cu │ │ │ │ ├── f8i4bf16_shuffled_grouped.cu │ │ │ │ ├── i8i8bf16.cu │ │ │ │ ├── i8i8bf16_dynamic.cu │ │ │ │ ├── include │ │ │ │ │ ├── fp8_blockwise_cutlass_helpers.h │ │ │ │ │ ├── kernel_mode.h │ │ │ │ │ └── threadblock.h │ │ │ │ └── mixed_dtype_utils.cu │ │ │ │ ├── fast_gemv │ │ │ │ ├── bf16_fast_gemv.cu │ │ │ │ ├── bf16fp8bf16_fast_gemv.cu │ │ │ │ ├── fp8fp8bf16_fast_gemv.cu │ │ │ │ ├── include │ │ │ │ │ ├── common_utils.h │ │ │ │ │ ├── fast_gemv.cu │ │ │ │ │ ├── fast_gemv.cuh │ │ │ │ │ └── utility.cuh │ │ │ │ └── sweep_utils.py │ │ │ │ ├── quantize.cpp │ │ │ │ └── quantize.cu │ │ └── test │ │ │ ├── attention │ │ │ └── gqa_test.py │ │ │ ├── coalesce │ │ │ └── coalesce_test.py │ │ │ ├── comm │ │ │ └── multi_gpu_car_test.py │ │ │ ├── gather_scatter │ │ │ └── gather_scatter_test.py │ │ │ ├── kv_cache │ │ │ ├── kv_cache_test.py │ │ │ └── rope_padded.py │ │ │ ├── moe │ │ │ ├── __init__.py │ │ │ ├── activation_test.py │ │ │ ├── gather_scatter_test.py │ │ │ ├── layers_test.py │ │ │ ├── parallelism.py │ │ │ ├── shuffling_test.py │ │ │ └── utils.py │ │ │ └── quantize │ │ │ └── quantize_test.py │ └── hstu │ │ ├── CMakeLists.txt │ │ ├── LICENSE │ │ ├── README.md │ │ ├── context_causal_target.png │ │ ├── deltaq_causal.png │ │ ├── deltaq_local.png │ │ ├── hstu │ │ ├── __init__.py │ │ └── cuda_hstu_attention.py │ │ ├── src │ │ ├── generate_kernels.py │ │ ├── hstu_ampere │ │ │ ├── block_info.h │ │ │ ├── hstu.h │ │ │ ├── hstu_bwd.h │ │ │ ├── hstu_fwd.h │ │ │ ├── hstu_ops_gpu.cpp │ │ │ ├── instantiations │ │ │ │ └── .gitkeep │ │ │ ├── kernel_traits.h │ │ │ ├── static_switch.h │ │ │ └── utils.h │ │ └── hstu_hopper │ │ │ ├── epilogue_bwd_sm90_tma.hpp │ │ │ ├── epilogue_fwd_sm90_tma.hpp │ │ │ ├── hstu.h │ │ │ ├── hstu_bwd_kernel.h │ │ │ ├── hstu_bwd_launch_template.h │ │ │ ├── hstu_bwd_postprocess_kernel.h │ │ │ ├── hstu_fwd_kernel.h │ │ │ ├── hstu_fwd_launch_template.h │ │ │ ├── 
hstu_ops_gpu.cpp │ │ │ ├── instantiations │ │ │ └── .gitkeep │ │ │ ├── kernel_traits.h │ │ │ ├── mainloop_bwd_sm90_tma_gmma_ws.hpp │ │ │ ├── mainloop_fwd_sm90_tma_gmma_ws.hpp │ │ │ ├── named_barrier.hpp │ │ │ ├── seq_len.h │ │ │ ├── static_switch.h │ │ │ ├── tile_scheduler.hpp │ │ │ ├── tile_scheduler_bwd.hpp │ │ │ └── utils.h │ │ └── test │ │ └── hstu_test.py ├── fbgemm_gpu │ ├── __init__.py │ ├── batched_unary_embeddings_ops.py │ ├── config │ │ ├── __init__.py │ │ └── feature_list.py │ ├── docs │ │ ├── __init__.py │ │ ├── common.py │ │ ├── examples.py │ │ ├── jagged_tensor_ops.py │ │ ├── merge_pooled_embedding_ops.py │ │ ├── permute_pooled_embedding_ops.py │ │ ├── quantize_ops.py │ │ └── sparse_ops.py │ ├── enums.py │ ├── metrics.py │ ├── permute_pooled_embedding_modules.py │ ├── permute_pooled_embedding_modules_split.py │ ├── quantize │ │ ├── __init__.py │ │ └── quantize_ops.py │ ├── quantize_comm.py │ ├── quantize_utils.py │ ├── runtime_monitor.py │ ├── sll │ │ ├── __init__.py │ │ ├── cpu │ │ │ ├── __init__.py │ │ │ └── cpu_sll.py │ │ ├── meta │ │ │ ├── __init__.py │ │ │ └── meta_sll.py │ │ └── triton │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── triton_dense_jagged_cat_jagged_out.py │ │ │ ├── triton_jagged2_to_padded_dense.py │ │ │ ├── triton_jagged_bmm.py │ │ │ ├── triton_jagged_bmm_jagged_out.py │ │ │ ├── triton_jagged_dense_elementwise_add.py │ │ │ ├── triton_jagged_dense_elementwise_mul_jagged_out.py │ │ │ ├── triton_jagged_dense_flash_attention.py │ │ │ ├── triton_jagged_flash_attention_basic.py │ │ │ ├── triton_jagged_self_substraction_jagged_out.py │ │ │ ├── triton_jagged_softmax.py │ │ │ └── triton_multi_head_jagged_flash_attention.py │ ├── sparse_ops.py │ ├── split_embedding_configs.py │ ├── split_embedding_inference_converter.py │ ├── split_embedding_optimizer_ops.py │ ├── split_embedding_utils.py │ ├── split_table_batched_embeddings_ops.py │ ├── split_table_batched_embeddings_ops_common.py │ ├── split_table_batched_embeddings_ops_inference.py │ ├── split_table_batched_embeddings_ops_training.py │ ├── split_table_batched_embeddings_ops_training_common.py │ ├── ssd_split_table_batched_embeddings_ops.py │ ├── tbe │ │ ├── __init__.py │ │ ├── bench │ │ │ ├── __init__.py │ │ │ ├── bench_config.py │ │ │ ├── bench_runs.py │ │ │ ├── eeg_cli.py │ │ │ ├── embedding_ops_common_config.py │ │ │ ├── eval_compression.py │ │ │ ├── reporter.py │ │ │ ├── tbe_data_config.py │ │ │ ├── tbe_data_config_loader.py │ │ │ ├── tbe_data_config_param_models.py │ │ │ └── utils.py │ │ ├── cache │ │ │ ├── __init__.py │ │ │ └── split_embeddings_cache_ops.py │ │ ├── ssd │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── inference.py │ │ │ ├── training.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ └── partially_materialized_tensor.py │ │ ├── stats │ │ │ ├── __init__.py │ │ │ └── bench_params_reporter.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── offsets.py │ │ │ ├── quantize.py │ │ │ └── requests.py │ ├── tbe_input_multiplexer.py │ ├── triton │ │ ├── __init__.py │ │ ├── common.py │ │ ├── jagged │ │ │ ├── __init__.py │ │ │ └── triton_jagged_tensor_ops.py │ │ ├── quantize.py │ │ └── quantize_ref.py │ ├── utils │ │ ├── __init__.py │ │ ├── filestore.py │ │ ├── loader.py │ │ └── torch_library.py │ └── uvm.py ├── include │ └── fbgemm_gpu │ │ ├── config │ │ └── feature_gates.h │ │ ├── cumem_utils.h │ │ ├── embedding_backward_template_helpers.cuh │ │ ├── embedding_common.h │ │ ├── embedding_forward_split_cpu.h │ │ ├── embedding_forward_template_helpers.cuh │ │ ├── 
embedding_inplace_update.h │ │ ├── input_combine.h │ │ ├── intraining_embedding_pruning.h │ │ ├── layout_transform_ops.cuh │ │ ├── merge_pooled_embeddings.h │ │ ├── permute_multi_embedding_function.h │ │ ├── permute_pooled_embedding_function.h │ │ ├── permute_pooled_embedding_ops.h │ │ ├── permute_pooled_embedding_ops_split.h │ │ ├── permute_pooled_embs_function.h │ │ ├── permute_pooled_embs_function_split.h │ │ ├── quantize_ops.cuh │ │ ├── quantize_ops_utils.h │ │ ├── rocm │ │ ├── cdna_guard.h │ │ └── split_embeddings_common.h │ │ ├── sparse_ops.cuh │ │ ├── sparse_ops.h │ │ ├── split_embeddings_cache │ │ ├── cachelib_cache.h │ │ └── kv_db_cpp_utils.h │ │ ├── split_embeddings_cache_cuda.cuh │ │ ├── split_embeddings_utils.cuh │ │ ├── split_embeddings_utils.h │ │ └── utils │ │ ├── assert_macros.h │ │ ├── barrier_isolation.cuh │ │ ├── bench_utils.cuh │ │ ├── binary_search_range.cuh │ │ ├── binary_search_range.h │ │ ├── bitonic_sort.cuh │ │ ├── cpu_utils.h │ │ ├── cub_namespace_postfix.cuh │ │ ├── cub_namespace_prefix.cuh │ │ ├── cuda_block_count.h │ │ ├── cuda_prelude.cuh │ │ ├── device_cache_flusher.cuh │ │ ├── device_properties.cuh │ │ ├── dispatch_macros.h │ │ ├── embedding_bounds_check_common.cuh │ │ ├── enum_utils.h │ │ ├── find_qparams.cuh │ │ ├── fixed_divisor.cuh │ │ ├── float.cuh │ │ ├── host_device_buffer_pair.cuh │ │ ├── inclusive_sum_scan.cuh │ │ ├── kernel_execution_timer.cuh │ │ ├── kernel_launcher.cuh │ │ ├── log2.h │ │ ├── ops_utils.h │ │ ├── pt2_autograd_utils.h │ │ ├── rocm │ │ ├── half2.h │ │ ├── stochastic_rounding.h │ │ ├── vec2.h │ │ └── weight_row.h │ │ ├── shared_memory.cuh │ │ ├── source_context.h │ │ ├── stochastic_rounding.cuh │ │ ├── tensor_accessor.h │ │ ├── tensor_accessor_builder.h │ │ ├── tensor_utils.h │ │ ├── topology_utils.h │ │ ├── types.h │ │ ├── vec4.cuh │ │ ├── vec4acc.cuh │ │ ├── vec_quant.cuh │ │ ├── vecn.cuh │ │ └── weight_row.cuh ├── requirements.txt ├── requirements_genai.txt ├── setup.py ├── src │ ├── config │ │ └── feature_gates.cpp │ ├── docs │ │ └── example_code.cpp │ ├── dram_kv_embedding_cache │ │ ├── SynchronizedShardedMap.h │ │ ├── dram_kv_embedding_cache.h │ │ ├── dram_kv_embedding_cache_wrapper.h │ │ ├── fixed_block_pool.h │ │ ├── store_value.h │ │ └── store_value_utils.h │ ├── embedding_inplace_ops │ │ ├── embedding_inplace_update.cu │ │ ├── embedding_inplace_update_cpu.cpp │ │ ├── embedding_inplace_update_gpu.cpp │ │ └── embedding_inplace_update_test.cpp │ ├── histogram_binning_calibration_ops.cu │ ├── input_combine_ops │ │ ├── input_combine.cu │ │ ├── input_combine_cpu.cpp │ │ └── input_combine_gpu.cpp │ ├── intraining_embedding_pruning_ops │ │ ├── intraining_embedding_pruning.cu │ │ └── intraining_embedding_pruning_gpu.cpp │ ├── jagged_tensor_ops │ │ ├── batched_dense_vec_jagged_2d_mul_backward.cu │ │ ├── batched_dense_vec_jagged_2d_mul_forward.cu │ │ ├── common.cuh │ │ ├── common.h │ │ ├── dense_to_jagged_forward.cu │ │ ├── jagged_dense_bmm_forward.cu │ │ ├── jagged_dense_dense_elementwise_add_jagged_output_forward.cu │ │ ├── jagged_dense_elementwise_mul_backward.cu │ │ ├── jagged_dense_elementwise_mul_forward.cu │ │ ├── jagged_index_add_2d_forward.cu │ │ ├── jagged_index_select_2d_forward.cu │ │ ├── jagged_jagged_bmm_forward.cu │ │ ├── jagged_softmax_backward.cu │ │ ├── jagged_softmax_forward.cu │ │ ├── jagged_tensor_ops.cu │ │ ├── jagged_tensor_ops_autograd.cpp │ │ ├── jagged_tensor_ops_cpu.cpp │ │ ├── jagged_tensor_ops_meta.cpp │ │ ├── jagged_to_padded_dense_backward.cu │ │ ├── jagged_to_padded_dense_forward.cu │ │ ├── 
jagged_unique_indices.cu │ │ ├── keyed_jagged_index_select_dim1.cu │ │ ├── stacked_jagged_1d_to_dense.cu │ │ └── stacked_jagged_2d_to_dense.cu │ ├── layout_transform_ops │ │ ├── layout_transform_ops.cu │ │ ├── layout_transform_ops_cpu.cpp │ │ └── layout_transform_ops_gpu.cpp │ ├── memory_utils │ │ ├── common.cuh │ │ ├── common.h │ │ ├── memory_utils.cpp │ │ ├── memory_utils.cu │ │ ├── memory_utils_ops.cpp │ │ └── memory_utils_ops.cu │ ├── merge_pooled_embedding_ops │ │ ├── merge_pooled_embedding_ops_cpu.cpp │ │ └── merge_pooled_embedding_ops_gpu.cpp │ ├── metric_ops │ │ ├── metric_ops.cu │ │ ├── metric_ops.h │ │ └── metric_ops_host.cpp │ ├── permute_multi_embedding_ops │ │ ├── permute_multi_embedding_function.cpp │ │ ├── permute_multi_embedding_ops.cu │ │ └── permute_multi_embedding_ops_cpu.cpp │ ├── permute_pooled_embedding_ops │ │ ├── permute_pooled_embedding_function.cpp │ │ ├── permute_pooled_embedding_ops.cu │ │ ├── permute_pooled_embedding_ops_cpu.cpp │ │ ├── permute_pooled_embedding_ops_gpu.cpp │ │ ├── permute_pooled_embedding_ops_split.cu │ │ ├── permute_pooled_embedding_ops_split_cpu.cpp │ │ └── permute_pooled_embedding_ops_split_gpu.cpp │ ├── placeholder.cpp │ ├── ps_split_embeddings_cache │ │ ├── ps_split_table_batched_embeddings.cpp │ │ └── ps_table_batched_embeddings.h │ ├── quantize_ops │ │ ├── common.cuh │ │ ├── mx │ │ │ ├── LICENSE │ │ │ └── common.cuh │ │ ├── mx_common.cuh │ │ ├── quantize_bfloat16.cu │ │ ├── quantize_fp8_rowwise.cu │ │ ├── quantize_fused_8bit_rowwise.cu │ │ ├── quantize_fused_nbit_rowwise.cu │ │ ├── quantize_hfp8.cu │ │ ├── quantize_msfp.cu │ │ ├── quantize_mx.cu │ │ ├── quantize_mx.cuh │ │ ├── quantize_ops_cpu.cpp │ │ ├── quantize_ops_gpu.cpp │ │ ├── quantize_ops_meta.cpp │ │ └── quantize_padded_fp8_rowwise.cu │ ├── sparse_ops │ │ ├── common.cuh │ │ ├── common.h │ │ ├── sparse_async_batched_cumsum.cpp │ │ ├── sparse_async_batched_cumsum.cu │ │ ├── sparse_async_cumsum.cpp │ │ ├── sparse_async_cumsum.cu │ │ ├── sparse_batched_unary_embeddings.cu │ │ ├── sparse_block_bucketize_features.cu │ │ ├── sparse_bucketize_features.cu │ │ ├── sparse_compute_frequency_sequence.cu │ │ ├── sparse_expand_into_jagged_permute.cu │ │ ├── sparse_group_index.cu │ │ ├── sparse_index_add.cu │ │ ├── sparse_index_select.cu │ │ ├── sparse_invert_permute.cu │ │ ├── sparse_ops_cpu.cpp │ │ ├── sparse_ops_gpu.cpp │ │ ├── sparse_ops_meta.cpp │ │ ├── sparse_pack_segments_backward.cu │ │ ├── sparse_pack_segments_forward.cu │ │ ├── sparse_permute102.cu │ │ ├── sparse_permute_1d.cu │ │ ├── sparse_permute_2d.cu │ │ ├── sparse_permute_embeddings.cu │ │ ├── sparse_range.cu │ │ ├── sparse_reorder_batched_ad.cu │ │ ├── sparse_segment_sum_csr.cu │ │ └── sparse_zipf.cu │ ├── split_embeddings_cache │ │ ├── cachelib_cache.cpp │ │ ├── common.cuh │ │ ├── common.h │ │ ├── kv_db_cpp_utils.cpp │ │ ├── lfu_cache_find.cu │ │ ├── lfu_cache_populate.cu │ │ ├── lfu_cache_populate_byte.cpp │ │ ├── lfu_cache_populate_byte.cu │ │ ├── linearize_cache_indices.cpp │ │ ├── linearize_cache_indices.cu │ │ ├── lru_cache_find.cu │ │ ├── lru_cache_populate.cu │ │ ├── lru_cache_populate_byte.cpp │ │ ├── lru_cache_populate_byte.cu │ │ ├── lxu_cache.cpp │ │ ├── lxu_cache.cu │ │ ├── reset_weight_momentum.cu │ │ ├── split_embeddings_cache_ops.cpp │ │ └── split_embeddings_cache_ops.cu │ ├── split_embeddings_utils │ │ ├── generate_vbe_metadata.cu │ │ ├── get_infos_metadata.cu │ │ ├── radix_sort_pairs.cu │ │ ├── split_embeddings_utils.cpp │ │ ├── split_embeddings_utils_cpu.cpp │ │ ├── split_embeddings_utils_meta.cpp │ │ └── 
transpose_embedding_input.cu │ ├── ssd_split_embeddings_cache │ │ ├── embedding_rocksdb_wrapper.h │ │ ├── initializer.h │ │ ├── kv_db_cuda_utils.cpp │ │ ├── kv_db_cuda_utils.h │ │ ├── kv_db_table_batched_embeddings.cpp │ │ ├── kv_db_table_batched_embeddings.h │ │ ├── kv_tensor_wrapper.h │ │ ├── kv_tensor_wrapper_cpu.cpp │ │ ├── ssd_scratch_pad_indices_queue.cpp │ │ ├── ssd_split_embeddings_cache_cuda.cu │ │ ├── ssd_split_table_batched_embeddings.cpp │ │ ├── ssd_table_batched_embeddings.h │ │ └── test │ │ │ └── ssd_table_batched_embeddings_test.cpp │ ├── tbe │ │ └── eeg │ │ │ ├── eeg_models.cpp │ │ │ ├── eeg_models.h │ │ │ ├── eeg_utils.cpp │ │ │ ├── eeg_utils.h │ │ │ ├── indices_estimator.cpp │ │ │ ├── indices_estimator.h │ │ │ ├── indices_estimator_ops.cpp │ │ │ ├── indices_generator.cpp │ │ │ ├── indices_generator.h │ │ │ └── indices_generator_ops.cpp │ └── topology_utils.cpp └── test │ ├── batched_unary_embeddings_test.py │ ├── combine │ ├── __init__.py │ ├── common.py │ ├── empty_weights_test.py │ ├── failures_dict.json │ └── input_combine_test.py │ ├── config │ ├── feature_gate_cpp_test.cpp │ └── feature_gate_test.py │ ├── dram_kv_embedding_cache │ ├── fixed_block_pool_test.cpp │ ├── sharded_map_test.cpp │ └── store_value_utils_test.cpp │ ├── failures_dict.json │ ├── failures_dict_fast.json │ ├── jagged │ ├── 1d_to_dense_test.py │ ├── 2d_to_dense_test.py │ ├── __init__.py │ ├── batched_dense_vec_jagged_2d_mul_test.py │ ├── common.py │ ├── dense_bmm_test.py │ ├── dense_dense_elementwise_add_test.py │ ├── dense_to_jagged_test.py │ ├── elementwise_binary_test.py │ ├── expand_into_jagged_permute_test.py │ ├── failures_dict.json │ ├── jagged_index_select_2d_test.py │ ├── jagged_to_padded_dense_test.py │ ├── keyed_jagged_index_select_test.py │ ├── misc_ops_test.py │ ├── slice_test.py │ └── unique_indices_test.py │ ├── layout_transform_ops_test.py │ ├── lint │ ├── check_meta_header.py │ └── flake8_problem_matcher.json │ ├── merge_pooled_embeddings_test.py │ ├── metric_ops_test.py │ ├── permute │ ├── __init__.py │ ├── common.py │ ├── failures_dict.json │ └── permute_pooled_embedding_test.py │ ├── quantize │ ├── __init__.py │ ├── bfloat16_test.py │ ├── comm_codec_test.py │ ├── common.py │ ├── failures_dict.json │ ├── failures_dict_fast.json │ ├── fp8_rowwise_test.py │ ├── fused_8bit_rowwise_test.py │ ├── fused_nbit_rowwise_test.py │ ├── hfp8_test.py │ ├── mixed_dim_int8_test.py │ ├── msfp_test.py │ ├── mx │ │ ├── LICENSE │ │ └── common.py │ └── mx4_test.py │ ├── release │ ├── __init__.py │ ├── example.json │ ├── stable_ops_v1.json │ ├── stable_release_test.py │ └── utils.py │ ├── runtime_monitor_test.py │ ├── sll │ ├── __init__.py │ ├── array_jagged_bmm_jagged_out_test.py │ ├── common.py │ ├── dense_jagged_cat_jagged_out_test.py │ ├── jagged2_to_padded_dense_test.py │ ├── jagged_dense_bmm_test.py │ ├── jagged_dense_elementwise_add_test.py │ ├── jagged_dense_elementwise_mul_jagged_out_test.py │ ├── jagged_dense_flash_attention_test.py │ ├── jagged_flash_attention_basic_test.py │ ├── jagged_jagged_bmm_jagged_out_test.py │ ├── jagged_jagged_bmm_test.py │ ├── jagged_self_substraction_jagged_out_test.py │ ├── jagged_softmax_test.py │ └── multi_head_jagged_flash_attention_test.py │ ├── sparse │ ├── __init__.py │ ├── block_bucketize_test.py │ ├── common.py │ ├── cumsum_test.py │ ├── failures_dict.json │ ├── histogram_binning_calibration_test.py │ ├── index_select_test.py │ ├── misc_ops_test.py │ ├── pack_segments_test.py │ ├── permute_embeddings_test.py │ ├── permute_indices_test.py │ ├── 
permute_sparse_features_test.py │ ├── reorder_batched_test.py │ ├── tensor_assert_test.cpp │ └── utils_test.cpp │ ├── tbe │ ├── __init__.py │ ├── bench │ │ ├── tbe_data_config_loader_test.py │ │ └── tbe_data_config_models_test.py │ ├── cache │ │ ├── __init__.py │ │ ├── cache_common.py │ │ ├── cache_overflow_test.py │ │ ├── cache_test.py │ │ ├── failures_dict_fast.json │ │ ├── linearize_cache_indices_test.py │ │ └── lxu_cache_test.py │ ├── common.py │ ├── dram_kv │ │ ├── __init__.py │ │ ├── dram_kv_test.py │ │ └── failures_dict_fast.json │ ├── eeg │ │ ├── eeg_utils_test.cpp │ │ ├── tbe_indices_estimator_test.py │ │ └── tbe_indices_generator_test.py │ ├── inference │ │ ├── __init__.py │ │ ├── common.py │ │ ├── failures_dict_fast.json │ │ ├── inference_converter_test.py │ │ ├── nbit_cache_test.py │ │ ├── nbit_forward_autovec_test.py │ │ ├── nbit_forward_test.py │ │ └── nbit_split_embeddings_test.py │ ├── ssd │ │ ├── __init__.py │ │ ├── embedding_cache │ │ │ └── rocksdb_embedding_cache_test.cpp │ │ ├── failures_dict_fast.json │ │ ├── kv_backend_test.py │ │ ├── kv_tensor_wrapper_test.py │ │ ├── ssd_split_tbe_inference_test.py │ │ ├── ssd_split_tbe_training_test.py │ │ └── ssd_utils_test.py │ ├── stats │ │ ├── __init__.py │ │ ├── failures_dict_fast.json │ │ └── tbe_bench_params_reporter_test.py │ ├── training │ │ ├── __init__.py │ │ ├── backward_adagrad_common.py │ │ ├── backward_adagrad_global_weight_decay_test.py │ │ ├── backward_adagrad_large_dim_test.py │ │ ├── backward_adagrad_test.py │ │ ├── backward_dense_test.py │ │ ├── backward_none_test.py │ │ ├── backward_optimizers_test.py │ │ ├── backward_sgd_test.py │ │ ├── failures_dict_fast.json │ │ ├── forward_backward_int32_overflow_test.py │ │ └── forward_test.py │ └── utils │ │ ├── __init__.py │ │ ├── cpu_kernel_test.cpp │ │ ├── failures_dict_fast.json │ │ ├── generate_vbe_metadata_test.py │ │ ├── split_embeddings_test.py │ │ └── split_embeddings_utils_test.py │ ├── test_utils.py │ ├── utils │ ├── filestore_test.py │ ├── kernel_launcher_test.cu │ ├── stochastic_rounding_test.cu │ ├── tensor_accessor2_test.cu │ ├── tensor_accessor_builder_test.cu │ ├── tensor_accessor_builder_with_memcheck_test.cu │ ├── tensor_accessor_test.cu │ ├── tensor_accessor_with_memcheck_test.cu │ └── weight_row_test.cu │ └── uvm │ ├── cache_miss_emulate_test.cpp │ ├── copy_test.py │ ├── ops_load_test.py │ └── uvm_test.py ├── include └── fbgemm │ ├── ConvUtils.h │ ├── Fbgemm.h │ ├── FbgemmBuild.h │ ├── FbgemmConvert.h │ ├── FbgemmEmbedding.h │ ├── FbgemmFP16.h │ ├── FbgemmFP32.h │ ├── FbgemmFPCommon.h │ ├── FbgemmI64.h │ ├── FbgemmI8DepthwiseAvx2.h │ ├── FbgemmI8DirectconvAvx2.h │ ├── FbgemmI8Spmdm.h │ ├── FbgemmPackMatrixB.h │ ├── FbgemmSparse.h │ ├── FloatConversion.h │ ├── OutputProcessing-inl.h │ ├── PackingTraits-inl.h │ ├── QuantUtils.h │ ├── QuantUtilsAvx2.h │ ├── QuantUtilsAvx512.h │ ├── QuantUtilsNeon.h │ ├── SimdUtils.h │ ├── Types.h │ ├── Utils.h │ ├── UtilsAvx2.h │ ├── spmmUtils.h │ └── spmmUtilsAvx2.h ├── netlify.toml ├── src ├── CodeCache.h ├── CodeGenHelpers.h ├── DirectConv.h ├── EmbeddingSpMDM.cc ├── EmbeddingSpMDMAutovec.cc ├── EmbeddingSpMDMAutovec.h ├── EmbeddingSpMDMAvx2.cc ├── EmbeddingSpMDMAvx512.cc ├── EmbeddingSpMDMNBit.cc ├── ExecuteKernel.cc ├── ExecuteKernel.h ├── ExecuteKernelGeneric.h ├── ExecuteKernelU8S8.cc ├── ExecuteKernelU8S8.h ├── Fbgemm.cc ├── FbgemmBfloat16Convert.cc ├── FbgemmBfloat16ConvertAvx2.cc ├── FbgemmBfloat16ConvertAvx512.cc ├── FbgemmConv.cc ├── FbgemmFP16.cc ├── FbgemmFP16UKernelsAvx2.cc ├── FbgemmFP16UKernelsAvx2.h ├── 
FbgemmFP16UKernelsAvx512.cc ├── FbgemmFP16UKernelsAvx512.h ├── FbgemmFP16UKernelsAvx512_256.cc ├── FbgemmFP16UKernelsAvx512_256.h ├── FbgemmFP16UKernelsIntrinsicAvx2.cc ├── FbgemmFP16UKernelsIntrinsicAvx512.cc ├── FbgemmFP16UKernelsIntrinsicAvx512_256.cc ├── FbgemmFP16UKernelsSve128.cc ├── FbgemmFP16UKernelsSve128.h ├── FbgemmFPCommon.cc ├── FbgemmFloat16Convert.cc ├── FbgemmFloat16ConvertAvx2.cc ├── FbgemmFloat16ConvertAvx512.cc ├── FbgemmFloat16ConvertSVE.cc ├── FbgemmI64.cc ├── FbgemmI8Depthwise2DAvx2-inl.h ├── FbgemmI8Depthwise3DAvx2.cc ├── FbgemmI8DepthwiseAvx2-inl.h ├── FbgemmI8DepthwiseAvx2.cc ├── FbgemmI8DepthwisePerChannelQuantAvx2.cc ├── FbgemmI8Spmdm.cc ├── FbgemmPackMatrixB.cc ├── FbgemmSparseDense.cc ├── FbgemmSparseDenseAvx2.cc ├── FbgemmSparseDenseAvx512.cc ├── FbgemmSparseDenseInt8Avx2.cc ├── FbgemmSparseDenseInt8Avx512.cc ├── FbgemmSparseDenseVectorInt8Avx512.cc ├── GenerateI8Depthwise.cc ├── GenerateI8Depthwise.h ├── GenerateKernel.cc ├── GenerateKernel.h ├── GenerateKernelDirectConvU8S8S32ACC32.cc ├── GenerateKernelU8S8S32ACC16.cc ├── GenerateKernelU8S8S32ACC16Avx512.cc ├── GenerateKernelU8S8S32ACC16Avx512VNNI.cc ├── GenerateKernelU8S8S32ACC32.cc ├── GenerateKernelU8S8S32ACC32Avx512VNNI.cc ├── GroupwiseConv.cc ├── GroupwiseConv.h ├── GroupwiseConvAcc32Avx2.cc ├── GroupwiseConvAcc32Avx512.cc ├── InlineAsmDefines.h ├── KleidiAIFP16UKernelsNeon.cc ├── KleidiAIFP16UKernelsNeon.h ├── MaskAvx2.h ├── OptimizedKernelsAvx2.cc ├── OptimizedKernelsAvx2.h ├── PackAMatrix.cc ├── PackAWithIm2Col.cc ├── PackAWithQuantRowOffset.cc ├── PackAWithRowOffset.cc ├── PackBMatrix.cc ├── PackDepthwiseConvMatrixAvx2.cc ├── PackMatrix.cc ├── PackWeightMatrixForGConv.cc ├── PackWeightsForConv.cc ├── PackWeightsForDirectConv.cc ├── QuantUtils.cc ├── QuantUtilsAvx2.cc ├── QuantUtilsAvx512.cc ├── QuantUtilsNeon.cc ├── RefImplementations.cc ├── RefImplementations.h ├── RowWiseSparseAdagradFused.cc ├── SparseAdagrad.cc ├── TransposeUtils.cc ├── TransposeUtils.h ├── TransposeUtilsAvx2.h ├── TransposeUtilsNeon.h ├── TransposeUtilsSve.h ├── Utils.cc ├── UtilsAvx2.cc ├── UtilsAvx512.cc ├── UtilsNeon.cc ├── UtilsSve.cc ├── codegen_fp16fp32.cc ├── fp32 │ ├── FbgemmFP32.cc │ ├── FbgemmFP32UKernelsAvx2.cc │ ├── FbgemmFP32UKernelsAvx2.h │ ├── FbgemmFP32UKernelsAvx512.cc │ ├── FbgemmFP32UKernelsAvx512.h │ ├── FbgemmFP32UKernelsAvx512_256.cc │ ├── FbgemmFP32UKernelsAvx512_256.h │ ├── KleidiAIFP32UKernelsNeon.cc │ └── KleidiAIFP32UKernelsNeon.h ├── spmmUtils.cc └── spmmUtilsAvx2.cc └── test ├── Bfloat16ConvertTest.cc ├── CMakeLists.txt ├── EmbeddingSpMDM8BitTest.cc ├── EmbeddingSpMDMNBitTest.cc ├── EmbeddingSpMDMTest.cc ├── EmbeddingSpMDMTestUtils.cc ├── EmbeddingSpMDMTestUtils.h ├── FBGemmFPTest.h ├── FP16Test.cc ├── FP32Test.cc ├── Float16ConvertTest.cc ├── GConvTest.cc ├── I64Test.cc ├── I8DepthwiseTest.cc ├── I8DirectconvTest.cc ├── I8SpmdmTest.cc ├── Im2ColFusedRequantizeTest.cc ├── PackedRequantizeAcc16Test.cc ├── PackedRequantizeTest.cc ├── QuantUtilsTest.cc ├── QuantizationHelpers.cc ├── QuantizationHelpers.h ├── RadixSortTest.cc ├── RequantizeOnlyTest.cc ├── RowWiseSparseAdagradFusedTest.cc ├── SparseAdagradTest.cc ├── SparseDenseMMFP32Test.cc ├── SparseDenseMMInt8Test.cc ├── SparsePackUnpackTest.cc ├── TestUtils.cc ├── TestUtils.h ├── TransposeTest.cc ├── TransposedRequantizeTest.cc └── UniConvTest.cc /.clang-tidy: -------------------------------------------------------------------------------- 1 | # The configuration file is in a YAML format, 2 | # so the document starts with (---) and ends with (...) 
3 | --- 4 | # Get options for config files in parent directories, 5 | # but override them if there's a conflict. 6 | InheritParentConfig: true 7 | Checks: ' 8 | bugprone-argument-comment, 9 | ' 10 | CheckOptions: 11 | - key: facebook-cuda-safe-api-call-check.HandlerName 12 | # This is PyTorch's handler; you may need to define your own 13 | value: C10_CUDA_CHECK 14 | - key: facebook-cuda-safe-kernel-call-check.HandlerName 15 | # This is PyTorch's handler; you may need to define your own 16 | value: C10_CUDA_KERNEL_LAUNCH_CHECK 17 | ... 18 | -------------------------------------------------------------------------------- /.github/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /.github/metrics/requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | click 8 | python-dateutil 9 | ratelimit 10 | requests 11 | -------------------------------------------------------------------------------- /.github/scripts/fbgemm_gpu_postbuild.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | echo "################################################################################" 9 | echo "[CMAKE] Running post-build script ..." 10 | 11 | TARGET=$1 12 | SET_RPATH_TO_ORIGIN=$2 13 | echo "Target file: ${TARGET}" 14 | 15 | # Set or remove RPATHs for the .SO 16 | # https://github.com/pytorch/FBGEMM/issues/3098 17 | # https://github.com/NixOS/patchelf/issues/453 18 | if [ "${SET_RPATH_TO_ORIGIN}" != "" ]; then 19 | echo "Resetting RPATH to \$ORIGIN ..." 20 | patchelf --force-rpath --set-rpath "\$ORIGIN" "${TARGET}" || exit 1 21 | else 22 | echo "Removing all RPATHs ..." 23 | patchelf --remove-rpath "${TARGET}" || exit 1 24 | fi 25 | 26 | readelf -d "${TARGET}" | grep -i rpath 27 | echo "################################################################################" 28 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "external/asmjit"] 2 | path = external/asmjit 3 | url = https://github.com/asmjit/asmjit.git 4 | [submodule "external/cpuinfo"] 5 | path = external/cpuinfo 6 | url = https://github.com/pytorch/cpuinfo 7 | [submodule "external/googletest"] 8 | path = external/googletest 9 | url = https://github.com/google/googletest 10 | [submodule "external/hipify_torch"] 11 | path = external/hipify_torch 12 | url = https://github.com/ROCmSoftwarePlatform/hipify_torch.git 13 | # TODO Using a private copy of cutlass is a temporary mitigation to enable grouped gemm. 14 | # Go back to main cutlass when possible. 
15 | [submodule "external/cutlass"] 16 | path = external/cutlass 17 | url = https://github.com/jwfromm/cutlass 18 | branch = FBGEMM 19 | [submodule "external/composable_kernel"] 20 | path = external/composable_kernel 21 | url = https://github.com/jwfromm/composable_kernel.git 22 | branch = FBGEMM 23 | [submodule "external/json"] 24 | path = external/json 25 | url = https://github.com/nlohmann/json.git 26 | -------------------------------------------------------------------------------- /MODULE.bazel: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | module(name = "fbgemm") 7 | 8 | bazel_dep(name = "bazel_skylib", version = "1.7.1") 9 | -------------------------------------------------------------------------------- /cmake/modules/FindGnuH2fIeee.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | ################################################################################ 9 | # Finds and sets GNU_FH2_IEEE compilation flags 10 | ################################################################################ 11 | 12 | INCLUDE(CheckCXXSourceCompiles) 13 | 14 | CHECK_CXX_SOURCE_COMPILES(" 15 | #include 16 | int main() { 17 | float f = 1.0f; 18 | uint16_t h = __gnu_f2h_ieee(f); 19 | return 0; 20 | } 21 | " HAVE_GNU_F2H_IEEE) 22 | -------------------------------------------------------------------------------- /cmake/modules/FindSphinx.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Search sphinx-build 8 | find_program(SPHINX_EXECUTABLE 9 | NAMES sphinx-build 10 | DOC "Path to sphinx-build executable") 11 | 12 | include(FindPackageHandleStandardArgs) 13 | 14 | find_package_handle_standard_args(Sphinx 15 | "Failed to find sphinx-build executable" 16 | SPHINX_EXECUTABLE) 17 | -------------------------------------------------------------------------------- /cmake/modules/RocmSetup.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | include(${CMAKE_CURRENT_SOURCE_DIR}/../cmake/modules/Utilities.cmake) 8 | 9 | 10 | ################################################################################ 11 | # ROCm and HIPify Setup 12 | ################################################################################ 13 | 14 | if(FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM) 15 | # Load CMake modules 16 | list(APPEND CMAKE_MODULE_PATH 17 | "${PROJECT_SOURCE_DIR}/cmake" 18 | "${THIRDPARTY}/hipify_torch/cmake") 19 | include(Hip) 20 | include(Hipify) 21 | 22 | # Configure compiler for HIP 23 | list(APPEND HIP_HCC_FLAGS 24 | " \"-Wno-#pragma-messages\" " 25 | " \"-Wno-#warnings\" " 26 | -fclang-abi-compat=17 27 | -Wno-cuda-compat 28 | -Wno-deprecated-declarations 29 | -Wno-format 30 | -Wno-ignored-attributes 31 | -Wno-unused-result) 32 | 33 | BLOCK_PRINT( 34 | "HIP found: ${HIP_FOUND}" 35 | "HIPCC compiler flags:" 36 | "" 37 | "${HIP_HCC_FLAGS}" 38 | ) 39 | endif() 40 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. FBGEMM documentation master file, created by 2 | sphinx-quickstart on Wed Apr 24 15:19:01 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to FBGEMM's documentation! 7 | ======================================= 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | :ref:`genindex` 14 | 15 | Docs 16 | ==== 17 | 18 | .. doxygennamespace:: fbgemm 19 | :members: 20 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | breathe 8 | sphinx_rtd_theme 9 | -------------------------------------------------------------------------------- /external/asmjit.BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_library") 2 | 3 | cc_library( 4 | name = "asmjit", 5 | srcs = glob([ 6 | "src/asmjit/core/*.cpp", 7 | "src/asmjit/x86/*.cpp", 8 | "src/asmjit/arm/*.cpp", 9 | ]), 10 | hdrs = glob([ 11 | "src/asmjit/x86/*.h", 12 | "src/asmjit/core/*.h", 13 | "src/asmjit/*.h", 14 | "src/asmjit/arm/*.h", 15 | ]), 16 | copts = [ 17 | "-DASMJIT_STATIC", 18 | "-fno-tree-vectorize", 19 | "-fmerge-all-constants", 20 | "-DTH_BLAS_MKL", 21 | ], 22 | includes = [ 23 | "asmjit/", 24 | "src/", 25 | ], 26 | linkstatic = True, 27 | visibility = ["//visibility:public"], 28 | ) 29 | -------------------------------------------------------------------------------- /external/cpuinfo.BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_library") 2 | 3 | cc_library( 4 | name = "cpuinfo", 5 | srcs = glob( 6 | [ 7 | "src/*.c", 8 | "src/linux/*.c", 9 | "src/x86/*.c", 10 | "src/x86/cache/*.c", 11 | "src/x86/linux/*.c", 12 | ], 13 | exclude = [ 14 | "src/x86/mockcpuid.c", 15 | "src/linux/mockfile.c", 16 | ], 17 | ), 18 | hdrs = glob([ 19 | "include/*.h", 20 | "src/*.h", 21 | "src/cpuinfo/*.h", 22 | "src/include/*.h", 23 | "src/x86/*.h", 24 | "src/x86/linux/*.h", 25 | "src/linux/*.h", 26 | ]), 27 | copts = [ 28 | "-DCPUINFO_LOG_LEVEL=2", 29 | "-DTH_BLAS_MKL", 30 | "-D_GNU_SOURCE=1", 31 | ], 32 | includes = [ 33 | "include", 34 | "src", 35 | ], 36 | linkstatic = True, 37 | visibility = ["//visibility:public"], 38 | ) 39 | -------------------------------------------------------------------------------- /fbgemm_gpu/bench/README.md: -------------------------------------------------------------------------------- 1 | ### Benchmarks 2 | 3 | ## TorchRec FusedTableBatchedEmbeddingBags 4 | 5 | [Torchrec](https://pytorch.org/torchrec/) uses the fbgemm_gpu embedding and embedding bag implementations for its fused, batched, and quantized versions of embedding and embedding bag (in addition to other kernels). 6 | They have benchmarked FusedEmbeddingBagCollection, which is implemented with fbgemm_gpu's [`SplitTableBatchedEmbeddingBagsCodegen`](https://github.com/pytorch/FBGEMM/blob/253b8842eeb2b33e65f7e2a7cfb79923b0e46bd7/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py#L171), in both UVM and UVM-caching configurations. 7 | The [results](https://github.com/pytorch/torchrec/tree/main/benchmarks) show speedups of between 13x and 23x for typical DLRM embedding sizes. 8 | -------------------------------------------------------------------------------- /fbgemm_gpu/cmake/Asmjit.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree.
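As a rough illustration of the `SplitTableBatchedEmbeddingBagsCodegen` API referenced in the benchmarks README above, the following is a minimal sketch that builds a small table-batched embedding bag with UVM-caching placement. It assumes a CUDA build of fbgemm_gpu; the module paths and constructor defaults shown here are assumptions that have shifted across FBGEMM_GPU releases.

import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# Two tables of (num_embeddings, embedding_dim). MANAGED_CACHING keeps the
# weights in UVM with a software-managed cache on the GPU -- the "UVM-caching"
# setup mentioned in the benchmarks; MANAGED alone would be plain UVM.
tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (1_000_000, 128, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
        (1_000_000, 64, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
    ],
    optimizer=EmbOptimType.EXACT_ROWWISE_ADAGRAD,
    learning_rate=0.01,
)

# CSR-style lookup: T tables, B bags per table, L indices per bag, so offsets
# has T * B + 1 entries and indices holds all bag members back to back.
T, B, L = 2, 256, 1
indices = torch.randint(0, 1_000_000, (T * B * L,), dtype=torch.int64, device="cuda")
offsets = torch.arange(0, T * B * L + 1, L, dtype=torch.int64, device="cuda")
pooled = tbe(indices=indices, offsets=offsets)  # shape: (B, 128 + 64)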
6 | 7 | ################################################################################ 8 | # Asmjit Sources 9 | ################################################################################ 10 | 11 | file(GLOB_RECURSE asmjit_sources 12 | "${CMAKE_CURRENT_SOURCE_DIR}/../external/asmjit/src/asmjit/*/*.cpp") 13 | 14 | 15 | ################################################################################ 16 | # Build Intermediate Target (Static) 17 | ################################################################################ 18 | 19 | gpu_cpp_library( 20 | PREFIX 21 | asmjit 22 | TYPE 23 | SHARED 24 | INCLUDE_DIRS 25 | ${fbgemm_sources_include_directories} 26 | OTHER_SRCS 27 | ${asmjit_sources} 28 | DESTINATION 29 | fbgemm_gpu) 30 | -------------------------------------------------------------------------------- /fbgemm_gpu/codegen/genscript/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /fbgemm_gpu/codegen/genscript/scripts_argsparse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | # flake8: noqa F401 9 | 10 | import argparse 11 | from typing import List 12 | 13 | ################################################################################ 14 | # Parse Codegen Scripts' Arguments 15 | ################################################################################ 16 | 17 | parser = argparse.ArgumentParser() 18 | # By default the source template files are in the same folder as this script: 19 | # The install dir is by default the same as the current folder. 20 | parser.add_argument( 21 | "--install_dir", default=".", help="Output directory for generated source files" 22 | ) 23 | parser.add_argument("--opensource", action="store_false", dest="is_fbcode") 24 | parser.add_argument("--is_rocm", action="store_true") 25 | 26 | args: argparse.Namespace 27 | _: List[str] 28 | args, _ = parser.parse_known_args() 29 | 30 | print(f"[ARGS PARSE] Parsed arguments: {args}") 31 | -------------------------------------------------------------------------------- /fbgemm_gpu/codegen/training/embedding_ops_placeholder.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | /* 10 | This is placeholder code to force compilation and generation of an 11 | `libdeeplearning_fbgemm_fbgemm_gpu_codegen_embedding_ops.so` file, which 12 | allows downstream PyTorch code to continue loading the `embedding_ops` 13 | and `embedding_ops_cpu` (now-)shim targets correctly.
14 | */ 15 | namespace fbgemm_gpu {} 16 | -------------------------------------------------------------------------------- /fbgemm_gpu/codegen/training/python/optimizer_args.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from typing import NamedTuple 11 | 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class SplitEmbeddingOptimizerParams(NamedTuple): 17 | weights_dev: nn.Parameter 18 | # TODO: Enable weights_uvm and weights_lxu_cache support 19 | # weights_uvm: nn.Parameter 20 | # weights_lxu_cache: nn.Parameter 21 | 22 | 23 | class SplitEmbeddingArgs(NamedTuple): 24 | weights_placements: torch.Tensor 25 | weights_offsets: torch.Tensor 26 | max_D: int 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/README.md: -------------------------------------------------------------------------------- 1 | # FBGEMM_GPU Documentation 2 | 3 | [![FBGEMM_GPU Docs CI](https://github.com/pytorch/FBGEMM/actions/workflows/fbgemm_gpu_docs.yml/badge.svg)](https://github.com/pytorch/FBGEMM/actions/workflows/fbgemm_gpu_docs.yml) 4 | 5 | This is the repo for the FBGEMM_GPU project's documentation. Please visit 6 | [this page](src/general/DocsInstructions.rst) for more detailed information. 7 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Note: Sphinx 7+ currently runs into an `Undefinederror("'style' is undefined")` 8 | # We should eventually move to Sphinx 7+ to resolve 9 | # https://github.com/sphinx-doc/sphinx/issues/1514 10 | sphinx<7 11 | 12 | # PyTorch Theme 13 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 14 | 15 | breathe 16 | bs4 17 | docutils<0.20,>=0.18.1 18 | lxml 19 | myst-parser 20 | sphinx-lint 21 | sphinx-serve 22 | six 23 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm/cpp-api/QuantUtils.rst: -------------------------------------------------------------------------------- 1 | Quantization Utilities 2 | ====================== 3 | 4 | Reference Implementation Methods 5 | -------------------------------- 6 | 7 | .. doxygengroup:: fbgemm-quant-utils-generic 8 | :content-only: 9 | 10 | AVX-2 Implementation Methods 11 | ---------------------------- 12 | 13 | .. doxygengroup:: fbgemm-quant-utils-avx2 14 | :content-only: 15 | 16 | AVX-512 Implementation Methods 17 | ------------------------------ 18 | 19 | .. 
doxygengroup:: fbgemm-quant-utils-avx512 20 | :content-only: 21 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm/cpp-api/tbe_cpu_autovec.rst: -------------------------------------------------------------------------------- 1 | TBE CPU Autovectorization 2 | ========================= 3 | 4 | FP8/16/32 Autovec Implementation Methods 5 | ---------------------------------------- 6 | 7 | .. doxygengroup:: tbe-cpu-autovec 8 | :content-only: 9 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm/index.rst: -------------------------------------------------------------------------------- 1 | .. _fbgemm.main: 2 | 3 | FBGEMM 4 | ====== 5 | 6 | **FBGEMM** (Facebook GEneral Matrix Multiplication) is a low-precision, 7 | high-performance matrix-matrix multiplications and convolution library for 8 | server-side inference. This library is used as a backend of 9 | `PyTorch `__ 10 | quantized operators on x86 machines. 11 | 12 | .. _fbgemm.toc.development: 13 | 14 | .. toctree:: 15 | :maxdepth: 2 16 | :caption: FBGEMM Development 17 | 18 | development/BuildInstructions 19 | 20 | .. _fbgemm.toc.api.cpp: 21 | 22 | .. toctree:: 23 | :maxdepth: 2 24 | :caption: FBGEMM C++ API 25 | 26 | cpp-api/QuantUtils 27 | cpp-api/tbe_cpu_autovec 28 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_genai/index.rst: -------------------------------------------------------------------------------- 1 | .. _fbgemm-genai.main: 2 | 3 | FBGEMM GenAI 4 | ============ 5 | 6 | **FBGEMM GenAI** (FBGEMM Generative AI Kernels Library) is a collection of PyTorch 7 | GPU operator libraries that are designed for generative AI applications, such as 8 | FP8 row-wise quantization and collective communications. 9 | 10 | .. _fbgemm-genai.toc.development: 11 | 12 | .. toctree:: 13 | :maxdepth: 2 14 | :caption: FBGEMM GenAI Development 15 | 16 | development/BuildInstructions 17 | development/InstallationInstructions 18 | development/TestInstructions 19 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/cpp-api/embedding_ops.rst: -------------------------------------------------------------------------------- 1 | Embedding Operators 2 | =================== 3 | 4 | CUDA Operators 5 | -------------- 6 | .. doxygengroup:: embedding-cuda 7 | :content-only: 8 | 9 | CPU Operators 10 | ------------- 11 | .. doxygengroup:: embedding-cpu 12 | :content-only: 13 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/cpp-api/experimental_ops.rst: -------------------------------------------------------------------------------- 1 | Experimental Operators 2 | ====================== 3 | 4 | Attention Operators 5 | ------------------- 6 | .. doxygengroup:: experimental-gen-ai-attention 7 | :content-only: 8 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/cpp-api/feature_gates.rst: -------------------------------------------------------------------------------- 1 | .. _fbgemm-gpu.dev.config.cpp: 2 | 3 | Feature Gates (C++) 4 | =================== 5 | 6 | .. 
doxygengroup:: fbgemm-gpu-config 7 | :content-only: 8 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/cpp-api/input_combine.rst: -------------------------------------------------------------------------------- 1 | Combine Input Operators 2 | ======================= 3 | 4 | .. doxygengroup:: input-combine 5 | :content-only: 6 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/cpp-api/jagged_tensor_ops.rst: -------------------------------------------------------------------------------- 1 | Jagged Tensor Operators 2 | ======================= 3 | 4 | Jagged Tensor solves the issue when rows in dimension are of 5 | different length. This often occurs in sparse feature inputs 6 | in recommender systems, as well as natural language processing 7 | system batched inputs. 8 | 9 | CUDA Operators 10 | -------------- 11 | 12 | .. doxygengroup:: jagged-tensor-ops-cuda 13 | :content-only: 14 | 15 | CPU Operators 16 | ------------- 17 | 18 | .. doxygengroup:: jagged-tensor-ops-cpu 19 | :content-only: 20 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/cpp-api/layout_transform_ops.rst: -------------------------------------------------------------------------------- 1 | Layout Transformation Operators 2 | =============================== 3 | 4 | CUDA Operators 5 | -------------- 6 | 7 | .. doxygengroup:: layout-transform-cuda 8 | :content-only: 9 | 10 | CPU Operators 11 | ------------- 12 | 13 | .. doxygengroup:: layout-transform-cpu 14 | :content-only: 15 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/cpp-api/memory_utils.rst: -------------------------------------------------------------------------------- 1 | CUDA Memory Operators 2 | ===================== 3 | 4 | .. doxygengroup:: cumem-utils 5 | :content-only: 6 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/cpp-api/merge_pooled_embeddings.rst: -------------------------------------------------------------------------------- 1 | Pooled Embeddings Operators 2 | =========================== 3 | 4 | This section includes CUDA and CPU operators for various 5 | operations with pooled embeddings, including merge and 6 | permutation operators. 7 | 8 | Merge Operators 9 | ---------------- 10 | .. doxygengroup:: merge-pooled-emb 11 | :content-only: 12 | 13 | Permutation Operators 14 | --------------------- 15 | 16 | .. doxygengroup:: permute-pooled-embs-gpu 17 | :content-only: 18 | 19 | .. doxygengroup:: permute-pooled-embs-cpu 20 | :content-only: 21 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/cpp-api/quantize_ops.rst: -------------------------------------------------------------------------------- 1 | Quantization Operators 2 | =========================== 3 | 4 | Quantization is a model optimization technique to reduce the size of a large 5 | model in order to achieve better storage performance with a small loss in 6 | accuracy. 7 | 8 | CUDA Operators 9 | -------------- 10 | 11 | .. doxygengroup:: quantize-ops-cuda 12 | :content-only: 13 | 14 | CPU Operators 15 | ------------- 16 | 17 | .. 
doxygengroup:: quantize-data-cpu 18 | :content-only: 19 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/cpp-api/sparse_ops.rst: -------------------------------------------------------------------------------- 1 | Sparse Data Operators 2 | ===================== 3 | 4 | CUDA Operators 5 | -------------------------- 6 | 7 | .. doxygengroup:: sparse-data-cuda 8 | :content-only: 9 | 10 | CPU Operators 11 | -------------------------- 12 | 13 | .. doxygengroup:: sparse-data-cpu 14 | :content-only: 15 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/cpp-api/split_table_batched_embeddings.rst: -------------------------------------------------------------------------------- 1 | Table Batched Embedding Operators 2 | ================================= 3 | 4 | .. doxygengroup:: table-batched-embed-cuda 5 | :content-only: 6 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/cpp-api/ssd_embedding_ops.rst: -------------------------------------------------------------------------------- 1 | SSD Embedding Operators 2 | ======================= 3 | 4 | CUDA Operators 5 | -------------- 6 | .. doxygengroup:: embedding-ssd 7 | :content-only: 8 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/overview/jagged-tensor-ops/JaggedTensorConversion1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/FBGEMM/ee0264c59fc6a403100c18f9c676f787c84ccf5b/fbgemm_gpu/docs/src/fbgemm_gpu/overview/jagged-tensor-ops/JaggedTensorConversion1.png -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/overview/jagged-tensor-ops/JaggedTensorConversion2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/FBGEMM/ee0264c59fc6a403100c18f9c676f787c84ccf5b/fbgemm_gpu/docs/src/fbgemm_gpu/overview/jagged-tensor-ops/JaggedTensorConversion2.png -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/overview/jagged-tensor-ops/JaggedTensorConversion3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/FBGEMM/ee0264c59fc6a403100c18f9c676f787c84ccf5b/fbgemm_gpu/docs/src/fbgemm_gpu/overview/jagged-tensor-ops/JaggedTensorConversion3.png -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/overview/jagged-tensor-ops/JaggedTensorExample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/FBGEMM/ee0264c59fc6a403100c18f9c676f787c84ccf5b/fbgemm_gpu/docs/src/fbgemm_gpu/overview/jagged-tensor-ops/JaggedTensorExample.png -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/python-api/feature_gates.rst: -------------------------------------------------------------------------------- 1 | .. _fbgemm-gpu.dev.config.python: 2 | 3 | Feature Gates (Python) 4 | ====================== 5 | 6 | .. automodule:: fbgemm_gpu 7 | 8 | Stable API 9 | ---------- 10 | 11 | .. autoclass:: fbgemm_gpu.config.FeatureGateName 12 | 13 | .. 
autoclass:: fbgemm_gpu.config.FeatureGate 14 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/python-api/jagged_tensor_ops.rst: -------------------------------------------------------------------------------- 1 | Jagged Tensor Operators 2 | ======================= 3 | 4 | .. automodule:: fbgemm_gpu 5 | 6 | .. _jagged-tensor-ops-stable-api: 7 | 8 | Stable API 9 | ---------- 10 | 11 | .. autofunction:: torch.ops.fbgemm.jagged_to_padded_dense 12 | 13 | Other API 14 | --------- 15 | 16 | .. autofunction:: torch.ops.fbgemm.jagged_2d_to_dense 17 | 18 | .. autofunction:: torch.ops.fbgemm.jagged_1d_to_dense 19 | 20 | .. autofunction:: torch.ops.fbgemm.dense_to_jagged 21 | 22 | .. autofunction:: torch.ops.fbgemm.jagged_dense_elementwise_add 23 | 24 | .. autofunction:: torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output 25 | 26 | .. autofunction:: torch.ops.fbgemm.jagged_dense_dense_elementwise_add_jagged_output 27 | 28 | .. autofunction:: torch.ops.fbgemm.jagged_dense_elementwise_mul 29 | 30 | .. autofunction:: torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul 31 | 32 | .. autofunction:: torch.ops.fbgemm.stacked_jagged_1d_to_dense 33 | 34 | .. autofunction:: torch.ops.fbgemm.stacked_jagged_2d_to_dense 35 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/python-api/pooled_embedding_modules.rst: -------------------------------------------------------------------------------- 1 | Pooled Embedding Modules 2 | ======================== 3 | 4 | .. automodule:: fbgemm_gpu 5 | 6 | .. _pooled-embedding-modules-stable-api: 7 | 8 | Stable API 9 | ---------- 10 | 11 | .. autoclass:: fbgemm_gpu.permute_pooled_embedding_modules.PermutePooledEmbeddings 12 | :members: __call__ 13 | 14 | Other API 15 | --------- 16 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/python-api/pooled_embedding_ops.rst: -------------------------------------------------------------------------------- 1 | Pooled Embedding Operators 2 | ========================== 3 | 4 | .. automodule:: fbgemm_gpu 5 | 6 | .. _pooled-embedding-operators-stable-api: 7 | 8 | Stable API 9 | ---------- 10 | 11 | .. autofunction:: torch.ops.fbgemm.merge_pooled_embeddings 12 | 13 | .. autofunction:: torch.ops.fbgemm.permute_pooled_embs 14 | 15 | Other API 16 | --------- 17 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/python-api/quantize_ops.rst: -------------------------------------------------------------------------------- 1 | Quantization Operators 2 | ====================== 3 | 4 | .. automodule:: fbgemm_gpu 5 | 6 | .. _quantize-ops-stable-api: 7 | 8 | Stable API 9 | ---------- 10 | 11 | .. autofunction:: torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf 12 | 13 | Other API 14 | --------- 15 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/python-api/sparse_ops.rst: -------------------------------------------------------------------------------- 1 | Sparse Operators 2 | ================ 3 | 4 | .. automodule:: fbgemm_gpu 5 | 6 | .. _sparse-ops-stable-api: 7 | 8 | Stable API 9 | ---------- 10 | 11 | .. autofunction:: torch.ops.fbgemm.permute_2D_sparse_data 12 | 13 | .. autofunction:: torch.ops.fbgemm.permute_1D_sparse_data 14 | 15 | .. autofunction:: torch.ops.fbgemm.expand_into_jagged_permute 16 | 17 | .. 
autofunction:: torch.ops.fbgemm.asynchronous_complete_cumsum 18 | 19 | .. autofunction:: torch.ops.fbgemm.offsets_range 20 | 21 | .. autofunction:: torch.ops.fbgemm.segment_sum_csr 22 | 23 | .. autofunction:: torch.ops.fbgemm.keyed_jagged_index_select_dim1 24 | 25 | .. autofunction:: torch.ops.fbgemm.block_bucketize_sparse_features 26 | 27 | Other API 28 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/python-api/tbe_ops_inference.rst: -------------------------------------------------------------------------------- 1 | Table Batched Embedding (TBE) Inference Module 2 | ============================================== 3 | 4 | .. _tbe-ops-inference-stable-api: 5 | 6 | Stable API 7 | ---------- 8 | 9 | .. autoclass:: fbgemm_gpu.split_table_batched_embeddings_ops_inference.IntNBitTableBatchedEmbeddingBagsCodegen 10 | :members: forward, 11 | fill_random_weights, 12 | assign_embedding_weights, 13 | split_embedding_weights, 14 | split_embedding_weights_with_scale_bias, 15 | recompute_module_buffers 16 | 17 | Other API 18 | --------- 19 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/fbgemm_gpu/python-api/tbe_ops_training.rst: -------------------------------------------------------------------------------- 1 | Table Batched Embedding (TBE) Training Module 2 | ============================================= 3 | 4 | .. _tbe-ops-training-stable-api: 5 | 6 | Stable API 7 | ---------- 8 | 9 | .. autoclass:: fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen 10 | :members: forward, 11 | split_embedding_weights, 12 | split_optimizer_states, 13 | set_learning_rate, 14 | update_hyper_parameters, 15 | set_optimizer_step 16 | 17 | Other API 18 | --------- 19 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/general/ContactUs.rst: -------------------------------------------------------------------------------- 1 | Contact Us 2 | ========== 3 | 4 | GitHub 5 | ------ 6 | 7 | * `GitHub Issues `__: Use this to 8 | file questions, issues, and feature requests concerning FBGEMM and/or 9 | FBGEMM_GPU. 10 | 11 | * `GitHub Discussions `__: Use 12 | this to kick off longer discussions regarding FBGEMM and/or FBGEMM_GPU. 13 | 14 | Slack 15 | ----- 16 | 17 | For both FBGEMM and FBGEMM_GPU, feel free to reach out to us on the ``#fbgemm`` 18 | channel in `Pytorch Slack `__. 19 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/general/Contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../../../../CONTRIBUTING.md 2 | :parser: myst_parser.sphinx_ 3 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/general/License.rst: -------------------------------------------------------------------------------- 1 | License 2 | ======= 3 | 4 | Both FBGEMM and FBGEMM_GPU are licensed under the 3-clause BSD License: 5 | 6 | .. 
literalinclude:: ../../../../LICENSE 7 | :language: text 8 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/general/documentation/ExampleGraph.dot: -------------------------------------------------------------------------------- 1 | digraph "sphinx-ext-graphviz" { 2 | size="6,8"; 3 | rankdir="LR"; 4 | graph [fontname="Verdana", fontsize="12"]; 5 | node [fontname="Verdana", fontsize="12"]; 6 | edge [fontname="Sans", fontsize="9"]; 7 | 8 | sphinx [label="Sphinx", shape="component", 9 | href="https://www.sphinx-doc.org/", 10 | target="_blank"]; 11 | dot [label="GraphViz", shape="component", 12 | href="https://www.graphviz.org/", 13 | target="_blank"]; 14 | docs [label="Docs (.rst)", shape="folder", 15 | fillcolor=green, style=filled]; 16 | svg_file [label="SVG Image", shape="note", fontcolor=white, 17 | fillcolor="#3333ff", style=filled]; 18 | html_files [label="HTML Files", shape="folder", 19 | fillcolor=yellow, style=filled]; 20 | 21 | docs -> sphinx [label=" parse "]; 22 | sphinx -> dot [label=" call ", style=dashed, arrowhead=none]; 23 | dot -> svg_file [label=" draw "]; 24 | sphinx -> html_files [label=" render "]; 25 | svg_file -> html_files [style=dashed]; 26 | } 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/docs/src/general/index.rst: -------------------------------------------------------------------------------- 1 | General Info 2 | ============ 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | Releases 8 | Contributing 9 | documentation/Overview 10 | ContactUs 11 | License 12 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/example/example/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import os 11 | 12 | import torch 13 | 14 | try: 15 | # pyre-ignore[21] 16 | # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils 17 | from fbgemm_gpu import open_source 18 | 19 | # pyre-ignore[21] 20 | # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils 21 | from fbgemm_gpu.docs.version import __version__ # noqa: F401 22 | except Exception: 23 | open_source: bool = False 24 | 25 | # pyre-ignore[16] 26 | if open_source: 27 | torch.ops.load_library( 28 | os.path.join(os.path.dirname(__file__), "fbgemm_gpu_experimental_example_py.so") 29 | ) 30 | else: 31 | torch.ops.load_library( 32 | "//deeplearning/fbgemm/fbgemm_gpu/experimental/example:example_ops_cuda" 33 | ) 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/example/example/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
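# Usage sketch: importing fbgemm_gpu.experimental.example runs the package
# __init__ above, which loads fbgemm_gpu_experimental_example_py.so (or the
# fbcode target) and registers its operators under torch.ops.fbgemm, so the
# wrappers below are thin calls into that extension. A minimal example,
# assuming the experimental example extension has been built and installed:
import torch
import fbgemm_gpu.experimental.example  # noqa: F401  (triggers the op registration)

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = torch.ops.fbgemm.add_tensors_float(a, b)  # tensor([5., 7., 9.])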
7 | 8 | # pyre-strict 9 | 10 | import torch 11 | 12 | 13 | def add_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 14 | return torch.ops.fbgemm.add_tensors_float(a, b) 15 | 16 | 17 | def sgemm( 18 | alpha: float, TA: torch.Tensor, TB: torch.Tensor, beta: float, TC: torch.Tensor 19 | ) -> torch.Tensor: 20 | return torch.ops.fbgemm.sgemm_float(alpha, TA, TB, beta, TC) 21 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/example/src/example_nccl.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | #include 11 | 12 | namespace fbgemm_gpu::experimental { 13 | 14 | void example_nccl_code() { 15 | ncclComm_t comms[4]; 16 | int devs[4] = {0, 1, 2, 3}; 17 | ncclCommInitAll(comms, 4, devs); 18 | 19 | for (const auto i : c10::irange(4)) { 20 | ncclCommDestroy(comms[i]); 21 | } 22 | } 23 | 24 | } // namespace fbgemm_gpu::experimental 25 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/example/src/example_ops.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | #include 11 | 12 | namespace fbgemm_gpu::experimental { 13 | 14 | at::Tensor add_tensors_float(const at::Tensor& a, const at::Tensor& b) { 15 | return a.to(at::kFloat) + b.to(at::kFloat); 16 | } 17 | 18 | TORCH_LIBRARY_FRAGMENT(fbgemm, m) { 19 | m.def("add_tensors_float(Tensor a, Tensor b) -> Tensor"); 20 | } 21 | 22 | TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { 23 | m.impl( 24 | "add_tensors_float", 25 | torch::dispatch( 26 | c10::DispatchKey::CPU, 27 | TORCH_FN(fbgemm_gpu::experimental::add_tensors_float))); 28 | } 29 | 30 | } // namespace fbgemm_gpu::experimental 31 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/example/test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | 13 | 14 | gpu_unavailable: Tuple[bool, str] = ( 15 | not torch.cuda.is_available() or torch.cuda.device_count() == 0, 16 | "CUDA is not available or no GPUs detected", 17 | ) 18 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/example/test/add_tensors_float_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | import unittest 10 | 11 | import torch 12 | from fbgemm_gpu.experimental.example import utils 13 | 14 | 15 | class AddTensorsFloatTest(unittest.TestCase): 16 | def test_add_tensors_float(self) -> None: 17 | a = torch.tensor([1, 2, 3]) 18 | b = torch.tensor([4, 5, 6]) 19 | expected = torch.tensor([5, 7, 9], dtype=torch.float) 20 | c = utils.add_tensors(a, b) 21 | torch.testing.assert_close(c.cpu(), expected.cpu()) 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/example/test/sgemm_float_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import unittest 10 | 11 | import torch 12 | from fbgemm_gpu.experimental.example import utils 13 | 14 | from . import gpu_unavailable 15 | 16 | 17 | class SgemmFloatTest(unittest.TestCase): 18 | @unittest.skipIf(*gpu_unavailable) 19 | def test_sgemm_float(self) -> None: 20 | alpha = 3.14 21 | beta = 2.71 22 | 23 | A = torch.rand(4, 3, dtype=torch.float, device="cuda") 24 | B = torch.rand(3, 5, dtype=torch.float, device="cuda") 25 | C = torch.rand(4, 5, dtype=torch.float, device="cuda") 26 | D = utils.sgemm(alpha, A, B, beta, C) 27 | 28 | expected = torch.add(alpha * torch.matmul(A, B), beta * C) 29 | torch.testing.assert_close(D.cpu(), expected.cpu()) 30 | 31 | 32 | if __name__ == "__main__": 33 | unittest.main() 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gemm/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | ################################################################################ 8 | # Target Sources 9 | ################################################################################ 10 | 11 | # Python sources 12 | file(GLOB_RECURSE experimental_triton_gemm_python_source_files 13 | triton_gemm/*.py) 14 | 15 | 16 | ################################################################################ 17 | # Install Python Files 18 | ################################################################################ 19 | 20 | add_to_package( 21 | DESTINATION fbgemm_gpu/experimental/gemm/triton_gemm 22 | FILES ${experimental_triton_gemm_python_source_files}) 23 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gemm/test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | try: 11 | # pyre-ignore[21] 12 | # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils 13 | from fbgemm_gpu import open_source 14 | 15 | # pyre-ignore[21] 16 | # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils 17 | from fbgemm_gpu.docs.version import __version__ # noqa: F401 18 | except Exception: 19 | open_source: bool = False 20 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/bench/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | try: 9 | # pyre-ignore[21] 10 | # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils 11 | from fbgemm_gpu import open_source 12 | 13 | # pyre-ignore[21] 14 | # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils 15 | from fbgemm_gpu.docs.version import __version__ # noqa: F401 16 | except Exception: 17 | open_source: bool = False 18 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/gen_ai/moe/README.md: -------------------------------------------------------------------------------- 1 | # FBGEMM GenAI MoE Support 2 | 3 | MetaShuffling MoE kernel support in FBGEMM GenAI kernel library. 4 | 5 | # **Overview** 6 | 7 | Mixture-of-Experts (MoE) is a popular model architecture for large language models (LLMs). Although it reduces computation in training and inference by activating less parameters per token, it imposes additional challenges in achieving optimal computation efficiency with high memory and communication pressure, as well as the complexity to handle the dynamism and sparsity nature of the model. Here we introduce a new MoE inference solution, MetaShuffling, which enables us to efficiently deploy Llama 4 models for real scenario inference. 8 | 9 | [Technical design blog](https://pytorch.org/blog/metashuffling-accelerating-llama-4-moe-inference/). 10 | 11 | # **Updates** 12 | 13 | - 2025-05-01: Initial release of MetaShuffling MoE PyTorch examples. 14 | 15 | - 2025-04-17: Initial release of MetaShuffling MoE GPU kernels. 16 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/coalesce/coalesce.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | 14 | namespace fbgemm_gpu { 15 | std::vector coalesce_batches_cpu( 16 | const std::vector& input, 17 | const std::vector& output, 18 | const at::Tensor& old_bids, 19 | const at::Tensor& new_bids); 20 | 21 | std::vector coalesce_batches_gpu( 22 | const std::vector& input, 23 | const std::vector& output, 24 | const at::Tensor& old_bids, 25 | const at::Tensor& new_bids); 26 | 27 | } // namespace fbgemm_gpu 28 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | // A kernel that works well on small but not super tiny shapes. 19 | using DeviceGemmInstance = DeviceGemmHelper< 20 | 128, 21 | 128, 22 | 16, 23 | 128, 24 | 16, 25 | 16, 26 | 4, 27 | 1, 28 | S<8, 16, 1>, 29 | S<8, 16, 1>, 30 | S<1, 16, 1, 8>, 31 | S<2, 2, 1>, 32 | 1, 33 | 1, 34 | ck::BlockGemmPipelineScheduler::Interwave, 35 | ck::BlockGemmPipelineVersion::v2>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 128, 20 | 16, 21 | 32, 22 | 128, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | S<8, 16, 1>, 28 | S<8, 16, 1>, 29 | S<1, 16, 1, 8>, 30 | S<4, 4, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Interwave, 34 | ck::BlockGemmPipelineVersion::v2, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_4_split_k.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_4( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 128, 20 | 16, 21 | 32, 22 | 128, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | S<8, 16, 1>, 28 | S<8, 16, 1>, 29 | S<1, 16, 1, 8>, 30 | S<4, 4, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Interwave, 34 | ck::BlockGemmPipelineVersion::v2, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y, 4); 38 | } 39 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8_split_k.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 128, 20 | 16, 21 | 32, 22 | 128, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | S<8, 16, 1>, 28 | S<8, 16, 1>, 29 | S<1, 16, 1, 8>, 30 | S<4, 4, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Interwave, 34 | ck::BlockGemmPipelineVersion::v2, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y, 8); 38 | } 39 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2_8_split_k.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2_8( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 128, 20 | 16, 21 | 32, 22 | 128, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | S<8, 16, 1>, 28 | S<8, 16, 1>, 29 | S<1, 16, 1, 8>, 30 | S<4, 4, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v2, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 
37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y, 8); 38 | } 39 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 128, 20 | 16, 21 | 32, 22 | 256, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | S<16, 8, 1>, 28 | S<16, 8, 1>, 29 | S<1, 16, 1, 8>, 30 | S<4, 4, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v1, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2_2_split_k.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2_2( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 128, 20 | 16, 21 | 32, 22 | 512, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | S<32, 4, 1>, 28 | S<32, 4, 1>, 29 | S<1, 16, 1, 8>, 30 | S<4, 4, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Interwave, 34 | ck::BlockGemmPipelineVersion::v2, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y, 2); 38 | } 39 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 128, 20 | 16, 21 | 32, 22 | 512, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | S<32, 4, 1>, 28 | S<32, 4, 1>, 29 | S<1, 16, 1, 8>, 30 | S<4, 4, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v2, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2_2_split_k.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2_2( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 128, 20 | 16, 21 | 32, 22 | 512, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | S<32, 4, 1>, 28 | S<32, 4, 1>, 29 | S<1, 16, 1, 8>, 30 | S<4, 4, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v2, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y, 2); 38 | } 39 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | // The smallest kernel we have available. Works well for memory bound shapes. 19 | using DeviceGemmInstance = DeviceGemmHelper< 20 | 128, 21 | 16, 22 | 32, 23 | 512, 24 | 16, 25 | 16, 26 | 1, 27 | 1, 28 | S<8, 16, 1>, 29 | S<8, 16, 1>, 30 | S<1, 16, 1, 8>, 31 | S<4, 4, 1>, 32 | 1, 33 | 1, 34 | ck::BlockGemmPipelineScheduler::Interwave, 35 | ck::BlockGemmPipelineVersion::v2>; 36 | // Run kernel instance. 
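// The DeviceGemmInstance alias above appears to mirror the fields encoded in
// the kernel filename: 128 threads per block, a 16x32x512 M/N/K block tile,
// 16x16 MFMA tiles, 1x1 XDL waves per block, the S<...> tuples describing the
// A/B/C block-transfer thread-cluster layouts, and finally the Interwave
// scheduler with pipeline version v2. f8f8bf16_rowwise_impl is then expected
// to launch that CK device GEMM instance.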
37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 128, 20 | 32, 21 | 16, 22 | 128, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | S<8, 16, 1>, 28 | S<8, 16, 1>, 29 | S<1, 16, 1, 8>, 30 | S<2, 2, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Interwave, 34 | ck::BlockGemmPipelineVersion::v2, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x32x16x256_16x16_1x1_16x8x1_16x8x1_1x32x1x4_4x4x1_1x1_interwave_v1.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x32x16x256_16x16_1x1_16x8x1_16x8x1_1x32x1x4_4x4x1_1x1_interwave_v1( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 128, 20 | 32, 21 | 16, 22 | 256, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | S<16, 8, 1>, 28 | S<16, 8, 1>, 29 | S<1, 32, 1, 4>, 30 | S<4, 4, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Interwave, 34 | ck::BlockGemmPipelineVersion::v1, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_intrawave_v2.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_intrawave_v2( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 128, 20 | 32, 21 | 16, 22 | 512, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | S<32, 4, 1>, 28 | S<32, 4, 1>, 29 | S<1, 32, 1, 4>, 30 | S<4, 4, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v2, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 128, 20 | 32, 21 | 64, 22 | 128, 23 | 32, 24 | 32, 25 | 1, 26 | 1, 27 | S<8, 16, 1>, 28 | S<8, 16, 1>, 29 | S<1, 16, 1, 8>, 30 | S<8, 8, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Interwave, 34 | ck::BlockGemmPipelineVersion::v2, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 128, 20 | 64, 21 | 32, 22 | 128, 23 | 32, 24 | 32, 25 | 1, 26 | 1, 27 | S<8, 16, 1>, 28 | S<8, 16, 1>, 29 | S<1, 16, 1, 8>, 30 | S<4, 4, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Interwave, 34 | ck::BlockGemmPipelineVersion::v2, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 
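// Sibling files in this directory differ only in the tile/cluster numbers,
// the scheduler (Interwave vs. Intrawave), the pipeline version (v1-v3), and
// an optional trailing GemmSpecialization argument, so a tuning sweep can
// presumably select the best instance per problem shape at dispatch time.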
37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 128, 21 | 128, 22 | 256, 23 | 32, 24 | 32, 25 | 2, 26 | 2, 27 | S<16, 16, 1>, 28 | S<16, 16, 1>, 29 | S<1, 32, 1, 8>, 30 | S<8, 8, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 128, 21 | 160, 22 | 128, 23 | 32, 24 | 32, 25 | 1, 26 | 5, 27 | S<8, 32, 1>, 28 | S<8, 32, 1>, 29 | S<1, 64, 1, 4>, 30 | S<8, 8, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 128, 21 | 192, 22 | 128, 23 | 32, 24 | 32, 25 | 2, 26 | 3, 27 | S<8, 32, 1>, 28 | S<8, 32, 1>, 29 | S<1, 32, 1, 8>, 30 | S<8, 8, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 128, 21 | 256, 22 | 128, 23 | 32, 24 | 32, 25 | 2, 26 | 4, 27 | S<8, 32, 1>, 28 | S<8, 32, 1>, 29 | S<1, 32, 1, 8>, 30 | S<8, 8, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 128, 21 | 64, 22 | 128, 23 | 32, 24 | 32, 25 | 2, 26 | 1, 27 | S<8, 32, 1>, 28 | S<8, 32, 1>, 29 | S<1, 32, 1, 8>, 30 | S<8, 8, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 
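// In contrast to the small interwave_v2 instances above (commented as working
// well for memory-bound shapes), these 256-thread intrawave_v3 instances use
// much larger block tiles and so presumably target larger, compute-bound GEMM
// shapes.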
37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 128, 21 | 96, 22 | 256, 23 | 32, 24 | 32, 25 | 1, 26 | 3, 27 | S<16, 16, 1>, 28 | S<16, 16, 1>, 29 | S<1, 64, 1, 4>, 30 | S<8, 8, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 256, 21 | 128, 22 | 128, 23 | 16, 24 | 16, 25 | 8, 26 | 4, 27 | S<8, 32, 1>, 28 | S<8, 32, 1>, 29 | S<1, 32, 1, 8>, 30 | S<8, 8, 1>, 31 | 1, 32 | 2, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x256x96x128_16x16_8x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x256x96x128_16x16_8x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 256, 21 | 96, 22 | 128, 23 | 16, 24 | 16, 25 | 8, 26 | 3, 27 | S<8, 32, 1>, 28 | S<8, 32, 1>, 29 | S<1, 64, 1, 4>, 30 | S<8, 8, 1>, 31 | 2, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 256, 21 | 96, 22 | 128, 23 | 32, 24 | 32, 25 | 2, 26 | 3, 27 | S<8, 32, 1>, 28 | S<8, 32, 1>, 29 | S<1, 64, 1, 4>, 30 | S<8, 8, 1>, 31 | 2, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x64x128x128_32x32_1x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x64x128x128_32x32_1x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 64, 21 | 128, 22 | 128, 23 | 32, 24 | 32, 25 | 1, 26 | 2, 27 | S<8, 32, 1>, 28 | S<8, 32, 1>, 29 | S<1, 32, 1, 8>, 30 | S<8, 8, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 
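// Illustrative, hypothetical caller sketch (names and shapes are assumptions):
// Y is a preallocated BF16 output of shape [M, N], e.g.
//   at::Tensor Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
//   fp8_rowwise_256x64x128x128_32x32_1x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3(
//       XQ, WQ, x_scale, w_scale, Y);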
37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x64x16x512_16x16_1x1_32x8x1_32x8x1_1x64x1x4_4x4x1_1x1_intrawave_v2.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x64x16x512_16x16_1x1_32x8x1_32x8x1_1x64x1x4_4x4x1_1x1_intrawave_v2( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 64, 21 | 16, 22 | 512, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | S<32, 8, 1>, 28 | S<32, 8, 1>, 29 | S<1, 64, 1, 4>, 30 | S<4, 4, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v2, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x64x192x128_32x32_1x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x64x192x128_32x32_1x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 64, 21 | 192, 22 | 128, 23 | 32, 24 | 32, 25 | 1, 26 | 3, 27 | S<8, 32, 1>, 28 | S<8, 32, 1>, 29 | S<1, 32, 1, 8>, 30 | S<8, 8, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x64x192x256_32x32_1x3_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x64x192x256_32x32_1x3_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 64, 21 | 192, 22 | 256, 23 | 32, 24 | 32, 25 | 1, 26 | 3, 27 | S<16, 16, 1>, 28 | S<16, 16, 1>, 29 | S<1, 32, 1, 8>, 30 | S<8, 8, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x64x256x128_32x32_1x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x64x256x128_32x32_1x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 64, 21 | 256, 22 | 128, 23 | 32, 24 | 32, 25 | 1, 26 | 4, 27 | S<8, 32, 1>, 28 | S<8, 32, 1>, 29 | S<1, 32, 1, 8>, 30 | S<8, 8, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 64, 21 | 64, 22 | 128, 23 | 32, 24 | 32, 25 | 1, 26 | 1, 27 | S<8, 32, 1>, 28 | S<8, 32, 1>, 29 | S<1, 32, 1, 8>, 30 | S<8, 8, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 
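// Every instance in this directory shares the same
// (XQ, WQ, x_scale, w_scale, Y) -> at::Tensor signature, so a heuristic or
// autotuned dispatcher can be a plain function-pointer table, e.g.
// (hypothetical sketch):
//   using Fp8RowwiseKernel =
//       at::Tensor (*)(at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor);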
37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x64x96x256_16x16_2x3_16x16x1_16x16x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_256x64x96x256_16x16_2x3_16x16x1_16x16x1_1x64x1x4_8x8x1_2x1_intrawave_v3( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 256, 20 | 64, 21 | 96, 22 | 256, 23 | 16, 24 | 16, 25 | 2, 26 | 3, 27 | S<16, 16, 1>, 28 | S<16, 16, 1>, 29 | S<1, 64, 1, 4>, 30 | S<8, 8, 1>, 31 | 2, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v3, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | using DeviceGemmInstance = DeviceGemmHelper< 19 | 64, 20 | 16, 21 | 16, 22 | 256, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | S<16, 4, 1>, 28 | S<16, 4, 1>, 29 | S<1, 16, 1, 4>, 30 | S<4, 4, 1>, 31 | 1, 32 | 1, 33 | ck::BlockGemmPipelineScheduler::Intrawave, 34 | ck::BlockGemmPipelineVersion::v1, 35 | ck::tensor_operation::device::GemmSpecialization::Default>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | // The smallest kernel we have available. Works well for memory bound shapes. 19 | using DeviceGemmInstance = DeviceGemmHelper< 20 | 64, 21 | 16, 22 | 16, 23 | 512, 24 | 16, 25 | 16, 26 | 1, 27 | 1, 28 | S<8, 8, 1>, 29 | S<8, 8, 1>, 30 | S<1, 16, 1, 4>, 31 | S<4, 4, 1>, 32 | 1, 33 | 1, 34 | ck::BlockGemmPipelineScheduler::Interwave, 35 | ck::BlockGemmPipelineVersion::v2>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fp8_rowwise_common.h" 10 | 11 | at::Tensor 12 | fp8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( 13 | at::Tensor XQ, 14 | at::Tensor WQ, 15 | at::Tensor x_scale, 16 | at::Tensor w_scale, 17 | at::Tensor Y) { 18 | // The smallest kernel we have available. Works well for memory bound shapes. 19 | using DeviceGemmInstance = DeviceGemmHelper< 20 | 64, 21 | 16, 22 | 16, 23 | 64, 24 | 16, 25 | 16, 26 | 1, 27 | 1, 28 | S<4, 16, 1>, 29 | S<4, 16, 1>, 30 | S<1, 16, 1, 4>, 31 | S<4, 4, 1>, 32 | 1, 33 | 1, 34 | ck::BlockGemmPipelineScheduler::Interwave, 35 | ck::BlockGemmPipelineVersion::v2>; 36 | // Run kernel instance. 37 | return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); 38 | } 39 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moesorting.hpp: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 3 | 4 | #pragma once 5 | #include 6 | #include "ck_tile/core.hpp" 7 | #include "ck_tile/host.hpp" 8 | #include "ck_tile/ops/fused_moe.hpp" 9 | 10 | struct fused_moesorting_trait { 11 | std::string index_type; 12 | std::string weight_type; // currently always float 13 | bool local_expert_masking; // if mask experts as local expert 14 | }; 15 | 16 | struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs {}; 17 | 18 | float fused_moesorting( 19 | fused_moesorting_trait t, 20 | fused_moesorting_args a, 21 | ck_tile::stream_config s); 22 | 23 | int moe_sorting_get_workspace_size(int tokens, int num_experts); 24 | float moe_sorting_mp( 25 | fused_moesorting_trait t, 26 | fused_moesorting_args a, 27 | ck_tile::stream_config s); 28 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cublas_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #define CUBLAS_WORKSPACE_SIZE 4194304 16 | 17 | inline void checkCublasStatus(cublasStatus_t status) { 18 | if (status != CUBLAS_STATUS_SUCCESS) { 19 | printf("cuBLAS API failed with status %d\n", status); 20 | throw std::logic_error("cuBLAS API failed"); 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_2_1_1_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "bf16bf16bf16_grouped_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_t( 14 | at::Tensor X, // BF16 15 | at::Tensor W, // BF16 16 | at::Tensor output, 17 | std::optional zero_start_index_M, 18 | std::optional M_sizes) { 19 | return bf16bf16bf16_grouped_impl( 20 | X, W, output, zero_start_index_M, M_sizes); 21 | } 22 | 23 | at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_t( 24 | at::TensorList X, // BF16 25 | at::TensorList W, // BF16 26 | at::Tensor output, 27 | std::optional zero_start_index_M, 28 | std::optional M_sizes) { 29 | return bf16bf16bf16_grouped_impl( 30 | X, W, output, zero_start_index_M, M_sizes); 31 | } 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_1_1_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_128_128_4_1_1_f( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::nv_float4_t, 24 | 128, 25 | 128, 26 | 4, 27 | 1, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_1_1_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_128_128_4_1_1_t( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::mx_float4_t, 24 | 128, 25 | 128, 26 | 4, 27 | 1, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_2_1_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_128_192_2_2_1_f( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::nv_float4_t, 24 | 128, 25 | 192, 26 | 2, 27 | 2, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_2_1_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_128_192_2_2_1_t( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::mx_float4_t, 24 | 128, 25 | 192, 26 | 2, 27 | 2, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_256_2_1_1_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_128_256_2_1_1_f( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::nv_float4_t, 24 | 128, 25 | 256, 26 | 2, 27 | 1, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_256_2_1_1_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_128_256_2_1_1_t( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::mx_float4_t, 24 | 128, 25 | 256, 26 | 2, 27 | 1, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_2_1_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_128_2_2_1_f( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::nv_float4_t, 24 | 256, 25 | 128, 26 | 2, 27 | 2, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_2_1_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_128_2_2_1_t( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::mx_float4_t, 24 | 256, 25 | 128, 26 | 2, 27 | 2, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_4_1_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_128_2_4_1_f( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::nv_float4_t, 24 | 256, 25 | 128, 26 | 2, 27 | 4, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_128_2_4_1_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_128_2_4_1_t( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::mx_float4_t, 24 | 256, 25 | 128, 26 | 2, 27 | 4, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_2_1_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_192_2_2_1_f( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::nv_float4_t, 24 | 256, 25 | 192, 26 | 2, 27 | 2, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_2_1_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_192_2_2_1_t( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::mx_float4_t, 24 | 256, 25 | 192, 26 | 2, 27 | 2, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_4_1_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_192_2_4_1_f( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::nv_float4_t, 24 | 256, 25 | 192, 26 | 2, 27 | 4, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_2_4_1_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_192_2_4_1_t( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::mx_float4_t, 24 | 256, 25 | 192, 26 | 2, 27 | 4, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_4_1_1_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_192_4_1_1_f( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::nv_float4_t, 24 | 256, 25 | 192, 26 | 4, 27 | 1, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_192_4_1_1_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_192_4_1_1_t( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::mx_float4_t, 24 | 256, 25 | 192, 26 | 4, 27 | 1, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_1_1_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_256_2_1_1_f( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::nv_float4_t, 24 | 256, 25 | 256, 26 | 2, 27 | 1, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_1_1_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_256_2_1_1_t( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::mx_float4_t, 24 | 256, 25 | 256, 26 | 2, 27 | 1, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_2_1_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_256_2_2_1_f( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::nv_float4_t, 24 | 256, 25 | 256, 26 | 2, 27 | 2, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_2_1_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_256_2_2_1_t( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::mx_float4_t, 24 | 256, 25 | 256, 26 | 2, 27 | 2, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_4_1_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_256_2_4_1_f( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::nv_float4_t, 24 | 256, 25 | 256, 26 | 2, 27 | 4, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_2_4_1_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f4f4bf16_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) 14 | 15 | at::Tensor f4f4bf16_256_256_2_4_1_t( 16 | at::Tensor XQ, // FP4 17 | at::Tensor WQ, // FP4 18 | at::Tensor x_scale, 19 | at::Tensor w_scale, 20 | std::optional global_scale = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return _f4f4bf16< 23 | cutlass::mx_float4_t, 24 | 256, 25 | 256, 26 | 2, 27 | 4, 28 | 1>(XQ, WQ, x_scale, w_scale, global_scale); 29 | } 30 | 31 | #endif 32 | 33 | } // namespace fbgemm_gpu 34 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_4_1_1_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 |  */
8 | 
9 | #include "f4f4bf16_common.cuh"
10 | 
11 | namespace fbgemm_gpu {
12 | 
13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
14 | 
15 | at::Tensor f4f4bf16_256_256_4_1_1_f(
16 |     at::Tensor XQ, // FP4
17 |     at::Tensor WQ, // FP4
18 |     at::Tensor x_scale,
19 |     at::Tensor w_scale,
20 |     std::optional<at::Tensor> global_scale = std::nullopt) {
21 |   // Dispatch this kernel to the correct underlying implementation.
22 |   return _f4f4bf16<
23 |       cutlass::nv_float4_t<cutlass::float_e2m1_t>,
24 |       256,
25 |       256,
26 |       4,
27 |       1,
28 |       1>(XQ, WQ, x_scale, w_scale, global_scale);
29 | }
30 | 
31 | #endif
32 | 
33 | } // namespace fbgemm_gpu
34 | 
--------------------------------------------------------------------------------
/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_256_256_4_1_1_t.cu:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  * All rights reserved.
4 |  *
5 |  * This source code is licensed under the BSD-style license found in the
6 |  * LICENSE file in the root directory of this source tree.
7 |  */
8 | 
9 | #include "f4f4bf16_common.cuh"
10 | 
11 | namespace fbgemm_gpu {
12 | 
13 | #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
14 | 
15 | at::Tensor f4f4bf16_256_256_4_1_1_t(
16 |     at::Tensor XQ, // FP4
17 |     at::Tensor WQ, // FP4
18 |     at::Tensor x_scale,
19 |     at::Tensor w_scale,
20 |     std::optional<at::Tensor> global_scale = std::nullopt) {
21 |   // Dispatch this kernel to the correct underlying implementation.
22 |   return _f4f4bf16<
23 |       cutlass::mx_float4_t<cutlass::float_e2m1_t>,
24 |       256,
25 |       256,
26 |       4,
27 |       1,
28 |       1>(XQ, WQ, x_scale, w_scale, global_scale);
29 | }
30 | 
31 | #endif
32 | 
33 | } // namespace fbgemm_gpu
34 | 
--------------------------------------------------------------------------------
/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_128_128_128_2_1_1_9_t_f.cu:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  * All rights reserved.
4 |  *
5 |  * This source code is licensed under the BSD-style license found in the
6 |  * LICENSE file in the root directory of this source tree.
7 |  */
8 | 
9 | #include "f8f8bf16_rowwise_common.cuh"
10 | 
11 | namespace fbgemm_gpu {
12 | 
13 | at::Tensor f8f8bf16_rowwise_128_128_128_2_1_1_9_t_f(
14 |     at::Tensor XQ,
15 |     at::Tensor WQ,
16 |     at::Tensor x_scale,
17 |     at::Tensor w_scale,
18 |     bool use_fast_accum = true,
19 |     std::optional<at::Tensor> bias = std::nullopt,
20 |     std::optional<at::Tensor> output = std::nullopt) {
21 |   // Dispatch this kernel to the correct underlying implementation.
22 |   return f8f8bf16_rowwise_wrapper<128, 128, 128, 2, 1, 1, 9, true, false>(
23 |       XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
24 | }
25 | 
26 | } // namespace fbgemm_gpu
27 | 
--------------------------------------------------------------------------------
/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_128_128_128_2_2_1_10_f_f.cu:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  * All rights reserved.
4 |  *
5 |  * This source code is licensed under the BSD-style license found in the
6 |  * LICENSE file in the root directory of this source tree.
7 | */ 8 | 9 | #include "f8f8bf16_rowwise_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_128_128_128_2_2_1_10_f_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_wrapper<128, 128, 128, 2, 2, 1, 10, false, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_128_256_128_1_2_1_10_f_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_128_256_128_1_2_1_10_f_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_wrapper<128, 256, 128, 1, 2, 1, 10, false, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_128_256_128_2_1_1_10_f_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_wrapper<128, 256, 128, 2, 1, 1, 10, false, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "f8f8bf16_rowwise_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_128_256_128_2_1_1_9_f_t( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_wrapper<128, 256, 128, 2, 1, 1, 9, false, true>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_128_256_128_4_4_1_9_f_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_128_256_128_4_4_1_9_f_t( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_wrapper<128, 256, 128, 4, 4, 1, 9, false, true>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_128_32_128_1_1_1_10_f_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_128_32_128_1_1_1_10_f_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_wrapper<128, 32, 128, 1, 1, 1, 10, false, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_128_64_128_1_1_1_10_f_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "f8f8bf16_rowwise_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_128_64_128_1_1_1_10_f_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_wrapper<128, 64, 128, 1, 1, 1, 10, false, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_64_128_128_1_1_1_9_f_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_wrapper<64, 128, 128, 1, 1, 1, 9, false, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_64_16_128_1_1_1_9_f_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_wrapper<64, 16, 128, 1, 1, 1, 9, false, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "f8f8bf16_rowwise_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_64_256_128_1_1_1_9_f_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_wrapper<64, 256, 128, 1, 1, 1, 9, false, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_64_256_128_2_1_1_9_f_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_64_256_128_2_1_1_9_f_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_wrapper<64, 256, 128, 2, 1, 1, 9, false, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_64_32_128_2_1_1_9_f_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_wrapper<64, 32, 128, 2, 1, 1, 9, false, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "f8f8bf16_rowwise_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_64_64_128_2_1_1_9_f_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_wrapper<64, 64, 128, 2, 1, 1, 9, false, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_128_128_128_1_2_1_10_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_batched_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_batched_128_128_128_1_2_1_10_t( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_batched_wrapper<128, 128, 128, 1, 2, 1, 10, true>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_128_128_128_1_2_1_9_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_batched_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_batched_128_128_128_1_2_1_9_t( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_batched_wrapper<128, 128, 128, 1, 2, 1, 9, true>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_128_128_128_2_1_1_10_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "f8f8bf16_rowwise_batched_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_batched_128_128_128_2_1_1_10_t( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_batched_wrapper<128, 128, 128, 2, 1, 1, 10, true>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_128_128_128_2_1_1_9_t.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_batched_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_batched_128_128_128_2_1_1_9_t( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_batched_wrapper<128, 128, 128, 2, 1, 1, 9, true>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_64_128_128_1_2_1_10_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_batched_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_batched_64_128_128_1_2_1_10_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_batched_wrapper<64, 128, 128, 1, 2, 1, 10, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_64_128_128_1_2_1_9_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "f8f8bf16_rowwise_batched_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_batched_64_128_128_1_2_1_9_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_batched_wrapper<64, 128, 128, 1, 2, 1, 9, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_64_128_128_2_1_1_10_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_batched_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_batched_64_128_128_2_1_1_10_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_batched_wrapper<64, 128, 128, 2, 1, 1, 10, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "f8f8bf16_rowwise_batched_common.cuh" 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f( 14 | at::Tensor XQ, 15 | at::Tensor WQ, 16 | at::Tensor x_scale, 17 | at::Tensor w_scale, 18 | bool use_fast_accum = true, 19 | std::optional bias = std::nullopt, 20 | std::optional output = std::nullopt) { 21 | // Dispatch this kernel to the correct underlying implementation. 22 | return f8f8bf16_rowwise_batched_wrapper<64, 128, 128, 2, 1, 1, 9, false>( 23 | XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/gen_ai/test/moe/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
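# Note on the f8f8bf16_rowwise_batched_* (and f8f8bf16_rowwise_*) instantiation
# files above: each .cu file only pins one configuration of the shared wrapper,
# and the numbers/flags in the file name are passed through verbatim as the
# wrapper's template arguments. A minimal parsing sketch of that mapping; the
# field labels below are descriptive assumptions for illustration, not names
# taken from the wrapper itself:

def parse_rowwise_batched_kernel_name(name: str) -> dict:
    # e.g. "f8f8bf16_rowwise_batched_128_128_128_2_1_1_9_t"
    parts = name.replace("f8f8bf16_rowwise_batched_", "").split("_")
    tb_m, tb_n, tb_k, c_m, c_n, c_k, tag = map(int, parts[:7])
    return {
        "tile": (tb_m, tb_n, tb_k),   # per-threadblock tile shape (assumed)
        "cluster": (c_m, c_n, c_k),   # cluster shape (assumed)
        "schedule_tag": tag,          # the 9 vs 10 field in the file names above
        "last_flag": parts[7] == "t", # trailing _t/_f boolean template argument
    }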
6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/hstu/context_causal_target.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/FBGEMM/ee0264c59fc6a403100c18f9c676f787c84ccf5b/fbgemm_gpu/experimental/hstu/context_causal_target.png -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/hstu/deltaq_causal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/FBGEMM/ee0264c59fc6a403100c18f9c676f787c84ccf5b/fbgemm_gpu/experimental/hstu/deltaq_causal.png -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/hstu/deltaq_local.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/FBGEMM/ee0264c59fc6a403100c18f9c676f787c84ccf5b/fbgemm_gpu/experimental/hstu/deltaq_local.png -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/hstu/hstu/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024, NVIDIA Corporation & AFFILIATES. 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | # pyre-strict 10 | 11 | import os 12 | 13 | import torch 14 | 15 | try: 16 | # pyre-ignore[21] 17 | # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils 18 | from fbgemm_gpu import open_source 19 | 20 | # pyre-ignore[21] 21 | # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils 22 | from fbgemm_gpu.docs.version import __version__ # noqa: F401 23 | except Exception: 24 | open_source: bool = False 25 | 26 | # pyre-ignore[16] 27 | if open_source: 28 | torch.ops.load_library( 29 | os.path.join(os.path.dirname(__file__), "fbgemm_gpu_experimental_hstu.so") 30 | ) 31 | torch.classes.load_library( 32 | os.path.join(os.path.dirname(__file__), "fbgemm_gpu_experimental_hstu.so") 33 | ) 34 | else: 35 | torch.ops.load_library( 36 | "//deeplearning/fbgemm/fbgemm_gpu/experimental/hstu:hstu_ops" 37 | ) 38 | -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/hstu/src/hstu_ampere/instantiations/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/FBGEMM/ee0264c59fc6a403100c18f9c676f787c84ccf5b/fbgemm_gpu/experimental/hstu/src/hstu_ampere/instantiations/.gitkeep -------------------------------------------------------------------------------- /fbgemm_gpu/experimental/hstu/src/hstu_hopper/instantiations/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/FBGEMM/ee0264c59fc6a403100c18f9c676f787c84ccf5b/fbgemm_gpu/experimental/hstu/src/hstu_hopper/instantiations/.gitkeep -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from .feature_list import FeatureGate, FeatureGateName # noqa F401 10 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/docs/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # Trigger the manual addition of docstrings to pybind11-generated operators 9 | try: 10 | from . import ( # noqa: F401 11 | jagged_tensor_ops, 12 | merge_pooled_embedding_ops, 13 | permute_pooled_embedding_ops, 14 | quantize_ops, 15 | sparse_ops, 16 | ) 17 | except Exception: 18 | pass 19 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/docs/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | def add_docs(method, docstr: str): 9 | method.__doc__ = docstr 10 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/enums.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import enum 11 | import typing 12 | from typing import Any, Callable, List, Tuple 13 | 14 | 15 | # Create enums in given namespace with information from query_op 16 | def create_enums( 17 | namespace: typing.Dict[str, Any], 18 | query_op: Callable[[], List[Tuple[str, List[Tuple[str, int]]]]], 19 | ) -> None: 20 | for enum_name, items in query_op(): 21 | # Create matching python enumeration 22 | # pyre-fixme[19]: Expected 1 positional argument. 23 | new_enum = enum.Enum(enum_name, items) 24 | # and store it in the module 25 | namespace[enum_name] = new_enum 26 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/quantize/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from fbgemm_gpu.quantize.quantize_ops import dequantize_mx, quantize_mx # noqa F401 10 | from fbgemm_gpu.utils import TorchLibraryFragment 11 | 12 | lib = TorchLibraryFragment("fbgemm") 13 | 14 | lib.define( 15 | """quantize_mx( 16 | Tensor input, 17 | int scale_bits, 18 | int elem_ebits, 19 | int elem_mbits, 20 | float elem_max_norm, 21 | int mx_group_size, 22 | int? 
rounding_mode = None 23 | ) -> Tensor 24 | """ 25 | ) 26 | 27 | lib.define( 28 | """dequantize_mx( 29 | Tensor input, 30 | int mx_group_size 31 | ) -> Tensor 32 | """ 33 | ) 34 | 35 | lib.register( 36 | "quantize_mx", 37 | {"CUDA": quantize_mx, "CPU": quantize_mx}, 38 | ) 39 | 40 | lib.register( 41 | "dequantize_mx", 42 | {"CUDA": dequantize_mx, "CPU": dequantize_mx}, 43 | ) 44 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/sll/meta/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from fbgemm_gpu.sll.meta.meta_sll import ( # noqa F401 11 | meta_array_jagged_bmm_jagged_out, 12 | meta_jagged2_softmax, 13 | meta_jagged_dense_elementwise_mul_jagged_out, 14 | meta_jagged_jagged_bmm_jagged_out, 15 | meta_jagged_self_substraction_jagged_out, 16 | ) 17 | 18 | # pyre-ignore[5] 19 | op_registrations = { 20 | "sll_jagged_self_substraction_jagged_out": { 21 | "Meta": meta_jagged_self_substraction_jagged_out, 22 | }, 23 | "sll_jagged_dense_elementwise_mul_jagged_out": { 24 | "Meta": meta_jagged_dense_elementwise_mul_jagged_out, 25 | }, 26 | "sll_jagged2_softmax": { 27 | "AutogradMeta": meta_jagged2_softmax, 28 | }, 29 | "sll_array_jagged_bmm_jagged_out": { 30 | "AutogradMeta": meta_array_jagged_bmm_jagged_out, 31 | }, 32 | "sll_jagged_jagged_bmm_jagged_out": { 33 | "AutogradMeta": meta_jagged_jagged_bmm_jagged_out, 34 | }, 35 | } 36 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/sll/triton/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-unsafe 8 | 9 | import torch 10 | 11 | 12 | def next_power_of_two(N: int) -> int: 13 | if N > 4096: 14 | raise Exception(f"{N} is too large that is not supported yet") 15 | 16 | if N > 2048: 17 | return 4096 18 | elif N > 1024: 19 | return 2048 20 | elif N > 512: 21 | return 1024 22 | elif N > 256: 23 | return 512 24 | elif N > 128: 25 | return 256 26 | elif N > 64: 27 | return 128 28 | elif N > 32: 29 | return 64 30 | else: 31 | return 32 32 | 33 | 34 | def expect_contiguous(x: torch.Tensor) -> torch.Tensor: 35 | if not x.is_contiguous(): 36 | return x.contiguous() 37 | else: 38 | return x 39 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/split_embedding_optimizer_ops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
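# The quantize_mx / dequantize_mx schemas registered in fbgemm_gpu/quantize/__init__.py
# above become callable through torch.ops.fbgemm once that module is imported. A
# minimal round-trip sketch; the FP4 (E2M1) settings below (8-bit shared scale,
# 2 exponent / 3 mantissa bits, max-norm 6.0, groups of 32) and the 1-D input
# shape are assumptions for illustration, not values mandated by the schema:

import torch

import fbgemm_gpu.quantize  # noqa: F401  (registers fbgemm::quantize_mx / fbgemm::dequantize_mx)

x = torch.randn(1024)
packed = torch.ops.fbgemm.quantize_mx(
    x,
    scale_bits=8,
    elem_ebits=2,
    elem_mbits=3,
    elem_max_norm=6.0,
    mx_group_size=32,
)
x_hat = torch.ops.fbgemm.dequantize_mx(packed, mx_group_size=32)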
7 | 8 | # pyre-strict 9 | 10 | # flake8: noqa F401 11 | 12 | # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_optimizer_codegen 13 | from fbgemm_gpu.split_embedding_optimizer_codegen.optimizer_args import ( 14 | SplitEmbeddingArgs, 15 | SplitEmbeddingOptimizerParams, 16 | ) 17 | 18 | # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_optimizer_codegen 19 | from fbgemm_gpu.split_embedding_optimizer_codegen.split_embedding_optimizer_rowwise_adagrad import ( 20 | SplitEmbeddingRowwiseAdagrad, 21 | ) 22 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import warnings 10 | 11 | from fbgemm_gpu.tbe.utils import ( # noqa: F401 12 | b_indices, # noqa: F401 13 | fake_quantize_embs, # noqa: F401 14 | generate_requests, # noqa: F401 15 | get_device, # noqa: F401 16 | get_table_batched_offsets_from_dense, # noqa: F401 17 | quantize_embs, # noqa: F401 18 | round_up, # noqa: F401 19 | TBERequest, # noqa: F401 20 | to_device, # noqa: F401 21 | ) 22 | 23 | warnings.warn( # noqa: B028 24 | f"""\033[93m 25 | The Python module {__name__} is now DEPRECATED and will be removed in the 26 | future. Users should import fbgemm_gpu.tbe.utils into their scripts instead. 27 | \033[0m""", 28 | DeprecationWarning, 29 | ) 30 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | # pyre-ignore-all-errors[56] 10 | 11 | import warnings 12 | 13 | from fbgemm_gpu.tbe.ssd import ( # noqa: F401 14 | ASSOC, # noqa: F401 15 | SSDIntNBitTableBatchedEmbeddingBags, # noqa: F401 16 | SSDTableBatchedEmbeddingBags, # noqa: F401 17 | ) 18 | 19 | 20 | warnings.warn( # noqa: B028 21 | f"""\033[93m 22 | The Python module {__name__} is now DEPRECATED and will be removed in the 23 | future. Users should import fbgemm_gpu.tbe.ssd into their scripts instead. 24 | \033[0m""", 25 | DeprecationWarning, 26 | ) 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/tbe/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/tbe/bench/reporter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 
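# The split_embedding_utils and ssd_split_table_batched_embeddings_ops modules
# above are thin deprecation shims: they re-export their symbols from the new
# fbgemm_gpu.tbe.utils / fbgemm_gpu.tbe.ssd packages and emit a DeprecationWarning
# at import time. Migrating is a pure import change, for example:

# Deprecated spelling (still works, but warns at import time):
from fbgemm_gpu.split_embedding_utils import generate_requests, TBERequest

# Preferred spelling going forward:
from fbgemm_gpu.tbe.utils import generate_requests, TBERequest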
4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | 11 | import logging 12 | from dataclasses import dataclass 13 | 14 | haveAIBench = False 15 | try: 16 | from aibench_observer.utils.observer import emitMetric 17 | 18 | haveAIBench = True 19 | except Exception: 20 | haveAIBench = False 21 | 22 | 23 | @dataclass 24 | class BenchmarkReporter: 25 | report: bool 26 | logger: logging.Logger = logging.getLogger() 27 | 28 | # pyre-ignore[3] 29 | def __post_init__(self): 30 | self.logger.setLevel(logging.INFO) 31 | 32 | # pyre-ignore[2] 33 | def emit_metric(self, **kwargs) -> None: 34 | if self.report and haveAIBench: 35 | self.logger.info(emitMetric(**kwargs)) 36 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-unsafe 9 | 10 | from .split_embeddings_cache_ops import get_unique_indices_v2 # noqa: F401 11 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/tbe/ssd/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | # Load the prelude 11 | from .common import ASSOC # noqa: F401 12 | 13 | # Load the inference and training ops 14 | from .inference import SSDIntNBitTableBatchedEmbeddingBags # noqa: F401 15 | from .training import SSDTableBatchedEmbeddingBags # noqa: F401 16 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ..common import ASSOC # noqa: F401 8 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/tbe/stats/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from .bench_params_reporter import TBEBenchmarkParamsReporter # noqa F401 11 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/tbe/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 
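# Usage sketch for the BenchmarkReporter defined in tbe/bench/reporter.py above:
# emit_metric simply forwards its keyword arguments to aibench_observer's
# emitMetric when that package is importable and report=True, and is a no-op
# otherwise. The keyword names below are placeholders assumed for illustration,
# not arguments defined by this class:

from fbgemm_gpu.tbe.bench.reporter import BenchmarkReporter

reporter = BenchmarkReporter(report=True)
reporter.emit_metric(
    type="tbe_training",            # placeholder kwargs, forwarded as-is
    metric="bandwidth_gb_per_sec",  # to emitMetric when AIBench is available
    value=123.4,
)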
4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-unsafe 9 | 10 | from .common import get_device, round_up, to_device # noqa: F401 11 | from .offsets import b_indices, get_table_batched_offsets_from_dense # noqa: F401 12 | from .quantize import dequantize_embs, fake_quantize_embs, quantize_embs # noqa: F401 13 | from .requests import generate_requests, TBERequest # noqa: F401 14 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/triton/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-unsafe 9 | 10 | # Attempt to import triton kernels, fallback to reference if we cannot. 11 | from .common import RoundingMode # noqa 12 | 13 | try: 14 | from .quantize import ( 15 | triton_dequantize_mx4 as dequantize_mx4, 16 | triton_quantize_mx4 as quantize_mx4, 17 | ) 18 | except ImportError: 19 | from .quantize_ref import ( # noqa: F401, E402 20 | py_dequantize_mx4 as dequantize_mx4, 21 | py_quantize_mx4 as quantize_mx4, 22 | ) 23 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/triton/jagged/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-unsafe 9 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-unsafe 9 | 10 | from .filestore import FileStore # noqa F401 11 | from .torch_library import TorchLibraryFragment # noqa F401 12 | -------------------------------------------------------------------------------- /fbgemm_gpu/fbgemm_gpu/utils/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 
8 | # pyre-strict
9 | # pyre-ignore-all-errors[56]
10 | 
11 | from typing import Optional
12 | 
13 | import torch
14 | 
15 | 
16 | def load_torch_module(
17 |     unified_path: str, cuda_path: Optional[str] = None, hip_path: Optional[str] = None
18 | ) -> None:
19 |     try:
20 |         torch.ops.load_library(unified_path)
21 |     except Exception:
22 |         if torch.version.hip:
23 |             if not hip_path:
24 |                 hip_path = f"{unified_path}_hip"
25 |             torch.ops.load_library(hip_path)
26 |         else:
27 |             if not cuda_path:
28 |                 cuda_path = f"{unified_path}_cuda"
29 |             torch.ops.load_library(cuda_path)
30 | 
31 | 
32 | def load_torch_module_bc(new_path: str, old_path: str) -> None:
33 |     try:
34 |         torch.ops.load_library(new_path)
35 |     except Exception:
36 |         torch.ops.load_library(old_path)
37 | 
--------------------------------------------------------------------------------
/fbgemm_gpu/fbgemm_gpu/uvm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | # pyre-strict
9 | 
10 | from enum import Enum
11 | from typing import Optional
12 | 
13 | import torch
14 | 
15 | from fbgemm_gpu.enums import create_enums
16 | 
17 | try:
18 |     # pyre-ignore[21]
19 |     from fbgemm_gpu import open_source  # noqa: F401
20 | except Exception:
21 |     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils")
22 | 
23 | # Import all uvm enums from c++ library
24 | create_enums(globals(), torch.ops.fbgemm.fbgemm_gpu_uvm_enum_query)
25 | 
26 | 
27 | def cudaMemAdvise(
28 |     t: torch.Tensor,
29 |     advice: Enum,
30 | ) -> None:
31 |     torch.ops.fbgemm.cuda_mem_advise(t, advice.value)
32 | 
33 | 
34 | def cudaMemPrefetchAsync(
35 |     t: torch.Tensor,
36 |     device_t: Optional[torch.Tensor] = None,
37 | ) -> None:
38 |     torch.ops.fbgemm.cuda_mem_prefetch_async(t, device_t)
39 | 
--------------------------------------------------------------------------------
/fbgemm_gpu/include/fbgemm_gpu/merge_pooled_embeddings.h:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  * All rights reserved.
4 |  *
5 |  * This source code is licensed under the BSD-style license found in the
6 |  * LICENSE file in the root directory of this source tree.
7 |  */
8 | 
9 | #pragma once
10 | 
11 | #include <ATen/ATen.h>
12 | 
13 | namespace fbgemm_gpu {
14 | /// @defgroup merge-pooled-emb Merge Operators
15 | 
16 | ///@ingroup merge-pooled-emb
17 | std::vector<at::Tensor> all_to_one_device(
18 |     std::vector<at::Tensor> inputTensors,
19 |     at::Device target_device);
20 | 
21 | } // namespace fbgemm_gpu
22 | 
--------------------------------------------------------------------------------
/fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embs_function.h:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  * All rights reserved.
4 |  *
5 |  * This source code is licensed under the BSD-style license found in the
6 |  * LICENSE file in the root directory of this source tree.
7 |  */
8 | 
--------------------------------------------------------------------------------
/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #ifdef USE_ROCM 12 | #define HIPCUB_ARCH 1 13 | #endif 14 | 15 | #include 16 | #include 17 | #include 18 | 19 | // clang-format off 20 | #include "fbgemm_gpu/utils/cub_namespace_prefix.cuh" 21 | #include 22 | #include "fbgemm_gpu/utils/cub_namespace_postfix.cuh" 23 | // clang-format on 24 | -------------------------------------------------------------------------------- /fbgemm_gpu/include/fbgemm_gpu/utils/assert_macros.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #define FBGEMM_KERNEL_ERROR_CHECK(CODE, COND, ERROR_VAL) \ 10 | if (!(COND)) { \ 11 | error_code = CODE; \ 12 | error_value = ERROR_VAL; \ 13 | goto kernel_error_handler; \ 14 | } 15 | 16 | #define FBGEMM_KERNEL_ERROR_THROW(CODE, COND, MSG, ...) \ 17 | if (error_code == CODE) { \ 18 | printf("CUDA Kernel Assertion: " #COND " " #MSG "\n", __VA_ARGS__); \ 19 | CUDA_KERNEL_ASSERT(false && "Please search for 'CUDA Kernel Assertion'"); \ 20 | } 21 | -------------------------------------------------------------------------------- /fbgemm_gpu/include/fbgemm_gpu/utils/cub_namespace_postfix.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #undef FBGEMM_GPU_CUB_NS_PREFIX 10 | 11 | #ifdef FBGEMM_CUB_USE_NAMESPACE 12 | 13 | #undef CUB_NS_PREFIX 14 | #undef CUB_NS_POSTFIX 15 | 16 | #include // for CUDA_VERSION 17 | #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 18 | #include 19 | #else 20 | #define CUB_VERSION 0 21 | #endif 22 | 23 | // PR https://github.com/NVIDIA/cub/pull/350 introduced breaking change. 24 | // When the CUB_NS_[PRE|POST]FIX macros are set, 25 | // CUB_NS_QUALIFIER must also be defined to the fully qualified CUB namespace 26 | #if CUB_VERSION >= 101301 27 | #undef CUB_NS_QUALIFIER 28 | #endif 29 | 30 | #define FBGEMM_GPU_CUB_NS_PREFIX fbgemm_gpu:: 31 | 32 | #else 33 | 34 | #define FBGEMM_GPU_CUB_NS_PREFIX 35 | 36 | #endif 37 | -------------------------------------------------------------------------------- /fbgemm_gpu/include/fbgemm_gpu/utils/embedding_bounds_check_common.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" 10 | #include "fbgemm_gpu/utils/kernel_launcher.cuh" 11 | #include "fbgemm_gpu/utils/tensor_accessor_builder.h" 12 | 13 | #include 14 | #include 15 | 16 | using Tensor = at::Tensor; 17 | using namespace fbgemm_gpu; 18 | 19 | template 20 | __device__ void adjust_offset_kernel( 21 | index_t& indices_start, 22 | index_t& indices_end, 23 | const index_t num_indices, 24 | index_t* const offset_acc_start, 25 | index_t* const offset_acc_end) { 26 | indices_start = 27 | std::max(static_cast(0), std::min(indices_start, num_indices)); 28 | indices_end = std::max(indices_start, std::min(indices_end, num_indices)); 29 | *offset_acc_start = indices_start; 30 | *offset_acc_end = indices_end; 31 | } 32 | -------------------------------------------------------------------------------- /fbgemm_gpu/include/fbgemm_gpu/utils/log2.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | template 12 | struct log2_calc_ { 13 | enum { value = log2_calc_<(x >> 1)>::value + 1 }; 14 | }; 15 | template <> 16 | struct log2_calc_<0> { 17 | enum { value = 0 }; 18 | }; 19 | 20 | template 21 | struct log2_calc { 22 | enum { value = log2_calc_<(x - 1)>::value }; 23 | }; 24 | 25 | #if 0 26 | template <> 27 | struct log2_calc<0> { enum { value = 0 }; }; 28 | template <> 29 | struct log2_calc<1> { enum { value = 0 }; }; 30 | #endif 31 | -------------------------------------------------------------------------------- /fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | #include 11 | 12 | using Tensor = at::Tensor; 13 | 14 | namespace fbgemm_gpu { 15 | 16 | //////////////////////////////////////////////////////////////////////////////// 17 | // Helper Functions 18 | //////////////////////////////////////////////////////////////////////////////// 19 | 20 | Tensor reshape_vbe_output( 21 | const Tensor& grad_output, 22 | const int64_t max_B, 23 | const Tensor& B_offsets_rank_per_feature, 24 | const Tensor& D_offsets); 25 | 26 | template 27 | Tensor reshape_vbe_offsets( 28 | const Tensor& offsets, 29 | const Tensor& B_offsets_rank_per_feature, 30 | const int64_t max_B, 31 | const int32_t T); 32 | } // namespace fbgemm_gpu 33 | -------------------------------------------------------------------------------- /fbgemm_gpu/include/fbgemm_gpu/utils/topology_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
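`adjust_offset_kernel` above clamps a bag's `[indices_start, indices_end)` range into `[0, num_indices]` and writes the corrected bounds back through the offset pointers, while `log2_calc<x>` in `log2.h` evaluates to `ceil(log2(x))` at compile time. A short Python restatement of the clamping, for reference:

```python
def adjust_offsets(indices_start: int, indices_end: int, num_indices: int) -> tuple[int, int]:
    # Mirrors adjust_offset_kernel: clamp the start into [0, num_indices],
    # then force the end to be at least the start and at most num_indices.
    start = max(0, min(indices_start, num_indices))
    end = max(start, min(indices_end, num_indices))
    return start, end

assert adjust_offsets(-3, 7, 5) == (0, 5)  # out-of-range bounds are clamped
assert adjust_offsets(4, 2, 5) == (4, 4)   # an inverted range collapses to empty
```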
7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | 13 | using Node = int64_t; 14 | using Links = int64_t; 15 | template 16 | using AdjacencyMatrix = std::function; 17 | 18 | namespace fbgemm_gpu { 19 | AdjacencyMatrix get_nvlink_matrix(); 20 | } // namespace fbgemm_gpu 21 | -------------------------------------------------------------------------------- /fbgemm_gpu/include/fbgemm_gpu/utils/types.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | namespace fbgemm_gpu { 12 | 13 | using fint32 = union fint32 { 14 | uint32_t I; 15 | float F; 16 | }; 17 | 18 | inline int64_t div_up(int64_t val, int64_t unit) { 19 | return (val + unit - 1) / unit; 20 | } 21 | 22 | } // namespace fbgemm_gpu 23 | -------------------------------------------------------------------------------- /fbgemm_gpu/requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # NOTES: 8 | # 9 | # - A fixed version of mpmath is needed to work around an AttributeError; see: 10 | # * https://github.com/nod-ai/SHARK/issues/2095 11 | # * https://github.com/jianyicheng/mase-docker/pull/9 12 | 13 | backports.tarfile 14 | build 15 | cmake 16 | click 17 | hypothesis 18 | jinja2 19 | mpmath==1.3.0 20 | ninja 21 | numpy>=2.0.2 22 | pyre-extensions 23 | pyyaml 24 | scikit-build 25 | setuptools 26 | setuptools_git_versioning 27 | tabulate 28 | patchelf 29 | fairscale 30 | -------------------------------------------------------------------------------- /fbgemm_gpu/requirements_genai.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # NOTES: 8 | # 9 | # - A fixed version of mpmath is needed to work around an AttributeError; see: 10 | # * https://github.com/nod-ai/SHARK/issues/2095 11 | # * https://github.com/jianyicheng/mase-docker/pull/9 12 | 13 | # Requirements for GENAI build variant 14 | 15 | backports.tarfile 16 | build 17 | cmake 18 | click 19 | hypothesis 20 | jinja2 21 | mpmath==1.3.0 22 | ninja 23 | numpy 24 | pyre-extensions 25 | pyyaml 26 | scikit-build 27 | setuptools 28 | setuptools_git_versioning 29 | tabulate 30 | patchelf 31 | fairscale 32 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
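`utils/types.h` above defines `fint32`, a union used to reinterpret the 32 bits of a `float` as a `uint32_t` (and back) without a numeric conversion, alongside an integer ceiling division. The same bit reinterpretation expressed in Python, purely to illustrate the semantics:

```python
import struct

def float_to_bits(f: float) -> int:
    # fint32 usage pattern: write the F member, read the I member.
    return struct.unpack("<I", struct.pack("<f", f))[0]

def bits_to_float(i: int) -> float:
    return struct.unpack("<f", struct.pack("<I", i))[0]

assert hex(float_to_bits(1.0)) == "0x3f800000"
assert bits_to_float(0x3F800000) == 1.0
```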
7 | */ 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include "fbgemm_gpu/embedding_inplace_update.h" 14 | #include "fbgemm_gpu/utils/ops_utils.h" 15 | 16 | TORCH_LIBRARY_FRAGMENT(fbgemm, m) { 17 | DISPATCH_TO_CUDA( 18 | "emb_inplace_update", fbgemm_gpu::embedding_inplace_update_cuda); 19 | DISPATCH_TO_CUDA( 20 | "pruned_array_lookup_from_row_idx", 21 | fbgemm_gpu::pruned_array_lookup_from_row_idx_cuda); 22 | } 23 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/jagged_tensor_ops/stacked_jagged_1d_to_dense.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/FBGEMM/ee0264c59fc6a403100c18f9c676f787c84ccf5b/fbgemm_gpu/src/jagged_tensor_ops/stacked_jagged_1d_to_dense.cu -------------------------------------------------------------------------------- /fbgemm_gpu/src/jagged_tensor_ops/stacked_jagged_2d_to_dense.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/FBGEMM/ee0264c59fc6a403100c18f9c676f787c84ccf5b/fbgemm_gpu/src/jagged_tensor_ops/stacked_jagged_2d_to_dense.cu -------------------------------------------------------------------------------- /fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_gpu.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | #include 11 | #include 12 | #include "fbgemm_gpu/sparse_ops.h" 13 | #include "fbgemm_gpu/utils/ops_utils.h" 14 | 15 | TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { 16 | DISPATCH_TO_CUDA( 17 | "recat_embedding_grad_output_mixed_D_batch", 18 | fbgemm_gpu::recat_embedding_grad_output_mixed_D_batch_cuda); 19 | DISPATCH_TO_CUDA( 20 | "recat_embedding_grad_output_mixed_D", 21 | fbgemm_gpu::recat_embedding_grad_output_mixed_D_cuda); 22 | DISPATCH_TO_CUDA( 23 | "recat_embedding_grad_output", 24 | fbgemm_gpu::recat_embedding_grad_output_cuda); 25 | } 26 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/memory_utils/common.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include "common.h" 20 | #include "fbgemm_gpu/cumem_utils.h" 21 | #include "fbgemm_gpu/utils/enum_utils.h" 22 | 23 | namespace fbgemm_gpu { 24 | 25 | FBGEMM_GPU_ENUM_CREATE_TAG(uvm) 26 | 27 | } // namespace fbgemm_gpu 28 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/memory_utils/common.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | 13 | using Tensor = at::Tensor; 14 | 15 | namespace fbgemm_gpu { 16 | 17 | Tensor new_unified_tensor_cpu( 18 | const Tensor& self, 19 | const std::vector& sizes, 20 | bool is_host_mapped); 21 | 22 | } // namespace fbgemm_gpu 23 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/memory_utils/memory_utils.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "common.h" 10 | #include "fbgemm_gpu/cumem_utils.h" 11 | 12 | using Tensor = at::Tensor; 13 | 14 | namespace fbgemm_gpu { 15 | 16 | Tensor new_managed_tensor_meta( 17 | const Tensor& self, 18 | const std::vector& sizes) { 19 | return at::empty(sizes, self.options()); 20 | } 21 | 22 | Tensor new_unified_tensor_meta( 23 | const Tensor& self, 24 | const std::vector& sizes, 25 | bool /*is_host_mapped*/) { 26 | return at::empty(sizes, self.options()); 27 | } 28 | 29 | Tensor new_unified_tensor_cpu( 30 | const Tensor& self, 31 | const std::vector& sizes, 32 | bool is_host_mapped) { 33 | return at::empty({0}, self.options()); 34 | } 35 | 36 | } // namespace fbgemm_gpu 37 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/memory_utils/memory_utils_ops.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | #include "common.h" 11 | #include "fbgemm_gpu/cumem_utils.h" 12 | #include "fbgemm_gpu/utils/ops_utils.h" 13 | 14 | using Tensor = at::Tensor; 15 | 16 | namespace fbgemm_gpu { 17 | 18 | TORCH_LIBRARY_FRAGMENT(fbgemm, m) { 19 | m.def("new_managed_tensor(Tensor self, int[] sizes) -> Tensor"); 20 | m.def("new_host_mapped_tensor(Tensor self, int[] sizes) -> Tensor"); 21 | m.def("new_vanilla_managed_tensor(Tensor self, int[] sizes) -> Tensor"); 22 | m.def( 23 | "new_unified_tensor(Tensor self, int[] sizes, bool is_host_mapped) -> Tensor"); 24 | 25 | DISPATCH_TO_CPU("new_unified_tensor", new_unified_tensor_cpu); 26 | DISPATCH_TO_META("new_managed_tensor", new_managed_tensor_meta); 27 | DISPATCH_TO_META("new_unified_tensor", new_unified_tensor_meta); 28 | } 29 | 30 | } // namespace fbgemm_gpu 31 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/metric_ops/metric_ops.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
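`memory_utils_ops.cpp` above registers the managed/unified tensor constructors under `torch.ops.fbgemm`, e.g. `new_unified_tensor(Tensor self, int[] sizes, bool is_host_mapped) -> Tensor`; note that the CPU dispatch shown is a stub returning an empty tensor, so a real UVM or host-mapped allocation needs the CUDA build. A hedged call sketch:

```python
import torch
import fbgemm_gpu  # noqa: F401  # importing registers the fbgemm operator library

# `self` only supplies dtype/device options; `sizes` describes the new allocation.
proto = torch.empty(0, dtype=torch.float32)
uvm_tensor = torch.ops.fbgemm.new_unified_tensor(proto, [1024, 64], False)
```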
7 | */ 8 | 9 | #include 10 | 11 | namespace fbgemm_gpu { 12 | 13 | at::Tensor batch_auc( 14 | const int64_t num_tasks, 15 | const at::Tensor& indices, 16 | const at::Tensor& labels, 17 | const at::Tensor& weights); 18 | 19 | } // namespace fbgemm_gpu 20 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/metric_ops/metric_ops_host.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | #include 11 | 12 | #include "fbgemm_gpu/utils/ops_utils.h" 13 | #include "metric_ops.h" 14 | 15 | namespace fbgemm_gpu { 16 | 17 | TORCH_LIBRARY_FRAGMENT(fbgemm, m) { 18 | m.def( 19 | "batch_auc(int num_tasks, Tensor indices, Tensor laebls, Tensor weights) -> Tensor"); 20 | } 21 | 22 | TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { 23 | DISPATCH_TO_CUDA("batch_auc", fbgemm_gpu::batch_auc); 24 | } 25 | 26 | } // namespace fbgemm_gpu 27 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/placeholder.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | /// This is a placeholder source file that is used to force compilation and 10 | /// generation of an .SO file. 11 | namespace fbgemm_gpu {} 12 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/quantize_ops/mx/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/sparse_ops/common.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 
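`metric_ops_host.cpp` above registers a `batch_auc` operator that takes the number of tasks plus indices, labels, and weights tensors, and is dispatched to CUDA only. A hedged sketch of an invocation; the flat per-task layout and dtypes below are assumptions for illustration, not a documented contract:

```python
import torch
import fbgemm_gpu  # noqa: F401  # registers torch.ops.fbgemm.* (CUDA build required)

num_tasks, n = 2, 1000
indices = torch.arange(num_tasks * n, device="cuda")                   # assumed layout
labels = torch.randint(0, 2, (num_tasks * n,), device="cuda").float()  # binary labels
weights = torch.ones(num_tasks * n, device="cuda")                     # per-sample weights

auc = torch.ops.fbgemm.batch_auc(num_tasks, indices, labels, weights)
```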
4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | 11 | using Tensor = at::Tensor; 12 | 13 | namespace fbgemm_gpu { 14 | 15 | namespace { 16 | inline Tensor native_empty_like(const Tensor& self) { 17 | return at::native::empty_like( 18 | self, 19 | c10::optTypeMetaToScalarType(self.options().dtype_opt()), 20 | self.options().layout_opt(), 21 | self.options().device_opt(), 22 | self.options().pinned_memory_opt(), 23 | std::nullopt); 24 | } 25 | 26 | } // namespace 27 | 28 | }; // namespace fbgemm_gpu 29 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "common.h" 10 | 11 | using Tensor = at::Tensor; 12 | using namespace fbgemm_gpu; 13 | 14 | namespace fbgemm_gpu { 15 | 16 | DLL_PUBLIC void lfu_cache_populate_byte_cpu( 17 | Tensor weights, 18 | Tensor cache_hash_size_cumsum, 19 | int64_t total_cache_hash_size, 20 | Tensor cache_index_table_map, 21 | Tensor weights_offsets, 22 | Tensor weights_tys, 23 | Tensor D_offsets, 24 | Tensor linear_cache_indices, 25 | Tensor lxu_cache_state, 26 | Tensor lxu_cache_weights, 27 | Tensor lfu_state, 28 | int64_t row_alignment) { 29 | return; 30 | } 31 | 32 | } // namespace fbgemm_gpu 33 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" // @manual 10 | #include "fbgemm_gpu/split_embeddings_utils.cuh" // @manual 11 | #include "fbgemm_gpu/utils/ops_utils.h" // @manual 12 | 13 | using Tensor = at::Tensor; 14 | using namespace fbgemm_gpu; 15 | 16 | DLL_PUBLIC std::tuple 17 | get_infos_metadata(Tensor unused, int64_t B, int64_t T) { 18 | return get_info_B_num_bits_from_T(T, B); 19 | } 20 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "fbgemm_gpu/split_embeddings_utils.cuh" // @manual 10 | #include 11 | #include 12 | #include "fbgemm_gpu/utils/ops_utils.h" 13 | 14 | using Tensor = at::Tensor; 15 | using namespace fbgemm_gpu; 16 | 17 | TORCH_LIBRARY_FRAGMENT(fbgemm, m) { 18 | DISPATCH_TO_CUDA("transpose_embedding_input", transpose_embedding_input); 19 | DISPATCH_TO_CUDA("get_infos_metadata", get_infos_metadata); 20 | DISPATCH_TO_CUDA("generate_vbe_metadata", generate_vbe_metadata); 21 | } 22 | -------------------------------------------------------------------------------- /fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_cuda_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include 12 | 13 | namespace kv_db_utils { 14 | 15 | /// @ingroup embedding-ssd 16 | /// 17 | /// @brief A callback function for `cudaStreamAddCallback` 18 | /// 19 | /// A common callback function for `cudaStreamAddCallback`, i.e., 20 | /// `cudaStreamCallback_t callback`. This function casts `functor` 21 | /// into a void function, invokes it and then delete it (the deletion 22 | /// occurs in another thread) 23 | /// 24 | /// @param stream CUDA stream that `cudaStreamAddCallback` operates on 25 | /// @param status CUDA status 26 | /// @param functor A functor that will be called 27 | /// 28 | /// @return None 29 | void cuda_callback_func(cudaStream_t stream, cudaError_t status, void* functor); 30 | 31 | }; // namespace kv_db_utils 32 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/combine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/jagged/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/lint/flake8_problem_matcher.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "flake8", 5 | "severity": "error", 6 | "pattern": [ 7 | { 8 | "regexp": "^([^:]+):(\\d+):(\\d+):\\s+(.*)$", 9 | "file": 1, 10 | "line": 2, 11 | "column": 3, 12 | "message": 4 13 | } 14 | ] 15 | } 16 | ] 17 | } 18 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/permute/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
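The `flake8_problem_matcher.json` above lets GitHub Actions turn flake8 output into inline error annotations by capturing file, line, column, and message. Applying the same regular expression to a typical flake8 line shows what each group picks up (the sample path is made up):

```python
import re

pattern = re.compile(r"^([^:]+):(\d+):(\d+):\s+(.*)$")
sample = "fbgemm_gpu/bench/bench_utils.py:12:1: F401 'os' imported but unused"

m = pattern.match(sample)
assert m is not None
print(m.group(1))  # file:    fbgemm_gpu/bench/bench_utils.py
print(m.group(2))  # line:    12
print(m.group(3))  # column:  1
print(m.group(4))  # message: F401 'os' imported but unused
```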
6 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/permute/failures_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit", 3 | "_version": 1, 4 | "data": { 5 | "fbgemm::permute_pooled_embs": {}, 6 | "fbgemm::permute_pooled_embs_auto_grad": {} 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/quantize/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from .common import ( # noqa F401 10 | fused_rowwise_8bit_dequantize_reference, 11 | fused_rowwise_8bit_dequantize_reference_half, 12 | fused_rowwise_8bit_quantize_reference, 13 | fused_rowwise_nbit_quantize_dequantize_reference, 14 | fused_rowwise_nbit_quantize_reference, 15 | ) 16 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/quantize/mx/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/release/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/release/example.json: -------------------------------------------------------------------------------- 1 | { 2 | "_description": "This is a dict containing example schemas. The schema of future releases need to be backward and forward compatible. 
For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1", 3 | "_version": 1, 4 | "data": { 5 | "mx4_to_fp32": 6 | "mx4_to_fp32(Tensor tensor, int group_size=32, bool use_triton=True, int ebits=2, int mbits=1) -> Tensor", 7 | "merge_pooled_embeddings": 8 | "merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor", 9 | "dummy_func": 10 | "dummy_func(str var1, int var2) -> ()" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/sll/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/sll/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | # pyre-ignore-all-errors[56] 9 | 10 | import fbgemm_gpu 11 | import fbgemm_gpu.sll 12 | import torch 13 | 14 | # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 15 | open_source: bool = getattr(fbgemm_gpu, "open_source", False) 16 | 17 | if not open_source: 18 | torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") 19 | 20 | 21 | def clone_tensor(data: torch.Tensor) -> torch.Tensor: 22 | if data.requires_grad: 23 | return data.detach().clone().requires_grad_() 24 | return data.detach().clone() 25 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/tbe/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/tbe/cache/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
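`test/sll/common.py` above provides `clone_tensor`, which detaches and clones its input and re-enables gradient tracking only when the input required it, so tests can reuse data without sharing autograd history. A quick check of that behavior, with the helper restated inline so the snippet stands alone:

```python
import torch

def clone_tensor(data: torch.Tensor) -> torch.Tensor:
    # Same logic as test/sll/common.py: detach + clone, preserving requires_grad.
    if data.requires_grad:
        return data.detach().clone().requires_grad_()
    return data.detach().clone()

x = torch.randn(4, requires_grad=True)
y = clone_tensor(x)
assert y.requires_grad and y is not x and y.grad_fn is None  # fresh leaf tensor
```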
7 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/tbe/dram_kv/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/tbe/dram_kv/failures_dict_fast.json: -------------------------------------------------------------------------------- 1 | { 2 | } 3 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/tbe/inference/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/tbe/ssd/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/tbe/ssd/failures_dict_fast.json: -------------------------------------------------------------------------------- 1 | { 2 | } 3 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/tbe/stats/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/tbe/stats/failures_dict_fast.json: -------------------------------------------------------------------------------- 1 | { 2 | } 3 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/tbe/training/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /fbgemm_gpu/test/tbe/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | -------------------------------------------------------------------------------- /include/fbgemm/FbgemmI64.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | 13 | #include "fbgemm/Utils.h" 14 | 15 | namespace fbgemm { 16 | 17 | FBGEMM_API void cblas_gemm_i64_i64acc( 18 | matrix_op_t transa, 19 | matrix_op_t transb, 20 | int M, 21 | int N, 22 | int K, 23 | const std::int64_t* A, 24 | int lda, 25 | const std::int64_t* B, 26 | int ldb, 27 | bool accumulate, 28 | std::int64_t* C, 29 | int ldc); 30 | 31 | } // namespace fbgemm 32 | -------------------------------------------------------------------------------- /include/fbgemm/QuantUtilsAvx512.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include "./FbgemmBuild.h" // @manual 13 | #include "./UtilsAvx2.h" // @manual 14 | 15 | /// @defgroup fbgemm-quant-utils-avx512 Quantization Utilities (AVX512) 16 | /// 17 | 18 | namespace fbgemm { 19 | 20 | /// @ingroup fbgemm-quant-utils-avx512 21 | /// 22 | /// Requantize with AVX512. 23 | template < 24 | bool A_SYMMETRIC, 25 | bool B_SYMMETRIC, 26 | QuantizationGranularity Q_GRAN, 27 | bool HAS_BIAS, 28 | bool FUSE_RELU, 29 | int C_PER_G, 30 | typename BIAS_TYPE = std::int32_t> 31 | FBGEMM_API void requantizeOutputProcessingGConvAvx512( 32 | std::uint8_t* out, 33 | const std::int32_t* inp, 34 | const block_type_t& block, 35 | int ld_out, 36 | int ld_in, 37 | const requantizationParams_t& r); 38 | } // namespace fbgemm 39 | -------------------------------------------------------------------------------- /include/fbgemm/QuantUtilsNeon.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #ifdef __aarch64__ 12 | 13 | #include 14 | #include "./FbgemmBuild.h" // @manual 15 | 16 | /// @defgroup fbgemm-quant-utils-avx2 Quantization Utilities (AVX2) 17 | /// 18 | 19 | namespace fbgemm { 20 | 21 | //////////////////////////////////////////////////////////////////////////////// 22 | // Utility functions 23 | //////////////////////////////////////////////////////////////////////////////// 24 | 25 | template 26 | void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( 27 | const std::uint8_t* input, 28 | size_t input_rows, 29 | int input_columns, 30 | OutputType* output); 31 | 32 | } // namespace fbgemm 33 | 34 | #endif // __aarch64__ 35 | -------------------------------------------------------------------------------- /include/fbgemm/Types.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 
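`FbgemmI64.h` above declares `cblas_gemm_i64_i64acc`, a BLAS-style GEMM with 64-bit integer inputs and 64-bit accumulation, where `accumulate` chooses between overwriting and adding into `C`. A NumPy sketch of the computation it performs, given only as a semantic reference rather than the FBGEMM API:

```python
import numpy as np

def gemm_i64_i64acc(transa: bool, transb: bool, A, B, C, accumulate: bool):
    # op(A) is M x K and op(B) is K x N; everything stays in int64.
    op_a = A.T if transa else A
    op_b = B.T if transb else B
    prod = op_a.astype(np.int64) @ op_b.astype(np.int64)
    return C + prod if accumulate else prod

A = np.arange(6, dtype=np.int64).reshape(2, 3)
B = np.ones((3, 4), dtype=np.int64)
C = np.zeros((2, 4), dtype=np.int64)
print(gemm_i64_i64acc(False, False, A, B, C, accumulate=True))
```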
4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | 13 | namespace fbgemm { 14 | 15 | using float16 = std::uint16_t; 16 | using bfloat16 = std::uint16_t; 17 | 18 | inline int64_t round_up(int64_t val, int64_t unit) { 19 | return (val + unit - 1) / unit * unit; 20 | } 21 | 22 | inline int64_t div_up(int64_t val, int64_t unit) { 23 | return (val + unit - 1) / unit; 24 | } 25 | 26 | } // namespace fbgemm 27 | -------------------------------------------------------------------------------- /netlify.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | base = "fbgemm_gpu/docs" 3 | 4 | # Unconditionally rebuild the docs 5 | # https://docs.netlify.com/configure-builds/ignore-builds/ 6 | ignore = "/bin/false" 7 | 8 | [context.deploy-preview] 9 | publish = "build/html" 10 | command = """ 11 | # Load scripts 12 | export BUILD_ENV=build_docs 13 | . ../../.github/scripts/setup_env.bash 14 | 15 | # Print system info 16 | print_exec uname -a 17 | print_exec ldd --version 18 | 19 | # Set up Conda environment 20 | setup_miniconda $HOME/miniconda 21 | create_conda_environment $BUILD_ENV 3.13 22 | 23 | # Install tools 24 | install_cxx_compiler $BUILD_ENV 25 | install_build_tools $BUILD_ENV 26 | install_docs_tools $BUILD_ENV 27 | install_pytorch_pip $BUILD_ENV nightly cpu 28 | 29 | # Build the code 30 | cd .. 31 | prepare_fbgemm_gpu_build $BUILD_ENV 32 | build_fbgemm_gpu_install $BUILD_ENV docs 33 | 34 | # Build the docs 35 | cd docs 36 | build_fbgemm_gpu_docs $BUILD_ENV 37 | """ 38 | -------------------------------------------------------------------------------- /src/ExecuteKernel.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | namespace fbgemm {} // namespace fbgemm 10 | -------------------------------------------------------------------------------- /src/ExecuteKernel.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include "./ExecuteKernelGeneric.h" // @manual 12 | #include "./ExecuteKernelU8S8.h" // @manual 13 | #include "fbgemm/Fbgemm.h" 14 | -------------------------------------------------------------------------------- /src/FbgemmFP16UKernelsAvx2.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
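`include/fbgemm/Types.h` above stores `float16` and `bfloat16` as raw `uint16_t` payloads and defines the `round_up` / `div_up` helpers used throughout the library for tile and buffer sizing. Their arithmetic, restated in Python:

```python
def round_up(val: int, unit: int) -> int:
    return (val + unit - 1) // unit * unit

def div_up(val: int, unit: int) -> int:
    return (val + unit - 1) // unit

assert round_up(70, 32) == 96  # next multiple of 32 at or above 70
assert div_up(70, 32) == 3     # number of 32-wide blocks needed to cover 70
```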
7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include "fbgemm/FbgemmBuild.h" 12 | #include "fbgemm/FbgemmFPCommon.h" 13 | #include "fbgemm/Types.h" 14 | 15 | namespace fbgemm { 16 | 17 | using GemmParamsFP16 = GemmParams; 18 | 19 | void NOINLINE gemmkernel_1x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); 20 | void NOINLINE gemmkernel_2x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); 21 | void NOINLINE gemmkernel_3x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); 22 | void NOINLINE gemmkernel_4x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); 23 | void NOINLINE gemmkernel_5x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); 24 | void NOINLINE gemmkernel_6x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); 25 | 26 | } // namespace fbgemm 27 | -------------------------------------------------------------------------------- /src/FbgemmFP16UKernelsAvx512_256.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include "fbgemm/FbgemmBuild.h" 12 | #include "fbgemm/FbgemmFPCommon.h" 13 | #include "fbgemm/Types.h" 14 | 15 | namespace fbgemm { 16 | 17 | using GemmParamsFP16 = GemmParams; 18 | 19 | void NOINLINE gemmkernel_7x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp); 20 | void NOINLINE gemmkernel_8x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp); 21 | void NOINLINE gemmkernel_9x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp); 22 | void NOINLINE gemmkernel_10x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp); 23 | void NOINLINE gemmkernel_11x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp); 24 | void NOINLINE gemmkernel_12x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp); 25 | void NOINLINE gemmkernel_13x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp); 26 | void NOINLINE gemmkernel_14x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp); 27 | 28 | } // namespace fbgemm 29 | -------------------------------------------------------------------------------- /src/FbgemmFP16UKernelsSve128.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include "fbgemm/FbgemmBuild.h" 12 | #include "fbgemm/FbgemmFPCommon.h" 13 | #include "fbgemm/Types.h" 14 | 15 | namespace fbgemm { 16 | 17 | using GemmParamsFP16 = GemmParams; 18 | 19 | void NOINLINE gemmkernel_1x2_Sve128_fp16_fA0fB0fC0(GemmParamsFP16* gp); 20 | void NOINLINE gemmkernel_2x2_Sve128_fp16_fA0fB0fC0(GemmParamsFP16* gp); 21 | void NOINLINE gemmkernel_3x2_Sve128_fp16_fA0fB0fC0(GemmParamsFP16* gp); 22 | void NOINLINE gemmkernel_4x2_Sve128_fp16_fA0fB0fC0(GemmParamsFP16* gp); 23 | void NOINLINE gemmkernel_5x2_Sve128_fp16_fA0fB0fC0(GemmParamsFP16* gp); 24 | void NOINLINE gemmkernel_6x2_Sve128_fp16_fA0fB0fC0(GemmParamsFP16* gp); 25 | 26 | } // namespace fbgemm 27 | -------------------------------------------------------------------------------- /src/GenerateI8Depthwise.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | 14 | namespace fbgemm { 15 | 16 | class GenI8Depthwise { 17 | public: 18 | using jit_kernel_signature = void (*)( 19 | const std::uint8_t* a, 20 | const std::int8_t* b, 21 | std::int32_t* c, 22 | std::int32_t* a_sum, // row_wise sum of A 23 | int h, 24 | int w, 25 | int ic, // the number of input channels == the number of groups 26 | const int* mask, 27 | int A_zero_point); 28 | 29 | jit_kernel_signature getOrCreate( 30 | int D, // dimension 31 | std::array F, // filter size (K_T, K_H, K_W) 32 | int oc_per_g, // the number of output channels per group 33 | bool compute_a_sum, 34 | int remainder, // the number of channels in the remainder loop 35 | int prev_skip, 36 | int next_skip, 37 | int top_skip, 38 | int bottom_skip, 39 | int left_skip, 40 | int right_skip); 41 | }; 42 | 43 | } // namespace fbgemm 44 | -------------------------------------------------------------------------------- /src/GenerateKernel.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "./GenerateKernel.h" // @manual 10 | 11 | namespace fbgemm { 12 | 13 | namespace x86 = asmjit::x86; 14 | 15 | /** 16 | * Generate instructions for initializing the C registers to 0 in 32-bit 17 | * Accumulation kernel. 18 | */ 19 | void initCRegs(x86::Emitter* a, int rowRegs, int colRegs) { 20 | using CRegs = x86::Xmm; 21 | // Take advantage of implicit zeroing out 22 | // i.e., zero out xmm and ymm will be zeroed out too 23 | for (int i = 0; i < rowRegs; ++i) { 24 | for (int j = 0; j < colRegs; ++j) { 25 | a->vpxor( 26 | CRegs(i * colRegs + j), 27 | CRegs(i * colRegs + j), 28 | CRegs(i * colRegs + j)); 29 | } 30 | } 31 | } 32 | 33 | } // namespace fbgemm 34 | -------------------------------------------------------------------------------- /src/KleidiAIFP16UKernelsNeon.h: -------------------------------------------------------------------------------- 1 | /* 2 | * @lint-ignore-every LICENSELINT 3 | * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliate 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | */ 6 | #ifdef FBGEMM_ENABLE_KLEIDIAI 7 | 8 | #pragma once 9 | #include 10 | #include "fbgemm/FbgemmBuild.h" 11 | #include "fbgemm/FbgemmFPCommon.h" 12 | #include "fbgemm/Types.h" 13 | 14 | namespace kleidiai { 15 | 16 | using GemmParamsFP16 = fbgemm::GemmParams; 17 | 18 | void NOINLINE gemmkernel_1x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); 19 | void NOINLINE gemmkernel_2x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); 20 | void NOINLINE gemmkernel_3x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); 21 | void NOINLINE gemmkernel_4x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); 22 | void NOINLINE gemmkernel_5x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); 23 | void NOINLINE gemmkernel_6x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); 24 | void NOINLINE gemmkernel_7x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); 25 | void NOINLINE gemmkernel_8x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); 26 | 27 | } // namespace kleidiai 28 | 29 | #endif 30 | -------------------------------------------------------------------------------- /src/OptimizedKernelsAvx2.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include // for std::int32_t 12 | #include "fbgemm/FbgemmBuild.h" 13 | 14 | namespace fbgemm { 15 | 16 | /** 17 | * @brief Sum a given vector. 18 | */ 19 | FBGEMM_API std::int32_t reduceAvx2(const std::uint8_t* A, int len); 20 | 21 | /** 22 | * @brief Transpose 8 rows from source matrix. 23 | */ 24 | void transpose_8rows( 25 | int N, 26 | const uint8_t* src, 27 | int ld_src, 28 | uint8_t* dst, 29 | int ld_dst); 30 | 31 | /** 32 | * @brief avx2 part of the spmdm code. 33 | */ 34 | void spmdmKernelAvx2( 35 | int N, 36 | const uint8_t* A_buffer, 37 | const int32_t* colptr, 38 | const int8_t* values, 39 | const int16_t* rowidx, 40 | int32_t* C_buffer); 41 | 42 | } // namespace fbgemm 43 | -------------------------------------------------------------------------------- /src/fp32/FbgemmFP32UKernelsAvx2.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include "fbgemm/FbgemmBuild.h" 12 | #include "fbgemm/FbgemmFPCommon.h" 13 | #include "fbgemm/Types.h" 14 | 15 | namespace fbgemm { 16 | 17 | using GemmParamsFP32 = GemmParams; 18 | 19 | void NOINLINE gemmkernel_1x2_Avx2_fp32_fA0fB0fC0(GemmParamsFP32* gp); 20 | void NOINLINE gemmkernel_2x2_Avx2_fp32_fA0fB0fC0(GemmParamsFP32* gp); 21 | void NOINLINE gemmkernel_3x2_Avx2_fp32_fA0fB0fC0(GemmParamsFP32* gp); 22 | void NOINLINE gemmkernel_4x2_Avx2_fp32_fA0fB0fC0(GemmParamsFP32* gp); 23 | void NOINLINE gemmkernel_5x2_Avx2_fp32_fA0fB0fC0(GemmParamsFP32* gp); 24 | void NOINLINE gemmkernel_6x2_Avx2_fp32_fA0fB0fC0(GemmParamsFP32* gp); 25 | 26 | } // namespace fbgemm 27 | -------------------------------------------------------------------------------- /src/fp32/FbgemmFP32UKernelsAvx512_256.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #pragma once 10 | #include 11 | #include "fbgemm/FbgemmBuild.h" 12 | #include "fbgemm/FbgemmFPCommon.h" 13 | #include "fbgemm/Types.h" 14 | 15 | namespace fbgemm { 16 | 17 | using GemmParamsFP32 = GemmParams; 18 | 19 | void NOINLINE gemmkernel_7x2_Avx512_256_fp32_fA0fB0fC0(GemmParamsFP32* gp); 20 | void NOINLINE gemmkernel_8x2_Avx512_256_fp32_fA0fB0fC0(GemmParamsFP32* gp); 21 | void NOINLINE gemmkernel_9x2_Avx512_256_fp32_fA0fB0fC0(GemmParamsFP32* gp); 22 | void NOINLINE gemmkernel_10x2_Avx512_256_fp32_fA0fB0fC0(GemmParamsFP32* gp); 23 | void NOINLINE gemmkernel_11x2_Avx512_256_fp32_fA0fB0fC0(GemmParamsFP32* gp); 24 | void NOINLINE gemmkernel_12x2_Avx512_256_fp32_fA0fB0fC0(GemmParamsFP32* gp); 25 | void NOINLINE gemmkernel_13x2_Avx512_256_fp32_fA0fB0fC0(GemmParamsFP32* gp); 26 | void NOINLINE gemmkernel_14x2_Avx512_256_fp32_fA0fB0fC0(GemmParamsFP32* gp); 27 | 28 | } // namespace fbgemm 29 | -------------------------------------------------------------------------------- /src/fp32/KleidiAIFP32UKernelsNeon.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #ifdef FBGEMM_ENABLE_KLEIDIAI 7 | 8 | #pragma once 9 | #include 10 | #include "fbgemm/FbgemmBuild.h" 11 | #include "fbgemm/FbgemmFPCommon.h" 12 | #include "fbgemm/Types.h" 13 | 14 | namespace kleidiai { 15 | 16 | using GemmParamsFP32 = fbgemm::GemmParams; 17 | 18 | void NOINLINE gemmkernel_1x2_Neon_fp32_fA0fB0fC0(GemmParamsFP32* gp); 19 | void NOINLINE gemmkernel_2x2_Neon_fp32_fA0fB0fC0(GemmParamsFP32* gp); 20 | void NOINLINE gemmkernel_3x2_Neon_fp32_fA0fB0fC0(GemmParamsFP32* gp); 21 | void NOINLINE gemmkernel_4x2_Neon_fp32_fA0fB0fC0(GemmParamsFP32* gp); 22 | void NOINLINE gemmkernel_5x2_Neon_fp32_fA0fB0fC0(GemmParamsFP32* gp); 23 | void NOINLINE gemmkernel_6x2_Neon_fp32_fA0fB0fC0(GemmParamsFP32* gp); 24 | 25 | } // namespace kleidiai 26 | 27 | #endif 28 | -------------------------------------------------------------------------------- /test/FP16Test.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | 11 | #include "./FBGemmFPTest.h" 12 | #include "fbgemm/FbgemmFP16.h" 13 | 14 | using FBGemmFP16Test = fbgemm::FBGemmFPTest; 15 | 16 | INSTANTIATE_TEST_CASE_P( 17 | InstantiationName, 18 | FBGemmFP16Test, 19 | ::testing::Values( 20 | std::pair( 21 | fbgemm::matrix_op_t::NoTranspose, fbgemm::matrix_op_t::NoTranspose), 22 | std::pair( 23 | fbgemm::matrix_op_t::NoTranspose, fbgemm::matrix_op_t::Transpose)/*, 24 | pair( 25 | matrix_op_t::Transpose, matrix_op_t::NoTranspose), 26 | pair( 27 | matrix_op_t::Transpose, matrix_op_t::Transpose)*/)); 28 | 29 | TEST_P(FBGemmFP16Test, Test) { 30 | TestRun(); 31 | } 32 | 33 | TEST_P(FBGemmFP16Test, Unpack) { 34 | UnpackTestRun(); 35 | } 36 | -------------------------------------------------------------------------------- /test/QuantizationHelpers.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | #include <cstdint> 11 | 12 | namespace fbgemm { 13 | 14 | /* 15 | * @brief Make sure we won't have overflows from vpmaddubsw instruction. 16 | */ 17 | template <typename T> 18 | void avoidOverflow( 19 | int m, 20 | int n, 21 | int k, 22 | const uint8_t* Aint8, 23 | int lda, 24 | T* B, 25 | int ldb); 26 | 27 | template <typename T> 28 | void avoidOverflow(int m, int n, int k, const uint8_t* Aint8, T* B); 29 | 30 | } // namespace fbgemm 31 | --------------------------------------------------------------------------------
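`avoidOverflow` above adjusts test inputs so the int8 GEMM path cannot saturate `vpmaddubsw`, which multiplies unsigned 8-bit by signed 8-bit values and sums adjacent pairs into a signed 16-bit lane. The arithmetic behind that concern, worked through in Python:

```python
INT16_MAX = 2**15 - 1  # 32767

# Worst case for one 16-bit lane: two adjacent u8 * s8 products at their extremes.
worst = 255 * 127 + 255 * 127
print(worst, worst > INT16_MAX)   # 64770 True  -> would saturate without adjustment

# Keeping one factor small enough keeps the pairwise sum inside int16 range.
safe = 255 * 64 + 255 * 64
print(safe, safe <= INT16_MAX)    # 32640 True
```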