├── .github ├── merge_rules.yaml ├── pytorch-probot.yml ├── scripts │ ├── github_utils.py │ ├── gitutils.py │ ├── label_utils.py │ ├── trymerge.py │ ├── trymerge_explainer.py │ └── validate_binaries.sh └── workflows │ ├── build_wheels_linux.yml │ ├── dashboard_perf_test.yml │ ├── doc_build.yml │ ├── float8_test.yml │ ├── nightly_smoke_test.yml │ ├── pr-label-check.yml │ ├── regression_test.yml │ ├── regression_test_rocm.yml │ ├── ruff_linter.yml │ ├── run_tutorials.yml │ ├── torchao_experimental_test.yml │ ├── trymerge.yml │ └── validate-binaries.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── CITATION.cff ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── __init__.py ├── bench_galore_fused_kernels.py ├── benchmark_aq.py ├── benchmark_blockwise_scaled_linear_triton.py ├── benchmark_e2e_fp8_sparse_linear.py ├── benchmark_fp6.py ├── benchmark_gpu_sparsity.py ├── benchmark_hqq.py ├── benchmark_low_bit_adam.py ├── benchmark_marlin_qqq.py ├── benchmark_rowwise_scaled_linear_cutlass.py ├── benchmark_rowwise_scaled_linear_sparse_cutlass.py ├── benchmark_semi_sparse_training.py ├── benchmark_sparse_conversion_cutlass.py ├── benchmark_uintx.py ├── float8 │ ├── bench_linear_float8.py │ ├── bench_matmul.py │ ├── bench_padding.py │ ├── float8_roofline.py │ ├── profile_lowp_training.py │ ├── training │ │ ├── README.md │ │ ├── parse_torchtitan_logs.py │ │ └── torchtitan_benchmark.sh │ └── utils.py ├── fused_benchmark_utils.py ├── intmm.py ├── intmm_shapes.csv ├── microbenchmarks │ ├── README.md │ ├── __init__.py │ ├── benchmark_inference.py │ ├── benchmark_runner.py │ ├── profiler.py │ ├── test │ │ ├── __init__.py │ │ ├── benchmark_config.yml │ │ ├── test_benchmark_inference.py │ │ ├── test_benchmark_profiler.py │ │ ├── test_benchmark_runner.py │ │ └── test_utils.py │ └── utils.py ├── mx_formats │ └── cast_bench.py ├── print_config_shapes.py ├── quantized_training │ ├── benchmark_int8mm.py │ └── pretrain_llama2.py ├── sam_benchmark_results.csv └── sam_vit_b_shapes.csv ├── dev-requirements.txt ├── docs ├── Makefile ├── README.txt ├── requirements.txt ├── source │ ├── _static │ │ ├── css │ │ │ └── custom.css │ │ └── img │ │ │ ├── card-background.svg │ │ │ ├── generic-pytorch-logo.png │ │ │ └── pytorch-logo-dark.svg │ ├── _templates │ │ ├── autosummary │ │ │ ├── class.rst │ │ │ └── function.rst │ │ └── layout.html │ ├── api_ref_dtypes.rst │ ├── api_ref_intro.rst │ ├── api_ref_kernel.rst │ ├── api_ref_quantization.rst │ ├── api_ref_sparsity.rst │ ├── conf.py │ ├── contributor_guide.rst │ ├── custom_directives.py │ ├── dtypes.rst │ ├── index.rst │ ├── performant_kernels.rst │ ├── quantization.rst │ ├── quick_start.rst │ ├── serialization.rst │ ├── sparsity.rst │ ├── subclass_advanced.rst │ ├── subclass_basic.rst │ └── tutorials_source │ │ ├── README.txt │ │ └── template_tutorial.py └── static │ ├── microbenchmarking_process_diagram.png │ ├── microbenchmarks_code_flow_diagram.png │ ├── pruning_ecosystem_diagram.png │ ├── pruning_flow.png │ └── supported_sparsity_patterns.png ├── examples ├── README.md ├── sam2_amg_server │ ├── README.md │ ├── amg_example.py │ ├── annotate_with_rle.py │ ├── cli.py │ ├── cli_on_modal.py │ ├── compare_rle_lists.py │ ├── compile_export_utils.py │ ├── dog.jpg │ ├── dog_rle.json │ ├── example.html │ ├── generate_data.py │ ├── modal_experiments.sh │ ├── reproduce_experiments.py │ ├── requirements.txt │ ├── result.csv │ ├── result_batch_size_16.csv │ ├── result_batch_size_8.csv │ └── server.py └── sam2_vos_example 
│ ├── compile_export_utils.py │ ├── requirements.txt │ └── video_profile.py ├── packaging ├── env_var_script_linux.sh ├── post_build_script.sh ├── pre_build_script.sh ├── smoke_test.py └── vc_env_helper.bat ├── ruff.toml ├── scripts ├── check_copyright_header.py ├── clean_release_notes.py ├── convert_hf_checkpoint.py ├── create_weight_map.py ├── download.py ├── download_sam2_ckpts.sh ├── prepare.sh ├── quick_start.py ├── run_ruff_fix.sh └── upload_to_s3.py ├── setup.py ├── test ├── dtypes │ ├── ddp │ │ ├── check_ddp_nf4.py │ │ ├── ddp_nf4.py │ │ └── run_ddp_nf4_test.sh │ ├── test_affine_quantized.py │ ├── test_affine_quantized_float.py │ ├── test_affine_quantized_tensor_parallel.py │ ├── test_bitpacking.py │ ├── test_fbgemm_quantized.py │ ├── test_fbgemm_quantized_tensor.py │ ├── test_floatx.py │ ├── test_nf4.py │ ├── test_uint4.py │ └── test_uintx.py ├── float8 │ ├── test_base.py │ ├── test_compile.py │ ├── test_dtensor.py │ ├── test_dtensor.sh │ ├── test_everything.sh │ ├── test_float8_utils.py │ ├── test_fsdp.py │ ├── test_fsdp.sh │ ├── test_fsdp2 │ │ └── test_fsdp2.py │ ├── test_fsdp2_tp.py │ ├── test_fsdp_compile.py │ ├── test_fsdp_compile.sh │ └── test_numerics_integration.py ├── galore │ ├── README.md │ ├── memory_analysis_utils.py │ ├── model_configs.py │ ├── profile_memory_usage.py │ └── profiling_utils.py ├── hqq │ ├── test_hqq_affine.py │ ├── test_triton_mm.py │ └── test_triton_qkv_fused.py ├── integration │ ├── test_integration.py │ └── test_vllm.py ├── kernel │ ├── galore_test_utils.py │ ├── test_autotuner.py │ ├── test_fused_kernels.py │ └── test_galore_downproj.py ├── prototype │ ├── inductor │ │ └── test_int8_sdpa_fusion.py │ ├── module_swap_quantization │ │ ├── test_kmeans_codebook.py │ │ ├── test_llm_ptq_data_getter.py │ │ ├── test_module_swap.py │ │ ├── test_module_swap_quantization_utils.py │ │ ├── test_quantized_modules.py │ │ ├── test_quantizers.py │ │ └── test_range_setting_methods.py │ ├── mx_formats │ │ ├── test_kernels.py │ │ ├── test_mx_linear.py │ │ ├── test_mx_mm.py │ │ └── test_mx_tensor.py │ ├── scaled_grouped_mm │ │ ├── __init__.py │ │ ├── test_kernels.py │ │ └── test_scaled_grouped_mm.py │ ├── test_autoround.py │ ├── test_awq.py │ ├── test_blockwise_triton.py │ ├── test_codebook_quant.py │ ├── test_gguf_quant.py │ ├── test_mixed_precision.py │ ├── test_parametrization.py │ ├── test_paretoq.py │ ├── test_parq.py │ ├── test_quantized_training.py │ ├── test_scheduler.py │ ├── test_smoothquant.py │ ├── test_sparsifier.py │ ├── test_sparsity_utils.py │ ├── test_spinquant.py │ └── test_structured_sparsifier.py ├── quantization │ ├── pt2e │ │ ├── test_arm_inductor_quantizer.py │ │ ├── test_duplicate_dq.py │ │ ├── test_graph_utils.py │ │ ├── test_metadata_porting.py │ │ ├── test_numeric_debugger.py │ │ ├── test_quantize_pt2e.py │ │ ├── test_quantize_pt2e_qat.py │ │ ├── test_representation.py │ │ ├── test_x86inductor_fusion.py │ │ └── test_x86inductor_quantizer.py │ ├── test_config_serialization.py │ ├── test_galore_quant.py │ ├── test_gptq_mt.py │ ├── test_marlin_qqq.py │ ├── test_moe_quant.py │ ├── test_observer.py │ ├── test_qat.py │ ├── test_quant_api.py │ └── test_quant_primitives.py ├── smoke_tests │ └── smoke_tests.py ├── sparsity │ ├── test_activation24.py │ ├── test_fast_sparse_training.py │ ├── test_marlin.py │ ├── test_sparse_api.py │ ├── test_supermask.py │ └── test_wanda.py ├── test_ao_models.py ├── test_low_bit_optim.py ├── test_model_architecture.py ├── test_ops.py ├── test_ops_rowwise_scaled_linear_cutlass.py ├── 
test_ops_rowwise_scaled_linear_sparse_cutlass.py └── test_utils.py ├── torchao ├── __init__.py ├── _executorch_ops.py ├── _models │ ├── README.md │ ├── __init__.py │ ├── _eval.py │ ├── llama │ │ ├── .gitignore │ │ ├── README.md │ │ ├── __init__.py │ │ ├── benchmark_results.txt │ │ ├── benchmarks.sh │ │ ├── bsr_bench_results.txt │ │ ├── bsr_benchmarks.sh │ │ ├── demo_summarize.sh │ │ ├── eval.py │ │ ├── evals.sh │ │ ├── generate.py │ │ ├── model.py │ │ └── tokenizer.py │ ├── mixtral-moe │ │ ├── README.md │ │ ├── generate.py │ │ ├── model.py │ │ ├── run.sh │ │ └── scripts │ │ │ ├── convert_hf_checkpoint.py │ │ │ ├── download.py │ │ │ └── prepare.sh │ ├── sam │ │ ├── .gitignore │ │ ├── README.md │ │ ├── benchmark.sh │ │ ├── data.py │ │ ├── eval_combo.py │ │ ├── flash_4_configs.p │ │ ├── metrics.py │ │ ├── results.csv │ │ └── setup.sh │ ├── sam2 │ │ ├── __init__.py │ │ ├── automatic_mask_generator.py │ │ ├── build_sam.py │ │ ├── configs │ │ │ ├── sam2.1 │ │ │ │ ├── sam2.1_hiera_b+.yaml │ │ │ │ ├── sam2.1_hiera_l.yaml │ │ │ │ ├── sam2.1_hiera_s.yaml │ │ │ │ └── sam2.1_hiera_t.yaml │ │ │ ├── sam2.1_training │ │ │ │ └── sam2.1_hiera_b+_MOSE_finetune.yaml │ │ │ └── sam2 │ │ │ │ ├── sam2_hiera_b+.yaml │ │ │ │ ├── sam2_hiera_l.yaml │ │ │ │ ├── sam2_hiera_s.yaml │ │ │ │ └── sam2_hiera_t.yaml │ │ ├── csrc │ │ │ └── connected_components.cu │ │ ├── map_tensor.py │ │ ├── modeling │ │ │ ├── __init__.py │ │ │ ├── backbones │ │ │ │ ├── __init__.py │ │ │ │ ├── hieradet.py │ │ │ │ ├── image_encoder.py │ │ │ │ └── utils.py │ │ │ ├── memory_attention.py │ │ │ ├── memory_encoder.py │ │ │ ├── position_encoding.py │ │ │ ├── sam │ │ │ │ ├── __init__.py │ │ │ │ ├── mask_decoder.py │ │ │ │ ├── prompt_encoder.py │ │ │ │ └── transformer.py │ │ │ ├── sam2_base.py │ │ │ └── sam2_utils.py │ │ ├── sam2_hiera_b+.yaml │ │ ├── sam2_hiera_l.yaml │ │ ├── sam2_hiera_s.yaml │ │ ├── sam2_hiera_t.yaml │ │ ├── sam2_image_predictor.py │ │ ├── sam2_video_predictor.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── amg.py │ │ │ ├── misc.py │ │ │ └── transforms.py │ └── utils.py ├── core │ ├── __init__.py │ └── config.py ├── csrc │ ├── README.md │ ├── cpu │ │ └── int8_sdpa.cpp │ ├── cuda │ │ ├── activation24 │ │ │ ├── compute_sparse_tile.h │ │ │ ├── sparse24_metadata.h │ │ │ ├── sparse_gemm.cu │ │ │ ├── sparsify24.cu │ │ │ ├── static_sort.h │ │ │ └── warp_tensor.h │ │ ├── cutlass_extensions │ │ │ └── common.h │ │ ├── fp6_llm │ │ │ ├── README.md │ │ │ ├── configs.h │ │ │ ├── fp6_linear.cu │ │ │ ├── kernel_matmul.cuh │ │ │ ├── kernel_reduction.cuh │ │ │ ├── ptx_cp.async.cuh │ │ │ ├── ptx_mma.cuh │ │ │ ├── utils_core.cuh │ │ │ ├── utils_gmem.cuh │ │ │ └── utils_parallel_dequant.cuh │ │ ├── marlin_qqq │ │ │ ├── base.h │ │ │ ├── marlin_qqq_kernel.cu │ │ │ └── mem.h │ │ ├── mx_kernels │ │ │ └── mx_fp_cutlass_kernels.cu │ │ ├── rowwise_scaled_linear_cutlass │ │ │ ├── README.md │ │ │ ├── rowwise_scaled_linear_cutlass.cuh │ │ │ ├── rowwise_scaled_linear_cutlass_s4s4.cu │ │ │ └── rowwise_scaled_linear_cutlass_s8s4.cu │ │ ├── rowwise_scaled_linear_sparse_cutlass │ │ │ ├── rowwise_scaled_linear_sparse_cutlass.cuh │ │ │ ├── rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.cu │ │ │ ├── rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h │ │ │ ├── rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.cu │ │ │ ├── rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h │ │ │ ├── rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.cu │ │ │ ├── rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h │ │ │ ├── rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.cu │ │ │ ├── 
rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h │ │ │ └── rowwise_scaled_linear_sparse_cutlass_f8f8.cu │ │ ├── sparse_marlin │ │ │ ├── base.h │ │ │ ├── marlin_kernel_nm.cu │ │ │ ├── mem.h │ │ │ └── mma.h │ │ ├── tensor_core_tiled_layout │ │ │ └── tensor_core_tiled_layout.cu │ │ └── to_sparse_semi_structured_cutlass_sm9x │ │ │ ├── to_sparse_semi_structured_cutlass_sm9x.cuh │ │ │ └── to_sparse_semi_structured_cutlass_sm9x_f8.cu │ └── rocm │ │ └── swizzle │ │ └── swizzle.cpp ├── dtypes │ ├── README.md │ ├── __init__.py │ ├── _nf4tensor_api.py │ ├── affine_quantized_tensor.py │ ├── affine_quantized_tensor_ops.py │ ├── fbgemm_quantized_tensor.py │ ├── floatx │ │ ├── README.md │ │ ├── __init__.py │ │ ├── cutlass_semi_sparse_layout.py │ │ ├── float8_layout.py │ │ └── floatx_tensor_core_layout.py │ ├── nf4tensor.py │ ├── uintx │ │ ├── __init__.py │ │ ├── bitpacking.py │ │ ├── block_sparse_layout.py │ │ ├── cutlass_int4_packed_layout.py │ │ ├── gemlite_layout.py │ │ ├── int4_cpu_layout.py │ │ ├── int4_xpu_layout.py │ │ ├── marlin_qqq_tensor.py │ │ ├── marlin_sparse_layout.py │ │ ├── packed_linear_int8_dynamic_activation_intx_weight_layout.py │ │ ├── plain_layout.py │ │ ├── q_dq_layout.py │ │ ├── semi_sparse_layout.py │ │ ├── tensor_core_tiled_layout.py │ │ ├── uint4_layout.py │ │ └── uintx_layout.py │ └── utils.py ├── experimental │ ├── CMakeLists.txt │ ├── Utils.cmake │ ├── __init__.py │ ├── benchmark_infra │ │ ├── ios │ │ │ ├── Entitlements-Dev.plist │ │ │ ├── TorchAOBenchmark-Info.plist │ │ │ ├── main_empty.mm │ │ │ ├── output_redirect.h │ │ │ └── output_redirect.mm │ │ └── test │ │ │ └── test_bench.cpp │ ├── benchmarks │ │ └── cpu_memory_bw.cpp │ ├── build_torchao_ops.sh │ ├── docs │ │ └── readme.md │ ├── install_requirements.sh │ ├── kernels │ │ ├── cpu │ │ │ ├── aarch64 │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── benchmarks │ │ │ │ │ ├── CMakeLists.txt │ │ │ │ │ ├── benchmark_bitpacking.cpp │ │ │ │ │ ├── benchmark_linear.cpp │ │ │ │ │ ├── benchmark_quantization.cpp │ │ │ │ │ └── build_and_run_benchmarks.sh │ │ │ │ ├── bitpacking │ │ │ │ │ ├── bitpack.h │ │ │ │ │ ├── uint1.h │ │ │ │ │ ├── uint2.h │ │ │ │ │ ├── uint3.h │ │ │ │ │ ├── uint4.h │ │ │ │ │ ├── uint5.h │ │ │ │ │ ├── uint6.h │ │ │ │ │ └── uint7.h │ │ │ │ ├── embedding │ │ │ │ │ └── embedding.h │ │ │ │ ├── kleidi │ │ │ │ │ ├── kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h │ │ │ │ │ └── pack.h │ │ │ │ ├── linear │ │ │ │ │ └── channelwise_8bit_activation_groupwise_lowbit_weight │ │ │ │ │ │ ├── channelwise_8bit_activation_groupwise_lowbit_weight.h │ │ │ │ │ │ ├── kernel_1x1x32_f32_neondot-impl.h │ │ │ │ │ │ ├── kernel_1x4x16_f32_neondot-impl.h │ │ │ │ │ │ ├── kernel_1x8x16_f32_neondot-impl.h │ │ │ │ │ │ ├── pack_activations.h │ │ │ │ │ │ └── pack_weights.h │ │ │ │ ├── macro.h │ │ │ │ ├── matmul │ │ │ │ │ ├── channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h │ │ │ │ │ ├── channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h │ │ │ │ │ ├── channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h │ │ │ │ │ ├── fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h │ │ │ │ │ ├── fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h │ │ │ │ │ ├── matmul.h │ │ │ │ │ └── matmul_utils.h │ │ │ │ ├── quantization │ │ │ │ │ ├── quantize.cpp │ │ │ │ │ └── quantize.h │ │ │ │ ├── reduction │ │ │ │ │ ├── compute_sum.cpp │ │ │ │ │ ├── find_min_and_max.cpp │ │ │ │ │ └── reduction.h │ │ │ │ ├── tests │ │ │ │ │ ├── CMakeLists.txt │ │ │ │ │ ├── build_and_run_tests.sh │ │ │ │ │ ├── test_bitpacking.cpp │ │ │ │ │ ├── test_embedding.cpp │ │ │ │ │ ├── 
test_linear.cpp │ │ │ │ │ ├── test_qmatmul.cpp │ │ │ │ │ ├── test_quantization.cpp │ │ │ │ │ ├── test_reduction.cpp │ │ │ │ │ ├── test_utils.h │ │ │ │ │ ├── test_utils_quantized_attention.h │ │ │ │ │ ├── test_valpacking.cpp │ │ │ │ │ └── test_weight_packing.cpp │ │ │ │ └── valpacking │ │ │ │ │ ├── interleave.cpp │ │ │ │ │ └── valpack.h │ │ │ ├── fallback │ │ │ │ └── matmul │ │ │ │ │ ├── channelwise_8bit_a_channelwise_8bit_b.h │ │ │ │ │ └── fp32_a_channelwise_8bit_b_fp32_c.h │ │ │ └── interface │ │ │ │ ├── quantized_matmul.h │ │ │ │ └── test_qmatmul_interface.cpp │ │ └── mps │ │ │ ├── codegen │ │ │ └── gen_metal_shader_lib.py │ │ │ ├── metal.yaml │ │ │ ├── metal │ │ │ ├── common.metal │ │ │ ├── int1mm.metal │ │ │ ├── int2mm_opt.metal │ │ │ ├── int3mm_opt.metal │ │ │ ├── int4mm_opt.metal │ │ │ ├── int5mm.metal │ │ │ ├── int6mm.metal │ │ │ ├── int7mm.metal │ │ │ └── qmv_fast.metal │ │ │ ├── src │ │ │ ├── MetalShaderLibrary.h │ │ │ ├── OperationUtils.h │ │ │ ├── OperationUtils.mm │ │ │ ├── common.h │ │ │ ├── dispatch.h │ │ │ ├── lowbit.h │ │ │ └── packing.h │ │ │ └── test │ │ │ ├── Makefile │ │ │ ├── bfloat16.h │ │ │ └── test_lowbit.mm │ ├── op_lib.py │ ├── op_lib_utils.py │ ├── ops │ │ ├── benchmarks │ │ │ ├── CMakeLists.txt │ │ │ ├── benchmark_linear_8bit_act_xbit_weight.cpp │ │ │ └── build_and_run_benchmarks.sh │ │ ├── embedding_xbit │ │ │ ├── CMakeLists.txt │ │ │ ├── op_embedding_xbit-impl.h │ │ │ ├── op_embedding_xbit_aten.cpp │ │ │ ├── op_embedding_xbit_executorch.cpp │ │ │ └── packed_weights_header.h │ │ ├── library.h │ │ ├── linear_8bit_act_xbit_weight │ │ │ ├── CMakeLists.txt │ │ │ ├── examples │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── Linear8BitActXBitWeightOperator.h │ │ │ │ ├── build_and_run_examples.sh │ │ │ │ ├── separate_function_wrappers.cpp │ │ │ │ └── stateful_class_wrapper.cpp │ │ │ ├── kernel_config.h │ │ │ ├── kernel_selector.h │ │ │ ├── linear_8bit_act_xbit_weight.cpp │ │ │ ├── linear_8bit_act_xbit_weight.h │ │ │ ├── op_linear_8bit_act_xbit_weight-impl.h │ │ │ ├── op_linear_8bit_act_xbit_weight_aten.cpp │ │ │ ├── op_linear_8bit_act_xbit_weight_executorch.cpp │ │ │ └── packed_weights_format.h │ │ ├── memory.h │ │ ├── mps │ │ │ ├── .gitignore │ │ │ ├── CMakeLists.txt │ │ │ ├── build.sh │ │ │ ├── linear_fp_act_xbit_weight_aten.mm │ │ │ ├── linear_fp_act_xbit_weight_executorch.mm │ │ │ ├── mps_op_lib.py │ │ │ └── test │ │ │ │ ├── test_lowbit.py │ │ │ │ └── test_quantizer.py │ │ ├── packed_weights_header.h │ │ ├── parallel-aten-impl.h │ │ ├── parallel-executorch-impl.h │ │ ├── parallel-openmp-impl.h │ │ ├── parallel-pthreadpool-impl.h │ │ ├── parallel-single_threaded-impl.h │ │ ├── parallel-test_dummy-impl.h │ │ ├── parallel.h │ │ └── tests │ │ │ ├── CMakeLists.txt │ │ │ ├── build_and_run_tests.sh │ │ │ ├── generate_tests.py │ │ │ └── test_linear_8bit_act_xbit_weight.cpp │ ├── packed_linear_int8_dynamic_activation_intx_weight_layout.py │ ├── q_dq_layout.py │ ├── quant_api.py │ ├── quant_passes.py │ ├── temp_build.py │ └── tests │ │ ├── test_embedding_xbit_quantizer.py │ │ ├── test_int8_dynamic_activation_intx_weight.py │ │ ├── test_load_libtorchao_ops.py │ │ └── test_quant_passes.py ├── float8 │ ├── README.md │ ├── __init__.py │ ├── config.py │ ├── distributed_utils.py │ ├── float8_linear.py │ ├── float8_linear_utils.py │ ├── float8_ops.py │ ├── float8_scaling_utils.py │ ├── float8_tensor.py │ ├── float8_tensor_parallel.py │ ├── float8_utils.py │ ├── fsdp_utils.py │ └── inference.py ├── kernel │ ├── README.md │ ├── __init__.py │ ├── autotuner.py │ ├── bsr_triton_ops.py │ ├── 
configs │ │ └── data_a100.pkl │ ├── intmm.py │ └── intmm_triton.py ├── ops.py ├── optim │ ├── README.md │ ├── __init__.py │ ├── adam.py │ ├── cpu_offload.py │ ├── quant_utils.py │ ├── subclass_4bit.py │ ├── subclass_8bit.py │ └── subclass_fp8.py ├── prototype │ ├── README.md │ ├── __init__.py │ ├── autoround │ │ ├── README.md │ │ ├── __init__.py │ │ ├── autoround_llm.py │ │ ├── core.py │ │ ├── eval_autoround.py │ │ ├── multi_tensor.py │ │ ├── requirements.txt │ │ ├── run_example.sh │ │ └── utils.py │ ├── awq │ │ ├── README.md │ │ ├── __init__.py │ │ ├── api.py │ │ ├── core.py │ │ └── example.py │ ├── blockwise_fp8 │ │ ├── README.md │ │ ├── __init__.py │ │ ├── blockwise_linear.py │ │ └── blockwise_quantization.py │ ├── common │ │ ├── __init__.py │ │ ├── profiling_tools.py │ │ └── triton │ │ │ ├── __init__.py │ │ │ ├── matmul.py │ │ │ └── matmul_perf_model.py │ ├── custom_fp_utils.py │ ├── float8nocompile │ │ ├── README.md │ │ ├── __init__.py │ │ ├── benchmark │ │ │ └── benchmark.py │ │ ├── examples │ │ │ └── example.py │ │ ├── float8nocompile_linear.py │ │ ├── float8nocompile_linear_test.py │ │ ├── float8nocompile_linear_utils.py │ │ ├── float8nocompile_loss_curves.png │ │ ├── float8nocompile_scaling_utils.py │ │ ├── kernels │ │ │ ├── __init__.py │ │ │ ├── fp8_dynamic_tensorwise.py │ │ │ └── fp8_dynamic_tensorwise_test.py │ │ └── test │ │ │ ├── fsdp_test.py │ │ │ └── train_test.py │ ├── galore │ │ ├── README.md │ │ ├── __init__.py │ │ ├── docs │ │ │ ├── README.md │ │ │ └── galore_adam8bit.md │ │ ├── kernels │ │ │ ├── __init__.py │ │ │ ├── adam_downproj_fused.py │ │ │ ├── adam_step.py │ │ │ ├── custom_autotune.py │ │ │ ├── matmul.py │ │ │ └── quant.py │ │ ├── optim │ │ │ ├── __init__.py │ │ │ └── galore_torch.py │ │ └── utils.py │ ├── hqq │ │ ├── README.md │ │ ├── __init__.py │ │ ├── example.py │ │ ├── hqq_tinygemm_linear.py │ │ ├── kernels.py │ │ └── mixed_mm.py │ ├── inductor │ │ ├── __init__.py │ │ ├── codegen │ │ │ ├── __init__.py │ │ │ ├── cpp_int8_sdpa_template.py │ │ │ └── utils.py │ │ ├── fx_passes │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ └── int8_sdpa_fusion.py │ │ └── int8_sdpa_lowering.py │ ├── moe_quant │ │ ├── README.md │ │ ├── __init__.py │ │ ├── llama4_quant.py │ │ ├── quantizable_moe_modules.py │ │ └── utils.py │ ├── mx_formats │ │ ├── README.md │ │ ├── __init__.py │ │ ├── benchmarks │ │ │ └── bench_qdq.py │ │ ├── config.py │ │ ├── constants.py │ │ ├── fp_format_spec.py │ │ ├── kernels.py │ │ ├── mx_funcs.py │ │ ├── mx_linear.py │ │ ├── mx_ops.py │ │ ├── mx_subclass.py │ │ ├── mx_tensor.py │ │ └── utils.py │ ├── paretoq │ │ ├── 1_run_train.sh │ │ ├── 2_run_eval.sh │ │ ├── README.md │ │ ├── __init__.py │ │ ├── main_result_234bit.jpg │ │ ├── main_result_scaling_law.jpg │ │ ├── main_result_ternary.jpg │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── configuration_llama.py │ │ │ ├── modeling_llama_quant.py │ │ │ └── utils_quant.py │ │ ├── requirement.txt │ │ ├── train.py │ │ └── utils │ │ │ ├── datautils.py │ │ │ ├── process_args.py │ │ │ └── utils.py │ ├── parq │ │ ├── README.md │ │ ├── __init__.py │ │ ├── optim │ │ │ ├── __init__.py │ │ │ ├── binarelax.py │ │ │ ├── parq.py │ │ │ ├── proxmap.py │ │ │ └── quantopt.py │ │ ├── quant │ │ │ ├── __init__.py │ │ │ ├── lsbq.py │ │ │ ├── quantizer.py │ │ │ ├── uniform.py │ │ │ └── uniform_torchao.py │ │ └── utils.py │ ├── quantization │ │ ├── __init__.py │ │ ├── autoquant_v2.py │ │ ├── codebook │ │ │ ├── __init__.py │ │ │ ├── codebook_ops.py │ │ │ └── codebook_quantized_tensor.py │ │ ├── gguf │ │ │ ├── __init__.py │ │ │ ├── api.py │ 
│ │ └── gguf_quantized_tensor.py │ │ ├── mixed_precision │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ └── scripts │ │ │ │ ├── BO_acc_modelsize.py │ │ │ │ ├── BO_acc_throughput.py │ │ │ │ ├── Llama3-8B_initial_samples.json │ │ │ │ ├── Llama3-8B_parameters.json │ │ │ │ ├── Mistral-7B_initial_samples.json │ │ │ │ ├── Mistral-7B_parameters.json │ │ │ │ ├── __init__.py │ │ │ │ ├── fit.py │ │ │ │ ├── hessian_grad.py │ │ │ │ ├── hessian_vhp.py │ │ │ │ ├── mp_quant_eval.py │ │ │ │ ├── naive_intNwo.py │ │ │ │ └── utils.py │ │ ├── module_swap │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── algorithms │ │ │ │ ├── __init__.py │ │ │ │ └── kmeans_codebook.py │ │ │ ├── data_getters │ │ │ │ ├── __init__.py │ │ │ │ ├── llm_ptq_data_getter.py │ │ │ │ └── ptq_data_getter.py │ │ │ ├── module_swap.py │ │ │ ├── quantized_modules.py │ │ │ ├── quantizers.py │ │ │ ├── range_setting_methods.py │ │ │ └── utils.py │ │ └── subgraph_utils │ │ │ ├── __init__.py │ │ │ └── extract_subgraphs.py │ ├── quantized_training │ │ ├── README.md │ │ ├── __init__.py │ │ ├── bitnet.py │ │ ├── int8.py │ │ ├── int8_mixed_precision.py │ │ └── int8_mm.py │ ├── scaled_grouped_mm │ │ ├── __init__.py │ │ ├── benchmarks │ │ │ ├── benchmark_kernels.py │ │ │ └── benchmark_scaled_grouped_mm.py │ │ ├── kernels │ │ │ ├── __init__.py │ │ │ └── jagged_float8_scales.py │ │ ├── scaled_grouped_mm.py │ │ └── utils.py │ ├── smoothquant │ │ ├── README.md │ │ ├── __init__.py │ │ ├── api.py │ │ ├── core.py │ │ └── example.py │ ├── sparsity │ │ ├── __init__.py │ │ ├── activation │ │ │ ├── __init__.py │ │ │ ├── srelu_linear.py │ │ │ └── utils.py │ │ ├── pruner │ │ │ ├── FPGM_pruner.py │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── base_structured_sparsifier.py │ │ │ ├── images │ │ │ │ ├── prune_1.png │ │ │ │ ├── prune_2.png │ │ │ │ ├── prune_3.png │ │ │ │ ├── prune_4.png │ │ │ │ ├── prune_5.png │ │ │ │ └── prune_6.png │ │ │ ├── lstm_saliency_pruner.py │ │ │ ├── match_utils.py │ │ │ ├── parametrization.py │ │ │ ├── prune_functions.py │ │ │ └── saliency_pruner.py │ │ ├── scheduler │ │ │ ├── __init__.py │ │ │ ├── base_scheduler.py │ │ │ ├── cubic_scheduler.py │ │ │ └── lambda_scheduler.py │ │ ├── sparsifier │ │ │ ├── __init__.py │ │ │ ├── base_sparsifier.py │ │ │ ├── nearly_diagonal_sparsifier.py │ │ │ ├── utils.py │ │ │ └── weight_norm_sparsifier.py │ │ └── superblock │ │ │ ├── .gitignore │ │ │ ├── README.md │ │ │ ├── TRAINING.md │ │ │ ├── __init__.py │ │ │ ├── benchmark.py │ │ │ ├── benchmark.sh │ │ │ ├── benchmark_results.txt │ │ │ ├── evaluate.py │ │ │ ├── evaluate.sh │ │ │ ├── evaluation_results.txt │ │ │ ├── train.py │ │ │ └── utils.py │ └── spinquant │ │ ├── README.md │ │ ├── __init__.py │ │ ├── _hadamard_matrices.py │ │ ├── hadamard_utils.py │ │ └── spinquant.py ├── quantization │ ├── GPTQ.py │ ├── GPTQ_MT.py │ ├── README.md │ ├── __init__.py │ ├── autoquant.py │ ├── dynamic_quant.py │ ├── granularity.py │ ├── linear_activation_quantized_tensor.py │ ├── linear_activation_scale.py │ ├── linear_activation_weight_observed_tensor.py │ ├── marlin_qqq │ │ ├── README.md │ │ ├── __init__.py │ │ └── utils.py │ ├── observer.py │ ├── prototype │ │ ├── __init__.py │ │ └── qat │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── _module_swap_api.py │ │ │ ├── affine_fake_quantized_tensor.py │ │ │ ├── api.py │ │ │ ├── embedding.py │ │ │ ├── fake_quantizer.py │ │ │ └── linear.py │ ├── pt2e │ │ ├── __init__.py │ │ ├── _affine_quantization.py │ │ ├── _numeric_debugger.py │ │ ├── constant_fold.py │ │ ├── convert.py │ │ ├── export_utils.py │ │ ├── 
fake_quantize.py │ │ ├── graph_utils.py │ │ ├── inductor_passes │ │ │ ├── __init__.py │ │ │ └── x86.py │ │ ├── lowering.py │ │ ├── observer.py │ │ ├── prepare.py │ │ ├── qat_utils.py │ │ ├── quantize_pt2e.py │ │ ├── quantizer │ │ │ ├── __init__.py │ │ │ ├── arm_inductor_quantizer.py │ │ │ ├── composable_quantizer.py │ │ │ ├── duplicate_dq_pass.py │ │ │ ├── embedding_quantizer.py │ │ │ ├── port_metadata_pass.py │ │ │ ├── quantizer.py │ │ │ ├── utils.py │ │ │ ├── x86_inductor_quantizer.py │ │ │ └── xpu_inductor_quantizer.py │ │ ├── reference_representation_rewrite.py │ │ └── utils.py │ ├── qat │ │ ├── README.md │ │ ├── __init__.py │ │ ├── affine_fake_quantized_tensor.py │ │ ├── api.py │ │ ├── embedding.py │ │ ├── fake_quantizer.py │ │ ├── images │ │ │ └── qat_diagram.png │ │ ├── linear.py │ │ └── utils.py │ ├── quant_api.py │ ├── quant_primitives.py │ ├── smoothquant.py │ ├── subclass.py │ ├── transform_module.py │ ├── unified.py │ ├── utils.py │ ├── weight_only.py │ └── weight_tensor_linear_activation_quantization.py ├── sparsity │ ├── README.md │ ├── __init__.py │ ├── blocksparse.py │ ├── marlin │ │ ├── README.md │ │ ├── __init__.py │ │ └── utils.py │ ├── sparse_api.py │ ├── supermask.py │ ├── training │ │ ├── README.md │ │ ├── __init__.py │ │ ├── autograd.py │ │ └── pointwise_ops.py │ ├── utils.py │ └── wanda.py ├── swizzle │ ├── __init__.py │ ├── swizzle_ops.py │ └── swizzle_tensor.py ├── testing │ ├── __init__.py │ ├── float8 │ │ ├── __init__.py │ │ ├── dtensor_utils.py │ │ ├── fsdp2_utils.py │ │ ├── roofline_utils.py │ │ └── test_utils.py │ ├── model_architectures.py │ ├── pt2e │ │ ├── __init__.py │ │ ├── _xnnpack_quantizer.py │ │ ├── _xnnpack_quantizer_utils.py │ │ └── utils.py │ └── utils.py └── utils.py ├── tutorials ├── add_an_op.py ├── calibration_flow │ ├── awq_like.py │ ├── gptq_like.py │ └── static_quant.py ├── developer_api_guide │ ├── __init__.py │ ├── export_to_executorch.py │ ├── my_dtype_tensor_subclass.py │ ├── my_trainable_tensor_subclass.py │ ├── print_op_and_shapes.py │ └── tensor_parallel.py ├── examples │ ├── logging_subclass.py │ ├── quantized_module_swap.py │ └── quantized_subclass.py ├── quantize_vit │ ├── bfloat16.json.gz │ ├── bfloat16_code.py │ ├── quant.json.gz │ ├── quant_code.py │ ├── run.sh │ ├── run_vit_b.py │ └── run_vit_b_quant.py └── run_all.sh └── version.txt /.github/merge_rules.yaml: -------------------------------------------------------------------------------- 1 | - name: superuser 2 | patterns: 3 | - '*' 4 | approved_by: 5 | - pytorch/metamates 6 | mandatory_checks_name: 7 | - Facebook CLA Check 8 | -------------------------------------------------------------------------------- /.github/pytorch-probot.yml: -------------------------------------------------------------------------------- 1 | mergebot: True 2 | ciflow_push_tags: 3 | - ciflow/benchmark 4 | - ciflow/tutorials 5 | - ciflow/rocm 6 | -------------------------------------------------------------------------------- /.github/scripts/validate_binaries.sh: -------------------------------------------------------------------------------- 1 | pip install ${PYTORCH_PIP_PREFIX} torchao --index-url ${PYTORCH_PIP_DOWNLOAD_URL} 2 | # Initial smoke test, tries importing torchao 3 | python ./test/smoke_tests/smoke_tests.py 4 | # Now we install dev-requirements and try to run the tests 5 | pip install -r dev-requirements.txt 6 | pytest test --verbose -s 7 | -------------------------------------------------------------------------------- /.github/workflows/nightly_smoke_test.yml:
-------------------------------------------------------------------------------- 1 | name: PyTorch CUDA Nightly Smoke Test 2 | 3 | on: 4 | schedule: 5 | # 6 am PST every day 6 | - cron: "0 14 * * *" 7 | workflow_dispatch: 8 | 9 | concurrency: 10 | group: regression_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} 11 | cancel-in-progress: true 12 | 13 | env: 14 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 15 | 16 | jobs: 17 | test: 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | include: 22 | - name: CUDA Nightly 23 | runs-on: linux.g5.12xlarge.nvidia.gpu 24 | torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu126' 25 | gpu-arch-type: "cuda" 26 | gpu-arch-version: "12.6" 27 | 28 | permissions: 29 | id-token: write 30 | contents: read 31 | uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main 32 | with: 33 | runner: ${{ matrix.runs-on }} 34 | gpu-arch-type: ${{ matrix.gpu-arch-type }} 35 | gpu-arch-version: ${{ matrix.gpu-arch-version }} 36 | submodules: recursive 37 | script: | 38 | python -m pip install --upgrade pip 39 | pip install ${{ matrix.torch-spec }} 40 | pip install -r dev-requirements.txt 41 | python setup.py install 42 | pytest test --verbose -s 43 | -------------------------------------------------------------------------------- /.github/workflows/pr-label-check.yml: -------------------------------------------------------------------------------- 1 | name: PR Label Check 2 | on: 3 | pull_request: 4 | types: [opened, labeled, unlabeled, synchronize] 5 | 6 | jobs: 7 | check-labels: 8 | name: Check PR Labels 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Check for Topic label 12 | run: | 13 | # Get the labels using GitHub API 14 | LABELS=$(curl -s -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \ 15 | "https://api.github.com/repos/${{ github.repository }}/pulls/${{ github.event.pull_request.number }}" \ 16 | | jq -r '.labels[].name') 17 | 18 | # Check if there are any labels 19 | if [ -z "$LABELS" ]; then 20 | echo "::error::This PR requires at least one topic label. Please add a topic from: https://github.com/pytorch/ao/labels?q=topic" 21 | exit 1 22 | fi 23 | 24 | # Check for Topic label 25 | if ! echo "$LABELS" | grep -i "topic:" > /dev/null; then 26 | echo "::error::This PR requires at least one label starting with 'topic:'. 
Available topics can be found at: https://github.com/pytorch/ao/labels?q=topic" 27 | exit 1 28 | fi 29 | 30 | echo "PR has required topic label" 31 | -------------------------------------------------------------------------------- /.github/workflows/regression_test_rocm.yml: -------------------------------------------------------------------------------- 1 | name: Run Regression Tests on ROCm 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - ciflow/rocm/* 9 | 10 | concurrency: 11 | group: regression_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} 12 | cancel-in-progress: true 13 | 14 | env: 15 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 16 | 17 | jobs: 18 | test-nightly: 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | include: 23 | - name: ROCM Nightly 24 | runs-on: linux.rocm.gpu.mi300.2 25 | torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3' 26 | gpu-arch-type: "rocm" 27 | gpu-arch-version: "6.3" 28 | 29 | permissions: 30 | id-token: write 31 | contents: read 32 | uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main 33 | with: 34 | timeout: 120 35 | no-sudo: ${{ matrix.gpu-arch-type == 'rocm' }} 36 | runner: ${{ matrix.runs-on }} 37 | gpu-arch-type: ${{ matrix.gpu-arch-type }} 38 | gpu-arch-version: ${{ matrix.gpu-arch-version }} 39 | submodules: recursive 40 | script: | 41 | conda create -n venv python=3.9 -y 42 | conda activate venv 43 | python -m pip install --upgrade pip 44 | pip install ${{ matrix.torch-spec }} 45 | pip install -r dev-requirements.txt 46 | pip install . 47 | export CONDA=$(dirname $(dirname $(which conda))) 48 | export LD_LIBRARY_PATH=$CONDA/lib/:$LD_LIBRARY_PATH 49 | pytest test --verbose -s 50 | -------------------------------------------------------------------------------- /.github/workflows/run_tutorials.yml: -------------------------------------------------------------------------------- 1 | name: Run tutorials 2 | 3 | on: 4 | push: 5 | tags: 6 | - ciflow/tutorials/* 7 | workflow_dispatch: 8 | 9 | jobs: 10 | run_tutorials: 11 | runs-on: linux.aws.a100 12 | strategy: 13 | matrix: 14 | torch-spec: 15 | - '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126' 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Setup miniconda 20 | uses: pytorch/test-infra/.github/actions/setup-miniconda@main 21 | with: 22 | python-version: "3.9" 23 | 24 | - name: Run tutorials 25 | shell: bash 26 | run: | 27 | set -eux 28 | ${CONDA_RUN} python -m pip install --upgrade pip 29 | ${CONDA_RUN} pip install ${{ matrix.torch-spec }} 30 | ${CONDA_RUN} pip install -r dev-requirements.txt 31 | ${CONDA_RUN} pip install . 
32 | cd tutorials 33 | ${CONDA_RUN} bash run_all.sh 34 | -------------------------------------------------------------------------------- /.github/workflows/validate-binaries.yml: -------------------------------------------------------------------------------- 1 | name: Validate binaries 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | channel: 7 | description: "Channel to use (nightly, test, release, all)" 8 | required: false 9 | type: string 10 | default: release 11 | ref: 12 | description: "Reference to checkout, defaults to empty" 13 | default: "" 14 | required: false 15 | type: string 16 | workflow_dispatch: 17 | inputs: 18 | channel: 19 | description: "Channel to use (nightly, test, release, all)" 20 | required: true 21 | type: choice 22 | options: 23 | - release 24 | - nightly 25 | - test 26 | - all 27 | ref: 28 | description: "Reference to checkout, defaults to empty" 29 | default: "" 30 | required: false 31 | type: string 32 | pytorch_version: 33 | description: "PyTorch version to validate (ie. 2.0, 2.2.2, etc.) - optional" 34 | default: "" 35 | required: false 36 | type: string 37 | jobs: 38 | validate-binaries: 39 | uses: pytorch/test-infra/.github/workflows/validate-domain-library.yml@main 40 | with: 41 | package_type: "wheel" 42 | version: ${{ inputs.version }} 43 | os: "linux" 44 | channel: ${{ inputs.channel }} 45 | repository: "pytorch/ao" 46 | with_cuda: "enable" 47 | with_rocm: "disable" 48 | smoke_test: "source ./.github/scripts/validate_binaries.sh" 49 | install_torch: true 50 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/cutlass"] 2 | path = third_party/cutlass 3 | url = https://github.com/NVIDIA/cutlass 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-added-large-files 11 | 12 | - repo: https://github.com/astral-sh/ruff-pre-commit 13 | # Ruff version. 14 | rev: v0.11.6 15 | hooks: 16 | # Run the linter. 17 | - id: ruff 18 | args: 19 | - --fix 20 | - --select 21 | - F,I 22 | # Run the formatter. 23 | - id: ruff-format 24 | # Run isolated checks. 25 | - id: ruff 26 | alias: ruff-isolated 27 | args: 28 | - --isolated 29 | - --select 30 | - F821,F823,W191 31 | 32 | - repo: local 33 | hooks: 34 | - id: privacy-policy-check 35 | name: Check Privacy Policy Headers 36 | entry: python scripts/check_copyright_header.py 37 | language: python 38 | types_or: [python, c, c++, shell, text] 39 | files: \.(py|cu|h|cuh|sh|metal)$ 40 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: "torchao: PyTorch native quantization and sparsity for training and inference" 3 | message: "If you use this software, please cite it as below." 
4 | type: software 5 | authors: 6 | - given-names: "torchao maintainers and contributors" 7 | url: "https://github.com/pytorch/torchao" 8 | license: "BSD-3-Clause" 9 | date-released: "2024-10-25" 10 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | msaroufim 2 | cpuhrsch 3 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Torchao 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Linting 16 | 17 | We use [ruff](https://beta.ruff.rs/docs/) for linting. 18 | 1. `pip install ruff==0.11.6` 19 | 2. `ruff check --fix` 20 | 3. `ruff format .` 21 | 22 | ## Contributor License Agreement ("CLA") 23 | In order to accept your pull request, we need you to submit a CLA. You only need 24 | to do this once to work on any of Meta's open source projects. 25 | 26 | Complete your CLA here: <https://code.facebook.com/cla> 27 | 28 | ## Issues 29 | We use GitHub issues to track public bugs. Please ensure your description is 30 | clear and has sufficient instructions to be able to reproduce the issue. 31 | 32 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 33 | disclosure of security bugs. In those cases, please go through the process 34 | outlined on that page and do not file a public issue. 35 | 36 | ## Coding Style 37 | * 2 spaces for indentation rather than tabs 38 | * 80 character line length 39 | * ... 40 | 41 | ## License 42 | By contributing to Torchao, you agree that your contributions will be licensed 43 | under the LICENSE file in the root directory of this source tree. 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Meta 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /benchmarks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/benchmarks/__init__.py -------------------------------------------------------------------------------- /benchmarks/float8/training/README.md: -------------------------------------------------------------------------------- 1 | # Float8 training benchmarking 2 | 3 | The `torchtitan_benchmark.sh` script in this directory can be used to launch a Llama3 8b training run with [torchtitan](https://github.com/pytorch/torchtitan), and parse the logs to calculate the median tokens/sec and peak memory usage for you. 4 | 5 | ## Usage 6 | 7 | Example: `TORCHTITAN_ROOT=${HOME}/torchtitan FLOAT8_RECIPE_WITH_BEST_SETTINGS=rowwise ./torchtitan_benchmark.sh` 8 | 9 | Training parameters can be configured via environment variables. 10 | 11 | - Required: 12 | - `TORCHTITAN_ROOT`: Root directory of torchtitan in your local filesystem 13 | - Optional: 14 | - `FLOAT8_RECIPE_WITH_BEST_SETTINGS`: "rowwise" or "tensorwise". Applies float8 training with the specified scaling recipe, as well as additional training configs which are optimal for that scaling recipe. See `torchtitan_benchmark.sh` for more details. 15 | - `BATCH_SIZE`: Defaults to 1. 16 | - `STEPS`: Defaults to 100. 17 | - `EXTRA_ARGS`: Extra arguments to pass to the torchtitan training script. See [torchtitan](https://github.com/pytorch/torchtitan) docs for the full list of options. 18 | 19 | **NOTE**: `torch.compile` and FSDP2 are always used. Other forms of parallelism supported in torchtitan are not yet supported in this script. 20 | -------------------------------------------------------------------------------- /benchmarks/microbenchmarks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/benchmarks/microbenchmarks/__init__.py -------------------------------------------------------------------------------- /benchmarks/microbenchmarks/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/benchmarks/microbenchmarks/test/__init__.py -------------------------------------------------------------------------------- /benchmarks/print_config_shapes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree.
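# Print the (m, k, n) matmul shapes for which the torchao kernel autotuner
# has saved best configs. Each cached config key embeds the operand shapes:
# k[1] holds A's (M, K) and k[4] holds B's (K, N), and the shared inner
# dimension must agree, which the assert below checks.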
6 | from torchao.kernel import autotuner 7 | 8 | configs = autotuner._load_best_configs() 9 | 10 | print("m,k,n") 11 | for k, v in configs.items(): 12 | a_shape = k[1] 13 | b_shape = k[4] 14 | M, K0 = a_shape 15 | K1, N = b_shape 16 | 17 | assert K0 == K1 18 | 19 | print(f"{M},{K0},{N}") 20 | -------------------------------------------------------------------------------- /benchmarks/quantized_training/benchmark_int8mm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import pandas as pd 7 | import torch 8 | from triton.testing import do_bench 9 | 10 | from torchao.prototype.quantized_training.int8_mm import int8_mm_dequant 11 | 12 | 13 | def bench_f(f, *args): 14 | return do_bench(lambda: f(*args), return_mode="median") 15 | 16 | 17 | shapes = [(sz, sz, sz) for sz in [1024, 2048, 4096]] 18 | 19 | # Llama-8B shapes 20 | shapes += [ 21 | # linear in attention 22 | (32_768, 4096, 4096), 23 | (4096, 4096, 32_768), 24 | # linear in feed-forward 25 | (32_768, 14_336, 4096), 26 | (32_768, 4096, 14_336), 27 | (14_336, 4096, 32_768), 28 | ] 29 | 30 | data = [] 31 | for M, N, K in shapes: 32 | print(f"{M=}, {N=}, {K=}") 33 | 34 | A_bf16 = torch.randn(M, K).bfloat16().cuda() 35 | B_bf16 = torch.randn(N, K).bfloat16().cuda() 36 | A_i8 = torch.randint(-128, 127, size=(M, K), dtype=torch.int8).cuda() 37 | B_i8 = torch.randint(-128, 127, size=(N, K), dtype=torch.int8).cuda() 38 | A_scale = torch.randn(M).bfloat16().cuda() 39 | B_scale = torch.randn(N).bfloat16().cuda() 40 | 41 | # benchmark F.linear() i.e. 
A @ B.T 42 | bf16_time = bench_f(torch.mm, A_bf16, B_bf16.T) 43 | i8_time = bench_f(torch._int_mm, A_i8, B_i8.T) 44 | i8_dequant_time = bench_f(int8_mm_dequant, A_i8, B_i8.T, A_scale, B_scale) 45 | 46 | sample = [M, N, K, bf16_time / i8_time, bf16_time / i8_dequant_time] 47 | data.append(sample) 48 | 49 | df = pd.DataFrame( 50 | data, columns=["M", "N", "K", "CuBLAS INT8 speedup", "Triton INT8 dequant speedup"] 51 | ) 52 | print(df.to_markdown()) 53 | -------------------------------------------------------------------------------- /benchmarks/sam_benchmark_results.csv: -------------------------------------------------------------------------------- 1 | ,block_only,batchsize,dtype,compile,qkv,proj,lin1,lin2,time,memory,img/s 2 | 0,False,32,torch.bfloat16,True,,,,,1457.0417301729321,28.280423936,21.96230851686177 3 | 1,False,32,torch.bfloat16,True,quant,quant,quant,quant,1318.5919532552361,28.261341696,24.268311300551254 4 | 2,False,32,torch.bfloat16,True,quant+sparse (cusparselt),quant,quant+sparse (cutlass),quant+sparse (cutlass),1253.1237555667758,28.18694656,25.536184960061433 5 | 3,False,32,torch.bfloat16,True,quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),1290.4946617782116,27.837008896,24.796693041648258 6 | -------------------------------------------------------------------------------- /benchmarks/sam_vit_b_shapes.csv: -------------------------------------------------------------------------------- 1 | m,k,n 2 | 32768,3072,768 3 | 32768,768,2304 4 | 32768,768,3072 5 | 32768,768,768 6 | 39200,768,2304 7 | 39200,768,768 8 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | # Test utilities 2 | pytest 3 | unittest-xml-reporting 4 | parameterized 5 | packaging 6 | transformers 7 | hypothesis # Avoid test derandomization warning 8 | sentencepiece # for gpt-fast tokenizer 9 | expecttest 10 | 11 | # For prototype features and benchmarks 12 | bitsandbytes # needed for testing triton quant / dequant ops for 8-bit optimizers 13 | matplotlib 14 | pandas 15 | fire # QOL for commandline scripts 16 | tabulate # QOL for printing tables to stdout 17 | tiktoken 18 | blobfile 19 | lm_eval 20 | # sam 21 | diskcache 22 | pycocotools 23 | tqdm 24 | importlib_metadata 25 | 26 | # Custom CUDA Extensions 27 | ninja 28 | 29 | # CPU kernels 30 | cmake<4.0.0,>=3.19.0 31 | 32 | # Linting 33 | ruff==0.11.6 34 | pre-commit 35 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | ifneq ($(EXAMPLES_PATTERN),) 5 | EXAMPLES_PATTERN_OPTS := -D sphinx_gallery_conf.filename_pattern="$(EXAMPLES_PATTERN)" 6 | endif 7 | 8 | # You can set these variables from the command line. 9 | 10 | # TODO: Revert this when we have docs on pytorch.org/ao 11 | # SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS) 12 | # SPHINXOPTS = -WT -j auto --keep-going # enable later when the files are included in the doc build 13 | 14 | 15 | SPHINXBUILD = sphinx-build 16 | SPHINXPROJ = torchao 17 | SOURCEDIR = source 18 | BUILDDIR = build 19 | 20 | # Put it first so that "make" without argument is like "make help".
21 | help: 22 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 23 | 24 | html-noplot: # Avoids running the gallery examples, which may take time 25 | $(SPHINXBUILD) -D plot_gallery=0 -b html "${SOURCEDIR}" "$(BUILDDIR)"/html 26 | @echo 27 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 28 | 29 | clean: 30 | rm -rf $(BUILDDIR)/* 31 | rm -rf $(SOURCEDIR)/generated_examples/ # sphinx-gallery 32 | rm -rf $(SOURCEDIR)/gen_modules/ # sphinx-gallery 33 | rm -rf $(SOURCEDIR)/sg_execution_times.rst # sphinx-gallery 34 | rm -rf $(SOURCEDIR)/generated/ # autosummary 35 | 36 | .PHONY: help Makefile docset 37 | 38 | # Catch-all target: route all unknown targets to Sphinx using the new 39 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 40 | %: Makefile 41 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 42 | -------------------------------------------------------------------------------- /docs/README.txt: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | template_tutorial.py 5 | Template Tutorial 6 | tutorials/template_tutorial.html 7 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-gallery>0.11 2 | sphinx==5.0.0 3 | sphinx_design 4 | sphinx_copybutton 5 | sphinx-tabs 6 | matplotlib 7 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 8 | -------------------------------------------------------------------------------- /docs/source/_static/img/card-background.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /docs/source/_static/img/generic-pytorch-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/docs/source/_static/img/generic-pytorch-logo.png -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. autoclass:: {{ name }} 9 | :members: 10 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/function.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. autofunction:: {{ name }} 9 | -------------------------------------------------------------------------------- /docs/source/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | 3 | {% block sidebartitle %} 4 |
5 | {{ version }} ▼ 6 |
7 | {% include "searchbox.html" %} 8 | {% endblock %} 9 | 10 | 11 | {% block footer %} 12 | 13 | 14 | 17 | 18 | {{ super() }} 19 | 34 | {% endblock %} 35 | -------------------------------------------------------------------------------- /docs/source/api_ref_dtypes.rst: -------------------------------------------------------------------------------- 1 | .. _api_dtypes: 2 | 3 | ================ 4 | torchao.dtypes 5 | ================ 6 | 7 | .. currentmodule:: torchao.dtypes 8 | 9 | Layouts and Tensor Subclasses 10 | ----------------------------- 11 | .. autosummary:: 12 | :toctree: generated/ 13 | :nosignatures: 14 | 15 | NF4Tensor 16 | AffineQuantizedTensor 17 | Layout 18 | PlainLayout 19 | SemiSparseLayout 20 | TensorCoreTiledLayout 21 | Float8Layout 22 | FloatxTensor 23 | FloatxTensorCoreLayout 24 | MarlinSparseLayout 25 | BlockSparseLayout 26 | UintxLayout 27 | MarlinQQQTensor 28 | MarlinQQQLayout 29 | Int4CPULayout 30 | CutlassInt4PackedLayout 31 | CutlassSemiSparseLayout 32 | 33 | Quantization techniques 34 | ----------------------- 35 | .. autosummary:: 36 | :toctree: generated/ 37 | :nosignatures: 38 | 39 | to_affine_quantized_intx 40 | to_affine_quantized_intx_static 41 | to_affine_quantized_fpx 42 | to_affine_quantized_floatx 43 | to_affine_quantized_floatx_static 44 | to_marlinqqq_quantized_intx 45 | to_nf4 46 | .. 47 | _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring 48 | of torchao.dtypes.nf4tensor.NF4Tensor.dequantize_scalers:6:Unexpected indentation. 49 | -------------------------------------------------------------------------------- /docs/source/api_ref_intro.rst: -------------------------------------------------------------------------------- 1 | ``torchao`` API Reference 2 | ========================= 3 | 4 | This section introduces the torchao API reference. Dive into the details of how torchao integrates with PyTorch to optimize your machine learning models. 5 | 6 | .. toctree:: 7 | :glob: 8 | :maxdepth: 1 9 | :caption: Python API Reference 10 | 11 | api_ref_dtypes 12 | api_ref_quantization 13 | api_ref_sparsity 14 | -------------------------------------------------------------------------------- /docs/source/api_ref_kernel.rst: -------------------------------------------------------------------------------- 1 | .. _api_kernel: 2 | 3 | ================ 4 | torchao.kernel 5 | ================ 6 | 7 | .. currentmodule:: torchao.kernel 8 | 9 | .. autosummary:: 10 | :toctree: generated/ 11 | :nosignatures: 12 | 13 | TBA 14 | -------------------------------------------------------------------------------- /docs/source/api_ref_sparsity.rst: -------------------------------------------------------------------------------- 1 | .. _api_sparsity: 2 | 3 | ================ 4 | torchao.sparsity 5 | ================ 6 | .. automodule:: torchao.sparsity 7 | .. currentmodule:: torchao.sparsity 8 | 9 | .. 
autosummary:: 10 | :toctree: generated/ 11 | :nosignatures: 12 | 13 | sparsify_ 14 | semi_sparse_weight 15 | int8_dynamic_activation_int8_semi_sparse_weight 16 | apply_fake_sparsity 17 | WandaSparsifier 18 | PerChannelNormObserver 19 | -------------------------------------------------------------------------------- /docs/source/dtypes.rst: -------------------------------------------------------------------------------- 1 | Dtypes 2 | ====== 3 | 4 | TBA 5 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to the torchao Documentation 2 | ==================================== 3 | 4 | `torchao `__ is a library for custom data types and optimizations. 5 | Quantize and sparsify weights, gradients, optimizers, and activations for inference and training 6 | using native PyTorch. Please check out the torchao `README `__ 7 | for an overall introduction to the library and recent highlights and updates. 8 | 9 | .. toctree:: 10 | :glob: 11 | :maxdepth: 1 12 | :caption: Getting Started 13 | 14 | quick_start 15 | 16 | .. toctree:: 17 | :glob: 18 | :maxdepth: 1 19 | :caption: Developer Notes 20 | 21 | quantization 22 | sparsity 23 | contributor_guide 24 | 25 | .. toctree:: 26 | :glob: 27 | :maxdepth: 1 28 | :caption: API Reference 29 | 30 | api_ref_dtypes 31 | api_ref_quantization 32 | api_ref_sparsity 33 | 34 | .. toctree:: 35 | :glob: 36 | :maxdepth: 1 37 | :caption: Tutorials 38 | 39 | serialization 40 | subclass_basic 41 | subclass_advanced 42 | -------------------------------------------------------------------------------- /docs/source/performant_kernels.rst: -------------------------------------------------------------------------------- 1 | Performant Kernels 2 | ================== 3 | 4 | TBA 5 | -------------------------------------------------------------------------------- /docs/source/subclass_advanced.rst: -------------------------------------------------------------------------------- 1 | Writing Your Own Quantized Tensor (advanced) 2 | -------------------------------------------- 3 | 4 | Coming soon! 5 | -------------------------------------------------------------------------------- /docs/source/tutorials_source/README.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/docs/source/tutorials_source/README.txt -------------------------------------------------------------------------------- /docs/source/tutorials_source/template_tutorial.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -*- coding: utf-8 -*- 7 | 8 | """ 9 | Template Tutorial 10 | ================= 11 | 12 | **Author:** `FirstName LastName `_ 13 | 14 | **What you will learn** 15 | 16 | * Item 1 17 | * Item 2 18 | * Item 3 19 | 20 | **Prerequisites** 21 | 22 | * PyTorch v2.0.0 23 | * GPU ??? 24 | * Other items 3 25 | 26 | """ 27 | 28 | ######################################################################### 29 | # Overview 30 | # -------- 31 | # 32 | # Describe why this topic is important. Add links to relevant research papers. 33 | # 34 | # This tutorial walks you through the process of....
35 | # 36 | # Steps 37 | # ----- 38 | # 39 | # Example code (the output below is generated automatically): 40 | # 41 | import torch 42 | 43 | x = torch.rand(5, 3) 44 | print(x) 45 | 46 | ###################################################################### 47 | # (Optional) Additional Exercises 48 | # ------------------------------- 49 | # 50 | # Add additional practice exercises for users to test their knowledge. 51 | # Example: `NLP from Scratch `__. 52 | # 53 | 54 | ###################################################################### 55 | # Conclusion 56 | # ---------- 57 | # 58 | # Summarize the steps and concepts covered. Highlight key takeaways. 59 | # 60 | # Further Reading 61 | # --------------- 62 | # 63 | # * Link1 64 | # * Link2 65 | -------------------------------------------------------------------------------- /docs/static/microbenchmarking_process_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/docs/static/microbenchmarking_process_diagram.png -------------------------------------------------------------------------------- /docs/static/microbenchmarks_code_flow_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/docs/static/microbenchmarks_code_flow_diagram.png -------------------------------------------------------------------------------- /docs/static/pruning_ecosystem_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/docs/static/pruning_ecosystem_diagram.png -------------------------------------------------------------------------------- /docs/static/pruning_flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/docs/static/pruning_flow.png -------------------------------------------------------------------------------- /docs/static/supported_sparsity_patterns.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/docs/static/supported_sparsity_patterns.png -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | Various example scripts and applications of torchao and PyTorch in general. 
4 | -------------------------------------------------------------------------------- /examples/sam2_amg_server/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/examples/sam2_amg_server/dog.jpg -------------------------------------------------------------------------------- /examples/sam2_amg_server/requirements.txt: -------------------------------------------------------------------------------- 1 | uvicorn 2 | fire 3 | fastapi 4 | opencv-python 5 | matplotlib 6 | hydra-core 7 | tqdm 8 | iopath 9 | python-multipart 10 | requests 11 | scipy 12 | pandas 13 | tabulate 14 | -------------------------------------------------------------------------------- /examples/sam2_vos_example/requirements.txt: -------------------------------------------------------------------------------- 1 | requests 2 | fire 3 | -------------------------------------------------------------------------------- /packaging/env_var_script_linux.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # This file is sourced into the environment before building a pip wheel. It 8 | # should typically only contain shell variable assignments. Be sure to export 9 | # any variables so that subprocesses will see them. 10 | if [[ ${CHANNEL:-nightly} == "nightly" ]]; then 11 | export TORCHAO_NIGHTLY=1 12 | fi 13 | 14 | # Set ARCH list so that we can build fp16 with SM75+; the logic is copied from 15 | # pytorch/builder 16 | TORCH_CUDA_ARCH_LIST="8.0;8.6" 17 | if [[ ${CU_VERSION:-} == "cu124" ]]; then 18 | TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0" 19 | fi 20 | -------------------------------------------------------------------------------- /packaging/post_build_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -eux 9 | 10 | # Prepare manywheel, only for CUDA. 11 | # The wheel is a pure python wheel for other platforms. 12 | if [[ "$CU_VERSION" == cu* ]]; then 13 | WHEEL_NAME=$(ls dist/) 14 | 15 | pushd dist 16 | manylinux_plat=manylinux_2_28_x86_64 17 | auditwheel repair --plat "$manylinux_plat" -w . \ 18 | --exclude libtorch.so \ 19 | --exclude libtorch_python.so \ 20 | --exclude libtorch_cuda.so \ 21 | --exclude libtorch_cpu.so \ 22 | --exclude libc10.so \ 23 | --exclude libc10_cuda.so \ 24 | --exclude libcudart.so.12 \ 25 | --exclude libcudart.so.11.0 \ 26 | "${WHEEL_NAME}" 27 | 28 | ls -lah . 29 | # Clean up the linux_x86_64 wheel 30 | rm "${WHEEL_NAME}" 31 | popd 32 | fi 33 | 34 | MANYWHEEL_NAME=$(ls dist/) 35 | # Try to install the new wheel 36 | pip install "dist/${MANYWHEEL_NAME}" 37 | python -c "import torchao" 38 | -------------------------------------------------------------------------------- /packaging/pre_build_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved.
4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -eux 9 | 10 | echo "This script is run before building torchao binaries" 11 | 12 | python -m pip install --upgrade pip 13 | if [ -z ${PYTORCH_VERSION:-} ]; then 14 | PYTORCH_DEP="torch" 15 | else 16 | PYTORCH_DEP="torch==$PYTORCH_VERSION" 17 | fi 18 | pip install $PYTORCH_DEP 19 | 20 | pip install setuptools wheel twine auditwheel 21 | -------------------------------------------------------------------------------- /packaging/smoke_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torchao 9 | 10 | 11 | def main(): 12 | """ 13 | Run torchao binary smoke tests like importing and performing simple ops 14 | """ 15 | print(dir(torchao)) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /packaging/vc_env_helper.bat: -------------------------------------------------------------------------------- 1 | @echo on 2 | 3 | set VC_VERSION_LOWER=17 4 | set VC_VERSION_UPPER=18 5 | if "%VC_YEAR%" == "2019" ( set VC_VERSION_LOWER=16 set VC_VERSION_UPPER=17) 6 | if "%VC_YEAR%" == "2017" ( set VC_VERSION_LOWER=15 set VC_VERSION_UPPER=16) 7 | 8 | for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( 9 | if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( 10 | set "VS15INSTALLDIR=%%i" 11 | set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" 12 | goto vswhere 13 | ) 14 | ) 15 | 16 | :vswhere 17 | if "%VSDEVCMD_ARGS%" == "" ( 18 | call "%VS15VCVARSALL%" x64 || exit /b 1 19 | ) else ( 20 | call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1 21 | ) 22 | 23 | @echo on 24 | 25 | if "%CU_VERSION%" == "xpu" call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" 26 | 27 | set DISTUTILS_USE_SDK=1 28 | 29 | set args=%1 30 | shift 31 | :start 32 | if [%1] == [] goto done 33 | set args=%args% %1 34 | shift 35 | goto start 36 | 37 | :done 38 | if "%args%" == "" ( 39 | echo Usage: vc_env_helper.bat [command] [args] 40 | echo e.g. vc_env_helper.bat cl /c test.cpp 41 | ) 42 | 43 | %args% || exit /b 1 44 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | # NOTE: The AO repo is completely linted. 2 | # Add linting rules here 3 | lint.select = ["F", "I"] 4 | lint.ignore = ["E731"] 5 | 6 | 7 | # Exclude third-party modules 8 | exclude = [ 9 | "third_party/*", 10 | "torchao/prototype/paretoq/*", 11 | ] 12 | -------------------------------------------------------------------------------- /scripts/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # copied from https://github.com/pytorch-labs/gpt-fast/blob/main/scripts/download.py 8 | import os 9 | from typing import Optional 10 | 11 | from requests.exceptions import HTTPError 12 | 13 | 14 | def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: 15 | from huggingface_hub import snapshot_download 16 | 17 | os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) 18 | try: 19 | snapshot_download( 20 | repo_id, 21 | local_dir=f"checkpoints/{repo_id}", 22 | local_dir_use_symlinks=False, 23 | token=hf_token, 24 | ) 25 | except HTTPError as e: 26 | if e.response.status_code == 401: 27 | print( 28 | "You need to pass a valid `--hf_token=...` to download private checkpoints." 29 | ) 30 | else: 31 | raise e 32 | 33 | 34 | if __name__ == "__main__": 35 | import argparse 36 | 37 | parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.") 38 | parser.add_argument( 39 | "--repo_id", 40 | type=str, 41 | default="checkpoints/meta-llama/llama-2-7b-chat-hf", 42 | help="Repository ID to download from.", 43 | ) 44 | parser.add_argument( 45 | "--hf_token", type=str, default=None, help="HuggingFace API token." 46 | ) 47 | 48 | args = parser.parse_args() 49 | hf_download(args.repo_id, args.hf_token) 50 | -------------------------------------------------------------------------------- /scripts/prepare.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf 7 | python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B 8 | python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B 9 | python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B-Instruct 10 | python scripts/download.py --repo_id meta-llama/Llama-3.2-3B 11 | python scripts/download.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4 12 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf 13 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B 14 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B 15 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B-Instruct 16 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B 17 | # the neuralmagic checkpoint doesn't come with a tokenizer, so we need to copy it over 18 | mkdir -p checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original && cp checkpoints/meta-llama/Meta-Llama-3-8B/original/tokenizer.model checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original/tokenizer.model 19 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4 20 | -------------------------------------------------------------------------------- /scripts/run_ruff_fix.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | ruff check .
--fix 7 | # --isolated is used to skip the allowlist entirely, so this applies to all files 8 | # please be careful when using this: large changes mean everyone needs to rebase 9 | ruff check --isolated --select F821,F823,W191 --fix 10 | ruff check --select F,I --fix 11 | ruff format . 12 | -------------------------------------------------------------------------------- /test/dtypes/ddp/run_ddp_nf4_test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #!/bin/bash 7 | 8 | set -euo pipefail 9 | WORLD_SIZE=${1:-2} 10 | 11 | 12 | # Test params 13 | GLOBAL_BS=8 14 | DIM=128 15 | NUM_LINEARS=1 16 | NUM_STEPS=3 17 | 18 | PARAMS="--global_bs $GLOBAL_BS --dim $DIM --num_linears $NUM_LINEARS --num_steps $NUM_STEPS" 19 | SAVE_DIR="checkpoints" 20 | REF_DIR="${SAVE_DIR}/ref" 21 | TEST_DIR="${SAVE_DIR}/test" 22 | DDP_PROGRAM="ddp_nf4.py" 23 | CHECK_PROGRAM="check_ddp_nf4.py" 24 | REF_CMD="torchrun --nproc_per_node 1 $DDP_PROGRAM $PARAMS --save_dir $REF_DIR" 25 | TEST_CMD="torchrun --nproc_per_node $WORLD_SIZE $DDP_PROGRAM $PARAMS --save_dir $TEST_DIR" 26 | CHECK_CMD="python $CHECK_PROGRAM --ref_checkpoint_dir $REF_DIR --test_checkpoints_dir $TEST_DIR" 27 | CLEANUP_CMD="rm -rf $SAVE_DIR" 28 | 29 | echo "Step 1: Generating reference checkpoint..." 30 | echo $REF_CMD 31 | $REF_CMD 32 | echo -e "\n --- \n" 33 | sleep 2 34 | 35 | echo "Step 2: Generating test checkpoints..." 36 | echo $TEST_CMD 37 | $TEST_CMD 38 | echo -e "\n --- \n" 39 | sleep 2 40 | 41 | # Check params 42 | echo "Step 3: Checking params..." 43 | echo $CHECK_CMD 44 | $CHECK_CMD 45 | echo -e "\n --- \n" 46 | sleep 2 47 | 48 | # Cleanup 49 | echo "Step 4: Cleaning up..." 50 | echo $CLEANUP_CMD 51 | $CLEANUP_CMD 52 | echo -e "\n --- \n" 53 | echo "Done!" 54 | -------------------------------------------------------------------------------- /test/dtypes/test_fbgemm_quantized.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | import unittest 8 | 9 | import torch 10 | from torch.testing._internal.common_utils import ( 11 | TestCase, 12 | run_tests, 13 | ) 14 | 15 | from torchao.quantization import ( 16 | FbgemmConfig, 17 | quantize_, 18 | ) 19 | from torchao.quantization.utils import compute_error 20 | from torchao.utils import is_sm_at_least_90 21 | 22 | 23 | class TestFbgemmInt4Tensor(TestCase): 24 | @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") 25 | @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+") 26 | def test_linear(self): 27 | dtype = torch.bfloat16 28 | device = "cuda" 29 | input = torch.randn(1, 128, dtype=dtype, device=device) 30 | linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) 31 | original = linear(input) 32 | config = FbgemmConfig( 33 | input_dtype=torch.bfloat16, 34 | weight_dtype=torch.int4, 35 | output_dtype=torch.bfloat16, 36 | block_size=(1, 128), 37 | ) 38 | quantize_(linear, config) 39 | quantized = linear(input) 40 | self.assertTrue(compute_error(original, quantized) > 20) 41 | 42 | 43 | if __name__ == "__main__": 44 | run_tests() 45 | -------------------------------------------------------------------------------- /test/dtypes/test_fbgemm_quantized_tensor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import unittest 8 | 9 | import torch 10 | from torch.testing._internal.common_utils import ( 11 | TestCase, 12 | run_tests, 13 | ) 14 | 15 | from torchao.quantization import ( 16 | FbgemmConfig, 17 | quantize_, 18 | ) 19 | from torchao.quantization.utils import compute_error 20 | from torchao.utils import ( 21 | TORCH_VERSION_AT_LEAST_2_6, 22 | is_sm_at_least_90, 23 | ) 24 | 25 | 26 | class TestFbgemmInt4Tensor(TestCase): 27 | @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") 28 | @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+") 29 | @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch >= 2.6") 30 | def test_linear(self): 31 | dtype = torch.bfloat16 32 | device = "cuda" 33 | input = torch.randn(1, 128, dtype=dtype, device=device) 34 | linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) 35 | original = linear(input) 36 | config = FbgemmConfig( 37 | input_dtype=torch.bfloat16, 38 | weight_dtype=torch.int4, 39 | output_dtype=torch.bfloat16, 40 | block_size=[1, 128], 41 | ) 42 | quantize_(linear, config) 43 | quantized = linear(input) 44 | self.assertTrue(compute_error(original, quantized) > 20) 45 | 46 | 47 | if __name__ == "__main__": 48 | run_tests() 49 | -------------------------------------------------------------------------------- /test/float8/test_dtensor.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #!/bin/bash 7 | 8 | # terminate script on first error 9 | set -e 10 | 11 | if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then 12 | echo "Skipping test_dtensor.sh because no CUDA devices are available."
13 | exit 14 | fi 15 | 16 | # integration tests for TP/SP 17 | NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/float8/test_dtensor.py 18 | 19 | # integration smoke tests for FSDP2 + TP 20 | NCCL_DEBUG=WARN torchrun --nproc_per_node 4 test/float8/test_fsdp2_tp.py 21 | -------------------------------------------------------------------------------- /test/float8/test_everything.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #!/bin/bash 7 | 8 | # terminate script on first error 9 | set -e 10 | IS_ROCM=$(rocm-smi --version || true) 11 | 12 | pytest test/float8/test_base.py 13 | pytest test/float8/test_compile.py 14 | pytest test/float8/test_numerics_integration.py 15 | 16 | # These tests do not work on ROCm yet 17 | if [ -z "$IS_ROCM" ] 18 | then 19 | ./test/float8/test_fsdp.sh 20 | ./test/float8/test_fsdp_compile.sh 21 | ./test/float8/test_dtensor.sh 22 | python test/float8/test_fsdp2/test_fsdp2.py 23 | fi 24 | 25 | echo "all tests successful" 26 | -------------------------------------------------------------------------------- /test/float8/test_fsdp.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #!/bin/bash 7 | 8 | # terminate script on first error 9 | set -e 10 | 11 | launch() { 12 | echo "launching compile_fsdp $COMPILE" 13 | 14 | # the NCCL_DEBUG setting is to avoid log spew 15 | # the CUDA_VISIBLE_DEVICES setting is for easy debugging 16 | NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/float8/test_fsdp.py \ 17 | --compile_fsdp $COMPILE 18 | 19 | echo "✅ All Tests Passed ✅" 20 | } 21 | 22 | if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then 23 | echo "Skipping test_fsdp.sh because no CUDA devices are available." 24 | exit 25 | fi 26 | 27 | COMPILE=False launch 28 | COMPILE=True launch 29 | -------------------------------------------------------------------------------- /test/float8/test_fsdp_compile.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #!/bin/bash 7 | 8 | # terminate script on first error 9 | set -e 10 | if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then 11 | echo "Skipping test_fsdp_compile.sh because no CUDA devices are available." 
12 | exit 13 | fi 14 | 15 | # Code to be executed if CUDA devices are available 16 | NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/float8/test_fsdp_compile.py 17 | -------------------------------------------------------------------------------- /test/prototype/module_swap_quantization/test_llm_ptq_data_getter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from typing import Tuple 3 | 4 | import torch 5 | from transformers.models.llama.modeling_llama import LlamaConfig, LlamaForCausalLM 6 | 7 | from torchao.prototype.quantization.module_swap.data_getters import LLMPTQDataGetter 8 | 9 | test_config = LlamaConfig( 10 | vocab_size=10, 11 | hidden_size=32, 12 | num_hidden_layers=2, 13 | num_attention_heads=2, 14 | intermediate_size=64, 15 | ) 16 | 17 | 18 | def get_test_llama_model_data() -> Tuple[LlamaForCausalLM, torch.Tensor]: 19 | model = LlamaForCausalLM(test_config) 20 | input_ids = torch.randint(0, test_config.vocab_size, (1, 10)) 21 | return model, input_ids 22 | 23 | 24 | class TestPTQDataGetter(unittest.TestCase): 25 | @unittest.skip("TypeError: cannot unpack non-iterable NoneType object") 26 | def test_data_getter(self) -> None: 27 | model, data = get_test_llama_model_data() 28 | data_getter = LLMPTQDataGetter(model, data, 1) 29 | for name, module in model.named_modules(): 30 | if isinstance(module, torch.nn.Linear): 31 | data = data_getter.pop(model, name) 32 | 33 | 34 | if __name__ == "__main__": 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /test/prototype/module_swap_quantization/test_module_swap.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torchao.prototype.quantization.module_swap import ( 7 | QuantizationRecipe, 8 | quantize_module_swap, 9 | ) 10 | 11 | 12 | class SimpleEmbeddingTestNetwork(nn.Module): 13 | def __init__(self) -> None: 14 | super().__init__() 15 | self.embedding = nn.Embedding(10, 64) 16 | 17 | def forward(self, x: torch.Tensor) -> torch.Tensor: 18 | return self.embedding(x) 19 | 20 | 21 | class TestEmbeddingSwap(unittest.TestCase): 22 | def test_embedding_swap(self) -> None: 23 | model = SimpleEmbeddingTestNetwork() 24 | recipe = QuantizationRecipe() 25 | recipe.embedding_bits = 4 26 | recipe.embedding_quantization = True 27 | model = quantize_module_swap(model, recipe) 28 | x = torch.randint(0, 10, (10, 64)) 29 | model(x) 30 | assert model.embedding.weight_quantizer.num_bits == 4 31 | assert model.embedding.weight_quantizer.group_size == 32 32 | 33 | 34 | if __name__ == "__main__": 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /test/prototype/scaled_grouped_mm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/test/prototype/scaled_grouped_mm/__init__.py -------------------------------------------------------------------------------- /test/prototype/test_paretoq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import unittest 8 | 9 | import torch 10 | 11 | from torchao.prototype.paretoq.models.utils_quant import ( 12 | LsqBinaryTernaryExtension, 13 | QuantizeLinear, 14 | StretchedElasticQuant, 15 | ) 16 | 17 | 18 | class M(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | self.linear = torch.nn.Linear(256, 256, bias=False).to(torch.float32) 22 | 23 | def forward(self, x): 24 | return self.linear(x) 25 | 26 | 27 | class TestParetoQ(unittest.TestCase): 28 | def test_quantized_linear(self): 29 | m = M() 30 | example_inputs = torch.randn(1, 256).to(torch.float32) 31 | for w_bits in [0, 1, 2, 3, 4, 16]: 32 | m.linear = QuantizeLinear( 33 | m.linear.in_features, 34 | m.linear.out_features, 35 | bias=False, 36 | w_bits=w_bits, 37 | ) 38 | m(example_inputs) 39 | 40 | def test_quantize_functions(self): 41 | x = torch.randn(256, 256).to(torch.float32) 42 | alpha = torch.Tensor(256, 1) 43 | for layerwise in [True, False]: 44 | LsqBinaryTernaryExtension.apply(x, alpha, 1, layerwise) 45 | LsqBinaryTernaryExtension.apply(x, alpha, 3, layerwise) 46 | LsqBinaryTernaryExtension.apply(x, alpha, 4, layerwise) 47 | StretchedElasticQuant.apply(x, alpha, 0, layerwise) 48 | StretchedElasticQuant.apply(x, alpha, 2, layerwise) 49 | 50 | 51 | if __name__ == "__main__": 52 | unittest.main() 53 | -------------------------------------------------------------------------------- /test/prototype/test_spinquant.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import pytest 7 | import torch 8 | 9 | from torchao._models.llama.model import Transformer 10 | from torchao.prototype.spinquant import apply_spinquant 11 | 12 | 13 | def _init_model(name="7B", device="cpu", precision=torch.bfloat16): 14 | model = Transformer.from_name(name) 15 | model.to(device=device, dtype=precision) 16 | return model.eval() 17 | 18 | 19 | _AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) 20 | 21 | 22 | @pytest.mark.parametrize("device", _AVAILABLE_DEVICES) 23 | def test_spinquant_no_quantization(device): 24 | model = _init_model(device=device) 25 | seq_len = 16 26 | batch_size = 1 27 | is_training = False 28 | input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device) 29 | input_pos = None if is_training else torch.arange(seq_len).to(device) 30 | with torch.device(device): 31 | model.setup_caches( 32 | max_batch_size=batch_size, max_seq_length=seq_len, training=is_training 33 | ) 34 | 35 | with torch.no_grad(): 36 | out = model(input_ids, input_pos) 37 | apply_spinquant(model) 38 | out_spinquant = model(input_ids, input_pos) 39 | 40 | # Output should be the same without quantization (the rotations cancel out) 41 | # TODO: not sure if these atol/rtol are excessively large (it fails for smaller values) 42 | torch.testing.assert_close(out, out_spinquant, atol=5e-2, rtol=1e-2) 43 | 44 | 45 | # TODO: test GPTQ compatibility? 46 | -------------------------------------------------------------------------------- /test/smoke_tests/smoke_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Run smoke tests""" 7 | 8 | import torchao 9 | 10 | print("torchao version is ", torchao.__version__) 11 | -------------------------------------------------------------------------------- /test/test_ao_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import pytest 7 | import torch 8 | 9 | from torchao._models.llama.model import Transformer 10 | 11 | _AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) 12 | 13 | 14 | def init_model(name="stories15M", device="cpu", precision=torch.bfloat16): 15 | model = Transformer.from_name(name) 16 | model.to(device=device, dtype=precision) 17 | return model.eval() 18 | 19 | 20 | @pytest.mark.parametrize("device", _AVAILABLE_DEVICES) 21 | @pytest.mark.parametrize("batch_size", [1, 4]) 22 | @pytest.mark.parametrize("is_training", [True, False]) 23 | def test_ao_llama_model_inference_mode(device, batch_size, is_training): 24 | random_model = init_model(device=device) 25 | seq_len = 16 26 | input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device) 27 | input_pos = None if is_training else torch.arange(seq_len).to(device) 28 | with torch.device(device): 29 | random_model.setup_caches( 30 | max_batch_size=batch_size, max_seq_length=seq_len, training=is_training 31 | ) 32 | for i in range(3): 33 | out = random_model(input_ids, input_pos) 34 | assert out is not None, "model failed to run" 35 | -------------------------------------------------------------------------------- /torchao/_models/README.md: -------------------------------------------------------------------------------- 1 | ## SAM2 2 | sam2 is a fork of https://github.com/facebookresearch/sam2 at commit c2ec8e14a185632b0a5d8b161928ceb50197eddc 3 | 4 | It includes 5 | - modifications to enable fullgraph=True compile 6 | - `mask_to_rle_pytorch_2` 7 | - small performance changes and fixes 8 | - integration into torchao's packaging 9 | -------------------------------------------------------------------------------- /torchao/_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/_models/__init__.py -------------------------------------------------------------------------------- /torchao/_models/llama/.gitignore: -------------------------------------------------------------------------------- 1 | moby.txt 2 | -------------------------------------------------------------------------------- /torchao/_models/llama/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/_models/llama/__init__.py -------------------------------------------------------------------------------- /torchao/_models/llama/bsr_benchmarks.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # BSR benchmarks 8 | export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder 9 | export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B 10 | 11 | # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result bsr_bench_results.txt 12 | # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result bsr_bench_results.txt 13 | # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision float16 --write_result bsr_bench_results.txt 14 | # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result bsr_bench_results.txt --sparsity bsr-0.8-32 15 | # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result bsr_bench_results.txt --sparsity bsr-0.8-64 16 | # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result bsr_bench_results.txt --sparsity bsr-0.9-32 17 | python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result bsr_bench_results.txt --sparsity bsr-0.9-64 18 | -------------------------------------------------------------------------------- /torchao/_models/llama/demo_summarize.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # grab moby dick prompt 7 | wget -nc -O moby.txt https://gist.githubusercontent.com/jcaip/f319146bb543e92e23b2c76815b0f29f/raw/31a9cd12b0b59f323eb197c9534953bdac352986/gistfile1.txt 8 | 9 | export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B-Instruct 10 | 11 | python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq_prefill_wo_decode --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt 12 | python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt 13 | python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt 14 | -------------------------------------------------------------------------------- /torchao/_models/mixtral-moe/README.md: -------------------------------------------------------------------------------- 1 | ## Mixtral-MoE 2 | 3 | This folder contains code and scripts for benchmarking the Mixtral-MoE model. 4 | Running 5 | 6 | `sh scripts/prepare.sh` 7 | 8 | should download the model, and `sh run.sh` will run the benchmarks.
9 | -------------------------------------------------------------------------------- /torchao/_models/mixtral-moe/scripts/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | from typing import Optional 8 | 9 | from requests.exceptions import HTTPError 10 | 11 | 12 | def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: 13 | from huggingface_hub import snapshot_download 14 | 15 | os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) 16 | try: 17 | snapshot_download( 18 | repo_id, 19 | local_dir=f"checkpoints/{repo_id}", 20 | local_dir_use_symlinks=False, 21 | token=hf_token, 22 | ignore_patterns="*.safetensors", 23 | ) 24 | except HTTPError as e: 25 | if e.response.status_code == 401: 26 | print( 27 | "You need to pass a valid `--hf_token=...` to download private checkpoints." 28 | ) 29 | else: 30 | raise e 31 | 32 | 33 | if __name__ == "__main__": 34 | import argparse 35 | 36 | parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.") 37 | parser.add_argument( 38 | "--repo_id", 39 | type=str, 40 | default="mistralai/Mixtral-8x7B-Instruct-v0.1", 41 | help="Repository ID to download from.", 42 | ) 43 | parser.add_argument( 44 | "--hf_token", type=str, default=None, help="HuggingFace API token." 45 | ) 46 | 47 | args = parser.parse_args() 48 | hf_download(args.repo_id, args.hf_token) 49 | -------------------------------------------------------------------------------- /torchao/_models/mixtral-moe/scripts/prepare.sh: -------------------------------------------------------------------------------- 1 | python scripts/download.py --repo_id mistralai/Mixtral-8x7B-Instruct-v0.1 2 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/mistralai/Mixtral-8x7B-Instruct-v0.1 3 | -------------------------------------------------------------------------------- /torchao/_models/sam/.gitignore: -------------------------------------------------------------------------------- 1 | tmp 2 | checkpoints 3 | datasets 4 | -------------------------------------------------------------------------------- /torchao/_models/sam/README.md: -------------------------------------------------------------------------------- 1 | # Benchmarking instructions 2 | 3 | Set up your environment with: 4 | ``` 5 | conda create -n "saf-ao" python=3.10 6 | conda activate saf-ao 7 | pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126 8 | pip3 install git+https://github.com/pytorch-labs/segment-anything-fast.git 9 | pip3 install tqdm fire pandas 10 | cd ../..
&& python setup.py install 11 | ``` 12 | 13 | Then download data and models by running 14 | ``` 15 | sh setup.sh 16 | ``` 17 | 18 | Finally, you can run benchmarks with 19 | ``` 20 | sh benchmark.sh 21 | ``` 22 | 23 | You can check out the result in results.csv 24 | -------------------------------------------------------------------------------- /torchao/_models/sam/flash_4_configs.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/_models/sam/flash_4_configs.p -------------------------------------------------------------------------------- /torchao/_models/sam/results.csv: -------------------------------------------------------------------------------- 1 | device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path 2 | cuda,vit_h,32,15172,18,22.533401716616083,44.37856354651513,0.5812715827356921,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None 3 | cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None 4 | cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None 5 | cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None 6 | cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None 7 | cuda,vit_h,32,17068,21,23.96093702681232,41.73459489004953,0.5485481164943489,max-autotune,torch.float16,int4_weight_only_sparse,False,True,True,32,154,4928,None,None 8 | -------------------------------------------------------------------------------- /torchao/_models/sam/setup.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | SETUP_HOME=$(pwd) 8 | 9 | 10 | mkdir -p checkpoints 11 | mkdir -p datasets 12 | 13 | mkdir -p tmp 14 | mkdir -p tmp/sam_coco_mask_center_cache 15 | mkdir -p tmp/sam_eval_masks_out 16 | 17 | wget -nc -P checkpoints https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 18 | wget -nc -P checkpoints https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth 19 | wget -nc -P checkpoints https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth 20 | 21 | mkdir -p datasets/coco2017 22 | wget -nc -P datasets/coco2017 http://images.cocodataset.org/zips/val2017.zip 23 | wget -nc -P datasets/coco2017 http://images.cocodataset.org/annotations/annotations_trainval2017.zip 24 | 25 | cd datasets/coco2017 && unzip -n val2017.zip && cd $SETUP_HOME 26 | cd datasets/coco2017 && unzip -n annotations_trainval2017.zip && cd $SETUP_HOME 27 | -------------------------------------------------------------------------------- /torchao/_models/sam2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from hydra import initialize_config_module 8 | from hydra.core.global_hydra import GlobalHydra 9 | 10 | if not GlobalHydra.instance().is_initialized(): 11 | initialize_config_module("torchao._models.sam2", version_base="1.2") 12 | -------------------------------------------------------------------------------- /torchao/_models/sam2/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchao/_models/sam2/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchao/_models/sam2/modeling/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /torchao/_models/sam2/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_b+.yaml -------------------------------------------------------------------------------- /torchao/_models/sam2/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_l.yaml -------------------------------------------------------------------------------- /torchao/_models/sam2/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_s.yaml -------------------------------------------------------------------------------- /torchao/_models/sam2/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_t.yaml -------------------------------------------------------------------------------- /torchao/_models/sam2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchao/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/core/__init__.py -------------------------------------------------------------------------------- /torchao/csrc/cuda/cutlass_extensions/common.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | #define CUTLASS_STATUS_CHECK(status, message_prefix) \ 12 | { \ 13 | TORCH_CHECK(status == cutlass::Status::kSuccess, message_prefix, \ 14 | " : Got CUTLASS error: ", cutlassGetStatusString(status)); \ 15 | } 16 | 17 | namespace torchao { 18 | 19 | template 20 | struct enable_2x_kernel_for_sm80_or_later : Kernel { 21 | template 22 | CUTLASS_DEVICE static void invoke(Args&&... args) { 23 | #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 24 | Kernel::invoke(std::forward(args)...); 25 | #endif 26 | } 27 | }; 28 | 29 | template 30 | struct enable_3x_kernel_for_sm90_or_later : Kernel { 31 | template 32 | CUTLASS_DEVICE void operator()(Args&&... args) { 33 | #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 34 | Kernel::operator()(std::forward(args)...); 35 | #endif 36 | } 37 | }; 38 | 39 | } // namespace torchao 40 | -------------------------------------------------------------------------------- /torchao/csrc/cuda/fp6_llm/README.md: -------------------------------------------------------------------------------- 1 | # FP6-LLM kernel 2 | 3 | This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs the linear op (A @ W.T), where A is in FP16 or BF16 and W is in FP6 (E3M2 without infinities and NaN).
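For a quick smoke test, here is a minimal sketch of reaching this kernel from Python through torchao's high-level quantization API. It assumes `fpx_weight_only` is exported by `torchao.quantization` in your build (the FPx entry points have moved between releases, so treat the exact import as an assumption); `ebits=3, mbits=2` selects the FP6 E3M2 format described above.

```python
# Minimal sketch, assuming a CUDA build of torchao that exports fpx_weight_only.
import torch
from torchao.quantization import quantize_, fpx_weight_only

# FP16 linear layer whose weight gets repacked into FP6 (E3M2).
model = torch.nn.Sequential(torch.nn.Linear(1024, 512, bias=False)).half().cuda()
quantize_(model, fpx_weight_only(ebits=3, mbits=2))

# Small batches fall in the regime where this kernel is typically faster (see below).
x = torch.randn(8, 1024, dtype=torch.half, device="cuda")
with torch.no_grad():
    y = model(x)  # linear computed as A @ W.T by the FP6 kernel
print(y.shape)  # torch.Size([8, 512])
```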
4 | 5 | On most hardware, this kernel is faster than FP16 linear for batch sizes from 1 to 128, and slower for batch sizes of 256 or larger. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. 6 | 7 | See https://github.com/pytorch/ao/pull/223 and https://github.com/pytorch/ao/pull/1147 for some benchmark results. 8 | -------------------------------------------------------------------------------- /torchao/csrc/cuda/marlin_qqq/base.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | /* 7 | * Modified by HandH1998 8 | * Modified by Neural Magic 9 | * Copyright (C) Marlin.2024 Elias Frantar 10 | * 11 | * Licensed under the Apache License, Version 2.0 (the "License"); 12 | * you may not use this file except in compliance with the License. 13 | * You may obtain a copy of the License at 14 | * 15 | * http://www.apache.org/licenses/LICENSE-2.0 16 | * 17 | * Unless required by applicable law or agreed to in writing, software 18 | * distributed under the License is distributed on an "AS IS" BASIS, 19 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | * See the License for the specific language governing permissions and 21 | * limitations under the License. 22 | */ 23 | 24 | #pragma once 25 | 26 | namespace torchao { 27 | 28 | constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } 29 | 30 | // Instances of `Vec` are used to organize groups of >>registers<<, as needed 31 | // for instance as inputs to tensor core operations. Consequently, all 32 | // corresponding index accesses must be compile-time constants, which is why we 33 | // extensively use `#pragma unroll` throughout the kernel code to guarantee 34 | // this. 35 | template 36 | struct Vec { 37 | T elems[n]; 38 | __device__ T& operator[](int i) { return elems[i]; } 39 | }; 40 | 41 | } // namespace torchao 42 | -------------------------------------------------------------------------------- /torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | #include 7 | 8 | #include "rowwise_scaled_linear_cutlass.cuh" 9 | 10 | namespace torchao { 11 | 12 | at::Tensor 13 | rowwise_scaled_linear_cutlass_s4s4( 14 | const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, 15 | const at::Tensor& W_scale, 16 | const std::optional& bias_opt = std::nullopt, 17 | const std::optional out_dtype_opt = std::nullopt) { 18 | // Validate input datatypes. 19 | TORCH_CHECK(Xq.dtype() == at::kChar && Wq.dtype() == at::kChar, 20 | __func__, " : The input datatypes combination ", Xq.dtype(), 21 | " for Xq and ", Wq.dtype(), " for Wq is not supported"); 22 | 23 | #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) 24 | // Dispatch to appropriate kernel template.
25 | using ElementA = cutlass::int4b_t; 26 | using ElementB = cutlass::int4b_t; 27 | return rowwise_scaled_linear_cutlass( 28 | Xq, X_scale, Wq, W_scale, bias_opt, out_dtype_opt); 29 | #else 30 | TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); 31 | return at::Tensor{}; 32 | #endif 33 | } 34 | 35 | TORCH_LIBRARY_IMPL(torchao, CUDA, m) { 36 | m.impl("torchao::rowwise_scaled_linear_cutlass_s4s4", 37 | &rowwise_scaled_linear_cutlass_s4s4); 38 | } 39 | 40 | } // namespace torchao 41 | -------------------------------------------------------------------------------- /torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | #include 7 | 8 | #include "rowwise_scaled_linear_cutlass.cuh" 9 | 10 | namespace torchao { 11 | 12 | at::Tensor 13 | rowwise_scaled_linear_cutlass_s8s4( 14 | const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, 15 | const at::Tensor& W_scale, 16 | const std::optional& bias_opt = std::nullopt, 17 | const std::optional out_dtype_opt = std::nullopt) { 18 | // Validate input datatypes. 19 | TORCH_CHECK(Xq.dtype() == at::kChar && Wq.dtype() == at::kChar, 20 | __func__, " : The input datatypes combination ", Xq.dtype(), 21 | " for Xq and ", Wq.dtype(), " for Wq is not supported"); 22 | 23 | #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) 24 | // Dispatch to appropriate kernel template. 25 | using ElementA = int8_t; 26 | using ElementB = cutlass::int4b_t; 27 | return rowwise_scaled_linear_cutlass( 28 | Xq, X_scale, Wq, W_scale, bias_opt, out_dtype_opt); 29 | #else 30 | TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); 31 | return at::Tensor{}; 32 | #endif 33 | } 34 | 35 | TORCH_LIBRARY_IMPL(torchao, CUDA, m) { 36 | m.impl("torchao::rowwise_scaled_linear_cutlass_s8s4", 37 | &rowwise_scaled_linear_cutlass_s8s4); 38 | } 39 | 40 | } // namespace torchao 41 | -------------------------------------------------------------------------------- /torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | #include "rowwise_scaled_linear_sparse_cutlass.cuh" 7 | #include "rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h" 8 | 9 | namespace torchao { 10 | 11 | at::Tensor 12 | rowwise_scaled_linear_sparse_cutlass_e4m3e4m3( 13 | const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, 14 | const at::Tensor& W_meta, const at::Tensor& W_scale, 15 | const std::optional& bias_opt, 16 | const std::optional out_dtype_opt) { 17 | // Validate input datatypes. 
18 | TORCH_CHECK( 19 | Xq.dtype() == at::kFloat8_e4m3fn && Wq.dtype() == at::kFloat8_e4m3fn, 20 | __func__, " : The input datatypes combination ", Xq.dtype(), " for Xq and ", 21 | Wq.dtype(), " for Wq is not supported"); 22 | 23 | #if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) 24 | using DtypeXq = cutlass::float_e4m3_t; 25 | using DtypeWq = cutlass::float_e4m3_t; 26 | return rowwise_scaled_linear_sparse_cutlass( 27 | Xq, X_scale, Wq, W_meta, W_scale, bias_opt, out_dtype_opt); 28 | #else 29 | TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); 30 | return at::Tensor{}; 31 | #endif 32 | } 33 | 34 | } // namespace torchao 35 | -------------------------------------------------------------------------------- /torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e4m3.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | namespace torchao { 12 | 13 | at::Tensor 14 | rowwise_scaled_linear_sparse_cutlass_e4m3e4m3( 15 | const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, 16 | const at::Tensor& W_meta, const at::Tensor& W_scale, 17 | const std::optional& bias_opt, 18 | const std::optional out_dtype_opt); 19 | 20 | } // namespace torchao 21 | -------------------------------------------------------------------------------- /torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | #include "rowwise_scaled_linear_sparse_cutlass.cuh" 7 | #include "rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h" 8 | 9 | namespace torchao { 10 | 11 | at::Tensor 12 | rowwise_scaled_linear_sparse_cutlass_e4m3e5m2( 13 | const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, 14 | const at::Tensor& W_meta, const at::Tensor& W_scale, 15 | const std::optional& bias_opt, 16 | const std::optional out_dtype_opt) { 17 | // Validate input datatypes. 18 | TORCH_CHECK( 19 | Xq.dtype() == at::kFloat8_e4m3fn && Wq.dtype() == at::kFloat8_e5m2, 20 | __func__, " : The input datatypes combination ", Xq.dtype(), " for Xq and ", 21 | Wq.dtype(), " for Wq is not supported"); 22 | 23 | #if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) 24 | using DtypeXq = cutlass::float_e4m3_t; 25 | using DtypeWq = cutlass::float_e5m2_t; 26 | return rowwise_scaled_linear_sparse_cutlass( 27 | Xq, X_scale, Wq, W_meta, W_scale, bias_opt, out_dtype_opt); 28 | #else 29 | TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); 30 | return at::Tensor{}; 31 | #endif 32 | } 33 | 34 | } // namespace torchao 35 | -------------------------------------------------------------------------------- /torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e4m3e5m2.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 
3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | #pragma once 7 | 8 | #include <optional> 9 | #include <ATen/ATen.h> 10 | 11 | namespace torchao { 12 | 13 | at::Tensor 14 | rowwise_scaled_linear_sparse_cutlass_e4m3e5m2( 15 | const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, 16 | const at::Tensor& W_meta, const at::Tensor& W_scale, 17 | const std::optional<at::Tensor>& bias_opt, 18 | const std::optional<at::ScalarType> out_dtype_opt); 19 | 20 | } // namespace torchao 21 | -------------------------------------------------------------------------------- /torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | #include "rowwise_scaled_linear_sparse_cutlass.cuh" 7 | #include "rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h" 8 | 9 | namespace torchao { 10 | 11 | at::Tensor 12 | rowwise_scaled_linear_sparse_cutlass_e5m2e4m3( 13 | const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, 14 | const at::Tensor& W_meta, const at::Tensor& W_scale, 15 | const std::optional<at::Tensor>& bias_opt, 16 | const std::optional<at::ScalarType> out_dtype_opt) { 17 | // Validate input datatypes. 18 | TORCH_CHECK( 19 | Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e4m3fn, 20 | __func__, " : The input datatypes combination ", Xq.dtype(), " for Xq and ", 21 | Wq.dtype(), " for Wq is not supported"); 22 | 23 | #if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) 24 | using DtypeXq = cutlass::float_e5m2_t; 25 | using DtypeWq = cutlass::float_e4m3_t; 26 | return rowwise_scaled_linear_sparse_cutlass<DtypeXq, DtypeWq>( 27 | Xq, X_scale, Wq, W_meta, W_scale, bias_opt, out_dtype_opt); 28 | #else 29 | TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); 30 | return at::Tensor{}; 31 | #endif 32 | } 33 | 34 | } // namespace torchao 35 | -------------------------------------------------------------------------------- /torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e4m3.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | #pragma once 7 | 8 | #include <optional> 9 | #include <ATen/ATen.h> 10 | 11 | namespace torchao { 12 | 13 | at::Tensor 14 | rowwise_scaled_linear_sparse_cutlass_e5m2e4m3( 15 | const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, 16 | const at::Tensor& W_meta, const at::Tensor& W_scale, 17 | const std::optional<at::Tensor>& bias_opt, 18 | const std::optional<at::ScalarType> out_dtype_opt); 19 | 20 | } // namespace torchao 21 | -------------------------------------------------------------------------------- /torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved.
3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | #include "rowwise_scaled_linear_sparse_cutlass.cuh" 7 | #include "rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h" 8 | 9 | namespace torchao { 10 | 11 | at::Tensor 12 | rowwise_scaled_linear_sparse_cutlass_e5m2e5m2( 13 | const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, 14 | const at::Tensor& W_meta, const at::Tensor& W_scale, 15 | const std::optional<at::Tensor>& bias_opt, 16 | const std::optional<at::ScalarType> out_dtype_opt) { 17 | // Validate input datatypes. 18 | TORCH_CHECK( 19 | Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e5m2, 20 | __func__, " : The input datatypes combination ", Xq.dtype(), " for Xq and ", 21 | Wq.dtype(), " for Wq is not supported"); 22 | 23 | #if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) 24 | using DtypeXq = cutlass::float_e5m2_t; 25 | using DtypeWq = cutlass::float_e5m2_t; 26 | return rowwise_scaled_linear_sparse_cutlass<DtypeXq, DtypeWq>( 27 | Xq, X_scale, Wq, W_meta, W_scale, bias_opt, out_dtype_opt); 28 | #else 29 | TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); 30 | return at::Tensor{}; 31 | #endif 32 | } 33 | 34 | } // namespace torchao 35 | -------------------------------------------------------------------------------- /torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_e5m2e5m2.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | #pragma once 7 | 8 | #include <optional> 9 | #include <ATen/ATen.h> 10 | 11 | namespace torchao { 12 | 13 | at::Tensor 14 | rowwise_scaled_linear_sparse_cutlass_e5m2e5m2( 15 | const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, 16 | const at::Tensor& W_meta, const at::Tensor& W_scale, 17 | const std::optional<at::Tensor>& bias_opt, 18 | const std::optional<at::ScalarType> out_dtype_opt); 19 | 20 | } // namespace torchao 21 | -------------------------------------------------------------------------------- /torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x_f8.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | #include <torch/library.h> 7 | 8 | #include "to_sparse_semi_structured_cutlass_sm9x.cuh" 9 | 10 | namespace torchao { 11 | 12 | std::tuple<at::Tensor, at::Tensor> 13 | to_sparse_semi_structured_cutlass_sm9x_f8(const at::Tensor& W) { 14 | // Validate input datatypes. 15 | TORCH_CHECK(W.dtype() == at::kFloat8_e5m2 || W.dtype() == at::kFloat8_e4m3fn, 16 | __func__, " : The input datatype ", W.dtype(), 17 | " is not supported"); 18 | 19 | #if defined(BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X) 20 | // Dispatch to appropriate kernel template.
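// The runtime dtype of W selects the compile-time CUTLASS element type: each
// branch below binds DtypeW and instantiates the templated conversion for
// exactly that type, so no dtype branching remains inside the kernel itself.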
21 | if (W.dtype() == at::kFloat8_e5m2) { 22 | using DtypeW = cutlass::float_e5m2_t; 23 | return to_sparse_semi_structured_cutlass_sm9x<DtypeW>(W); 24 | } else if (W.dtype() == at::kFloat8_e4m3fn) { 25 | using DtypeW = cutlass::float_e4m3_t; 26 | return to_sparse_semi_structured_cutlass_sm9x<DtypeW>(W); 27 | } 28 | return std::tuple<at::Tensor, at::Tensor>(at::Tensor{}, at::Tensor{}); 29 | #else 30 | TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); 31 | return std::tuple<at::Tensor, at::Tensor>(at::Tensor{}, at::Tensor{}); 32 | #endif 33 | } 34 | 35 | TORCH_LIBRARY_IMPL(torchao, CUDA, m) { 36 | m.impl("torchao::to_sparse_semi_structured_cutlass_sm9x_f8", 37 | &to_sparse_semi_structured_cutlass_sm9x_f8); 38 | } 39 | 40 | } // namespace torchao 41 | -------------------------------------------------------------------------------- /torchao/dtypes/README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | ## File Structure of the `dtypes` Folder 4 | 5 | The `dtypes` folder contains several important files and subfolders that are organized as follows: 6 | 7 | - **affine_quantized_tensor.py**: This is the main file; the tensor subclasses in the `uintx` and `floatx` subfolders inherit from it. It contains the base tensor subclass `AffineQuantizedTensor` and the code for layout and TensorImpl registration. 8 | 9 | - **affine_quantized_tensor_ops.py**: This file defines all the overridden ATen ops and the different dispatch kernels related to affine quantized tensors. 10 | 11 | - **utils.py**: A utility file that provides helper functions and common utilities used across different files in the `dtypes` folder. 12 | 13 | - **nf4tensor.py**: This file contains the NF4 tensor implementation and its layouts. 14 | 15 | ### Subfolders 16 | 17 | - **uintx**: A subfolder that contains layouts and tensor subclasses inheriting from `affine_quantized_tensor.py`. It is specialized for handling unsigned integer quantized tensors. 18 | 19 | - **floatx**: Similar to `uintx`, this subfolder contains layouts and tensor subclasses that inherit from `affine_quantized_tensor.py`, but it is focused on floating-point quantized tensors. 20 | -------------------------------------------------------------------------------- /torchao/dtypes/_nf4tensor_api.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import torch 7 | 8 | from torchao.core.config import AOBaseConfig 9 | from torchao.dtypes.nf4tensor import NF4Tensor 10 | from torchao.quantization.transform_module import ( 11 | register_quantize_module_handler, 12 | ) 13 | 14 | 15 | class NF4WeightOnlyConfig(AOBaseConfig): 16 | """ 17 | Note: the file location of this workflow is temporary.
18 | TODO(future PR): integrate this properly into torchao's directory structure 19 | """ 20 | 21 | block_size: int = 64 22 | scaler_block_size: int = 256 23 | 24 | 25 | # for bc 26 | nf4_weight_only = NF4WeightOnlyConfig 27 | 28 | 29 | @register_quantize_module_handler(NF4WeightOnlyConfig) 30 | def _nf4_weight_only_transform( 31 | module: torch.nn.Module, 32 | config: NF4WeightOnlyConfig, 33 | ) -> torch.nn.Module: 34 | block_size = config.block_size 35 | scaler_block_size = config.scaler_block_size 36 | 37 | new_weight = NF4Tensor.from_tensor(module.weight, block_size, scaler_block_size) 38 | module.weight = torch.nn.Parameter(new_weight, requires_grad=False) 39 | return module 40 | -------------------------------------------------------------------------------- /torchao/dtypes/floatx/__init__.py: -------------------------------------------------------------------------------- 1 | from .cutlass_semi_sparse_layout import ( 2 | CutlassSemiSparseLayout, 3 | ) 4 | from .float8_layout import Float8Layout 5 | from .floatx_tensor_core_layout import ( 6 | FloatxTensorCoreLayout, 7 | from_scaled_tc_floatx, 8 | to_scaled_tc_floatx, 9 | ) 10 | 11 | __all__ = [ 12 | "FloatxTensorCoreLayout", 13 | "to_scaled_tc_floatx", 14 | "from_scaled_tc_floatx", 15 | "Float8Layout", 16 | "CutlassSemiSparseLayout", 17 | ] 18 | -------------------------------------------------------------------------------- /torchao/dtypes/uintx/__init__.py: -------------------------------------------------------------------------------- 1 | from .block_sparse_layout import ( 2 | BlockSparseLayout, 3 | ) 4 | from .cutlass_int4_packed_layout import ( 5 | CutlassInt4PackedLayout, 6 | ) 7 | from .int4_cpu_layout import ( 8 | Int4CPULayout, 9 | ) 10 | from .int4_xpu_layout import ( 11 | Int4XPULayout, 12 | ) 13 | from .marlin_qqq_tensor import ( 14 | MarlinQQQLayout, 15 | MarlinQQQTensor, 16 | to_marlinqqq_quantized_intx, 17 | ) 18 | from .marlin_sparse_layout import ( 19 | MarlinSparseLayout, 20 | ) 21 | from .packed_linear_int8_dynamic_activation_intx_weight_layout import ( 22 | PackedLinearInt8DynamicActivationIntxWeightLayout, 23 | ) 24 | from .q_dq_layout import ( 25 | QDQLayout, 26 | ) 27 | from .semi_sparse_layout import ( 28 | SemiSparseLayout, 29 | ) 30 | from .tensor_core_tiled_layout import ( 31 | TensorCoreTiledLayout, 32 | ) 33 | from .uintx_layout import ( 34 | UintxLayout, 35 | ) 36 | 37 | __all__ = [ 38 | "UintxLayout", 39 | "BlockSparseLayout", 40 | "MarlinSparseLayout", 41 | "SemiSparseLayout", 42 | "TensorCoreTiledLayout", 43 | "Int4CPULayout", 44 | "MarlinQQQLayout", 45 | "MarlinQQQTensor", 46 | "to_marlinqqq_quantized_intx", 47 | "CutlassInt4PackedLayout", 48 | "PackedLinearInt8DynamicActivationIntxWeightLayout", 49 | "QDQLayout", 50 | "Int4XPULayout", 51 | ] 52 | -------------------------------------------------------------------------------- /torchao/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/experimental/__init__.py -------------------------------------------------------------------------------- /torchao/experimental/benchmark_infra/ios/Entitlements-Dev.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | get-task-allow 6 | 7 | keychain-access-groups 8 | 9 | T84QZS65DQ.platformFamily 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- 
/torchao/experimental/benchmark_infra/ios/TorchAOBenchmark-Info.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | CFBundleDevelopmentRegion 6 | en 7 | CFBundleDisplayName 8 | ${PRODUCT_NAME} 9 | CFBundleExecutable 10 | ${EXECUTABLE_NAME} 11 | CFBundleIdentifier 12 | ${FB_BUNDLE_ID} 13 | CFBundleInfoDictionaryVersion 14 | 6.0 15 | CFBundleName 16 | ${PRODUCT_NAME} 17 | CFBundlePackageType 18 | APPL 19 | CFBundleShortVersionString 20 | 1.0 21 | CFBundleSignature 22 | ???? 23 | CFBundleVersion 24 | 1.0 25 | LSRequiresIPhoneOS 26 | 27 | UILaunchStoryboardName 28 | LaunchScreen 29 | UISupportedInterfaceOrientations 30 | 31 | UIInterfaceOrientationPortrait 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /torchao/experimental/benchmark_infra/ios/main_empty.mm: -------------------------------------------------------------------------------- 1 | // (c) Meta Platforms, Inc. and affiliates. 2 | 3 | #include 4 | 5 | int main(int argc, char** argv) { 6 | std::cout << "Default main when no benchmarking deps are specified\n"; 7 | return 0; 8 | } 9 | -------------------------------------------------------------------------------- /torchao/experimental/benchmark_infra/ios/output_redirect.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchao/experimental/benchmark_infra/ios/output_redirect.mm: -------------------------------------------------------------------------------- 1 | // (c) Meta Platforms, Inc. and affiliates. 2 | 3 | #import "output_redirect.h" 4 | 5 | #import 6 | #import 7 | #import 8 | #import 9 | 10 | #import 11 | 12 | class STDIORedirector { 13 | public: 14 | STDIORedirector() { 15 | if (@available(iOS 17, *)) { 16 | /* duplicate stdout */ 17 | std::string file_name = 18 | std::string(std::getenv("HOME")) + "/tmp/BENCH_LOG"; 19 | redirect_out_ = fopen(file_name.c_str(), "w"); 20 | stdout_dupfd_ = dup(STDOUT_FILENO); 21 | stderr_dupfd_ = dup(STDERR_FILENO); 22 | /* replace stdout with our output fd */ 23 | dup2(fileno(redirect_out_), STDOUT_FILENO); 24 | dup2(fileno(redirect_out_), STDERR_FILENO); 25 | fflush(stdout); 26 | fflush(stderr); 27 | setvbuf(stdout, nil, _IONBF, 0); 28 | setvbuf(stderr, nil, _IONBF, 0); 29 | setvbuf(redirect_out_, nil, _IONBF, 0); 30 | } 31 | } 32 | 33 | ~STDIORedirector() { 34 | if (@available(iOS 17, *)) { 35 | fflush(stdout); 36 | fflush(stderr); 37 | /* restore stdout */ 38 | dup2(stdout_dupfd_, STDOUT_FILENO); 39 | dup2(stderr_dupfd_, STDERR_FILENO); 40 | close(stdout_dupfd_); 41 | close(stderr_dupfd_); 42 | fclose(redirect_out_); 43 | } 44 | } 45 | 46 | private: 47 | FILE* redirect_out_; 48 | int stdout_dupfd_; 49 | int stderr_dupfd_; 50 | }; 51 | 52 | static STDIORedirector stdio_redirector_; 53 | -------------------------------------------------------------------------------- /torchao/experimental/benchmark_infra/test/test_bench.cpp: -------------------------------------------------------------------------------- 1 | // (c) Meta Platforms, Inc. and affiliates. 
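// Pattern used by the benchmarks in this tree: read the problem size from
// state.range(0), keep setup outside the timed region, put only the measured
// work inside the `for (auto _ : state)` loop, and register the function
// with BENCHMARK(...)->Args({...}).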
2 | 3 | #include <benchmark/benchmark.h> 4 | 5 | #include <cstdint> 6 | #include <cstdlib> 7 | #include <memory> 8 | #include <random> 9 | 10 | namespace { 11 | std::random_device rd; 12 | std::mt19937 generator(rd()); 13 | 14 | // This test is to validate that the benchmarking binary 15 | // can be run on any device. Right now, it is focusing 16 | // on benchmarking on laptop (x86 or mac) and iOS 17 | static void TestBenchmark(benchmark::State& state) { 18 | const int32_t K = state.range(0); 19 | auto a = std::make_unique<float[]>(K); 20 | auto b = std::make_unique<float[]>(K); 21 | auto c = std::make_unique<float[]>(K); 22 | static std::uniform_real_distribution<> real_distrib(-1.0, 1.0); 23 | for (int ii = 0; ii < K; ++ii) { 24 | a[ii] = real_distrib(generator); 25 | b[ii] = real_distrib(generator); 26 | c[ii] = 0; 27 | } 28 | for (auto _ : state) { 29 | for (int ii = 0; ii < K; ++ii) { 30 | c[ii] = a[ii] + b[ii]; 31 | } 32 | } 33 | } 34 | 35 | BENCHMARK(TestBenchmark)->Args({4096 * 4})->UseRealTime(); 36 | } // namespace 37 | 38 | int main(int argc, char** argv) { 39 | char arg0_default[] = "benchmark"; 40 | char* args_default = arg0_default; 41 | if (!argv) { 42 | argc = 1; 43 | argv = &args_default; 44 | } 45 | ::benchmark::Initialize(&argc, argv); 46 | ::benchmark::RunSpecifiedBenchmarks(); 47 | ::benchmark::Shutdown(); 48 | return 0; 49 | } 50 | -------------------------------------------------------------------------------- /torchao/experimental/build_torchao_ops.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | if [[ $# -ne 1 ]]; then 9 | echo "Usage: $0 <aten|executorch>"; 10 | exit 1; 11 | fi 12 | TARGET="${1}" 13 | export CMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())') 14 | echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" 15 | if [[ $TARGET == "executorch" ]]; then 16 | TORCHAO_BUILD_EXECUTORCH_OPS=ON 17 | else 18 | TORCHAO_BUILD_EXECUTORCH_OPS=OFF 19 | fi 20 | export CMAKE_OUT=cmake-out 21 | cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ 22 | -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \ 23 | -DTORCHAO_BUILD_EXECUTORCH_OPS="${TORCHAO_BUILD_EXECUTORCH_OPS}" \ 24 | -DTORCHAO_BUILD_CPU_AARCH64=ON \ 25 | -DTORCHAO_ENABLE_ARM_NEON_DOT=ON \ 26 | -S . \ 27 | -B ${CMAKE_OUT} 28 | cmake --build ${CMAKE_OUT} -j 16 --target install --config Release 29 | -------------------------------------------------------------------------------- /torchao/experimental/install_requirements.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # Install requirements for experimental torchao ops. 9 | if [[ -z $PIP ]]; 10 | then 11 | PIP=pip 12 | fi 13 | 14 | NIGHTLY_VERSION="dev20241011" 15 | $PIP install "executorch==0.5.0.$NIGHTLY_VERSION" --extra-index-url https://download.pytorch.org/whl/nightly/cpu 16 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | if (TORCHAO_BUILD_CPU_AARCH64) 8 | add_library( 9 | torchao_kernels_aarch64 10 | ${CMAKE_CURRENT_SOURCE_DIR}/reduction/find_min_and_max.cpp 11 | ${CMAKE_CURRENT_SOURCE_DIR}/reduction/compute_sum.cpp 12 | ${CMAKE_CURRENT_SOURCE_DIR}/quantization/quantize.cpp 13 | ${CMAKE_CURRENT_SOURCE_DIR}/valpacking/interleave.cpp 14 | ) 15 | if (TORCHAO_BUILD_KLEIDIAI) 16 | include(FetchContent) 17 | # KleidiAI is an open-source library that provides optimized 18 | # performance-critical routines, also known as micro-kernels, for artificial 19 | # intelligence (AI) workloads tailored for Arm® CPUs. 20 | FetchContent_Declare(kleidiai 21 | GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git 22 | GIT_TAG v1.5.0) 23 | FetchContent_MakeAvailable(kleidiai) 24 | 25 | target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai) 26 | endif() 27 | 28 | install( 29 | TARGETS torchao_kernels_aarch64 30 | DESTINATION lib 31 | ) 32 | endif() 33 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #if defined(__aarch64__) || defined(__ARM_NEON) 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | static void benchmark_quantize(benchmark::State& state) { 15 | int nbit = state.range(0); 16 | int size = state.range(1); 17 | auto vals = torchao::get_random_vector(size, -10, 10); 18 | auto qvals = std::vector(size, 0); 19 | 20 | int qmin, qmax, zero; 21 | float vmin, vmax, scale; 22 | 23 | for (auto _ : state) { 24 | torchao::kernels::cpu::aarch64::reduction::find_min_and_max( 25 | vmin, vmax, vals.data(), vals.size()); 26 | 27 | torchao::quantization::get_qvals_range( 28 | qmin, qmax, nbit, /*is_symmetric=*/false); 29 | 30 | torchao::quantization::get_scale_and_zero( 31 | scale, zero, vmin, vmax, qmin, qmax); 32 | 33 | torchao::kernels::cpu::aarch64::quantization::quantize( 34 | qvals.data(), vals.data(), vals.size(), scale, zero, qmin, qmax); 35 | } 36 | } 37 | 38 | BENCHMARK(benchmark_quantize) 39 | ->ArgsProduct( 40 | {{3, 4, 8}, benchmark::CreateRange(1024, 131072, /*multi=*/4)}); 41 | 42 | // Run the benchmark 43 | BENCHMARK_MAIN(); 44 | 45 | #endif // defined(__aarch64__) || defined(__ARM_NEON) 46 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -eu 9 | 10 | if [[ $# -ne 1 ]]; then 11 | echo "Usage: $0 "; 12 | exit 1; 13 | fi 14 | 15 | BENCHMARK_TYPE="${1}" 16 | SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) 17 | 18 | export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. 
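# TORCHAO_LIBRARIES resolves to the repository root (six directories above
# this script); the cmake invocation below uses it to locate the benchmark
# sources under torchao/experimental/kernels/cpu/aarch64/benchmarks.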
19 | export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks 20 | 21 | # Build 22 | cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ 23 | -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/benchmarks \ 24 | -B ${CMAKE_OUT} 25 | 26 | cmake --build ${CMAKE_OUT} 27 | 28 | # Run 29 | case "${BENCHMARK_TYPE}" in 30 | quantization) ${CMAKE_OUT}/benchmark_quantization; ;; 31 | bitpacking) ${CMAKE_OUT}/benchmark_bitpacking; ;; 32 | linear) ${CMAKE_OUT}/benchmark_linear; ;; 33 | *) echo "Unknown benchmark: $1. Please specify quantization, bitpacking, or linear."; exit 1; ;; 34 | esac 35 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/cpu/aarch64/macro.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #define TORCHAO_ALWAYS_INLINE __attribute__((always_inline)) 10 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #if defined(__aarch64__) || defined(__ARM_NEON) 10 | #include 11 | 12 | // These methods are here temporarily 13 | // Eventually they will be moved to a non-arch specific location 14 | // or replaced by existing PyTorch functions 15 | // The quantize method in aarch64 namespace will remain here; 16 | // it is used for dynamic activation quantization 17 | namespace torchao { 18 | namespace quantization { 19 | 20 | void get_qvals_range(int& qmin, int& qmax, int nbit, bool is_symmetric); 21 | 22 | // val = scale * qval 23 | float get_scale(float vmin, float vmax, int qmin, int qmax); 24 | 25 | // val = scale * (qval - zero) 26 | void get_scale_and_zero( 27 | float& scale, 28 | int& zero, 29 | float vmin, 30 | float vmax, 31 | int qmin, 32 | int qmax); 33 | 34 | } // namespace quantization 35 | } // namespace torchao 36 | 37 | namespace torchao { 38 | namespace kernels { 39 | namespace cpu { 40 | namespace aarch64 { 41 | namespace quantization { 42 | void quantize( 43 | // Output 44 | int8_t* qvals, 45 | // Inputs 46 | const float32_t* vals, 47 | int size, 48 | float32_t scale, 49 | int8_t zero, 50 | int8_t qmin, 51 | int8_t qmax); 52 | 53 | } // namespace quantization 54 | } // namespace aarch64 55 | } // namespace cpu 56 | } // namespace kernels 57 | } // namespace torchao 58 | 59 | #endif // defined(__aarch64__) || defined(__ARM_NEON) 60 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
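// compute_sum returns the int32 sum of `size` signed 8-bit values. The NEON
// path below consumes 16 lanes per iteration (vaddlvq_s8 is a widening
// add across the vector); a scalar equivalent of the whole routine is:
//
//   int32_t res = 0;
//   for (int i = 0; i < size; i++) { res += vals[i]; }
//   return res;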
6 | 7 | #if defined(__aarch64__) || defined(__ARM_NEON) 8 | 9 | #include 10 | #include 11 | 12 | int32_t torchao::kernels::cpu::aarch64::reduction::compute_sum( 13 | const int8_t* vals, 14 | int size) { 15 | assert(size >= 1); 16 | 17 | int32_t res = 0; 18 | int i = 0; 19 | 20 | #pragma unroll(4) 21 | for (; i + 15 < size; i += 16) { 22 | int8x16_t vec_vals = vld1q_s8(vals + i); 23 | res += (int)(vaddlvq_s8(vec_vals)); 24 | } 25 | for (; i < size; i += 1) { 26 | res += vals[i]; 27 | } 28 | return res; 29 | } 30 | 31 | #endif // defined(__aarch64__) || defined(__ARM_NEON) 32 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #if defined(__aarch64__) || defined(__ARM_NEON) 8 | 9 | #include 10 | #include 11 | 12 | void torchao::kernels::cpu::aarch64::reduction::find_min_and_max( 13 | float32_t& min, 14 | float32_t& max, 15 | const float32_t* vals, 16 | int size) { 17 | assert(size > 0); 18 | 19 | // Needed in case size < 4 so we don't compare to 20 | // uninitialized min/max values 21 | min = vals[0]; 22 | max = min; 23 | 24 | int i = 0; 25 | if (i + 3 < size) { 26 | float32x4_t mins = vld1q_f32(vals + i); 27 | float32x4_t maxes = mins; 28 | i += 4; 29 | for (; i + 3 < size; i += 4) { 30 | float32x4_t v = vld1q_f32(vals + i); 31 | mins = vminq_f32(mins, v); 32 | maxes = vmaxq_f32(maxes, v); 33 | } 34 | min = vminvq_f32(mins); 35 | max = vmaxvq_f32(maxes); 36 | } 37 | 38 | // Remainder 39 | while (i < size) { 40 | if (vals[i] < min) { 41 | min = vals[i]; 42 | } 43 | if (vals[i] > max) { 44 | max = vals[i]; 45 | } 46 | i += 1; 47 | } 48 | } 49 | 50 | #endif // defined(__aarch64__) || defined(__ARM_NEON) 51 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #if defined(__aarch64__) || defined(__ARM_NEON) 10 | #include 11 | #include 12 | 13 | namespace torchao { 14 | namespace kernels { 15 | namespace cpu { 16 | namespace aarch64 { 17 | namespace reduction { 18 | void find_min_and_max( 19 | float32_t& min, 20 | float32_t& max, 21 | const float32_t* vals, 22 | int size); 23 | 24 | int32_t compute_sum(const int8_t* vals, int size); 25 | 26 | } // namespace reduction 27 | } // namespace aarch64 28 | } // namespace cpu 29 | } // namespace kernels 30 | } // namespace torchao 31 | 32 | #endif // defined(__aarch64__) || defined(__ARM_NEON) 33 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 
3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | namespace torchao { 10 | namespace kernels { 11 | namespace cpu { 12 | namespace valpacking { 13 | 14 | // TODO: should this be relocated out of aarch64? 15 | void interleave_data( 16 | void* data_interleaved, 17 | const void* data, 18 | int bytes_per_val, 19 | int vals_per_channel, 20 | int vals_per_group, 21 | int vals_per_chunk, 22 | int channels, 23 | int channel_stride_in_vals); 24 | 25 | } // namespace valpacking 26 | } // namespace cpu 27 | } // namespace kernels 28 | } // namespace torchao 29 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/mps/metal.yaml: -------------------------------------------------------------------------------- 1 | - func: Vec4Type 2 | file: common.metal 3 | 4 | - func: int1mm 5 | file: int1mm.metal 6 | 7 | - func: int2mm 8 | file: int2mm_opt.metal 9 | 10 | - func: int3mm 11 | file: int3mm_opt.metal 12 | 13 | - func: int4mm 14 | file: int4mm_opt.metal 15 | 16 | - func: int5mm 17 | file: int5mm.metal 18 | 19 | - func: int6mm 20 | file: int6mm.metal 21 | 22 | - func: int7mm 23 | file: int7mm.metal 24 | 25 | - func: qmv_fast 26 | file: qmv_fast.metal 27 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/mps/metal/common.metal: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD 3-Clause license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | template <typename T> struct Vec4Type {}; 8 | 9 | template <> struct Vec4Type<float> { 10 | using type = float4; 11 | }; 12 | 13 | template <> struct Vec4Type<half> { 14 | using type = half4; 15 | }; 16 | 17 | #if __METAL_VERSION__ >= 310 18 | template <> struct Vec4Type<bfloat> { 19 | using type = bfloat4; 20 | }; 21 | #endif 22 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/mps/src/OperationUtils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree.
6 | 7 | #pragma once 8 | 9 | id<MTLDevice> getMetalDevice(); 10 | 11 | class MPSStream { 12 | public: 13 | MPSStream() { 14 | _commandQueue = [getMetalDevice() newCommandQueue]; 15 | } 16 | 17 | ~MPSStream() { 18 | [_commandQueue release]; 19 | _commandQueue = nil; 20 | 21 | assert(_commandBuffer == nil); 22 | } 23 | 24 | id<MTLCommandQueue> queue() const { 25 | return _commandQueue; 26 | } 27 | 28 | id<MTLCommandBuffer> commandBuffer() { 29 | if (!_commandBuffer) { 30 | auto desc = [MTLCommandBufferDescriptor new]; 31 | desc.errorOptions = MTLCommandBufferErrorOptionEncoderExecutionStatus; 32 | _commandBuffer = [_commandQueue commandBufferWithDescriptor:desc]; 33 | } 34 | return _commandBuffer; 35 | } 36 | 37 | id<MTLComputeCommandEncoder> commandEncoder() { 38 | if (!_commandEncoder) { 39 | _commandEncoder = [commandBuffer() computeCommandEncoder]; 40 | } 41 | return _commandEncoder; 42 | } 43 | 44 | private: 45 | id<MTLCommandQueue> _commandQueue = nil; 46 | id<MTLCommandBuffer> _commandBuffer = nil; 47 | id<MTLComputeCommandEncoder> _commandEncoder = nil; 48 | }; 49 | 50 | inline MPSStream* getCurrentMPSStream() { 51 | return new MPSStream(); 52 | } 53 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/mps/src/OperationUtils.mm: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include <Metal/Metal.h> 8 | #include <stdexcept> 9 | #include <torchao/experimental/kernels/mps/src/OperationUtils.h> 10 | 11 | id<MTLDevice> getMetalDevice() { 12 | @autoreleasepool { 13 | NSArray* devices = [MTLCopyAllDevices() autorelease]; 14 | if (devices.count == 0) { 15 | throw std::runtime_error("Metal is not supported"); 16 | } 17 | static id<MTLDevice> MTL_DEVICE = devices[0]; 18 | return MTL_DEVICE; 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/mps/src/common.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree.
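// Backend shim: under USE_ATEN or USE_EXECUTORCH the MPSStream and Metal
// device come from the host framework; otherwise the standalone helpers from
// OperationUtils.h are used. dispatch_block() and
// optionally_wait_for_command_completion() hide the remaining differences
// from callers.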
6 | 7 | #pragma once 8 | 9 | #ifdef USE_ATEN 10 | #include 11 | using namespace at::native::mps; 12 | #elif defined(USE_EXECUTORCH) 13 | #include 14 | using namespace executorch::backends::mps::delegate; 15 | #else 16 | #include 17 | #endif 18 | 19 | inline void dispatch_block( 20 | MPSStream* mpsStream, 21 | void (^block)()) { 22 | #if defined(USE_ATEN) 23 | dispatch_sync_with_rethrow(mpsStream->queue(), block); 24 | #elif defined(USE_EXECUTORCH) 25 | dispatch_sync(mpsStream->queue(), block); 26 | #else 27 | (void)mpsStream; 28 | block(); 29 | #endif 30 | } 31 | 32 | inline void optionally_wait_for_command_completion(MPSStream* mpsStream) { 33 | #if defined(USE_ATEN) 34 | #elif defined(USE_EXECUTORCH) 35 | ET_CHECK(mpsStream->synchronize(SyncType::COMMIT_AND_WAIT) == executorch::runtime::Error::Ok); 36 | #else 37 | id encoder = mpsStream->commandEncoder(); 38 | id cmdBuffer = mpsStream->commandBuffer(); 39 | [encoder endEncoding]; 40 | [cmdBuffer commit]; 41 | [cmdBuffer waitUntilCompleted]; 42 | #endif 43 | } 44 | 45 | inline id get_metal_device() { 46 | #if defined(USE_ATEN) || defined(USE_EXECUTORCH) 47 | return MPSDevice::getInstance()->device(); 48 | #else 49 | return getMetalDevice(); 50 | #endif 51 | } 52 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/mps/src/dispatch.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | namespace torchao::kernels::mps::lowbit::dispatch { 12 | 13 | inline void dispatch_mm( 14 | id encoder, 15 | int32_t maxThreadsPerGroup, 16 | int32_t M, 17 | int32_t N, 18 | [[maybe_unused]] int32_t K) { 19 | [encoder dispatchThreads:MTLSizeMake(N, M, 1) 20 | threadsPerThreadgroup:MTLSizeMake(std::min(maxThreadsPerGroup, M), 1, 1)]; 21 | } 22 | 23 | inline void dispatch_mm_Mr1xNr4_per_TG( 24 | id encoder, 25 | int32_t maxThreadsPerGroup, 26 | int32_t M, 27 | int32_t N, 28 | int32_t K) { 29 | (void)K; 30 | if (maxThreadsPerGroup < 32) { 31 | throw std::runtime_error("Can't dispatch!"); 32 | } 33 | [encoder dispatchThreads:MTLSizeMake(N / 4 * 32, 1, M) 34 | threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; 35 | } 36 | 37 | inline void dispatch_qmv_fast( 38 | id encoder, 39 | int32_t maxThreadsPerGroup, 40 | int32_t M, 41 | int32_t N, 42 | int32_t K) { 43 | (void)K; 44 | if (maxThreadsPerGroup < 64) { 45 | throw std::runtime_error("Can't dispatch!"); 46 | } 47 | [encoder dispatchThreadgroups:MTLSizeMake(M, (N + 7) / 8, 1) 48 | threadsPerThreadgroup:MTLSizeMake(32, 2, 1)]; 49 | } 50 | 51 | } // namespace torchao::kernels::mps::lowbit::dispatch 52 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/mps/test/Makefile: -------------------------------------------------------------------------------- 1 | all: test_lowbit 2 | 3 | test_lowbit: test_lowbit.mm ../src/OperationUtils.mm 4 | clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $^ -framework Metal -framework Foundation 5 | 6 | run: test_lowbit 7 | ./test_lowbit 8 | -------------------------------------------------------------------------------- /torchao/experimental/kernels/mps/test/bfloat16.h: -------------------------------------------------------------------------------- 1 | /* 2 | * 
Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | /** 12 | * This implementation is copied from 13 | * executorch/runtime/core/portable_type/bfloat16.h 14 | */ 15 | 16 | inline float f32_from_bits(uint16_t src) { 17 | float res = 0; 18 | uint32_t tmp = src; 19 | tmp <<= 16; 20 | std::memcpy(&res, &tmp, sizeof(tmp)); 21 | return res; 22 | } 23 | 24 | inline uint16_t bits_from_f32(float src) { 25 | uint32_t res = 0; 26 | std::memcpy(&res, &src, sizeof(res)); 27 | return res >> 16; 28 | } 29 | 30 | inline uint16_t round_to_nearest_even(float src) { 31 | if (std::isnan(src)) { 32 | return UINT16_C(0x7FC0); 33 | } 34 | uint32_t U32 = 0; 35 | std::memcpy(&U32, &src, sizeof(U32)); 36 | uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); 37 | return static_cast((U32 + rounding_bias) >> 16); 38 | } 39 | 40 | /** 41 | * The "brain floating-point" type, compatible with c10/util/BFloat16.h from 42 | * pytorch core. 43 | * 44 | * This representation uses 1 bit for the sign, 8 bits for the exponent and 7 45 | * bits for the mantissa. 46 | */ 47 | struct alignas(2) BFloat16 { 48 | uint16_t x; 49 | 50 | BFloat16() = default; 51 | struct from_bits_t {}; 52 | static constexpr from_bits_t from_bits() { 53 | return from_bits_t(); 54 | } 55 | 56 | constexpr BFloat16(unsigned short bits, from_bits_t) : x(bits) {} 57 | /* implicit */ BFloat16(float value) : x(round_to_nearest_even(value)) {} 58 | operator float() const { 59 | return f32_from_bits(x); 60 | } 61 | }; 62 | -------------------------------------------------------------------------------- /torchao/experimental/op_lib_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | 10 | def _check_torchao_ops_loaded(): 11 | # Check kernels are installed/loaded 12 | try: 13 | torch.ops.torchao._pack_8bit_act_4bit_weight 14 | except AttributeError: 15 | raise Exception( 16 | "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU." 17 | + " You can also set target to 'aten' if you are using ARM CPU." 18 | ) 19 | -------------------------------------------------------------------------------- /torchao/experimental/ops/benchmarks/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | cmake_minimum_required(VERSION 3.19) 8 | project(benchmarks) 9 | 10 | set(CMAKE_CXX_STANDARD 17) 11 | set(CMAKE_BUILD_TYPE Release) 12 | add_compile_options("-Wall" "-Werror") 13 | 14 | set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) 15 | set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) 
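# Google Benchmark is fetched at configure time via FetchContent below;
# BENCHMARK_ENABLE_TESTING is switched off so only the library target builds.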
16 | 17 | include(FetchContent) 18 | FetchContent_Declare(googlebenchmark 19 | GIT_REPOSITORY https://github.com/google/benchmark.git 20 | GIT_TAG main) # need main for benchmark::benchmark 21 | 22 | set(BENCHMARK_ENABLE_TESTING OFF) 23 | FetchContent_MakeAvailable( 24 | googlebenchmark) 25 | 26 | include_directories(${TORCHAO_INCLUDE_DIRS}) 27 | 28 | set(TORCHAO_PARALLEL_BACKEND "openmp") 29 | 30 | include(${TORCHAO_ROOT}/Utils.cmake) 31 | 32 | add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) 33 | 34 | add_executable(benchmark_linear_8bit_act_xbit_weight 35 | benchmark_linear_8bit_act_xbit_weight.cpp 36 | ${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp 37 | ) 38 | target_link_torchao_parallel_backend(benchmark_linear_8bit_act_xbit_weight "${TORCHAO_PARALLEL_BACKEND}") 39 | target_link_libraries( 40 | benchmark_linear_8bit_act_xbit_weight 41 | PRIVATE 42 | benchmark::benchmark 43 | torchao_kernels_aarch64 44 | ) 45 | -------------------------------------------------------------------------------- /torchao/experimental/ops/benchmarks/build_and_run_benchmarks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # Call script with sh build_and_run_benchmarks.sh {BENCHAMRK} 9 | 10 | export CMAKE_OUT=/tmp/cmake-out/torchao/benchmarks 11 | cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ 12 | -S . \ 13 | -B ${CMAKE_OUT} \ 14 | -DOpenMP_ROOT=$(brew --prefix libomp) \ 15 | -DTORCHAO_PARALLEL_OMP=ON 16 | 17 | cmake --build ${CMAKE_OUT} 18 | 19 | # Run 20 | ${CMAKE_OUT}/benchmark_linear_8bit_act_xbit_weight 21 | -------------------------------------------------------------------------------- /torchao/experimental/ops/embedding_xbit/packed_weights_header.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | #include 9 | #include 10 | 11 | namespace torchao::ops::embedding_xbit { 12 | 13 | inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal( 14 | int weight_nbit, 15 | int min_value_chunk_size, 16 | int max_value_chunk_size, 17 | int version = 1) { 18 | return torchao::ops::PackedWeightsHeader( 19 | torchao::ops::PackedWeightsType::embedding_xbit_universal, 20 | {version, 21 | weight_nbit, 22 | min_value_chunk_size, 23 | max_value_chunk_size, 24 | 0, 25 | 0, 26 | 0, 27 | 0, 28 | 0, 29 | 0, 30 | 0, 31 | 0, 32 | 0, 33 | 0}); 34 | } 35 | 36 | } // namespace torchao::ops::embedding_xbit 37 | -------------------------------------------------------------------------------- /torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | project(examples) 8 | 9 | cmake_minimum_required(VERSION 3.19) 10 | set(CMAKE_CXX_STANDARD 17) 11 | set(CMAKE_BUILD_TYPE Release) 12 | 13 | include(CMakePrintHelpers) 14 | 15 | set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) 16 | set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) 17 | 18 | include_directories(${TORCHAO_INCLUDE_DIRS}) 19 | 20 | set(TORCHAO_PARALLEL_BACKEND "openmp") 21 | add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) 22 | 23 | include(${TORCHAO_ROOT}/Utils.cmake) 24 | 25 | add_executable(separate_function_wrappers 26 | separate_function_wrappers.cpp 27 | ${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp 28 | ) 29 | target_link_libraries( 30 | separate_function_wrappers 31 | PRIVATE 32 | torchao_kernels_aarch64 33 | ) 34 | target_link_torchao_parallel_backend(separate_function_wrappers "${TORCHAO_PARALLEL_BACKEND}") 35 | 36 | add_executable(stateful_class_wrapper 37 | stateful_class_wrapper.cpp 38 | ${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp 39 | ) 40 | target_link_libraries( 41 | stateful_class_wrapper 42 | PRIVATE 43 | torchao_kernels_aarch64 44 | ) 45 | target_link_torchao_parallel_backend(stateful_class_wrapper "${TORCHAO_PARALLEL_BACKEND}") 46 | -------------------------------------------------------------------------------- /torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/build_and_run_examples.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" 9 | echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" 10 | export CMAKE_OUT=/tmp/cmake-out/torchao/examples 11 | cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ 12 | -S . \ 13 | -B ${CMAKE_OUT} \ 14 | -DOpenMP_ROOT=$(brew --prefix libomp) 15 | cmake --build ${CMAKE_OUT} 16 | 17 | # Run 18 | case "$1" in 19 | separate_function_wrappers) ${CMAKE_OUT}/separate_function_wrappers; ;; 20 | stateful_class_wrapper) ${CMAKE_OUT}/stateful_class_wrapper; ;; 21 | *) echo "Unknown example: $1. 
Please specify one of: separate_function_wrappers, stateful_class_wrapper."; exit 1; ;; 22 | esac 23 | -------------------------------------------------------------------------------- /torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp: -------------------------------------------------------------------------------- 1 | #include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h> 2 | 3 | #define DEFINE_OP(weight_nbit) \ 4 | Tensor _op_out_##weight_nbit( \ 5 | RuntimeContext& ctx, \ 6 | const Tensor& activations, \ 7 | const Tensor& packed_weights, \ 8 | const int64_t& group_size, \ 9 | const int64_t& n, \ 10 | const int64_t& k, \ 11 | Tensor& out) { \ 12 | (void)ctx; \ 13 | linear_out_cpu<weight_nbit>( \ 14 | activations, packed_weights, group_size, n, k, out); \ 15 | return out; \ 16 | } \ 17 | EXECUTORCH_LIBRARY( \ 18 | torchao, \ 19 | "_linear_8bit_act_" #weight_nbit "bit_weight.out", \ 20 | _op_out_##weight_nbit) 21 | 22 | DEFINE_OP(1); 23 | DEFINE_OP(2); 24 | DEFINE_OP(3); 25 | DEFINE_OP(4); 26 | DEFINE_OP(5); 27 | DEFINE_OP(6); 28 | DEFINE_OP(7); 29 | DEFINE_OP(8); 30 | -------------------------------------------------------------------------------- /torchao/experimental/ops/memory.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include <cstdlib> 10 | #include <memory> 11 | #include <stdexcept> 12 | #include <string> 13 | 14 | namespace torchao { 15 | 16 | using aligned_byte_ptr = std::unique_ptr<char[], void (*)(void*)>; 17 | 18 | inline aligned_byte_ptr make_aligned_byte_ptr(size_t alignment, size_t size) { 19 | // Adjust size to next multiple of alignment >= size 20 | size_t adjusted_size = ((size + alignment - 1) / alignment) * alignment; 21 | 22 | char* ptr = static_cast<char*>(std::aligned_alloc(alignment, adjusted_size)); 23 | if (!ptr) { 24 | throw std::runtime_error( 25 | "Failed to allocate memory. Requested size: " + std::to_string(size) + 26 | ". Requested alignment: " + std::to_string(alignment) + 27 | ". Adjusted size: " + std::to_string(adjusted_size) + "."); 28 | } 29 | return std::unique_ptr<char[], void (*)(void*)>(ptr, std::free); 30 | } 31 | } // namespace torchao 32 | -------------------------------------------------------------------------------- /torchao/experimental/ops/mps/.gitignore: -------------------------------------------------------------------------------- 1 | cmake-out/ 2 | -------------------------------------------------------------------------------- /torchao/experimental/ops/mps/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | cd "$(dirname "$BASH_SOURCE")" 9 | 10 | export CMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())') 11 | echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" 12 | export CMAKE_OUT=${PWD}/cmake-out 13 | echo "CMAKE_OUT: ${CMAKE_OUT}" 14 | 15 | cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ 16 | -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \ 17 | -S .
\ 18 | -B ${CMAKE_OUT} 19 | cmake --build ${CMAKE_OUT} -j 16 --target install --config Release 20 | -------------------------------------------------------------------------------- /torchao/experimental/ops/mps/mps_op_lib.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor 9 | from torch.library import impl 10 | 11 | torchao_lib = torch.library.Library("torchao", "IMPL") 12 | for nbit in range(1, 8): 13 | 14 | @impl(torchao_lib, f"_linear_fp_act_{nbit}bit_weight", "Meta") 15 | def _( 16 | activations: Tensor, 17 | packed_weights: Tensor, 18 | group_size: int, 19 | scales: Tensor, 20 | zeros: Tensor, 21 | ): 22 | assert activations.dtype in [torch.float32, torch.float16, torch.bfloat16] 23 | assert activations.is_contiguous() 24 | assert activations.dim() == 2 25 | 26 | assert packed_weights.dtype == torch.uint8 27 | assert packed_weights.is_contiguous() 28 | 29 | m = activations.size(0) 30 | k = activations.size(1) 31 | n = packed_weights.size(0) 32 | 33 | assert k % 8 == 0 34 | assert n % 4 == 0 35 | 36 | assert group_size in [32, 64, 128, 256] 37 | 38 | assert scales.is_contiguous() 39 | assert scales.dim() == 2 40 | assert scales.size(0) == n 41 | 42 | assert zeros.is_contiguous() 43 | assert zeros.dim() == 2 44 | assert zeros.size(0) == n 45 | 46 | return torch.empty(m, n, dtype=activations.dtype, device="meta") 47 | -------------------------------------------------------------------------------- /torchao/experimental/ops/parallel-aten-impl.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | #include <ATen/Parallel.h> 9 | #include <torch/torch.h> 10 | #include <torchao/experimental/ops/parallel.h> 11 | 12 | // F has signature [&](int64_t idx) 13 | template <typename F> 14 | void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { 15 | at::parallel_for(begin, end, 1, [&](int64_t begin, int64_t end) { 16 | for (int64_t idx = begin; idx < end; idx++) { 17 | f(idx); 18 | } 19 | }); 20 | } 21 | 22 | inline void torchao::set_num_threads(int num_threads) { 23 | torch::set_num_threads(num_threads); 24 | } 25 | 26 | inline int torchao::get_num_threads() { 27 | return torch::get_num_threads(); 28 | } 29 | -------------------------------------------------------------------------------- /torchao/experimental/ops/parallel-executorch-impl.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree.
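// Each parallel-*-impl.h header fills in the same three-function contract
// declared by the ops layer: parallel_1d(begin, end, f) applies f to every
// index in [begin, end), plus set_num_threads()/get_num_threads(). This
// variant forwards the work to the ExecuTorch threadpool.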
6 | 7 | #pragma once 8 | 9 | #include <executorch/extension/threadpool/threadpool.h> 10 | 11 | template <typename F> 12 | void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { 13 | torch::executorch::threadpool::get_threadpool()->run( 14 | [&](size_t i) { 15 | int64_t idx = begin + i; 16 | f(idx); 17 | }, 18 | end - begin); 19 | } 20 | 21 | inline void torchao::set_num_threads(int num_threads) { 22 | torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool( 23 | num_threads); 24 | } 25 | 26 | inline int torchao::get_num_threads() { 27 | return torch::executorch::threadpool::get_threadpool()->get_thread_count(); 28 | } 29 | -------------------------------------------------------------------------------- /torchao/experimental/ops/parallel-openmp-impl.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | #include <omp.h> 9 | 10 | template <typename F> 11 | void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { 12 | #pragma omp parallel 13 | { 14 | #pragma omp for 15 | for (int i = begin; i < end; i += 1) { 16 | f(i); 17 | } 18 | } 19 | } 20 | 21 | inline void torchao::set_num_threads(int num_threads) { 22 | omp_set_num_threads(num_threads); 23 | } 24 | inline int torchao::get_num_threads() { 25 | // omp_get_num_threads returns the number of threads 26 | // in the current code section, which will be 1 in the routines 27 | // that select tiling params 28 | return omp_get_max_threads(); 29 | } 30 | -------------------------------------------------------------------------------- /torchao/experimental/ops/parallel-single_threaded-impl.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | template <typename F> 10 | void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { 11 | for (int i = begin; i < end; i += 1) { 12 | f(i); 13 | } 14 | } 15 | 16 | inline void torchao::set_num_threads(int num_threads) {} 17 | inline int torchao::get_num_threads() { 18 | return 1; 19 | } 20 | -------------------------------------------------------------------------------- /torchao/experimental/ops/parallel-test_dummy-impl.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree.
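// Serial backend used by tests: work runs on the calling thread, but the
// requested thread count is recorded so get_num_threads() reflects the most
// recent set_num_threads() call.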
6 | 7 | namespace torchao::parallel::internal { 8 | static int num_threads_test_dummy_{1}; 9 | } 10 | 11 | template 12 | void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { 13 | for (int i = begin; i < end; i += 1) { 14 | f(i); 15 | } 16 | } 17 | 18 | inline void torchao::set_num_threads(int num_threads) { 19 | torchao::parallel::internal::num_threads_test_dummy_ = num_threads; 20 | } 21 | inline int torchao::get_num_threads() { 22 | return torchao::parallel::internal::num_threads_test_dummy_; 23 | } 24 | -------------------------------------------------------------------------------- /torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # TODO: delete this file. 8 | # File is kept in torchao/experimental to avoid breaking existing code 9 | import logging 10 | 11 | logging.warning( 12 | "torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout.py is deprecated and will be removed. Please use torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout.py instead." 13 | ) 14 | from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import ( 15 | PackedLinearInt8DynamicActivationIntxWeightLayout, 16 | Target, 17 | ) 18 | 19 | __all__ = [ 20 | "PackedLinearInt8DynamicActivationIntxWeightLayout", 21 | "Target", 22 | ] 23 | -------------------------------------------------------------------------------- /torchao/experimental/q_dq_layout.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # TODO: delete this file. 8 | # File is kept in torchao/experimental to avoid breaking existing code 9 | import logging 10 | 11 | logging.warning( 12 | "torchao.experimental.q_dq_layout.py is deprecated and will be removed. Please use torchao.dtypes.uintx.q_dq_layout.py instead." 13 | ) 14 | from torchao.dtypes import QDQLayout 15 | 16 | __all__ = [ 17 | "QDQLayout", 18 | ] 19 | -------------------------------------------------------------------------------- /torchao/experimental/temp_build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import glob 8 | import subprocess 9 | import tempfile 10 | 11 | import torch 12 | 13 | 14 | def cmake_build_torchao_ops(cmake_lists_path, temp_build_dir): 15 | from distutils.sysconfig import get_python_lib 16 | 17 | print("Building torchao ops for ATen target") 18 | cmake_prefix_path = get_python_lib() 19 | subprocess.run( 20 | [ 21 | "cmake", 22 | "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path, 23 | "-DCMAKE_INSTALL_PREFIX=" + temp_build_dir.name, 24 | "-S " + cmake_lists_path, 25 | "-B " + temp_build_dir.name, 26 | ] 27 | ) 28 | subprocess.run( 29 | [ 30 | "cmake", 31 | "--build", 32 | temp_build_dir.name, 33 | "-j 16", 34 | "--target install", 35 | "--config Release", 36 | ] 37 | ) 38 | 39 | 40 | def temp_build_and_load_torchao_ops(cmake_lists_path): 41 | temp_build_dir = tempfile.TemporaryDirectory() 42 | cmake_build_torchao_ops(cmake_lists_path, temp_build_dir) 43 | libs = glob.glob(f"{temp_build_dir.name}/lib/libtorchao_ops_aten.*") 44 | libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) 45 | assert len(libs) == 1 46 | torch.ops.load_library(libs[0]) 47 | print(f"TorchAO ops are loaded from {libs[0]}") 48 | -------------------------------------------------------------------------------- /torchao/float8/__init__.py: -------------------------------------------------------------------------------- 1 | # Lets define a few top level things here 2 | from torchao.float8.config import ( 3 | CastConfig, 4 | Float8GemmConfig, 5 | Float8LinearConfig, 6 | ScalingType, 7 | ) 8 | from torchao.float8.float8_linear_utils import ( 9 | convert_to_float8_training, 10 | ) 11 | from torchao.float8.float8_tensor import ( 12 | Float8Tensor, 13 | GemmInputRole, 14 | LinearMMConfig, 15 | ScaledMMConfig, 16 | ) 17 | from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp 18 | from torchao.float8.inference import Float8MMConfig 19 | from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 20 | 21 | if TORCH_VERSION_AT_LEAST_2_5: 22 | # Needed to load Float8Tensor with weights_only = True 23 | from torch.serialization import add_safe_globals 24 | 25 | add_safe_globals( 26 | [ 27 | Float8Tensor, 28 | ScaledMMConfig, 29 | GemmInputRole, 30 | LinearMMConfig, 31 | Float8MMConfig, 32 | ] 33 | ) 34 | 35 | __all__ = [ 36 | # configuration 37 | "ScalingType", 38 | "Float8GemmConfig", 39 | "Float8LinearConfig", 40 | "CastConfig", 41 | # top level UX 42 | "convert_to_float8_training", 43 | "precompute_float8_dynamic_scale_for_fsdp", 44 | # note: Float8Tensor and Float8Linear are not public APIs 45 | ] 46 | -------------------------------------------------------------------------------- /torchao/float8/distributed_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.distributed._functional_collectives as funcol 9 | from torch.distributed._tensor import DTensor 10 | 11 | from torchao.float8.float8_tensor import Float8Tensor 12 | 13 | 14 | def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: 15 | """ 16 | Check if the tensor is already casted to fp8, works if the local 17 | tensor is wrapped in DTensor. 
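Note: the top-level training UX exported above is `convert_to_float8_training`. A minimal sketch of the intended flow (bfloat16 model on CUDA, default config; the layer sizes are arbitrary but kept at multiples of 16 so the float8 `scaled_mm` path applies):

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

m = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16)).bfloat16().cuda()
convert_to_float8_training(m)  # swaps nn.Linear modules for Float8Linear in-place

x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()  # forward/backward matmuls now run in float8
```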
--------------------------------------------------------------------------------
/torchao/float8/distributed_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD 3-Clause license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import torch
8 | import torch.distributed._functional_collectives as funcol
9 | from torch.distributed._tensor import DTensor
10 | 
11 | from torchao.float8.float8_tensor import Float8Tensor
12 | 
13 | 
14 | def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
15 |     """
16 |     Check if the tensor has already been cast to fp8; also works if the local
17 |     tensor is wrapped in a DTensor.
18 |     """
19 |     if isinstance(tensor, Float8Tensor):
20 |         return True
21 |     elif isinstance(tensor, DTensor):
22 |         # TODO: shall we stick to public API and directly use tensor.to_local() here?
23 |         return tensor_already_casted_to_fp8(tensor._local_tensor)
24 |     elif isinstance(tensor, funcol.AsyncCollectiveTensor):
25 |         return tensor_already_casted_to_fp8(tensor.elem)
26 | 
27 |     return False
28 | 
--------------------------------------------------------------------------------
/torchao/kernel/README.md:
--------------------------------------------------------------------------------
1 | ## Autotuner and custom Triton kernels
2 | 
3 | ### Environment variables
4 | 
5 | `TORCHAO_AUTOTUNER_ENABLE=1`
6 | 
7 | Set this to a nonzero value to enable the kernels generated by the autotuner. This is turned off by default, because it is still an experimental feature and can take a long time to run.
8 | 
9 | `TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE`
10 | 
11 | Use this to enable exhaustive search for both int8mm and scaled_mm kernels.
12 | 
13 | Searching a new config can take a long time and we'll save the updated data in `data.pkl`. If you'd like to contribute updated configs for your hardware or shapes, please open a pull request.
14 | 
15 | `TORCHAO_AUTOTUNER_DATA_PATH=torchao/kernel/configs/data_a100.pkl`
16 | 
17 | By default we load precomputed configs for A100. If we're not on an A100, we run the search and set the path to `data.pkl`.
18 | 
19 | Updated configs are always stored in the current working directory as `data.pkl` to avoid accidentally overwriting the supplied configs.
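Note: a hedged sketch of how the flags in the README above are meant to be used. The environment variable must be set before the kernels are first called, and `safe_int_mm` is the int8 matmul entry point re-exported by `torchao.kernel` (see the `__init__.py` below); expect the first call to be slow while the search runs:

```python
import os

os.environ["TORCHAO_AUTOTUNER_ENABLE"] = "1"  # opt in before the first kernel call

import torch
from torchao.kernel import safe_int_mm

a = torch.randint(-128, 127, (64, 32), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 127, (32, 16), dtype=torch.int8, device="cuda")
c = safe_int_mm(a, b)  # int8 x int8 -> int32 matmul, may autotune on first use
```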
--------------------------------------------------------------------------------
/torchao/kernel/__init__.py:
--------------------------------------------------------------------------------
1 | from torchao.kernel.bsr_triton_ops import bsr_dense_addmm
2 | from torchao.kernel.intmm import int_scaled_matmul, safe_int_mm
3 | 
4 | __all__ = [
5 |     "bsr_dense_addmm",
6 |     "safe_int_mm",
7 |     "int_scaled_matmul",
8 | ]
9 | 
--------------------------------------------------------------------------------
/torchao/kernel/configs/data_a100.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/kernel/configs/data_a100.pkl
--------------------------------------------------------------------------------
/torchao/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .adam import Adam4bit, Adam8bit, AdamFp8, AdamW4bit, AdamW8bit, AdamWFp8, _AdamW
2 | from .cpu_offload import CPUOffloadOptimizer
3 | 
4 | __all__ = [
5 |     "Adam4bit",
6 |     "Adam8bit",
7 |     "AdamFp8",
8 |     "AdamW4bit",
9 |     "AdamW8bit",
10 |     "AdamWFp8",
11 |     "_AdamW",
12 |     "CPUOffloadOptimizer",
13 | ]
14 | 
--------------------------------------------------------------------------------
/torchao/prototype/README.md:
--------------------------------------------------------------------------------
1 | # Prototype
2 | 
3 | ### Experimental kernels and utilities for efficient inference and training
4 | 
5 | > The goal isn't to reproduce all emerging methodologies but to extract common components across prevalent, proven paradigms that can be modularized and composed with the `torch` stack as well as other OSS ML frameworks.
6 | 
7 | #### Code structure
8 | 
9 | - `galore` - fused kernels for memory-efficient pre-training / fine-tuning per the [GaLore algorithm](https://arxiv.org/abs/2403.03507)
10 |   - `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm
11 |   - `galore/docs` - implementation notes and discussion of issues faced in kernel design.
12 | - [`quant_llm`](quant_llm) - FP16 x Floatx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)
13 | - ~~`low_bit_optim`~~ - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and 4-bit optimizers from [lpmm](https://github.com/thu-ml/low-bit-optimizers). **Promoted to `torchao.optim`.**
14 | - [`spinquant`](spinquant) - re-implementation of [SpinQuant](https://arxiv.org/abs/2405.16406)
15 | 
16 | #### Roadmap
17 | 
18 | - `hqq`, `awq`, `marlin`, `QuaRot`, and other well-researched methodologies for quantized fine-tuning and inference.
19 |   - ideally, techniques that are both **theoretically sound** and have **practical hardware-aware implementations**
20 |   - AWQ and GPTQ are good examples.
21 | - `cutlass` / `triton` utilities for common quantization ops (numeric conversion, quant / dequant, mixed type gemm, etc.)
22 |   - goal is to create a set of kernels and components that can expedite the implementation & optimization across the spectrum of quantization, fine-tuning, and inference patterns.
23 | 
--------------------------------------------------------------------------------
/torchao/prototype/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/__init__.py
--------------------------------------------------------------------------------
/torchao/prototype/autoround/__init__.py:
--------------------------------------------------------------------------------
1 | from torchao.prototype.autoround.core import (
2 |     apply_auto_round,
3 |     prepare_model_for_applying_auto_round_,
4 | )
5 | from torchao.prototype.autoround.multi_tensor import MultiTensor
6 | 
7 | __all__ = [
8 |     "apply_auto_round",
9 |     "prepare_model_for_applying_auto_round_",
10 |     "MultiTensor",
11 | ]
12 | 
--------------------------------------------------------------------------------
/torchao/prototype/autoround/requirements.txt:
--------------------------------------------------------------------------------
1 | auto_round @ git+https://github.com/intel/auto-round.git@patch-for-ao-2
2 | numpy < 2.0 # dataset requires numpy < 2.0, can be removed after dataset is updated
3 | datasets # for loading the calibration dataset
4 | transformers # for loading the model
--------------------------------------------------------------------------------
/torchao/prototype/autoround/run_example.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD 3-Clause license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # Run examples
7 | python autoround_llm.py -m meta-llama/Llama-2-7b-chat-hf
8 | python autoround_llm.py -m meta-llama/Llama-2-7b-chat-hf --quant_lm_head
9 | python autoround_llm.py -m meta-llama/Meta-Llama-3-8B-Instruct --model_device cpu
10 | python autoround_llm.py -m meta-llama/Meta-Llama-3.1-8B-Instruct --model_device cpu
11 | 
12 | # Evaluate with lm-eval
13 | # Auto-round
14 | python eval_autoround.py -m meta-llama/Llama-2-7b-chat-hf --tasks wikitext lambada_openai hellaswag winogrande piqa mmlu
15 | python eval_autoround.py -m meta-llama/Meta-Llama-3-8B-Instruct --model_device cpu --tasks wikitext lambada_openai hellaswag winogrande piqa mmlu
16 | python eval_autoround.py -m meta-llama/Meta-Llama-3.1-8B-Instruct --model_device cpu --tasks wikitext lambada_openai hellaswag winogrande piqa mmlu
17 | # wo_int4
18 | python eval_autoround.py -m meta-llama/Llama-2-7b-chat-hf --woq_int4 --tasks wikitext lambada_openai hellaswag winogrande piqa mmlu
19 | python eval_autoround.py -m meta-llama/Meta-Llama-3-8B-Instruct --woq_int4 --tasks wikitext lambada_openai hellaswag winogrande piqa mmlu
20 | python eval_autoround.py -m meta-llama/Meta-Llama-3.1-8B-Instruct --woq_int4 --tasks wikitext lambada_openai hellaswag winogrande piqa mmlu
21 | # uintx
22 | python eval_autoround.py -m /models/Meta-Llama-3.1-8B-Instruct/ --uintx --bits 2 --tasks wikitext lambada_openai hellaswag winogrande piqa mmlu
23 | 
--------------------------------------------------------------------------------
/torchao/prototype/awq/__init__.py:
--------------------------------------------------------------------------------
1 | from .api import awq_uintx, insert_awq_observer_
2 | from .core import AWQObservedLinear
3 | 
4 | __all__ = [
5 |     "awq_uintx",
6 |     "insert_awq_observer_",
7 |     "AWQObservedLinear",
8 | ]
9 | 
--------------------------------------------------------------------------------
/torchao/prototype/blockwise_fp8/__init__.py:
--------------------------------------------------------------------------------
1 | from .blockwise_linear import BlockwiseQuantLinear
2 | from .blockwise_quantization import (
3 |     blockwise_fp8_gemm,
4 |     fp8_blockwise_act_quant,
5 |     fp8_blockwise_weight_dequant,
6 |     fp8_blockwise_weight_quant,
7 | )
8 | 
9 | __all__ = [
10 |     "blockwise_fp8_gemm",
11 |     "BlockwiseQuantLinear",
12 |     "fp8_blockwise_act_quant",
13 |     "fp8_blockwise_weight_quant",
14 |     "fp8_blockwise_weight_dequant",
15 | ]
16 | 
--------------------------------------------------------------------------------
/torchao/prototype/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/common/__init__.py
--------------------------------------------------------------------------------
/torchao/prototype/common/triton/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/common/triton/__init__.py
--------------------------------------------------------------------------------
/torchao/prototype/float8nocompile/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/float8nocompile/__init__.py
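Note: the `blockwise_fp8` exports above follow a DeepSeek-V3-style scheme (per-block weight scales, per-group activation scales). A rough round-trip sketch; the exact signatures and the `block_size` keyword are assumptions inferred from the export names, so check the function definitions before relying on them:

```python
import torch
from torchao.prototype.blockwise_fp8 import (
    fp8_blockwise_weight_dequant,
    fp8_blockwise_weight_quant,
)

w = torch.randn(512, 512, device="cuda")
w_fp8, scales = fp8_blockwise_weight_quant(w, block_size=128)  # assumed signature
w_hat = fp8_blockwise_weight_dequant(w_fp8, scales, block_size=128)
print((w - w_hat).abs().max())  # reconstruction error should be small
```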
--------------------------------------------------------------------------------
/torchao/prototype/float8nocompile/examples/example.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD 3-Clause license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import torch
7 | import torch.nn as nn
8 | 
9 | from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
10 |     convert_to_float8_nocompile_training,
11 | )
12 | from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
13 | 
14 | if not TORCH_VERSION_AT_LEAST_2_5:
15 |     raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")
16 | 
17 | # create model and sample input
18 | m = (
19 |     nn.Sequential(
20 |         nn.Linear(32, 32),
21 |     )
22 |     .bfloat16()
23 |     .cuda()
24 | )
25 | x = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16)
26 | optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
27 | 
28 | # convert specified `torch.nn.Linear` modules to `Float8Linear`
29 | print("calling convert_to_float8_nocompile_training")
30 | convert_to_float8_nocompile_training(m)
31 | print("finished convert_to_float8_nocompile_training")
32 | 
33 | for i in range(10):
34 |     print(f"step {i}")
35 |     optimizer.zero_grad()
36 |     y = m(x)
37 |     y.sum().backward()
38 |     optimizer.step()
39 | 
--------------------------------------------------------------------------------
/torchao/prototype/float8nocompile/float8nocompile_loss_curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/float8nocompile/float8nocompile_loss_curves.png
--------------------------------------------------------------------------------
/torchao/prototype/float8nocompile/kernels/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/float8nocompile/kernels/__init__.py
--------------------------------------------------------------------------------
/torchao/prototype/galore/README.md:
--------------------------------------------------------------------------------
1 | ## Fused GaLore
2 | 
3 | ### Experimental kernels for fusing various parts of the GaLore algorithm
4 | 
5 | #### AdamW
6 | 
7 | See `docs/galore_adam.md` for implementation notes.
8 | 
9 | #### AdamW8bit
10 | 
11 | See `docs/galore_adam8bit.md` for implementation notes.
12 | 
--------------------------------------------------------------------------------
/torchao/prototype/galore/__init__.py:
--------------------------------------------------------------------------------
1 | from .kernels import *  # noqa: F403
2 | 
--------------------------------------------------------------------------------
/torchao/prototype/galore/kernels/__init__.py:
--------------------------------------------------------------------------------
1 | from .adam_downproj_fused import fused_adam_mm_launcher
2 | from .adam_step import triton_adam_launcher
3 | from .matmul import triton_mm_launcher
4 | from .quant import triton_dequant_blockwise, triton_quantize_blockwise
5 | 
6 | __all__ = [
7 |     "fused_adam_mm_launcher",
8 |     "triton_adam_launcher",
9 |     "triton_mm_launcher",
10 |     "triton_dequant_blockwise",
11 |     "triton_quantize_blockwise",
12 | ]
13 | 
--------------------------------------------------------------------------------
/torchao/prototype/galore/optim/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/galore/optim/__init__.py
--------------------------------------------------------------------------------
/torchao/prototype/hqq/__init__.py:
--------------------------------------------------------------------------------
1 | from .mixed_mm import pack_2xint4, triton_mixed_mm
2 | 
3 | __all__ = [
4 |     "pack_2xint4",
5 |     "triton_mixed_mm",
6 | ]
7 | 
--------------------------------------------------------------------------------
/torchao/prototype/inductor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/inductor/__init__.py
--------------------------------------------------------------------------------
/torchao/prototype/inductor/codegen/__init__.py:
--------------------------------------------------------------------------------
1 | from .cpp_int8_sdpa_template import CppInt8SdpaTemplate
2 | 
3 | __all__ = [
4 |     "CppInt8SdpaTemplate",
5 | ]
6 | 
--------------------------------------------------------------------------------
/torchao/prototype/inductor/codegen/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List
2 | 
3 | from torch._inductor import lowering as L
4 | from torch._inductor.codegen.cpp_template_kernel import (
5 |     parse_expr_with_index_symbols,
6 |     wrap_with_tensorbox,
7 | )
8 | 
9 | 
10 | def expand(node, sizes: List[Any]):
11 |     node = wrap_with_tensorbox(node)
12 |     sizes = parse_expr_with_index_symbols(sizes)
13 |     return L.expand(node, sizes).data
14 | 
--------------------------------------------------------------------------------
/torchao/prototype/inductor/fx_passes/README.md:
--------------------------------------------------------------------------------
1 | # Inductor FX Passes
2 | 
3 | This directory contains the FX passes used with Inductor. FX passes are transformations applied to the FX graph to optimize and modify it for better performance and functionality.
4 | 
5 | In TorchAO, you can replace the following customizable graph passes of Inductor:
6 | - `pre_grad_custom_pass`
7 | - `joint_custom_pre_pass`
8 | - `joint_custom_post_pass`
9 | - `post_grad_custom_post_pass`
10 | - `post_grad_custom_pre_pass`
11 | 
12 | ## Directory Structure
13 | 
14 | - `int8_sdpa_fusion`: Pattern match for int8 sdpa fusion.
15 | 
16 | ## Getting Started
17 | 
18 | To get started with using the FX passes in TorchAO, you can register and apply them to your FX graph as follows:
19 | 
20 | ```python
21 | from torch._inductor import config
22 | from torch._inductor.pattern_matcher import PatternMatcherPass
23 | 
24 | # Example usage
25 | class _CustomPass(PatternMatcherPass):  # define a custom pass class
26 |     ...
27 | 
28 | custom_pass = _CustomPass()  # create an instance of the custom pass
29 | with config.patch(post_grad_custom_post_pass=custom_pass):
30 |     _register_patterns(custom_pass)  # register your own patterns into the pass
31 | ```
32 | 
33 | ## Limitations
34 | 
35 | For now, only one pass can be registered for each custom-pass hook.
36 | In the future, this could be extended to accept a list of passes.
37 | 
--------------------------------------------------------------------------------
/torchao/prototype/inductor/fx_passes/__init__.py:
--------------------------------------------------------------------------------
1 | from .int8_sdpa_fusion import _int8_sdpa_init
2 | 
3 | __all__ = [
4 |     "_int8_sdpa_init",
5 | ]
6 | 
--------------------------------------------------------------------------------
/torchao/prototype/moe_quant/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/moe_quant/__init__.py
--------------------------------------------------------------------------------
/torchao/prototype/mx_formats/__init__.py:
--------------------------------------------------------------------------------
1 | from torchao.prototype.mx_formats.config import (
2 |     MXGemmKernelChoice,
3 |     MXInferenceLinearConfig,
4 |     MXLinearConfig,
5 |     MXLinearRecipeName,
6 | )
7 | 
8 | # Note: Prototype and subject to change
9 | from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
10 | 
11 | # import mx_linear here to register the quantize_ transform logic
12 | # ruff: noqa: I001
13 | import torchao.prototype.mx_formats.mx_linear  # noqa: F401
14 | 
15 | __all__ = [
16 |     "MXGemmKernelChoice",
17 |     "MXInferenceLinearConfig",
18 |     "MXLinearConfig",
19 |     "MXLinearRecipeName",
20 |     "MXFPInferenceConfig",
21 | ]
22 | 
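Note: since importing `torchao.prototype.mx_formats` registers the `quantize_` transform (per the comment in the `__init__.py` above), converting a model is a one-liner. A minimal sketch; whether the defaults of `MXLinearConfig()` are usable as-is is an assumption:

```python
import torch.nn as nn
from torchao.prototype.mx_formats import MXLinearConfig
from torchao.quantization import quantize_

m = nn.Sequential(nn.Linear(256, 256)).cuda().bfloat16()
quantize_(m, MXLinearConfig())  # swaps eligible nn.Linear modules for MX variants
```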
--------------------------------------------------------------------------------
/torchao/prototype/mx_formats/mx_funcs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | """
8 | This file defines the top-level torch ops that are extended by MXTensor.
9 | See: https://docs.pytorch.org/docs/stable/notes/extending.html#extending-torch-with-a-tensor-wrapper-type
10 | for more details.
11 | """
12 | 
13 | from typing import Any, Dict
14 | 
15 | import torch
16 | 
17 | from torchao.prototype.mx_formats.mx_ops import _addmm_mx_dispatch
18 | from torchao.prototype.mx_formats.mx_tensor import (  # noqa: E501
19 |     MXTensor,
20 | )
21 | 
22 | aten = torch.ops.aten
23 | 
24 | MX_FUNC_TABLE: Dict[Any, Any] = {}
25 | 
26 | 
27 | def implements_func(torch_ops):
28 |     """Register torch ops to the mx op table for torch function"""
29 | 
30 |     def decorator(func):
31 |         for op in torch_ops:
32 |             MX_FUNC_TABLE[op] = func
33 |         return func
34 | 
35 |     return decorator
36 | 
37 | 
38 | @implements_func([aten.linear.default])
39 | def mx_linear(func, types, args, kwargs):
40 |     a, b = args[0], args[1]
41 |     assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
42 |     bias = args[2] if len(args) == 3 else None
43 |     return _addmm_mx_dispatch(a, b.t(), func, bias=bias)
44 | 
--------------------------------------------------------------------------------
/torchao/prototype/paretoq/1_run_train.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | 
3 | # This source code is licensed under the BSD-style license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | torchrun --nnodes=1 --nproc_per_node=1 train.py \
7 | --local_dir "/tmp/llama/" \
8 | --input_model_filename "meta-llama/Llama-3.2-1B" \
9 | --output_model_filename "1B-finetuned" \
10 | --train_data_local_path "/tmp/train.jsonl" \
11 | --do_train True \
12 | --do_eval False \
13 | --model_max_length 2048 \
14 | --fp16 False \
15 | --bf16 True \
16 | --log_on_each_node False \
17 | --logging_dir /tmp/output/runs/current \
18 | --num_train_epochs 1 \
19 | --per_device_train_batch_size 2 \
20 | --per_device_eval_batch_size 1 \
21 | --gradient_accumulation_steps 1 \
22 | --evaluation_strategy "no" \
23 | --save_strategy "steps" \
24 | --save_steps 2000 \
25 | --report_to "tensorboard" \
26 | --save_total_limit 1 \
27 | --learning_rate 2e-5 \
28 | --weight_decay 0. \
29 | --warmup_ratio 0. \
30 | --lr_scheduler_type "cosine" \
31 | --logging_steps 1 \
32 | --tf32 False \
33 | --gradient_checkpointing False \
34 | --qat True \
35 | --w_bits 4 \
36 | 
--------------------------------------------------------------------------------
/torchao/prototype/paretoq/2_run_eval.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | 
3 | # This source code is licensed under the BSD-style license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | CUDA_VISIBLE_DEVICES=0 torchrun --nnodes=1 --nproc_per_node=1 train.py \
7 | --local_dir "/tmp/llama/" \
8 | --input_model_filename "/tmp/llama_1B/llama_1B_bit1" \
9 | --output_model_filename "1B-finetuned" \
10 | --train_data_local_path "/tmp/train.jsonl" \
11 | --eval_data_local_path "/tmp/wikitext-2/test.jsonl" \
12 | --do_train False \
13 | --do_eval True \
14 | --model_max_length 2048 \
15 | --fp16 False \
16 | --bf16 True \
17 | --log_on_each_node False \
18 | --logging_dir /tmp/output/runs/current \
19 | --num_train_epochs 1 \
20 | --per_device_train_batch_size 2 \
21 | --per_device_eval_batch_size 4 \
22 | --gradient_accumulation_steps 1 \
23 | --evaluation_strategy "no" \
24 | --save_strategy "steps" \
25 | --save_steps 2000 \
26 | --report_to "tensorboard" \
27 | --save_total_limit 1 \
28 | --learning_rate 2e-5 \
29 | --weight_decay 0. \
30 | --warmup_ratio 0. \
31 | --lr_scheduler_type "cosine" \
32 | --logging_steps 1 \
33 | --tf32 False \
34 | --gradient_checkpointing False \
35 | --qat True \
36 | --w_bits 1 \
37 | --contain_weight_clip_val True \
38 | 
--------------------------------------------------------------------------------
/torchao/prototype/paretoq/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/paretoq/__init__.py
--------------------------------------------------------------------------------
/torchao/prototype/paretoq/main_result_234bit.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/paretoq/main_result_234bit.jpg
--------------------------------------------------------------------------------
/torchao/prototype/paretoq/main_result_scaling_law.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/paretoq/main_result_scaling_law.jpg
--------------------------------------------------------------------------------
/torchao/prototype/paretoq/main_result_ternary.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/paretoq/main_result_ternary.jpg
--------------------------------------------------------------------------------
/torchao/prototype/paretoq/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/paretoq/models/__init__.py
--------------------------------------------------------------------------------
/torchao/prototype/paretoq/requirement.txt:
--------------------------------------------------------------------------------
1 | transformers==4.48.3
2 | accelerate>=0.26.0
3 | datasets==2.20.0
4 | sentencepiece
5 | tensorboardX
--------------------------------------------------------------------------------
/torchao/prototype/parq/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD 3-Clause license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from .optim import (  # noqa: F401
8 |     ProxBinaryRelax,
9 |     ProxHardQuant,
10 |     ProxMap,
11 |     ProxPARQ,
12 |     QuantOptimizer,
13 | )
14 | from .quant import (  # noqa: F401
15 |     Int4UnifTorchaoQuantizer,
16 |     LSBQuantizer,
17 |     MaxUnifQuantizer,
18 |     Quantizer,
19 |     TernaryUnifQuantizer,
20 |     UnifQuantizer,
21 |     UnifTorchaoQuantizer,
22 | )
23 | 
--------------------------------------------------------------------------------
/torchao/prototype/parq/optim/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD 3-Clause license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from .binarelax import ProxBinaryRelax  # noqa: F401
8 | from .parq import ProxPARQ  # noqa: F401
9 | from .proxmap import ProxHardQuant, ProxMap  # noqa: F401
10 | from .quantopt import QuantOptimizer  # noqa: F401
11 | 
--------------------------------------------------------------------------------
/torchao/prototype/parq/optim/binarelax.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD 3-Clause license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from typing import Optional
8 | 
9 | import torch
10 | from torch import Tensor
11 | 
12 | from ..utils import channel_bucketize
13 | from .proxmap import ProxMap
14 | 
15 | 
16 | class ProxBinaryRelax(ProxMap):
17 |     """Prox-map of Binary Relax, Q may not be evenly spaced."""
18 | 
19 |     def __init__(self, anneal_start: int, anneal_end: int) -> None:
20 |         self.anneal_start = anneal_start
21 |         self.anneal_end = anneal_end
22 | 
23 |     @torch.no_grad()
24 |     def apply_(
25 |         self,
26 |         p: Tensor,
27 |         q: Tensor,
28 |         Q: Tensor,
29 |         step_count: int,
30 |         dim: Optional[int] = None,
31 |     ) -> None:
32 |         if step_count < self.anneal_start:
33 |             return
34 | 
35 |         if q is None:
36 |             # hard quantization to the nearest point in Q
37 |             Q_mid = (Q[..., :-1] + Q[..., 1:]) / 2
38 |             if dim is None:
39 |                 q = Q[torch.bucketize(p, Q_mid)]
40 |             else:
41 |                 q = Q.gather(1, channel_bucketize(p, Q_mid))
42 | 
43 |         if step_count >= self.anneal_end:
44 |             p.copy_(q)
45 |         else:
46 |             # linear annealing of relaxation coefficient
47 |             theta = (step_count - self.anneal_start) / (
48 |                 self.anneal_end - self.anneal_start
49 |             )
50 |             p.mul_(1 - theta).add_(q, alpha=theta)
51 | 
--------------------------------------------------------------------------------
/torchao/prototype/parq/optim/proxmap.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD 3-Clause license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from abc import ABC, abstractmethod
8 | from typing import Optional
9 | 
10 | import torch
11 | from torch import Tensor
12 | 
13 | from ..utils import channel_bucketize
14 | 
15 | 
16 | # Create an abstract class to provide proximal-mapping interface
17 | class ProxMap(ABC):
18 |     @abstractmethod
19 |     def apply_(self, p: Tensor, q: Tensor, Q: Tensor, step_count: int) -> None:
20 |         """Provide interface for proximal mapping (modify p in-place):
21 |             prox_map.apply_(p, q, Q, step_count)
22 |         Inputs:
23 |             p (Tensor): tensor to be quantized
24 |             q (Tensor): None or hard quantized tensor of same size as p
25 |             Q (Tensor): set of target quantization values
26 |             step_count: trigger iteration-dependent mapping if needed
27 |         """
28 | 
29 | 
30 | class ProxHardQuant(ProxMap):
31 |     """Prox-map of hard quantization, Q may not be evenly spaced."""
32 | 
33 |     @torch.no_grad()
34 |     def apply_(
35 |         self,
36 |         p: Tensor,
37 |         q: Tensor,
38 |         Q: Tensor,
39 |         step_count: int,
40 |         dim: Optional[int] = None,
41 |     ) -> None:
42 |         if q is None:
43 |             # quantize to the nearest point in Q
44 |             Q_mid = (Q[..., :-1] + Q[..., 1:]) / 2
45 |             if dim is None:
46 |                 q = Q[torch.bucketize(p, Q_mid)]
47 |             else:
48 |                 q = Q.gather(1, channel_bucketize(p, Q_mid))
49 |         p.copy_(q)
50 | 
--------------------------------------------------------------------------------
/torchao/prototype/parq/quant/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD 3-Clause license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from .lsbq import LSBQuantizer  # noqa: F401
8 | from .quantizer import Quantizer  # noqa: F401
9 | from .uniform import (  # noqa: F401
10 |     MaxUnifQuantizer,
11 |     TernaryUnifQuantizer,
12 |     UnifQuantizer,
13 | )
14 | from .uniform_torchao import (  # noqa: F401
15 |     Int4UnifTorchaoQuantizer,
16 |     UnifTorchaoQuantizer,
17 | )
18 | 
--------------------------------------------------------------------------------
/torchao/prototype/parq/quant/quantizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD 3-Clause license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from abc import ABC, abstractmethod
8 | from typing import Optional
9 | 
10 | from torch import Tensor
11 | 
12 | 
13 | class Quantizer(ABC):
14 |     """Abstract base class that defines the quantization interface"""
15 | 
16 |     def __init__(self, center: bool = False) -> None:
17 |         self.center = center
18 | 
19 |     @abstractmethod
20 |     def get_quant_size(self, b: int) -> int:
21 |         """Given number of bits b, return total number of quantization values"""
22 | 
23 |     @abstractmethod
24 |     def quantize(self, p: Tensor, b: int) -> tuple[Tensor, Tensor]:
25 |         """Provide interface for quantization:
26 |             q, Q = Quantizer.quantize(p)
27 |         Inputs:
28 |             p (Tensor): tensor to be quantized
29 |         Outputs:
30 |             q (Tensor): quantized tensor of same size as p
31 |             Q (Tensor): set of 2^b quantization values
32 |         Instantiation should not modify p, leaving update to ProxMap.
33 |         """
34 | 
35 |     @staticmethod
36 |     def remove_mean(p: Tensor, dim: Optional[int] = None) -> tuple[Tensor, Tensor]:
37 |         """Center parameters in a Tensor, called if self.center == True.
38 |         Note that this is different from direct asymmetric quantization,
39 |         and may lead to (hopefully only) slightly different performance.
40 |         """
41 |         mean = p.mean(dim=dim, keepdim=dim is not None)
42 |         q = p - mean
43 |         return q, mean
44 | 
--------------------------------------------------------------------------------
/torchao/prototype/parq/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD 3-Clause license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import torch
8 | from torch import Tensor
9 | 
10 | try:
11 |     from torch.distributed.tensor import DTensor
12 | 
13 |     HAS_DTENSOR = True
14 | except ImportError:
15 |     HAS_DTENSOR = False
16 | 
17 | 
18 | def is_dtensor(x):
19 |     return HAS_DTENSOR and isinstance(x, DTensor)
20 | 
21 | 
22 | def channel_bucketize(input: Tensor, boundaries: Tensor, right: bool = False) -> Tensor:
23 |     """Generalizes torch.bucketize to run on 2-D boundaries."""
24 |     inf_pad = torch.full_like(boundaries[:, :1], torch.inf)
25 |     boundaries = (
26 |         torch.cat((-inf_pad, boundaries), dim=1)
27 |         if right
28 |         else torch.cat((boundaries, inf_pad), dim=1)
29 |     )
30 |     boundaries = boundaries.unsqueeze(1)
31 |     input = input.unsqueeze(-1)
32 |     mask = input.ge(boundaries) if right else input.le(boundaries)
33 |     return mask.to(torch.uint8).argmax(dim=-1)
34 | 
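Note: a small worked example of `channel_bucketize` above. Unlike `torch.bucketize`, each row of `input` is bucketized against its *own* row of `boundaries`, which is what lets the prox-maps pick per-channel quantization values via `Q.gather(1, channel_bucketize(p, Q_mid))`:

```python
import torch

input = torch.tensor([[0.1, 0.9],
                      [0.4, 0.6]])
boundaries = torch.tensor([[0.25, 0.75],   # boundaries for row 0
                           [0.50, 0.50]])  # boundaries for row 1
print(channel_bucketize(input, boundaries))
# tensor([[0, 2],
#         [0, 2]])  -- e.g. 0.1 <= 0.25 -> bucket 0, 0.9 > 0.75 -> bucket 2
```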
--------------------------------------------------------------------------------
/torchao/prototype/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | from .gguf import GGUFWeightOnlyConfig
2 | 
3 | __all__ = [
4 |     "GGUFWeightOnlyConfig",
5 | ]
6 | 
--------------------------------------------------------------------------------
/torchao/prototype/quantization/codebook/__init__.py:
--------------------------------------------------------------------------------
1 | from .codebook_ops import (
2 |     choose_qparams_codebook,
3 |     dequantize_codebook,
4 |     quantize_codebook,
5 | )
6 | from .codebook_quantized_tensor import CodebookQuantizedTensor, codebook_weight_only
7 | 
8 | __all__ = [
9 |     "CodebookQuantizedTensor",
10 |     "codebook_weight_only",
11 |     "quantize_codebook",
12 |     "dequantize_codebook",
13 |     "choose_qparams_codebook",
14 | ]
15 | 
--------------------------------------------------------------------------------
/torchao/prototype/quantization/gguf/__init__.py:
--------------------------------------------------------------------------------
1 | from .api import GGUFWeightOnlyConfig
2 | from .gguf_quantized_tensor import (
3 |     GGUFQuantizedTensor,
4 | )
5 | 
6 | __all__ = [
7 |     "GGUFQuantizedTensor",
8 |     "GGUFWeightOnlyConfig",
9 | ]
10 | 
--------------------------------------------------------------------------------
/torchao/prototype/quantization/gguf/api.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD 3-Clause license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from dataclasses import dataclass
8 | 
9 | import torch
10 | 
11 | from torchao.core.config import AOBaseConfig
12 | from torchao.quantization.transform_module import register_quantize_module_handler
13 | 
14 | from .gguf_quantized_tensor import GGUFQuantizedTensor
15 | 
16 | __all__ = [
17 |     "GGUFWeightOnlyConfig",
18 | ]
19 | 
20 | 
21 | @dataclass
22 | class GGUFWeightOnlyConfig(AOBaseConfig):
23 |     dtype: torch.dtype = torch.uint4
24 |     n_blocks_per_superblock: int = 8
25 | 
26 | 
27 | @register_quantize_module_handler(GGUFWeightOnlyConfig)
28 | def _gguf_weight_only_transform(
29 |     module: torch.nn.Module,
30 |     config: GGUFWeightOnlyConfig,
31 | ):
32 |     """
33 |     Applies GGUF weight-only quantization to linear layers.
34 | 
35 |     Args:
36 |         dtype: torch.uint1 to torch.uint8, torch.int32 supported.
37 |         n_blocks_per_superblock: the number of blocks in a 256-element superblock for gguf,
38 |             e.g. when it is 8 the superblock is split into 8 blocks of 32 elements each.
39 |     Returns:
40 |         The module with its weight replaced by a GGUF-quantized weight; layers whose
41 |         weight is not 2-D with a last dim divisible by 256 are returned unchanged.
42 |     """
43 |     weight = module.weight
44 |     if (weight.ndim != 2) or (weight.shape[-1] % 256 != 0):
45 |         return module
46 | 
47 |     quantized_weight = GGUFQuantizedTensor.from_float(
48 |         weight,
49 |         n_blocks_per_superblock=config.n_blocks_per_superblock,
50 |         target_dtype=config.dtype,
51 |     )
52 |     module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
53 |     return module
54 | 
--------------------------------------------------------------------------------
/torchao/prototype/quantization/mixed_precision/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/quantization/mixed_precision/__init__.py
--------------------------------------------------------------------------------
/torchao/prototype/quantization/mixed_precision/scripts/Llama3-8B_parameters.json:
--------------------------------------------------------------------------------
1 | {
2 |     "parameters": [
3 |         {
4 |             "name": "bitwidth",
5 |             "name_format": "bitwidth.{i}.",
6 |             "layers": [
7 |                 {"range": [0, 3], "type": "fixed", "value": 5},
8 |                 {"range": [3, 30], "type": "choice", "values": [2, 3, 4, 5, 6, 8]},
9 |                 {"range": [30, 32], "type": "fixed", "value": 5}
10 |             ]
11 |         },
12 |         {
13 |             "name": "groupsize",
14 |             "name_format": "groupsize.{i}.",
15 |             "layers": [
16 |                 {"range": [0, 3], "type": "fixed", "value": 32},
17 |                 {"range": [3, 30], "type": "choice", "values": [32, 64, 128, 256]},
18 |                 {"range": [30, 32], "type": "fixed", "value": 32}
19 |             ]
20 |         }
21 |     ]
22 | }
23 | 
--------------------------------------------------------------------------------
/torchao/prototype/quantization/mixed_precision/scripts/Mistral-7B_parameters.json:
--------------------------------------------------------------------------------
1 | {
2 |     "parameters": [
3 |         {
4 |             "name": "bitwidth",
5 |             "name_format": "bitwidth.{i}.",
6 |             "layers": [
7 |                 {"range": [0, 4], "type": "fixed", "value": 5},
8 |                 {"range": [4, 32], "type": "choice", "values": [2, 3, 4, 5, 6, 8]}
9 |             ]
10 |         },
11 |         {
12 |             "name": "groupsize",
13 |             "name_format": "groupsize.{i}.",
14 |             "layers": [
15 |                 {"range": [0, 4], "type": "fixed", "value": 32},
16 |                 {"range": [4, 32], "type": "choice", "values": [32, 64, 128, 256]}
17 |             ]
18 |         }
19 |     ]
20 | }
21 | 
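Note: a minimal sketch of applying `GGUFWeightOnlyConfig` (defined in `api.py` above) through torchao's standard `quantize_` entry point; recall that the handler silently skips any layer whose weight is not 2-D with a last dimension divisible by 256:

```python
import torch
import torch.nn as nn
from torchao.prototype.quantization.gguf import GGUFWeightOnlyConfig
from torchao.quantization import quantize_

m = nn.Sequential(nn.Linear(256, 512), nn.Linear(512, 256))
quantize_(m, GGUFWeightOnlyConfig(dtype=torch.uint4, n_blocks_per_superblock=8))
print(m[0].weight)  # weight is now backed by a GGUFQuantizedTensor
```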
--------------------------------------------------------------------------------
/torchao/prototype/quantization/mixed_precision/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | from .naive_intNwo import intN_weight_only
2 | 
3 | __all__ = [
4 |     "intN_weight_only",
5 | ]
6 | 
--------------------------------------------------------------------------------
/torchao/prototype/quantization/module_swap/__init__.py:
--------------------------------------------------------------------------------
1 | from .module_swap import (
2 |     QuantizationRecipe,
3 |     quantize_module_swap,
4 | )
5 | from .quantized_modules import (
6 |     QuantizedEmbedding,
7 |     QuantizedLinear,
8 | )
9 | from .quantizers import (
10 |     CodeBookQuantizer,
11 |     IntQuantizer,
12 | )
13 | 
14 | __all__ = [
15 |     "CodeBookQuantizer",
16 |     "IntQuantizer",
17 |     "QuantizedEmbedding",
18 |     "QuantizedLinear",
19 |     "QuantizationRecipe",
20 |     "quantize_module_swap",
21 | ]
22 | 
--------------------------------------------------------------------------------
/torchao/prototype/quantization/module_swap/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | from .kmeans_codebook import kmeans_codebook
2 | 
3 | __all__ = [
4 |     "kmeans_codebook",
5 | ]
6 | 
--------------------------------------------------------------------------------
/torchao/prototype/quantization/module_swap/algorithms/kmeans_codebook.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from torchao.prototype.quantization.module_swap.quantized_modules import QuantizedLinear
5 | from torchao.prototype.quantization.module_swap.quantizers import CodeBookQuantizer
6 | 
7 | 
8 | def kmeans_codebook(
9 |     model: nn.Module,
10 |     niter: int = 30,
11 |     nredo: int = 1,
12 |     dtype: torch.dtype = torch.float32,
13 | ) -> None:
14 |     import faiss
15 | 
16 |     with torch.no_grad():
17 |         for layer in model.modules():
18 |             if isinstance(layer, QuantizedLinear):
19 |                 if isinstance(layer.weight_quantizer, CodeBookQuantizer):
20 |                     weight = layer.weight
21 |                     codebook_dim = layer.weight_quantizer.codebook_dim
22 |                     weight = weight.reshape(
23 |                         weight.shape[0] * (weight.shape[1] // codebook_dim),
24 |                         codebook_dim,
25 |                     )
26 |                     num_centroids = layer.weight_quantizer.codebook.shape[0]
27 |                     kmeans = faiss.Kmeans(
28 |                         weight.shape[1],
29 |                         num_centroids,
30 |                         niter=niter,
31 |                         nredo=nredo,
32 |                         verbose=True,
33 |                         gpu=True if torch.cuda.is_available() else False,
34 |                     )
35 |                     kmeans.train(weight.to(device="cpu", dtype=dtype))
36 |                     C = kmeans.centroids
37 | 
38 |                     layer.weight_quantizer.codebook.data = torch.FloatTensor(C).to(
39 |                         weight.dtype
40 |                     )
41 | 
--------------------------------------------------------------------------------
/torchao/prototype/quantization/module_swap/data_getters/__init__.py:
--------------------------------------------------------------------------------
1 | from .llm_ptq_data_getter import (
2 |     LLMPTQDataGetter,
3 | )
4 | from .ptq_data_getter import (
5 |     DataGetter,
6 |     get_module_input_data,
7 | )
8 | 
9 | __all__ = [
10 |     "DataGetter",
11 |     "get_module_input_data",
12 |     "LLMPTQDataGetter",
13 | ]
14 | 
--------------------------------------------------------------------------------
/torchao/prototype/quantization/subgraph_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/quantization/subgraph_utils/__init__.py
--------------------------------------------------------------------------------
/torchao/prototype/quantized_training/__init__.py:
--------------------------------------------------------------------------------
1 | from .bitnet import (
2 |     BitNetTrainingLinearWeight,
3 |     bitnet_training,
4 |     precompute_bitnet_scale_for_fsdp,
5 | )
6 | from .int8 import (
7 |     Int8QuantizedTrainingLinearWeight,
8 |     int8_weight_only_quantized_training,
9 |     quantize_int8_rowwise,
10 | )
11 | from .int8_mixed_precision import (
12 |     Int8MixedPrecisionTrainingConfig,
13 |     Int8MixedPrecisionTrainingLinear,
14 |     Int8MixedPrecisionTrainingLinearWeight,
15 |     int8_mixed_precision_training,
16 | )
17 | 
18 | __all__ = [
19 |     "BitNetTrainingLinearWeight",
20 |     "bitnet_training",
21 |     "precompute_bitnet_scale_for_fsdp",
22 |     "Int8MixedPrecisionTrainingConfig",
23 |     "Int8MixedPrecisionTrainingLinear",
24 |     "Int8MixedPrecisionTrainingLinearWeight",
25 |     "int8_mixed_precision_training",
26 |     "Int8QuantizedTrainingLinearWeight",
27 |     "int8_weight_only_quantized_training",
28 |     "quantize_int8_rowwise",
29 | ]
30 | 
--------------------------------------------------------------------------------
/torchao/prototype/scaled_grouped_mm/__init__.py:
--------------------------------------------------------------------------------
1 | from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
2 |     _scaled_grouped_mm as _scaled_grouped_mm,
3 | )
4 | 
--------------------------------------------------------------------------------
/torchao/prototype/scaled_grouped_mm/kernels/__init__.py:
--------------------------------------------------------------------------------
1 | from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
2 |     triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
3 | )
4 | from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
5 |     triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales,
6 | )
7 | 
--------------------------------------------------------------------------------
/torchao/prototype/smoothquant/__init__.py:
--------------------------------------------------------------------------------
1 | from .api import (
2 |     SmoothQuantConfig,
3 |     insert_smooth_quant_observer_,
4 |     load_smooth_quant_recipe,
5 |     save_smooth_quant_recipe,
6 | )
7 | from .core import SmoothQuantObservedLinear
8 | 
9 | __all__ = [
10 |     "insert_smooth_quant_observer_",
11 |     "load_smooth_quant_recipe",
12 |     "save_smooth_quant_recipe",
13 |     "SmoothQuantConfig",
14 |     "SmoothQuantObservedLinear",
15 | ]
16 | 
--------------------------------------------------------------------------------
/torchao/prototype/sparsity/__init__.py:
--------------------------------------------------------------------------------
1 | # Sparsifier
2 | # Scheduler
3 | from torchao.prototype.sparsity.scheduler.base_scheduler import BaseScheduler
4 | from torchao.prototype.sparsity.scheduler.cubic_scheduler import CubicSL
5 | from torchao.prototype.sparsity.scheduler.lambda_scheduler import LambdaSL
6 | from torchao.prototype.sparsity.sparsifier.base_sparsifier import BaseSparsifier
7 | from torchao.prototype.sparsity.sparsifier.nearly_diagonal_sparsifier import (
8 |     NearlyDiagonalSparsifier,
9 | )
10 | 
11 | # Parametrizations
12 | from torchao.prototype.sparsity.sparsifier.utils import (
13 |     FakeSparsity,
14 |     fqn_to_module,
15 |     get_arg_info_from_tensor_fqn,
16 |     module_to_fqn,
17 | )
18 | from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import (
19 |     WeightNormSparsifier,
20 | )
21 | 
22 | __all__ = [
23 |     "BaseScheduler",
24 |     "CubicSL",
25 |     "LambdaSL",
26 |     "BaseSparsifier",
27 |     "NearlyDiagonalSparsifier",
28 |     "FakeSparsity",
29 |     "fqn_to_module",
30 |     "get_arg_info_from_tensor_fqn",
31 |     "module_to_fqn",
32 |     "WeightNormSparsifier",
33 | ]
34 | 
--------------------------------------------------------------------------------
/torchao/prototype/sparsity/activation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/sparsity/activation/__init__.py
--------------------------------------------------------------------------------
/torchao/prototype/sparsity/pruner/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_structured_sparsifier import BaseStructuredSparsifier
2 | from .FPGM_pruner import FPGMPruner
3 | from .lstm_saliency_pruner import LSTMSaliencyPruner
4 | from .parametrization import (
5 |     BiasHook,
6 |     FakeStructuredSparsity,
7 | )
8 | from .saliency_pruner import SaliencyPruner
9 | 
10 | __all__ = [
11 |     "BaseStructuredSparsifier",
12 |     "FPGMPruner",
13 |     "LSTMSaliencyPruner",
14 |     "BiasHook",
15 |     "FakeStructuredSparsity",
16 |     "SaliencyPruner",
17 | ]
18 | 
--------------------------------------------------------------------------------
/torchao/prototype/sparsity/pruner/images/prune_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/sparsity/pruner/images/prune_1.png
--------------------------------------------------------------------------------
/torchao/prototype/sparsity/pruner/images/prune_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/sparsity/pruner/images/prune_2.png
--------------------------------------------------------------------------------
/torchao/prototype/sparsity/pruner/images/prune_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/sparsity/pruner/images/prune_3.png
--------------------------------------------------------------------------------
/torchao/prototype/sparsity/pruner/images/prune_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/sparsity/pruner/images/prune_4.png
--------------------------------------------------------------------------------
/torchao/prototype/sparsity/pruner/images/prune_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/sparsity/pruner/images/prune_5.png
--------------------------------------------------------------------------------
/torchao/prototype/sparsity/pruner/images/prune_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/prototype/sparsity/pruner/images/prune_6.png
--------------------------------------------------------------------------------
/torchao/prototype/sparsity/pruner/saliency_pruner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD 3-Clause license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from .base_structured_sparsifier import BaseStructuredSparsifier
7 | 
8 | 
9 | class SaliencyPruner(BaseStructuredSparsifier):
10 |     """
11 |     Prune rows based on the saliency (L1 norm) of each row.
12 | 
13 |     This pruner works on N-Dimensional weight tensors.
14 |     For each row, we calculate the saliency, which is the sum of the absolute values (L1 norm) of the weights in that row.
15 |     We expect that the resulting saliency vector has the same shape as our mask.
16 |     We then pick elements to remove until we reach the target sparsity_level.
17 |     """
18 | 
19 |     def update_mask(self, module, tensor_name, **kwargs):
20 |         # tensor_name will give you the FQN; all other entries in the sparse config are present in kwargs
21 |         weights = getattr(module, tensor_name)
22 |         mask = getattr(module.parametrizations, tensor_name)[0].mask
23 | 
24 |         # negate the saliency so that topk selects the rows with the smallest L1 norm
25 |         if weights.dim() <= 1:
26 |             raise Exception(
27 |                 "Structured pruning can only be applied to a 2+dim weight tensor!"
28 |             )
29 |         saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1)
30 |         assert saliency.shape == mask.shape
31 | 
32 |         num_to_pick = int(len(mask) * kwargs["sparsity_level"])
33 |         prune = saliency.topk(num_to_pick).indices
34 | 
35 |         # Set the mask to be false for the rows we want to prune
36 |         mask.data[prune] = False
37 | 
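Note: a small numeric illustration of the `update_mask` logic above, run outside of the sparsifier machinery: for a 4x3 weight at `sparsity_level=0.5`, the two rows with the smallest L1 norm get masked out.

```python
import torch

W = torch.tensor([[1.0, 1.0, 1.0],
                  [0.1, 0.1, 0.1],
                  [2.0, 2.0, 2.0],
                  [0.2, 0.2, 0.2]])
mask = torch.ones(4, dtype=torch.bool)

saliency = -W.norm(dim=(1,), p=1)            # [-3.0, -0.3, -6.0, -0.6]
num_to_pick = int(len(mask) * 0.5)           # 2 rows to prune
prune = saliency.topk(num_to_pick).indices   # rows 1 and 3 (smallest L1 norm)
mask[prune] = False                          # mask -> [True, False, True, False]
```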
-------------------------------------------------------------------------------- /torchao/prototype/sparsity/superblock/evaluation_results.txt: -------------------------------------------------------------------------------- 1 | model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,top-1_acc,encoder img/s,max_mem (MB) 2 | vit_b_16,256,bfloat16,None,None,0.0,False,81.97716346153847,734.904399886552,247.97265625 3 | vit_b_16,256,bfloat16,None,None,0.0,True,81.89503205128206,230.83627917226997,196.841796875 4 | vit_b_16,256,bfloat16,semi_structured,None,0.0,False,77.05729166666667,1386.7278781133518,316.40234375 5 | vit_b_16,256,bfloat16,semi_structured,None,0.0,True,76.74078525641026,150.53603093207843,249.25390625 6 | vit_b_16,256,bfloat16,bsr,64,0.8,False,77.13541666666667,1469.2705176409308,179.55322265625 7 | vit_b_16,256,bfloat16,bsr,64,0.8,True,77.13341346153847,87.8480561274922,158.70361328125 8 | vit_b_16,256,bfloat16,bsr,64,0.84,False,76.14983974358974,1752.835540513905,174.01953125 9 | vit_b_16,256,bfloat16,bsr,64,0.84,True,76.0556891025641,1013.7495284783578,156.630859375 10 | vit_b_16,256,bfloat16,bsr,64,0.9,False,62.99879807692308,1702.289195236525,164.2822265625 11 | vit_b_16,256,bfloat16,bsr,64,0.9,True,62.946714743589745,987.5488468441617,152.5732421875 12 | 13 | model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,top-1_acc,encoder img/s,max_mem (MB) 14 | vit_h_14,128,bfloat16,None,None,0.0,False,89.29286858974359,81.02922135697278,1430.05615234375 15 | vit_h_14,128,bfloat16,None,None,0.0,True,89.3349358974359,56.076129157634355,1025.00927734375 16 | vit_h_14,128,bfloat16,semi_structured,None,0.0,False,82.03725961538461,75.83586253901329,1900.36279296875 17 | vit_h_14,128,bfloat16,semi_structured,None,0.0,True,82.06330128205128,36.36097831133589,1390.98779296875 18 | vit_h_14,128,bfloat16,bsr,64,0.9,False,78.21113782051282,350.91330496491446,599.6201171875 19 | vit_h_14,128,bfloat16,bsr,64,0.9,True,78.2051282051282,108.84048044884008,531.5810546875 20 | -------------------------------------------------------------------------------- /torchao/prototype/spinquant/__init__.py: -------------------------------------------------------------------------------- 1 | from .spinquant import apply_spinquant 2 | 3 | __all__ = [ 4 | "apply_spinquant", 5 | ] 6 | -------------------------------------------------------------------------------- /torchao/quantization/marlin_qqq/README.md: -------------------------------------------------------------------------------- 1 | # Marlin QQQ 2 | 3 | The Marlin QQQ kernel is compatible with GPUs of compute capability sm80 and above. 4 | The Marlin QQQ kernel differs from the original Marlin kernel in three main ways: 5 | 1. Marlin QQQ supports W4A8 mixed-precision GEMM using INT8 Tensor Cores, while the original Marlin kernel supports W4A16 mixed-precision GEMM using FP16 Tensor Cores. 6 | 2. Because the mma instruction requires the weight and activation data types to match, a type conversion is needed: Marlin QQQ converts the INT4 weights to INT8, while Marlin converts them to FP16. 7 | 3. Similar to W8A8, Marlin QQQ must dequantize to FP16 before writing the final result, because the result is accumulated in INT32; Marlin does not need this step. 8 | 9 | For more details about Marlin QQQ, please refer to the [paper](https://arxiv.org/pdf/2406.09904).
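To make points 2 and 3 concrete, here is a small emulation sketch in plain PyTorch — an illustration of the dataflow only, not the CUDA kernel; the shapes and per-tensor scales are toy values:

import torch

a_int8 = torch.randint(-128, 128, (16, 32), dtype=torch.int8)  # INT8 activations
w_int4 = torch.randint(-8, 8, (32, 64), dtype=torch.int8)      # INT4 values in an int8 container

a_scale, w_scale = 0.02, 0.01  # per-tensor scales (toy values)

# Point 2: widen the INT4 weights so the integer matmul can consume them.
# Point 3: accumulate the integer products in a wide accumulator (int64 here
# stands in for the kernel's INT32 accumulator), then dequantize to FP16
# before writing the final result.
acc = a_int8.long() @ w_int4.long()
out_fp16 = (acc.float() * (a_scale * w_scale)).half()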
10 | 11 | The Marlin QQQ implementation is adapted from the following two sources: 12 | 13 | * [QQQ](https://github.com/HandH1998/QQQ/tree/main) 14 | * [vllm](https://github.com/vllm-project/vllm/tree/main) 15 | -------------------------------------------------------------------------------- /torchao/quantization/prototype/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/quantization/prototype/__init__.py -------------------------------------------------------------------------------- /torchao/quantization/prototype/qat/README.md: -------------------------------------------------------------------------------- 1 | Note: QAT has been moved to torchao/quantization/qat. 2 | This is a legacy folder kept only for backward compatibility 3 | and will be removed in the near future. 4 | -------------------------------------------------------------------------------- /torchao/quantization/prototype/qat/__init__.py: -------------------------------------------------------------------------------- 1 | from torchao.quantization.qat import ( 2 | ComposableQATQuantizer, 3 | Int4WeightOnlyEmbeddingQATQuantizer, 4 | Int4WeightOnlyQATQuantizer, 5 | Int8DynActInt4WeightQATQuantizer, 6 | ) 7 | from torchao.quantization.qat.linear import ( 8 | Int8DynActInt4WeightQATLinear, 9 | disable_4w_fake_quant, 10 | disable_8da4w_fake_quant, 11 | enable_4w_fake_quant, 12 | enable_8da4w_fake_quant, 13 | ) 14 | 15 | __all__ = [ 16 | "disable_4w_fake_quant", 17 | "disable_8da4w_fake_quant", 18 | "enable_4w_fake_quant", 19 | "enable_8da4w_fake_quant", 20 | "ComposableQATQuantizer", 21 | "Int4WeightOnlyQATQuantizer", 22 | "Int4WeightOnlyEmbeddingQATQuantizer", 23 | "Int8DynActInt4WeightQATQuantizer", 24 | "Int8DynActInt4WeightQATLinear", 25 | ] 26 | -------------------------------------------------------------------------------- /torchao/quantization/prototype/qat/_module_swap_api.py: -------------------------------------------------------------------------------- 1 | # For backward compatibility only 2 | # These will be removed in the future 3 | 4 | from torchao.quantization.qat.linear import ( 5 | Int4WeightOnlyQATQuantizer as Int4WeightOnlyQATQuantizerModuleSwap, 6 | ) 7 | from torchao.quantization.qat.linear import ( 8 | Int8DynActInt4WeightQATQuantizer as Int8DynActInt4WeightQATQuantizerModuleSwap, 9 | ) 10 | from torchao.quantization.qat.linear import ( 11 | disable_4w_fake_quant as disable_4w_fake_quant_module_swap, 12 | ) 13 | from torchao.quantization.qat.linear import ( 14 | disable_8da4w_fake_quant as disable_8da4w_fake_quant_module_swap, 15 | ) 16 | from torchao.quantization.qat.linear import ( 17 | enable_4w_fake_quant as enable_4w_fake_quant_module_swap, 18 | ) 19 | from torchao.quantization.qat.linear import ( 20 | enable_8da4w_fake_quant as enable_8da4w_fake_quant_module_swap, 21 | ) 22 | 23 | __all__ = [ 24 | "Int8DynActInt4WeightQATQuantizerModuleSwap", 25 | "Int4WeightOnlyQATQuantizerModuleSwap", 26 | "enable_8da4w_fake_quant_module_swap", 27 | "disable_8da4w_fake_quant_module_swap", 28 | "enable_4w_fake_quant_module_swap", 29 | "disable_4w_fake_quant_module_swap", 30 | ] 31 | -------------------------------------------------------------------------------- /torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py: -------------------------------------------------------------------------------- 1 | from torchao.quantization.qat.affine_fake_quantized_tensor import ( 2 |
AffineFakeQuantizedTensor, 3 | to_affine_fake_quantized, 4 | ) 5 | 6 | __all__ = [ 7 | "AffineFakeQuantizedTensor", 8 | "to_affine_fake_quantized", 9 | ] 10 | -------------------------------------------------------------------------------- /torchao/quantization/prototype/qat/api.py: -------------------------------------------------------------------------------- 1 | from torchao.quantization.qat.api import ( 2 | ComposableQATQuantizer, 3 | FakeQuantizeConfig, 4 | ) 5 | 6 | __all__ = [ 7 | "ComposableQATQuantizer", 8 | "FakeQuantizeConfig", 9 | ] 10 | -------------------------------------------------------------------------------- /torchao/quantization/prototype/qat/embedding.py: -------------------------------------------------------------------------------- 1 | from torchao.quantization.qat.embedding import ( 2 | FakeQuantizedEmbedding, 3 | Int4WeightOnlyEmbedding, 4 | Int4WeightOnlyEmbeddingQATQuantizer, 5 | Int4WeightOnlyQATEmbedding, 6 | ) 7 | 8 | __all__ = [ 9 | "FakeQuantizedEmbedding", 10 | "Int4WeightOnlyEmbeddingQATQuantizer", 11 | "Int4WeightOnlyEmbedding", 12 | "Int4WeightOnlyQATEmbedding", 13 | ] 14 | -------------------------------------------------------------------------------- /torchao/quantization/prototype/qat/fake_quantizer.py: -------------------------------------------------------------------------------- 1 | from torchao.quantization.qat.fake_quantizer import ( 2 | FakeQuantizer, 3 | ) 4 | 5 | __all__ = [ 6 | "FakeQuantizer", 7 | ] 8 | -------------------------------------------------------------------------------- /torchao/quantization/prototype/qat/linear.py: -------------------------------------------------------------------------------- 1 | from torchao.quantization.qat.linear import ( 2 | FakeQuantizedLinear, 3 | Int4WeightOnlyQATLinear, 4 | Int4WeightOnlyQATQuantizer, 5 | Int8DynActInt4WeightQATLinear, 6 | Int8DynActInt4WeightQATQuantizer, 7 | disable_4w_fake_quant, 8 | disable_8da4w_fake_quant, 9 | enable_4w_fake_quant, 10 | enable_8da4w_fake_quant, 11 | ) 12 | 13 | __all__ = [ 14 | "disable_4w_fake_quant", 15 | "disable_8da4w_fake_quant", 16 | "enable_4w_fake_quant", 17 | "enable_8da4w_fake_quant", 18 | "FakeQuantizedLinear", 19 | "Int4WeightOnlyQATLinear", 20 | "Int4WeightOnlyQATQuantizer", 21 | "Int8DynActInt4WeightQATLinear", 22 | "Int8DynActInt4WeightQATQuantizer", 23 | ] 24 | -------------------------------------------------------------------------------- /torchao/quantization/pt2e/inductor_passes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/quantization/pt2e/inductor_passes/__init__.py -------------------------------------------------------------------------------- /torchao/quantization/pt2e/quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .composable_quantizer import ComposableQuantizer 2 | from .duplicate_dq_pass import DuplicateDQPass 3 | from .port_metadata_pass import PortNodeMetaForQDQ 4 | from .quantizer import ( 5 | DerivedQuantizationSpec, 6 | EdgeOrNode, 7 | FixedQParamsQuantizationSpec, 8 | QuantizationAnnotation, 9 | QuantizationSpec, 10 | QuantizationSpecBase, 11 | Quantizer, 12 | SharedQuantizationSpec, 13 | ) 14 | from .utils import ( 15 | OperatorConfig, 16 | OperatorPatternType, 17 | QuantizationConfig, 18 | annotate_input_qspec_map, 19 | annotate_output_qspec, 20 | get_bias_qspec, 21 | get_input_act_qspec, 22 | 
get_module_name_filter, 23 | get_output_act_qspec, 24 | get_weight_qspec, 25 | is_valid_annotation, 26 | ) 27 | 28 | __all__ = [ 29 | # basic classes for quantizer and annotations 30 | "Quantizer", 31 | "ComposableQuantizer", 32 | "EdgeOrNode", 33 | "QuantizationSpec", 34 | "QuantizationSpecBase", 35 | "DerivedQuantizationSpec", 36 | "FixedQParamsQuantizationSpec", 37 | "SharedQuantizationSpec", 38 | "QuantizationAnnotation", 39 | # utils 40 | "annotate_input_qspec_map", 41 | "annotate_output_qspec", 42 | "get_module_name_filter", 43 | "is_valid_annotation", 44 | "QuantizationConfig", 45 | "OperatorPatternType", 46 | "OperatorConfig", 47 | "get_input_act_qspec", 48 | "get_output_act_qspec", 49 | "get_weight_qspec", 50 | "get_bias_qspec", 51 | "DuplicateDQPass", 52 | "PortNodeMetaForQDQ", 53 | ] 54 | -------------------------------------------------------------------------------- /torchao/quantization/qat/__init__.py: -------------------------------------------------------------------------------- 1 | from .api import ( 2 | ComposableQATQuantizer, 3 | FakeQuantizeConfig, 4 | FromIntXQuantizationAwareTrainingConfig, 5 | IntXQuantizationAwareTrainingConfig, 6 | from_intx_quantization_aware_training, 7 | initialize_fake_quantizers, 8 | intx_quantization_aware_training, 9 | ) 10 | from .embedding import ( 11 | Int4WeightOnlyEmbeddingQATQuantizer, 12 | ) 13 | from .linear import ( 14 | Int4WeightOnlyQATQuantizer, 15 | Int8DynActInt4WeightQATQuantizer, 16 | ) 17 | 18 | __all__ = [ 19 | "ComposableQATQuantizer", 20 | "FakeQuantizeConfig", 21 | "FromIntXQuantizationAwareTrainingConfig", 22 | "Int4WeightOnlyEmbeddingQATQuantizer", 23 | "Int4WeightOnlyQATQuantizer", 24 | "Int8DynActInt4WeightQATQuantizer", 25 | "IntXQuantizationAwareTrainingConfig", 26 | "initialize_fake_quantizers", 27 | "intx_quantization_aware_training", 28 | "from_intx_quantization_aware_training", 29 | ] 30 | -------------------------------------------------------------------------------- /torchao/quantization/qat/images/qat_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/quantization/qat/images/qat_diagram.png -------------------------------------------------------------------------------- /torchao/quantization/unified.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from abc import ABC, abstractmethod 7 | from typing import Any 8 | 9 | import torch 10 | 11 | """ 12 | The vast majority of quantization algorithms follow one of two patterns 13 | 1. Single quantize call to create a quantized model with quantized state_dict 14 | 2. 
Flow that needs calibration or training 15 | 16 | This file defines the API for both patterns. 17 | """ 18 | 19 | 20 | # API 1, single quantize call to create a quantized model with quantized state_dict 21 | class Quantizer(ABC): 22 | @abstractmethod 23 | def quantize( 24 | self, model: torch.nn.Module, *args: Any, **kwargs: Any 25 | ) -> torch.nn.Module: 26 | pass 27 | 28 | 29 | # API 2, flow that needs calibration or training 30 | class TwoStepQuantizer: 31 | @abstractmethod 32 | def prepare( 33 | self, model: torch.nn.Module, *args: Any, **kwargs: Any 34 | ) -> torch.nn.Module: 35 | pass 36 | 37 | @abstractmethod 38 | def convert( 39 | self, model: torch.nn.Module, *args: Any, **kwargs: Any 40 | ) -> torch.nn.Module: 41 | pass 42 | -------------------------------------------------------------------------------- /torchao/sparsity/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torchao.quantization.quant_api import ( 8 | int8_dynamic_activation_int8_semi_sparse_weight, 9 | ) 10 | 11 | from .sparse_api import ( 12 | apply_fake_sparsity, 13 | block_sparse_weight, 14 | semi_sparse_weight, 15 | sparsify_, 16 | ) 17 | from .supermask import SupermaskLinear 18 | from .utils import PerChannelNormObserver  # noqa: F403 19 | from .wanda import WandaSparsifier  # noqa: F403 20 | 21 | __all__ = [ 22 | "WandaSparsifier", 23 | "SupermaskLinear", 24 | "PerChannelNormObserver", 25 | "apply_fake_sparsity", 26 | "sparsify_", 27 | "semi_sparse_weight", 28 | "block_sparse_weight", 29 | "int8_dynamic_activation_int8_semi_sparse_weight", 30 | ] 31 | -------------------------------------------------------------------------------- /torchao/sparsity/marlin/README.md: -------------------------------------------------------------------------------- 1 | # Sparse Marlin 2 | 3 | The Sparse Marlin implementation is adapted from the following two sources: 4 | 5 | * [Sparse-Marlin](https://github.com/IST-DASLab/Sparse-Marlin/tree/main) 6 | * [nm-vllm](https://github.com/neuralmagic/nm-vllm/tree/main) 7 | -------------------------------------------------------------------------------- /torchao/swizzle/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
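Returning to the two-pattern API defined in torchao/quantization/unified.py above: a calibration-based flow implements the TwoStepQuantizer protocol by instrumenting the model in prepare and consuming the collected statistics in convert. A minimal sketch follows — the observer bookkeeping, the toy int8 scale, and the class name are illustrative assumptions, not torchao APIs:

import torch

from torchao.quantization.unified import TwoStepQuantizer


def _observe(module, inputs):
    # Forward pre-hook: track the running absmax of the activations.
    module.act_absmax.copy_(
        torch.maximum(module.act_absmax, inputs[0].detach().abs().max())
    )


class ToyCalibrationQuantizer(TwoStepQuantizer):
    def prepare(self, model, *args, **kwargs):
        for mod in model.modules():
            if isinstance(mod, torch.nn.Linear):
                mod.register_buffer("act_absmax", torch.zeros(()))
                mod.register_forward_pre_hook(_observe)
        return model

    def convert(self, model, *args, **kwargs):
        # A real convert step would swap in quantized modules; here we only
        # derive a toy int8 activation scale from the calibrated statistics.
        for mod in model.modules():
            if isinstance(mod, torch.nn.Linear):
                mod.act_scale = float(mod.act_absmax) / 127.0
        return model


# Usage: prepare, run calibration batches under no_grad, then convert.
quantizer = ToyCalibrationQuantizer()
model = quantizer.prepare(torch.nn.Sequential(torch.nn.Linear(16, 16)))
with torch.no_grad():
    model(torch.randn(4, 16))
model = quantizer.convert(model)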
6 | 7 | from .swizzle_tensor import SwizzleTensor 8 | 9 | __all__ = ["SwizzleTensor"] 10 | -------------------------------------------------------------------------------- /torchao/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/testing/__init__.py -------------------------------------------------------------------------------- /torchao/testing/float8/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/testing/float8/__init__.py -------------------------------------------------------------------------------- /torchao/testing/float8/dtensor_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class FeedForward(nn.Module): 12 | """MLP-based model""" 13 | 14 | def __init__(self): 15 | super(FeedForward, self).__init__() 16 | self.w1 = nn.Linear(16, 32, bias=False) 17 | self.w2 = nn.Linear(16, 32, bias=False) 18 | self.out_proj = nn.Linear(32, 16, bias=False) 19 | 20 | def forward(self, x): 21 | return self.out_proj(F.silu(self.w1(x)) * self.w2(x)) 22 | 23 | 24 | class ToyModel(nn.Module): 25 | def __init__(self): 26 | super(ToyModel, self).__init__() 27 | self.ffn = FeedForward() 28 | 29 | def forward(self, x): 30 | return self.ffn(x) 31 | -------------------------------------------------------------------------------- /torchao/testing/float8/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | from torchao.float8.config import ( 7 | CastConfig, 8 | Float8LinearConfig, 9 | ) 10 | 11 | 12 | def get_test_float8_linear_config( 13 | scaling_type_input, 14 | scaling_type_weight, 15 | scaling_type_grad_output, 16 | emulate: bool, 17 | ): 18 | cast_config_input = CastConfig( 19 | scaling_type=scaling_type_input, 20 | ) 21 | cast_config_weight = CastConfig( 22 | scaling_type=scaling_type_weight, 23 | ) 24 | cast_config_grad_output = CastConfig( 25 | scaling_type=scaling_type_grad_output, 26 | ) 27 | 28 | config = Float8LinearConfig( 29 | cast_config_input=cast_config_input, 30 | cast_config_weight=cast_config_weight, 31 | cast_config_grad_output=cast_config_grad_output, 32 | emulate=emulate, 33 | ) 34 | return config 35 | -------------------------------------------------------------------------------- /torchao/testing/pt2e/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/torchao/testing/pt2e/__init__.py -------------------------------------------------------------------------------- /tutorials/developer_api_guide/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/tutorials/developer_api_guide/__init__.py -------------------------------------------------------------------------------- /tutorials/quantize_vit/bfloat16.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/tutorials/quantize_vit/bfloat16.json.gz -------------------------------------------------------------------------------- /tutorials/quantize_vit/quant.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/ao/8366465ebd8b017c89ae4a2fcddad744cbb9c405/tutorials/quantize_vit/quant.json.gz -------------------------------------------------------------------------------- /tutorials/quantize_vit/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD 3-Clause license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # Run bfloat16 version 9 | TORCH_LOGS='graph_breaks,recompiles' python run_vit_b.py 10 | 11 | # Run dynamic quantized version 12 | TORCH_LOGS='graph_breaks,recompiles' python run_vit_b_quant.py 13 | 14 | # Store the output code for further inspection 15 | echo "bfloat16 generated code lives in:" 16 | TORCH_LOGS='output_code' python run_vit_b.py 2>&1 | grep "Output code written to: " | awk -F" " '{print $NF}' 17 | echo "quantization generated code lives in:" 18 | TORCH_LOGS='output_code' python run_vit_b_quant.py 2>&1 | grep "Output code written to: " | awk -F" " '{print $NF}' 19 | -------------------------------------------------------------------------------- /tutorials/quantize_vit/run_vit_b.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | import torch 7 | from torchvision import models 8 | 9 | from torchao.utils import benchmark_model, profiler_runner 10 | 11 | torch.set_float32_matmul_precision("high") 12 | # Load Vision Transformer model 13 | model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1) 14 | 15 | # Set the model to evaluation mode 16 | model.eval().cuda().to(torch.bfloat16) 17 | 18 | # Input tensor (batch_size, channels, height, width) 19 | inputs = (torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device="cuda"),) 20 | 21 | model = torch.compile(model, mode="max-autotune") 22 | 23 | # Must run with no_grad when optimizing for inference 24 | with torch.no_grad(): 25 |     # warmup 26 |     benchmark_model(model, 5, inputs) 27 |     # benchmark 28 |     print("elapsed_time: ", benchmark_model(model, 100, inputs), " milliseconds") 29 |     # Create a trace 30 |     profiler_runner("bfloat16.json.gz", benchmark_model, model, 5, inputs) 31 | -------------------------------------------------------------------------------- /tutorials/run_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD 3-Clause license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | FAILED=0 8 | for dir in $(find . -type d); do 9 |   if [ -f "$dir/run.sh" ]; then 10 |     echo "Running: $dir/run.sh" 11 |     CURRENT_DIR=$(pwd) 12 |     cd "$dir" 13 |     bash run.sh 14 |     cd "$CURRENT_DIR" 15 |   else 16 |     for file in $(find "$dir" -maxdepth 1 -name "*.py"); do 17 |       filename=$(basename "$file") 18 |       if echo "$filename" | grep -q "tensor_parallel"; then 19 |         echo "Running: torchrun --standalone --nnodes=1 --nproc-per-node=4 $file" 20 |         torchrun --standalone --nnodes=1 --nproc-per-node=4 "$file" 21 |         STATUS=$? 22 |       else 23 |         echo "Running: python $file" 24 |         python "$file" 25 |         STATUS=$? 26 |       fi 27 | 28 |       if [ $STATUS -ne 0 ]; then 29 |         FAILED=1 30 |         echo "Test failed: $file" 31 |       fi 32 |     done 33 |   fi 34 | done 35 | 36 | if [ "$FAILED" -eq 1 ]; then 37 |   echo "One or more tests failed" 38 |   exit 1 39 | else 40 |   echo "All tests passed" 41 |   exit 0 42 | fi 43 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.12.0 2 | --------------------------------------------------------------------------------