├── .circleci └── config.yml ├── .clang-format ├── .flake8 ├── .github └── workflows │ ├── docs.yaml │ ├── pages.yaml │ ├── pylint.yaml │ └── rocm_ci.yml ├── .gitignore ├── .gitmodules ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── default.nix ├── docker ├── Dockerfile.cuda ├── Dockerfile.rocm ├── README.md ├── build.sh ├── install │ ├── install_ait.sh │ ├── install_basic_dep.sh │ ├── install_detection_deps.sh │ ├── install_doc_dep.sh │ ├── install_test_dep.sh │ └── rocm_dev-requirements.txt └── rocm_fix │ └── fix_10736.py ├── docs ├── Makefile ├── README.md ├── image │ ├── gpu_grid_block.png │ ├── pack_size_1.png │ ├── pack_size_2.png │ ├── pack_size_4.png │ ├── pack_size_8.png │ ├── softmax.png │ └── vs_oneflow.png ├── make.bat ├── source │ ├── arch │ │ ├── index.rst │ │ └── philosophy.rst │ ├── conf.py │ ├── debughints.rst │ ├── genindex.rst │ ├── index.rst │ ├── install │ │ └── index.rst │ ├── reference │ │ ├── backend.rst │ │ ├── compiler.rst │ │ ├── cuda.rst │ │ ├── env.rst │ │ ├── frontend.rst │ │ ├── index.rst │ │ ├── ops.rst │ │ ├── rocm.rst │ │ ├── testing.rst │ │ ├── transform.rst │ │ └── utils.rst │ ├── runtime │ │ ├── cxx_design.rst │ │ ├── index.rst │ │ └── py_design.rst │ └── tutorial │ │ ├── how_to_add_op.rst │ │ ├── how_to_infer_pt.rst │ │ ├── how_to_visualize.rst │ │ └── index.rst └── static │ └── ait_model.html ├── examples ├── 01_resnet-50 │ ├── README.md │ ├── benchmark_ait.py │ ├── benchmark_mi250.sh │ ├── benchmark_pt.py │ ├── infer_with_torch.py │ ├── modeling │ │ ├── __init__.py │ │ └── resnet.py │ ├── test_correctness.py │ └── weight_utils.py ├── 02_detectron2 │ ├── README.md │ ├── compile_model.py │ ├── configs │ │ ├── __init__.py │ │ ├── config.py │ │ ├── defaults.py │ │ ├── faster_rcnn_R_101_FPN.yaml │ │ ├── faster_rcnn_R_50_FPN.yaml │ │ ├── mask_rcnn_R_101_FPN.yaml │ │ └── mask_rcnn_R_50_FPN.yaml │ ├── demo.py │ ├── modeling │ │ ├── backbone │ │ │ ├── __init__.py │ │ │ ├── fpn.py │ │ │ ├── 
resnet.py │ │ │ └── utils.py │ │ ├── meta_arch │ │ │ ├── __init__.py │ │ │ └── rcnn.py │ │ ├── proposal_generator │ │ │ ├── __init__.py │ │ │ └── rpn.py │ │ └── roi_heads │ │ │ ├── __init__.py │ │ │ ├── box_head.py │ │ │ ├── fast_rcnn.py │ │ │ ├── mask_head.py │ │ │ └── roi_heads.py │ ├── predictor │ │ ├── __init__.py │ │ ├── builtin_meta.py │ │ └── predictor.py │ ├── prepare_and_run_rcnn.sh │ ├── test_correctness.py │ └── tools │ │ └── convert_pt2ait.py ├── 03_bert │ ├── README.md │ ├── benchmark_ait.py │ ├── benchmark_mi250.sh │ ├── benchmark_pt.py │ ├── demo.py │ ├── modeling │ │ ├── __init__.py │ │ ├── bert.py │ │ └── torch_model.py │ └── test_correctness.py ├── 04_vit │ ├── README.md │ ├── benchmark_ait.py │ ├── benchmark_mi250.sh │ ├── benchmark_pt.py │ ├── modeling │ │ └── vision_transformer.py │ ├── test_correctness.py │ ├── verification.py │ └── weight_utils.py ├── 05_stable_diffusion │ ├── .gitignore │ ├── README.md │ ├── scripts │ │ ├── compile.py │ │ ├── compile_alt.py │ │ ├── compile_controlnet.py │ │ ├── demo.py │ │ ├── demo_alt.py │ │ ├── demo_controlnet.py │ │ ├── demo_img2img.py │ │ └── download_pipeline.py │ └── src │ │ ├── __init__.py │ │ ├── benchmark.py │ │ ├── benchmark_pt.py │ │ ├── compile_lib │ │ ├── __init__.py │ │ ├── compile_clip.py │ │ ├── compile_clip_alt.py │ │ ├── compile_controlnet.py │ │ ├── compile_unet.py │ │ ├── compile_unet_alt.py │ │ ├── compile_vae.py │ │ ├── compile_vae_alt.py │ │ └── util.py │ │ ├── modeling │ │ ├── attention.py │ │ ├── clip.py │ │ ├── controlnet_unet_2d_condition.py │ │ ├── embeddings.py │ │ ├── resnet.py │ │ ├── unet_2d_condition.py │ │ ├── unet_blocks.py │ │ └── vae.py │ │ ├── pipeline_stable_diffusion_ait.py │ │ ├── pipeline_stable_diffusion_ait_alt.py │ │ ├── pipeline_stable_diffusion_controlnet_ait.py │ │ ├── pipeline_stable_diffusion_img2img_ait.py │ │ └── test_correctness.py ├── 06_how_to_add_an_op │ └── how_to_add_an_op.py ├── 07_how_to_run_pt_model │ └── how_to_run_pt_model.py └── 08_esrgan │ ├── 
README.md │ ├── compile.py │ ├── demo.py │ └── modeling │ └── rrdbnet.py ├── fx2ait ├── CMakeLists.txt ├── README.md ├── fx2ait │ ├── __init__.py │ ├── acc_tracer │ │ ├── __init__.py │ │ ├── acc_normalizer.py │ │ ├── acc_op_properties.py │ │ ├── acc_ops.py │ │ ├── acc_shape_prop.py │ │ ├── acc_tracer.py │ │ ├── acc_utils.py │ │ ├── ait_acc_normalizer.py │ │ ├── ait_acc_ops.py │ │ └── ait_acc_ops_registry.py │ ├── ait_module.py │ ├── ait_splitter.py │ ├── cache.py │ ├── converters │ │ ├── __init__.py │ │ ├── ait_converters.py │ │ ├── ait_module_converters.py │ │ ├── aten2ait_converters.py │ │ ├── converter_registry.py │ │ └── utils.py │ ├── csrc │ │ ├── AITModel.cpp │ │ ├── AITModel.h │ │ ├── AITModelImpl.cpp │ │ └── AITModelImpl.h │ ├── example │ │ ├── 01_transformer_model │ │ │ ├── README.md │ │ │ └── test_transformer_encoder.py │ │ ├── 02_vision_model │ │ │ ├── README.md │ │ │ └── test_vision_model.py │ │ ├── 03_lowering_split │ │ │ ├── README.md │ │ │ └── test_lower.py │ │ ├── __init__.py │ │ └── benchmark_utils.py │ ├── extension.py │ ├── find_batch_size_dim.py │ ├── fx2ait.py │ ├── lower │ │ ├── __init__.py │ │ ├── lower.py │ │ └── lower_settings.py │ ├── passes │ │ ├── __init__.py │ │ └── lower_basic_pass_aten.py │ ├── tensor_spec.py │ ├── test │ │ ├── __init__.py │ │ ├── converters │ │ │ ├── converters_model │ │ │ │ ├── test_ait_transformer_model.py │ │ │ │ └── test_ait_vision_model.py │ │ │ ├── converters_module │ │ │ │ └── test_ait_multihead_attention.py │ │ │ ├── test_ait_adaptive_avg_pool2d.py │ │ │ ├── test_ait_avg_pool2d.py │ │ │ ├── test_ait_avg_pool3d.py │ │ │ ├── test_ait_batch_norm.py │ │ │ ├── test_ait_binary_op.py │ │ │ ├── test_ait_cast.py │ │ │ ├── test_ait_cat.py │ │ │ ├── test_ait_chunk.py │ │ │ ├── test_ait_clamp.py │ │ │ ├── test_ait_contiguous.py │ │ │ ├── test_ait_conv2d.py │ │ │ ├── test_ait_conv3d.py │ │ │ ├── test_ait_conv3d_depthwise.py │ │ │ ├── test_ait_convtranspose2d.py │ │ │ ├── test_ait_elu.py │ │ │ ├── test_ait_expand.py │ │ │ 
├── test_ait_flatten.py │ │ │ ├── test_ait_full.py │ │ │ ├── test_ait_gelu.py │ │ │ ├── test_ait_group_norm.py │ │ │ ├── test_ait_index_select.py │ │ │ ├── test_ait_layer_norm.py │ │ │ ├── test_ait_leaky_relu.py │ │ │ ├── test_ait_linalg_norm.py │ │ │ ├── test_ait_linear.py │ │ │ ├── test_ait_masked_select.py │ │ │ ├── test_ait_matmul.py │ │ │ ├── test_ait_max_pool2d.py │ │ │ ├── test_ait_max_pool3d.py │ │ │ ├── test_ait_nan2num.py │ │ │ ├── test_ait_permute.py │ │ │ ├── test_ait_pooling_ops.py │ │ │ ├── test_ait_pow.py │ │ │ ├── test_ait_reduce.py │ │ │ ├── test_ait_reshape.py │ │ │ ├── test_ait_sigmoid.py │ │ │ ├── test_ait_slice_tensor.py │ │ │ ├── test_ait_softmax.py │ │ │ ├── test_ait_split.py │ │ │ ├── test_ait_square.py │ │ │ ├── test_ait_squeeze.py │ │ │ ├── test_ait_tile.py │ │ │ ├── test_ait_topk.py │ │ │ ├── test_ait_unary_ops.py │ │ │ ├── test_ait_unbind.py │ │ │ ├── test_ait_unsqueeze.py │ │ │ ├── test_ait_upsampling2d.py │ │ │ └── test_ait_var.py │ │ ├── test_ait_lower.py │ │ ├── test_ait_splitter.py │ │ ├── test_fx2ait.py │ │ └── test_tensor_spec.py │ ├── tools │ │ ├── __init__.py │ │ ├── ait_minimizer.py │ │ ├── ait_subgraph_rewriter.py │ │ ├── common_aten2ait.py │ │ └── common_fx2ait.py │ └── utils.py └── setup.py ├── licenses ├── LICENSE.composable_kernel.txt ├── LICENSE.cub.txt ├── LICENSE.cutlass.txt ├── LICENSE.dmlc.txt ├── LICENSE.flash_attention.txt ├── LICENSE.hipcub.txt ├── LICENSE.markdown_table.txt ├── LICENSE.oneflow.txt ├── LICENSE.pydot.txt ├── LICENSE.pytorch.txt ├── LICENSE.tensorrt.txt └── license.header.txt ├── python ├── aitemplate │ ├── __init__.py │ ├── _libinfo.py │ ├── backend │ │ ├── __init__.py │ │ ├── backend_spec.py │ │ ├── build_cache.py │ │ ├── build_cache_base.py │ │ ├── builder.py │ │ ├── codegen.py │ │ ├── common │ │ │ ├── concatenate_common.py │ │ │ ├── elementwise_common.py │ │ │ ├── gemm_common.py │ │ │ ├── split_common.py │ │ │ ├── tensor │ │ │ │ ├── argmax_common.py │ │ │ │ ├── batch_gather_common.py │ │ │ │ ├── 
identity_common.py │ │ │ │ ├── permute0213_common.py │ │ │ │ ├── permute021_common.py │ │ │ │ ├── permute102_common.py │ │ │ │ ├── permute210_common.py │ │ │ │ ├── slice_common.py │ │ │ │ ├── slice_reshape_scatter_common.py │ │ │ │ └── topk_common.py │ │ │ ├── tensor_accessor.cuh │ │ │ ├── tensor_accessor_codegen.py │ │ │ ├── upsampling2d_common.py │ │ │ └── vision_ops │ │ │ │ ├── efficient_nms_common.py │ │ │ │ ├── efficient_nms_kernel.py │ │ │ │ ├── multi_level_roi_align_common.py │ │ │ │ ├── nms_common.py │ │ │ │ ├── nms_kernel.py │ │ │ │ └── roi_align_common.py │ │ ├── cuda │ │ │ ├── __init__.py │ │ │ ├── attention │ │ │ │ ├── __init__.py │ │ │ │ ├── flash_attention.py │ │ │ │ ├── mem_eff_attention.py │ │ │ │ └── src │ │ │ │ │ ├── fmha.h │ │ │ │ │ ├── fmha │ │ │ │ │ ├── gemm.h │ │ │ │ │ ├── gmem_tile.h │ │ │ │ │ ├── kernel_traits.h │ │ │ │ │ ├── mask.h │ │ │ │ │ ├── smem_tile.h │ │ │ │ │ ├── softmax.h │ │ │ │ │ └── utils.h │ │ │ │ │ ├── fmha_block_fprop_fp16_kernel.sm80.cu │ │ │ │ │ ├── fmha_block_fprop_kernel_1xN.h │ │ │ │ │ ├── fmha_blockmask.h │ │ │ │ │ ├── fmha_fprop_fp16_kernel.sm80.cu │ │ │ │ │ ├── fmha_fprop_kernel_1xN.h │ │ │ │ │ ├── fmha_kernel.h │ │ │ │ │ ├── fmha_utils.h │ │ │ │ │ ├── licenses │ │ │ │ │ └── LICENSE │ │ │ │ │ └── philox.cuh │ │ │ ├── b2b_bmm │ │ │ │ ├── __init__.py │ │ │ │ ├── classic_b2b_bmm.py │ │ │ │ ├── fmha_style_b2b_bmm.py │ │ │ │ ├── grouped_classic_b2b_bmm.py │ │ │ │ └── grouped_fmha_style_b2b_bmm.py │ │ │ ├── builder_cmake.py │ │ │ ├── common │ │ │ │ ├── __init__.py │ │ │ │ └── dummy_op.py │ │ │ ├── conv2d │ │ │ │ ├── __init__.py │ │ │ │ ├── common.py │ │ │ │ ├── common_conv2d_bias_activation.py │ │ │ │ ├── common_conv2d_bias_add_activation.py │ │ │ │ ├── common_conv2d_few_channels.py │ │ │ │ ├── common_transposed_conv2d.py │ │ │ │ ├── conv2d.py │ │ │ │ ├── conv2d_bias.py │ │ │ │ ├── conv2d_bias_add.py │ │ │ │ ├── conv2d_bias_add_hardswish.py │ │ │ │ ├── conv2d_bias_add_relu.py │ │ │ │ ├── conv2d_bias_few_channels.py │ │ │ │ 
├── conv2d_bias_hardswish.py │ │ │ │ ├── conv2d_bias_hardswish_few_channels.py │ │ │ │ ├── conv2d_bias_relu.py │ │ │ │ ├── conv2d_bias_relu_few_channels.py │ │ │ │ ├── conv2d_bias_sigmoid.py │ │ │ │ ├── conv2d_depthwise.py │ │ │ │ ├── conv2d_depthwise_bias.py │ │ │ │ ├── transposed_conv2d.py │ │ │ │ └── transposed_conv2d_bias.py │ │ │ ├── conv3d │ │ │ │ ├── __init__.py │ │ │ │ ├── common.py │ │ │ │ ├── common_bias.py │ │ │ │ ├── conv3d.py │ │ │ │ ├── conv3d_bias.py │ │ │ │ ├── depthwise_conv3d.py │ │ │ │ └── depthwise_conv3d_bias.py │ │ │ ├── cuda_common.py │ │ │ ├── elementwise │ │ │ │ ├── __init__.py │ │ │ │ ├── custom_math.cuh │ │ │ │ ├── fused_elementwise.py │ │ │ │ └── int_elementwise.py │ │ │ ├── embedding │ │ │ │ ├── __init__.py │ │ │ │ └── bert_embeddings.py │ │ │ ├── gemm_epilogue_vistor │ │ │ │ ├── __init__.py │ │ │ │ ├── bmm_rcr_softmax.py │ │ │ │ ├── common_dual_gemm.py │ │ │ │ ├── common_softmax.py │ │ │ │ ├── dual_bmm_rrr_div.py │ │ │ │ ├── dual_gemm_rcr_fast_gelu.py │ │ │ │ ├── dual_gemm_rcr_silu.py │ │ │ │ ├── gemm_rcr_bias_softmax.py │ │ │ │ ├── gemm_rcr_softmax.py │ │ │ │ └── include │ │ │ │ │ └── gemm_with_softmax.h │ │ │ ├── gemm_special │ │ │ │ ├── __init__.py │ │ │ │ ├── batched_dense_vec_jagged_2d_mul.py │ │ │ │ ├── bmm_rcr_n1.py │ │ │ │ ├── bmm_rrr_k1_tanh.py │ │ │ │ └── gemm_rrr_small_nk.py │ │ │ ├── gemm_universal │ │ │ │ ├── __init__.py │ │ │ │ ├── bmm_common.py │ │ │ │ ├── bmm_permute_common.py │ │ │ │ ├── bmm_rcr_permute.py │ │ │ │ ├── bmm_rrr_permute.py │ │ │ │ ├── bmm_softmax_bmm_permute.py │ │ │ │ ├── bmm_xxx.py │ │ │ │ ├── bmm_xxx_add.py │ │ │ │ ├── common.py │ │ │ │ ├── common_bias.py │ │ │ │ ├── common_bias_activation.py │ │ │ │ ├── common_bias_broadcast.py │ │ │ │ ├── common_no_bias.py │ │ │ │ ├── common_permute.py │ │ │ │ ├── gemm_rcr.py │ │ │ │ ├── gemm_rcr_bias.py │ │ │ │ ├── gemm_rcr_bias_elementwise.py │ │ │ │ ├── gemm_rcr_bias_fast_gelu.py │ │ │ │ ├── gemm_rcr_bias_gelu.py │ │ │ │ ├── gemm_rcr_bias_hardswish.py │ │ │ │ ├── 
gemm_rcr_bias_permute.py │ │ │ │ ├── gemm_rcr_bias_relu.py │ │ │ │ ├── gemm_rcr_bias_sigmoid.py │ │ │ │ ├── gemm_rcr_bias_swish.py │ │ │ │ ├── gemm_rcr_bias_tanh.py │ │ │ │ ├── gemm_rcr_fast_gelu.py │ │ │ │ ├── gemm_rcr_permute.py │ │ │ │ ├── gemm_rcr_permute_elup1.py │ │ │ │ ├── gemm_rrr.py │ │ │ │ ├── gemm_rrr_bias.py │ │ │ │ ├── gemm_rrr_permute.py │ │ │ │ ├── group_common.py │ │ │ │ ├── group_common_bias.py │ │ │ │ ├── group_gemm_rcr.py │ │ │ │ ├── group_gemm_rcr_bias.py │ │ │ │ ├── group_gemm_rcr_bias_relu.py │ │ │ │ ├── group_gemm_rcr_bias_sigmoid.py │ │ │ │ ├── layout.py │ │ │ │ ├── perm021fc_ccr.py │ │ │ │ ├── perm021fc_ccr_bias.py │ │ │ │ ├── perm021fc_ccr_bias_permute.py │ │ │ │ ├── perm021fc_crc.py │ │ │ │ ├── perm021fc_crc_bias.py │ │ │ │ ├── perm102_bmm_rcr.py │ │ │ │ ├── perm102_bmm_rcr_bias.py │ │ │ │ ├── perm102_bmm_rrr.py │ │ │ │ └── perm102_bmm_rrr_bias.py │ │ │ ├── groupnorm │ │ │ │ ├── __init__.py │ │ │ │ ├── groupnorm.py │ │ │ │ ├── groupnorm_common.py │ │ │ │ ├── groupnorm_kernel.cuh │ │ │ │ ├── groupnorm_swish.py │ │ │ │ └── layer_norm.cuh │ │ │ ├── jagged │ │ │ │ ├── __init__.py │ │ │ │ ├── jagged_lengths_to_offsets.py │ │ │ │ └── jagged_lengths_to_presences.py │ │ │ ├── layernorm_sigmoid_mul │ │ │ │ ├── __init__.py │ │ │ │ ├── batch_layernorm_sigmoid_mul.py │ │ │ │ ├── group_layernorm_sigmoid_mul.py │ │ │ │ ├── layer_norm.cuh │ │ │ │ ├── layernorm_common.py │ │ │ │ ├── layernorm_sigmoid_mul.py │ │ │ │ ├── layernorm_sigmoid_mul_kernel.cuh │ │ │ │ └── layernorm_welford.cuh │ │ │ ├── lib_template.py │ │ │ ├── padding │ │ │ │ ├── __init__.py │ │ │ │ ├── ndhwc3to8.py │ │ │ │ ├── nhwc3to4.py │ │ │ │ ├── nhwc3to8.py │ │ │ │ └── pad_last_dim.py │ │ │ ├── pool2d │ │ │ │ ├── __init__.py │ │ │ │ ├── avg_pool2d.py │ │ │ │ ├── max_pool2d.py │ │ │ │ └── pool2d.py │ │ │ ├── reduce │ │ │ │ ├── __init__.py │ │ │ │ ├── reduce_3d.py │ │ │ │ ├── reduce_common.py │ │ │ │ ├── reduce_common_slim_tensor.py │ │ │ │ ├── reduce_max.py │ │ │ │ ├── reduce_mean.py │ │ │ 
│ ├── reduce_min.py │ │ │ │ ├── reduce_small_axis.py │ │ │ │ ├── reduce_sum.py │ │ │ │ ├── var.py │ │ │ │ └── vector_norm.py │ │ │ ├── softmax │ │ │ │ ├── __init__.py │ │ │ │ ├── softmax.cuh │ │ │ │ └── softmax.py │ │ │ ├── target_def.py │ │ │ ├── tensor │ │ │ │ ├── __init__.py │ │ │ │ ├── argmax.py │ │ │ │ ├── batch_gather.py │ │ │ │ ├── cast.py │ │ │ │ ├── concatenate.py │ │ │ │ ├── concatenate_fast.cuh │ │ │ │ ├── concatenate_fast.py │ │ │ │ ├── concatenate_tanh.py │ │ │ │ ├── dynamic_slice.py │ │ │ │ ├── expand.py │ │ │ │ ├── expand_static_shape.py │ │ │ │ ├── full.py │ │ │ │ ├── gather.py │ │ │ │ ├── identity.py │ │ │ │ ├── index_select.py │ │ │ │ ├── jagged_to_padded_dense.py │ │ │ │ ├── masked_select.py │ │ │ │ ├── padded_dense_to_jagged.py │ │ │ │ ├── permute.cuh │ │ │ │ ├── permute.py │ │ │ │ ├── permute021.py │ │ │ │ ├── permute0213.py │ │ │ │ ├── permute102.py │ │ │ │ ├── permute210.py │ │ │ │ ├── relational.py │ │ │ │ ├── repeat.cuh │ │ │ │ ├── slice_reshape_scatter.py │ │ │ │ ├── slice_scatter.py │ │ │ │ ├── split.py │ │ │ │ ├── topk.py │ │ │ │ └── where.py │ │ │ ├── upsample │ │ │ │ ├── __init__.py │ │ │ │ ├── upsampling2d.py │ │ │ │ └── upsampling2d_add.py │ │ │ ├── utils.py │ │ │ ├── view_ops │ │ │ │ ├── __init__.py │ │ │ │ ├── make_jagged.py │ │ │ │ └── view_ops.py │ │ │ └── vision_ops │ │ │ │ ├── __init__.py │ │ │ │ ├── nms │ │ │ │ ├── __init__.py │ │ │ │ ├── batched_nms.py │ │ │ │ ├── batched_nms_kernel.cuh │ │ │ │ ├── efficient_nms.py │ │ │ │ └── nms.py │ │ │ │ └── roi_ops │ │ │ │ ├── __init__.py │ │ │ │ ├── multi_level_roi_align.py │ │ │ │ ├── roi_align.py │ │ │ │ └── roi_ops.py │ │ ├── main_templates.py │ │ ├── profiler_cache.py │ │ ├── profiler_runner.py │ │ ├── registry.py │ │ ├── rocm │ │ │ ├── __init__.py │ │ │ ├── attention │ │ │ │ ├── __init__.py │ │ │ │ └── mem_eff_attention.py │ │ │ ├── common │ │ │ │ ├── __init__.py │ │ │ │ └── dummy_op.py │ │ │ ├── conv2d │ │ │ │ ├── __init__.py │ │ │ │ ├── common.py │ │ │ │ ├── conv2d.py │ │ │ │ ├── 
conv2d_bias.py │ │ │ │ ├── conv2d_bias_add.py │ │ │ │ ├── conv2d_bias_add_relu.py │ │ │ │ ├── conv2d_bias_relu.py │ │ │ │ ├── conv2d_bias_sigmoid.py │ │ │ │ ├── transposed_conv2d.py │ │ │ │ └── transposed_conv2d_bias_relu.py │ │ │ ├── elementwise │ │ │ │ ├── __init__.py │ │ │ │ ├── custom_math.h │ │ │ │ └── fused_elementwise.py │ │ │ ├── embedding │ │ │ │ ├── __init__.py │ │ │ │ └── bert_embeddings.py │ │ │ ├── gemm │ │ │ │ ├── __init__.py │ │ │ │ ├── bmm_ccr.py │ │ │ │ ├── bmm_ccr_add.py │ │ │ │ ├── bmm_common.py │ │ │ │ ├── bmm_crr.py │ │ │ │ ├── bmm_crr_add.py │ │ │ │ ├── bmm_permute_common.py │ │ │ │ ├── bmm_rcr.py │ │ │ │ ├── bmm_rcr_permute.py │ │ │ │ ├── bmm_rrr.py │ │ │ │ ├── bmm_rrr_add.py │ │ │ │ ├── bmm_rrr_permute.py │ │ │ │ ├── bmm_softmax_bmm.py │ │ │ │ ├── bmm_softmax_bmm_permute.py │ │ │ │ ├── common.py │ │ │ │ ├── gemm_epilogue.py │ │ │ │ ├── gemm_rcr.py │ │ │ │ ├── gemm_rcr_bias.py │ │ │ │ ├── gemm_rcr_bias_add.py │ │ │ │ ├── gemm_rcr_bias_add_add.py │ │ │ │ ├── gemm_rcr_bias_add_add_relu.py │ │ │ │ ├── gemm_rcr_bias_add_relu.py │ │ │ │ ├── gemm_rcr_bias_fast_gelu.py │ │ │ │ ├── gemm_rcr_bias_hardswish.py │ │ │ │ ├── gemm_rcr_bias_mul.py │ │ │ │ ├── gemm_rcr_bias_mul_add.py │ │ │ │ ├── gemm_rcr_bias_mul_tanh.py │ │ │ │ ├── gemm_rcr_bias_permute.py │ │ │ │ ├── gemm_rcr_bias_permute_m2n3.py │ │ │ │ ├── gemm_rcr_bias_permute_m3n2.py │ │ │ │ ├── gemm_rcr_bias_relu.py │ │ │ │ ├── gemm_rcr_bias_sigmoid.py │ │ │ │ ├── gemm_rcr_bias_sigmoid_mul.py │ │ │ │ ├── gemm_rcr_bias_sigmoid_mul_tanh.py │ │ │ │ ├── gemm_rcr_bias_swish.py │ │ │ │ ├── gemm_rcr_bias_tanh.py │ │ │ │ ├── gemm_rcr_permute_m2n3.py │ │ │ │ ├── gemm_rrr.py │ │ │ │ ├── gemm_rrr_bias_permute.py │ │ │ │ ├── layout.py │ │ │ │ └── permute_common.py │ │ │ ├── lib_template.py │ │ │ ├── normalization │ │ │ │ ├── __init__.py │ │ │ │ ├── groupnorm.py │ │ │ │ ├── groupnorm_swish.py │ │ │ │ ├── layernorm.py │ │ │ │ ├── norm_common.py │ │ │ │ └── softmax.py │ │ │ ├── padding │ │ │ │ ├── __init__.py │ │ │ 
│ ├── nhwc3to4.py │ │ │ │ ├── nhwc3to8.py │ │ │ │ └── pad_last_dim.py │ │ │ ├── pool2d │ │ │ │ ├── __init__.py │ │ │ │ ├── avg_pool2d.py │ │ │ │ ├── max_pool2d.py │ │ │ │ └── pool2d.py │ │ │ ├── target_def.py │ │ │ ├── tensor │ │ │ │ ├── __init__.py │ │ │ │ ├── argmax.py │ │ │ │ ├── batch_gather.py │ │ │ │ ├── concatenate.py │ │ │ │ ├── concatenate_tanh.py │ │ │ │ ├── dynamic_slice.py │ │ │ │ ├── expand.py │ │ │ │ ├── expand_static_shape.py │ │ │ │ ├── full.py │ │ │ │ ├── identity.py │ │ │ │ ├── permute021.py │ │ │ │ ├── permute0213.py │ │ │ │ ├── permute102.py │ │ │ │ ├── permute210.py │ │ │ │ ├── repeat.h │ │ │ │ ├── slice_reshape_scatter.py │ │ │ │ ├── slice_scatter.py │ │ │ │ ├── split.py │ │ │ │ └── topk.py │ │ │ ├── upsample │ │ │ │ ├── __init__.py │ │ │ │ ├── upsampling2d.py │ │ │ │ └── upsampling2d_add.py │ │ │ ├── utils.py │ │ │ ├── view_ops │ │ │ │ ├── __init__.py │ │ │ │ └── view_ops.py │ │ │ └── vision_ops │ │ │ │ ├── __init__.py │ │ │ │ ├── efficient_nms.py │ │ │ │ ├── nms.py │ │ │ │ └── roi_ops │ │ │ │ ├── __init__.py │ │ │ │ ├── multi_level_roi_align.py │ │ │ │ └── roi_align.py │ │ ├── target.py │ │ └── task_runner.py │ ├── compiler │ │ ├── __init__.py │ │ ├── base.py │ │ ├── compiler.py │ │ ├── dtype.py │ │ ├── model.py │ │ ├── op_registry.py │ │ ├── ops │ │ │ ├── __init__.py │ │ │ ├── attention │ │ │ │ ├── __init__.py │ │ │ │ ├── flash_attention.py │ │ │ │ └── mem_eff_attention.py │ │ │ ├── b2b_bmm │ │ │ │ ├── __init__.py │ │ │ │ ├── b2b_bmm_base.py │ │ │ │ ├── classic_b2b_bmm.py │ │ │ │ ├── fmha_style_b2b_bmm.py │ │ │ │ ├── grouped_classic_b2b_bmm.py │ │ │ │ └── grouped_fmha_style_b2b_bmm.py │ │ │ ├── common │ │ │ │ ├── __init__.py │ │ │ │ ├── elementwise.py │ │ │ │ ├── epilogue.py │ │ │ │ ├── fused_elementwise.py │ │ │ │ ├── int_elementwise.py │ │ │ │ ├── math.py │ │ │ │ ├── python_ops.py │ │ │ │ └── view_ops.py │ │ │ ├── conv │ │ │ │ ├── __init__.py │ │ │ │ ├── cache_entry.py │ │ │ │ ├── common_conv2d_bias_activation.py │ │ │ │ ├── 
common_conv2d_bias_add_activation.py │ │ │ │ ├── conv2d.py │ │ │ │ ├── conv2d_bias.py │ │ │ │ ├── conv2d_bias_add.py │ │ │ │ ├── conv2d_bias_add_hardswish.py │ │ │ │ ├── conv2d_bias_add_relu.py │ │ │ │ ├── conv2d_bias_few_channels.py │ │ │ │ ├── conv2d_bias_hardswish.py │ │ │ │ ├── conv2d_bias_hardswish_few_channels.py │ │ │ │ ├── conv2d_bias_relu.py │ │ │ │ ├── conv2d_bias_relu_few_channels.py │ │ │ │ ├── conv2d_bias_sigmoid.py │ │ │ │ ├── conv2d_depthwise.py │ │ │ │ ├── conv2d_depthwise_bias.py │ │ │ │ ├── conv3d.py │ │ │ │ ├── conv3d_bias.py │ │ │ │ ├── conv_common.py │ │ │ │ ├── depthwise_conv3d.py │ │ │ │ ├── special_conv2d_bias_activation.py │ │ │ │ ├── transposed_conv2d.py │ │ │ │ ├── transposed_conv2d_bias.py │ │ │ │ └── transposed_conv2d_bias_relu.py │ │ │ ├── embedding │ │ │ │ ├── __init__.py │ │ │ │ └── bert_embeddings.py │ │ │ ├── gemm_epilogue_vistor │ │ │ │ ├── __init__.py │ │ │ │ ├── bmm_rcr_softmax.py │ │ │ │ ├── dual_bmm_rrr_div.py │ │ │ │ ├── dual_gemm_rcr_fast_gelu.py │ │ │ │ ├── dual_gemm_rcr_silu.py │ │ │ │ ├── gemm_rcr_bias_softmax.py │ │ │ │ └── gemm_rcr_softmax.py │ │ │ ├── gemm_special │ │ │ │ ├── __init__.py │ │ │ │ ├── batched_dense_vec_jagged_2d_mul.py │ │ │ │ ├── bmm_rcr_n1.py │ │ │ │ ├── bmm_rrr_k1_tanh.py │ │ │ │ └── gemm_rrr_small_nk.py │ │ │ ├── gemm_universal │ │ │ │ ├── __init__.py │ │ │ │ ├── bmm.py │ │ │ │ ├── bmm_rcr_permute.py │ │ │ │ ├── bmm_rrr_permute.py │ │ │ │ ├── bmm_softmax_bmm.py │ │ │ │ ├── bmm_softmax_bmm_permute.py │ │ │ │ ├── bmm_xxx.py │ │ │ │ ├── bmm_xxx_add.py │ │ │ │ ├── cache_entry.py │ │ │ │ ├── gemm_common.py │ │ │ │ ├── gemm_rcr.py │ │ │ │ ├── gemm_rcr_bias.py │ │ │ │ ├── gemm_rcr_bias_add.py │ │ │ │ ├── gemm_rcr_bias_add_add.py │ │ │ │ ├── gemm_rcr_bias_add_add_relu.py │ │ │ │ ├── gemm_rcr_bias_add_relu.py │ │ │ │ ├── gemm_rcr_bias_broadcast.py │ │ │ │ ├── gemm_rcr_bias_fast_gelu.py │ │ │ │ ├── gemm_rcr_bias_gelu.py │ │ │ │ ├── gemm_rcr_bias_hardswish.py │ │ │ │ ├── gemm_rcr_bias_mul.py │ │ │ │ ├── 
gemm_rcr_bias_mul_add.py │ │ │ │ ├── gemm_rcr_bias_mul_tanh.py │ │ │ │ ├── gemm_rcr_bias_permute.py │ │ │ │ ├── gemm_rcr_bias_relu.py │ │ │ │ ├── gemm_rcr_bias_sigmoid.py │ │ │ │ ├── gemm_rcr_bias_sigmoid_mul.py │ │ │ │ ├── gemm_rcr_bias_sigmoid_mul_tanh.py │ │ │ │ ├── gemm_rcr_bias_swish.py │ │ │ │ ├── gemm_rcr_bias_tanh.py │ │ │ │ ├── gemm_rcr_fast_gelu.py │ │ │ │ ├── gemm_rcr_permute.py │ │ │ │ ├── gemm_rcr_permute_elup1.py │ │ │ │ ├── gemm_rrr.py │ │ │ │ ├── gemm_rrr_bias.py │ │ │ │ ├── gemm_rrr_bias_permute.py │ │ │ │ ├── gemm_rrr_permute.py │ │ │ │ ├── group_gemm_rcr.py │ │ │ │ ├── group_gemm_rcr_bias.py │ │ │ │ ├── group_gemm_rcr_bias_relu.py │ │ │ │ ├── group_gemm_rcr_bias_sigmoid.py │ │ │ │ ├── perm021fc_ccr.py │ │ │ │ ├── perm021fc_ccr_bias.py │ │ │ │ ├── perm021fc_ccr_bias_permute.py │ │ │ │ ├── perm021fc_crc.py │ │ │ │ ├── perm021fc_crc_bias.py │ │ │ │ ├── perm102_bmm_rcr.py │ │ │ │ ├── perm102_bmm_rcr_bias.py │ │ │ │ ├── perm102_bmm_rrr.py │ │ │ │ └── perm102_bmm_rrr_bias.py │ │ │ ├── groupnorm │ │ │ │ ├── __init__.py │ │ │ │ ├── groupnorm.py │ │ │ │ └── groupnorm_swish.py │ │ │ ├── jagged │ │ │ │ ├── __init__.py │ │ │ │ ├── jagged_lengths_to_offsets.py │ │ │ │ └── jagged_lengths_to_presences.py │ │ │ ├── layernorm │ │ │ │ ├── __init__.py │ │ │ │ ├── batch_layernorm_sigmoid_mul.py │ │ │ │ ├── group_layernorm.py │ │ │ │ ├── group_layernorm_sigmoid_mul.py │ │ │ │ ├── layernorm.py │ │ │ │ └── layernorm_sigmoid_mul.py │ │ │ ├── padding │ │ │ │ ├── __init__.py │ │ │ │ ├── ndhwc3to8.py │ │ │ │ ├── nhwc3to4.py │ │ │ │ ├── nhwc3to8.py │ │ │ │ ├── nhwc_pad_common.py │ │ │ │ └── pad_last_dim.py │ │ │ ├── pool │ │ │ │ ├── __init__.py │ │ │ │ ├── avg_pool2d.py │ │ │ │ ├── max_pool2d.py │ │ │ │ └── pool2d.py │ │ │ ├── reduce │ │ │ │ ├── __init__.py │ │ │ │ ├── reduce_common.py │ │ │ │ ├── reduce_max.py │ │ │ │ ├── reduce_mean.py │ │ │ │ ├── reduce_min.py │ │ │ │ ├── reduce_sum.py │ │ │ │ ├── var.py │ │ │ │ └── vector_norm.py │ │ │ ├── softmax │ │ │ │ ├── 
__init__.py │ │ │ │ ├── cache_entry.py │ │ │ │ └── softmax.py │ │ │ ├── tensor │ │ │ │ ├── __init__.py │ │ │ │ ├── argmax.py │ │ │ │ ├── batch_gather.py │ │ │ │ ├── cast.py │ │ │ │ ├── chunk.py │ │ │ │ ├── concatenate.py │ │ │ │ ├── concatenate_tanh.py │ │ │ │ ├── dynamic_slice.py │ │ │ │ ├── expand.py │ │ │ │ ├── full.py │ │ │ │ ├── gather.py │ │ │ │ ├── identity.py │ │ │ │ ├── index_select.py │ │ │ │ ├── jagged_to_padded_dense.py │ │ │ │ ├── masked_select.py │ │ │ │ ├── padded_dense_to_jagged.py │ │ │ │ ├── permute.py │ │ │ │ ├── permute021.py │ │ │ │ ├── permute0213.py │ │ │ │ ├── permute102.py │ │ │ │ ├── permute210.py │ │ │ │ ├── relational.py │ │ │ │ ├── size.py │ │ │ │ ├── slice_reshape_scatter.py │ │ │ │ ├── slice_scatter.py │ │ │ │ ├── split.py │ │ │ │ ├── topk.py │ │ │ │ ├── transpose.py │ │ │ │ └── where.py │ │ │ ├── upsample │ │ │ │ ├── __init__.py │ │ │ │ ├── upsampling2d.py │ │ │ │ ├── upsampling2d_add.py │ │ │ │ └── upsampling_common.py │ │ │ └── vision_ops │ │ │ │ ├── __init__.py │ │ │ │ ├── nms │ │ │ │ ├── __init__.py │ │ │ │ ├── batched_nms.py │ │ │ │ ├── efficient_nms.py │ │ │ │ └── nms.py │ │ │ │ └── roi_ops │ │ │ │ ├── __init__.py │ │ │ │ ├── multi_level_roi_align.py │ │ │ │ ├── roi_align.py │ │ │ │ └── roi_ops.py │ │ ├── public │ │ │ └── __init__.py │ │ ├── stable_set.py │ │ ├── symbolic.py │ │ ├── tensor_accessor.py │ │ └── transform │ │ │ ├── __init__.py │ │ │ ├── apply_padding.py │ │ │ ├── bind_constants.py │ │ │ ├── constant_folding.py │ │ │ ├── dedup_make_jagged_ops.py │ │ │ ├── fuse_bmm_permute.py │ │ │ ├── fuse_conv_elementwise.py │ │ │ ├── fuse_conv_patterns.py │ │ │ ├── fuse_duplicate_fused_elementwise.py │ │ │ ├── fuse_expand_bmm.py │ │ │ ├── fuse_group_ops.py │ │ │ ├── fuse_mm_elementwise.py │ │ │ ├── fuse_mm_elementwise_patterns.py │ │ │ ├── fuse_mm_reshape_permute.py │ │ │ ├── fuse_ops.py │ │ │ ├── fuse_parallel_gemms.py │ │ │ ├── fuse_permute_bmm_and_gemm.py │ │ │ ├── fuse_split.py │ │ │ ├── fuse_utils.py │ │ │ ├── 
mark_param_tensor.py │ │ │ ├── memory_planning.py │ │ │ ├── move_view_ops.py │ │ │ ├── name_graph.py │ │ │ ├── optimize_graph.py │ │ │ ├── profile.py │ │ │ ├── profile_dynamic_dim.py │ │ │ ├── refine_graph.py │ │ │ ├── remove_elementwise_no_ops.py │ │ │ ├── remove_no_ops.py │ │ │ ├── remove_unused_ops.py │ │ │ ├── split_large_concat_ops.py │ │ │ ├── split_large_slice_scatter_ops.py │ │ │ ├── split_large_split_ops.py │ │ │ ├── toposort.py │ │ │ ├── transform_memory_ops.py │ │ │ ├── transform_merge_slice_ops.py │ │ │ ├── transform_merge_view_ops.py │ │ │ ├── transform_odd_alignment.py │ │ │ ├── transform_permutations.py │ │ │ ├── transform_permute_to_reshape.py │ │ │ ├── transform_special_ops.py │ │ │ ├── transform_strided_op_and_view_op.py │ │ │ ├── transform_strided_ops.py │ │ │ ├── transform_strided_ops_utils.py │ │ │ ├── transform_strided_slice.py │ │ │ └── transform_utils.py │ ├── frontend │ │ ├── __init__.py │ │ ├── nn │ │ │ ├── __init__.py │ │ │ ├── activation.py │ │ │ ├── attention.py │ │ │ ├── batch_norm.py │ │ │ ├── container.py │ │ │ ├── conv1d.py │ │ │ ├── conv2d │ │ │ │ ├── __init__.py │ │ │ │ ├── common_conv2d_bias_act.py │ │ │ │ ├── common_conv2d_bias_add_act.py │ │ │ │ ├── conv2d.py │ │ │ │ ├── conv2d_bias.py │ │ │ │ ├── conv2d_bias_add_hardswish.py │ │ │ │ ├── conv2d_bias_add_relu.py │ │ │ │ ├── conv2d_bias_few_channels.py │ │ │ │ ├── conv2d_bias_hardswish.py │ │ │ │ ├── conv2d_bias_hardswish_few_channels.py │ │ │ │ ├── conv2d_bias_relu.py │ │ │ │ ├── conv2d_bias_relu_few_channels.py │ │ │ │ ├── conv2d_bias_sigmoid.py │ │ │ │ ├── conv2d_depthwise.py │ │ │ │ ├── conv2d_depthwise_bias.py │ │ │ │ ├── special_conv2d_bias_act.py │ │ │ │ ├── transposed_conv2d_bias.py │ │ │ │ ├── transposed_conv2d_bias_act.py │ │ │ │ └── transposed_conv2d_bias_relu.py │ │ │ ├── conv3d.py │ │ │ ├── dropout.py │ │ │ ├── dual_gemm.py │ │ │ ├── embedding.py │ │ │ ├── fpn_proposal.py │ │ │ ├── group_norm.py │ │ │ ├── head.py │ │ │ ├── identity.py │ │ │ ├── layer_norm.py │ │ │ 
├── linear.py │ │ │ ├── module.py │ │ │ ├── multiscale_attention.py │ │ │ ├── padding.py │ │ │ ├── parameter.py │ │ │ ├── patch_embed.py │ │ │ ├── pool2d.py │ │ │ ├── pool3d.py │ │ │ ├── positional_encoding.py │ │ │ ├── proposal.py │ │ │ ├── roi_ops.py │ │ │ ├── softmax.py │ │ │ ├── upsample.py │ │ │ ├── vanilla_attention.py │ │ │ ├── view_ops.py │ │ │ └── vision_transformers.py │ │ └── parameter.py │ ├── testing │ │ ├── __init__.py │ │ ├── benchmark_ait.py │ │ ├── benchmark_pt.py │ │ ├── benchmark_trt.py │ │ ├── detect_target.py │ │ ├── jagged_utils.py │ │ ├── profile.py │ │ └── test_utils.py │ └── utils │ │ ├── __init__.py │ │ ├── alignment.py │ │ ├── debug_settings.py │ │ ├── environ.py │ │ ├── graph_utils.py │ │ ├── import_path.py │ │ ├── io.py │ │ ├── json_utils.py │ │ ├── markdown_table.py │ │ ├── misc.py │ │ ├── mk_ck_lib │ │ ├── __init__.py │ │ ├── conv2d_operation.py │ │ ├── gemm_operation.py │ │ ├── generator.py │ │ ├── groupnorm_operation.py │ │ ├── layernorm_operation.py │ │ ├── library.py │ │ ├── manifest.py │ │ └── softmax_operation.py │ │ ├── mk_cutlass_lib │ │ ├── extra_conv_emit.py │ │ ├── extra_cutlass_generator.py │ │ ├── extra_enum.py │ │ ├── extra_gemm_emit.py │ │ └── mk_cutlass_lib.py │ │ ├── serialization │ │ ├── ait_program.py │ │ └── serdes_code.py │ │ ├── shape_utils.py │ │ ├── tensor_utils.py │ │ ├── torch_utils.py │ │ └── visualization │ │ ├── __init__.py │ │ ├── op_attr_factory.py │ │ ├── plot.py │ │ ├── pydot.py │ │ └── web_template.py └── setup.py ├── static ├── README.md ├── csrc │ ├── debug_utility.cpp │ ├── model_container.cpp │ ├── model_interface.cpp │ ├── rocm_hack.cpp │ ├── standalone.cpp │ ├── utility.cpp │ └── windll.cpp └── include │ ├── cuda_device_functions.h │ ├── debug_utility.h │ ├── jagged.h │ ├── kernels │ ├── classic_b2b_bmm │ │ ├── device │ │ │ └── b2b_batched_gemm.h │ │ ├── kernel │ │ │ ├── b2b_batched_gemm.h │ │ │ └── default_b2b_batched_gemm.h │ │ ├── thread │ │ │ └── linear_combination_triu.h │ │ ├── threadblock 
│ │ │ ├── b2b_mma_base.h │ │ │ ├── b2b_mma_multistage.h │ │ │ ├── b2b_mma_pipelined.h │ │ │ ├── custom_epilogue_tensor_op.h │ │ │ ├── default_b2b_mma.h │ │ │ ├── default_gmem_to_accum_loader_tensor_op.h │ │ │ ├── gmem_to_accum_loader.h │ │ │ └── gmem_to_accum_loader_shared_load_iterator.h │ │ └── warp │ │ │ ├── gmem_to_accum_loader_fragment_iterator_tensor_op.h │ │ │ └── triu_mma_tensor_op_fragment_iterator.h │ ├── debug_string.h │ ├── fmha_style_b2b_bmm │ │ ├── attention_scaling_coefs_updater.h │ │ ├── debug_utils.h │ │ ├── epilogue_pipelined.h │ │ ├── epilogue_rescale_output.h │ │ ├── find_default_mma.h │ │ ├── gemm_kernel_utils.h │ │ ├── iterators │ │ │ ├── epilogue_predicated_tile_iterator.h │ │ │ ├── make_residual_last.h │ │ │ ├── predicated_tile_access_iterator_residual_last.h │ │ │ ├── predicated_tile_iterator_residual_last.h │ │ │ ├── transpose_warp_iterator.h │ │ │ └── warp_iterator_from_smem.h │ │ ├── kernel_forward.h │ │ ├── mma_from_smem.h │ │ └── transform │ │ │ └── tile_smem_loader.h │ ├── grouped_classic_b2b_bmm │ │ ├── device │ │ │ └── b2b_batched_gemm.h │ │ ├── kernel │ │ │ ├── b2b_batched_gemm.h │ │ │ └── default_b2b_batched_gemm.h │ │ ├── thread │ │ │ └── linear_combination_triu.h │ │ ├── threadblock │ │ │ ├── b2b_mma_base.h │ │ │ ├── b2b_mma_multistage.h │ │ │ ├── b2b_mma_pipelined.h │ │ │ ├── custom_epilogue_tensor_op.h │ │ │ ├── default_b2b_mma.h │ │ │ ├── default_gmem_to_accum_loader_tensor_op.h │ │ │ ├── gmem_to_accum_loader.h │ │ │ ├── gmem_to_accum_loader_shared_load_iterator.h │ │ │ └── non_predicated_tile_access_iterator.h │ │ └── warp │ │ │ ├── gmem_to_accum_loader_fragment_iterator_tensor_op.h │ │ │ └── triu_mma_tensor_op_fragment_iterator.h │ ├── kat_printf.h │ └── mem_eff_attention │ │ ├── attention_scaling_coefs_updater.h │ │ ├── debug_utils.h │ │ ├── default_fmha_grouped.h │ │ ├── epilogue_pipelined.h │ │ ├── epilogue_rescale_output.h │ │ ├── epilogue_thread_apply_logsumexp.h │ │ ├── find_default_mma.h │ │ ├── fmha_grouped.h │ │ 
├── fmha_grouped_problem_visitor.h │ │ ├── gemm │ │ ├── custom_mma.h │ │ ├── custom_mma_base.h │ │ ├── custom_mma_multistage.h │ │ └── custom_mma_pipelined.h │ │ ├── gemm_kernel_utils.h │ │ ├── iterators │ │ ├── epilogue_predicated_tile_iterator.h │ │ ├── make_residual_last.h │ │ ├── predicated_tile_access_iterator_residual_last.h │ │ └── predicated_tile_iterator_residual_last.h │ │ ├── kernel_forward.h │ │ └── mma_from_smem.h │ ├── logging.h │ ├── macros.h │ ├── model.h │ ├── model_container.h │ ├── model_interface.h │ ├── owned_constants.h │ ├── raii_wrapper.h │ ├── rocm_device_functions.h │ ├── utility.h │ └── windll.h └── tests ├── ci_profile_cache ├── README.md └── update_cache.py ├── lint ├── check_meta_header.py └── flake8_problem_matcher.json └── unittest ├── backend ├── test_build_cache.py ├── test_codegen_output_aliases.py ├── test_codegen_output_tensor.py ├── test_cuda_graph.py ├── test_fused_elementwise_backend.py ├── test_gen_standalone.py ├── test_model_api.py └── test_profiler.py ├── benchmark ├── test_gemm_benchmark.py ├── test_group_gemm_benchmark.py └── test_strided_layernorm_benchmark.py ├── compiler ├── test_compilation_failure.py ├── test_constant_folding.py ├── test_eliminate_permutations.py ├── test_fuse_bmm_permute.py ├── test_fuse_cat_view_cat.py ├── test_fuse_conv_elementwise.py ├── test_fuse_duplicate_fused_elementwise.py ├── test_fuse_expand.py ├── test_fuse_expand_bmm.py ├── test_fuse_mm_elementwise.py ├── test_fuse_mm_reshape_permute.py ├── test_fuse_ops.py ├── test_fuse_permute_bmm.py ├── test_fuse_permute_gemm.py ├── test_fuse_split_cat.py ├── test_fused_elementwise_complex_dependency.py ├── test_fused_elementwise_out_of_order.py ├── test_fused_elementwise_singleton.py ├── test_group_fusions.py ├── test_memory_planning.py ├── test_merge_slice_ops.py ├── test_merge_view_ops.py ├── test_move_view_ops.py ├── test_op_common_elementwise.py ├── test_pad_bmm_rrr_bias_with_cat.py ├── test_pad_gemm_rrr_with_cat.py ├── 
test_pad_gemm_with_cat.py ├── test_pad_gemm_with_elementwise.py ├── test_parallel_gemm_fusions.py ├── test_permute_bmm_special_op.py ├── test_public_import.py ├── test_refine_graph.py ├── test_remove_elementwise_no_ops.py ├── test_remove_id_ops.py ├── test_remove_no_op_concats.py ├── test_remove_no_op_dynamic_slices.py ├── test_remove_no_op_splits.py ├── test_remove_unused_ops.py ├── test_slice_bmm_fusion.py ├── test_slice_elemwise_fusion.py ├── test_slice_gemm_fusion.py ├── test_slice_permute021_fusion.py ├── test_slice_reshape_scatter.py ├── test_slice_scatter_pattern.py ├── test_slice_view_strided.py ├── test_split_bmm_fusion.py ├── test_split_bmm_softmax_bmm.py ├── test_split_full_idx.py ├── test_split_large_concat.py ├── test_split_large_slice_reshape_scatter.py ├── test_split_large_slice_scatter.py ├── test_split_large_split.py ├── test_split_view_strided.py ├── test_strided_group_gemm.py ├── test_strided_group_layernorm.py ├── test_strided_layernorm.py ├── test_strided_layernorm_reshape.py ├── test_strided_op_cat_pattern.py ├── test_strided_reshape_cat.py ├── test_strided_scatter.py ├── test_strided_split_group_gemm.py ├── test_strided_view_cat.py ├── test_strided_view_op.py ├── test_symbolic.py ├── test_tensor.py ├── test_tensor_accessor.py ├── test_transform_memory_ops.py ├── test_transform_odd_alignment.py ├── test_transform_permute_to_reshape.py ├── test_transform_special_op.py ├── test_transform_toposort.py ├── test_transform_utils.py └── test_view_strided_op.py ├── frontend └── test_module.py ├── ops ├── test_activation.py ├── test_argmax.py ├── test_argmax_sm80.py ├── test_attention.py ├── test_avg_pool2d.py ├── test_b2b_bmm.py ├── test_batch_gather.py ├── test_batch_norm.py ├── test_batched_dense_vec_jagged_2d_mul.py ├── test_bert_embeddings.py ├── test_bmm.py ├── test_bmm_add.py ├── test_bmm_alpha.py ├── test_bmm_permute.py ├── test_bmm_rcr_n1.py ├── test_bmm_rrr_k1_tanh.py ├── test_bmm_softmax.py ├── test_bmm_softmax_bmm.py ├── test_cast.py ├── 
test_chunk.py ├── test_clamp_nan_to_num.py ├── test_concatenate.py ├── test_concatenate_tanh.py ├── test_conv.py ├── test_conv2d_bias_add.py ├── test_conv3d.py ├── test_conv3d_profiler_cache.py ├── test_conv_bias.py ├── test_conv_bias_act_few_channels.py ├── test_conv_bias_add_hardswish.py ├── test_conv_bias_add_relu.py ├── test_conv_bias_hardswish.py ├── test_conv_bias_relu.py ├── test_conv_bias_sigmoid.py ├── test_conv_depthwise.py ├── test_conv_depthwise_bias.py ├── test_conv_profiler_cache.py ├── test_cross_attention.py ├── test_depthwise_conv3d.py ├── test_dual_bmm.py ├── test_dual_gemm.py ├── test_dynamic_conv.py ├── test_efficient_nms.py ├── test_expand.py ├── test_flatten.py ├── test_fpn_roi_align.py ├── test_full.py ├── test_fused_elementwise.py ├── test_fused_elementwise_broadcast.py ├── test_fused_elementwise_with_strided_outputs.py ├── test_gather.py ├── test_gemm.py ├── test_gemm_bias.py ├── test_gemm_bias_broadcast.py ├── test_gemm_bias_hardswish.py ├── test_gemm_bias_permute.py ├── test_gemm_bias_relu.py ├── test_gemm_bias_sigmoid.py ├── test_gemm_bias_softmax.py ├── test_gemm_bias_swish.py ├── test_gemm_bias_tanh.py ├── test_gemm_no_tf32.py ├── test_gemm_permute.py ├── test_gemm_profiler_cache.py ├── test_gemm_rcr_bias_fast_gelu.py ├── test_gemm_rcr_fast_gelu.py ├── test_gemm_rrr_small_nk.py ├── test_gemm_softmax.py ├── test_group_gemm_rcr.py ├── test_group_gemm_rcr_bias.py ├── test_group_gemm_rcr_bias_activation.py ├── test_group_gemm_rcr_bias_cat.py ├── test_group_gemm_rcr_cat.py ├── test_grouped_b2b_bmm.py ├── test_grouped_classic_b2b_bmm.py ├── test_groupnorm.py ├── test_identity.py ├── test_index_select.py ├── test_int_elementwise_dynamic_reshape.py ├── test_jagged_elementwise.py ├── test_jagged_lengths_to_offsets.py ├── test_jagged_lengths_to_presences.py ├── test_jagged_to_padded_dense.py ├── test_layernorm.py ├── test_layernorm_sigmoid_mul.py ├── test_make_jagged.py ├── test_masked_select.py ├── test_max_pool2d.py ├── test_max_pool3d.py ├── 
test_ndhwc3to8.py ├── test_nhwc3to4.py ├── test_nhwc3to8.py ├── test_nms.py ├── test_nn_gelu.py ├── test_norm.py ├── test_pad_last_dim.py ├── test_padded_dense_to_jagged.py ├── test_perm021fc_ccr.py ├── test_perm021fc_ccr_bias.py ├── test_perm021fc_ccr_bias_perm021.py ├── test_perm021fc_crc.py ├── test_perm021fc_crc_bias.py ├── test_perm102_bmm_rcr.py ├── test_perm102_bmm_rrr.py ├── test_permute.py ├── test_permute021.py ├── test_permute0213.py ├── test_permute102.py ├── test_permute210.py ├── test_proposal.py ├── test_reduce.py ├── test_relational.py ├── test_reshape.py ├── test_roi_align.py ├── test_size_getitem_ops.py ├── test_slice.py ├── test_softmax.py ├── test_split.py ├── test_split_getitem.py ├── test_squeeze.py ├── test_topk.py ├── test_transpose.py ├── test_transpose_conv2d.py ├── test_transpose_conv2d_bias.py ├── test_transpose_conv2d_bias_relu.py ├── test_tuple_list_construct.py ├── test_upsampling2d.py ├── test_upsampling2d_add.py ├── test_vanilla_attention.py ├── test_var.py └── test_where.py ├── test_stable_set.py └── util ├── test_debug_utils.py └── test_serdes.py /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: Docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | pull_request: 9 | branches: 10 | - main 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: ["3.9"] 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install dependencies 24 | run: | 25 | python3.9 -m pip install --upgrade pip 26 | python3.9 -m pip install numpy autodocsumm 'sphinx<6' sphinx_rtd_theme sphinx_gallery sphinxcontrib-inlinesyntaxhighlight sphinx_toolbox 27 | cd python 28 | python setup.py develop 29 | cd .. 
30 | - name: Build documents with Sphinx 31 | run: | 32 | cd docs 33 | make html 34 | cd .. 35 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | pull_request: 9 | branches: 10 | - main 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: ["3.8"] 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install ufmt==2.0.1 click==8.1.3 black==24.4.0 flake8==5.0.4 27 | - name: Analyzing the code with flake8 28 | run: | 29 | echo "::add-matcher::tests/lint/flake8_problem_matcher.json" 30 | flake8 . 31 | - name: Analyzing the code with ufmt 32 | run: | 33 | ufmt diff python 34 | ufmt diff tests 35 | ufmt diff docs 36 | - name: Check Meta copyright header 37 | run: | 38 | python tests/lint/check_meta_header.py --path=./tests --fixit=False 39 | python tests/lint/check_meta_header.py --path=./python --fixit=False 40 | python tests/lint/check_meta_header.py --path=./fx2ait --fixit=False 41 | -------------------------------------------------------------------------------- /.github/workflows/rocm_ci.yml: -------------------------------------------------------------------------------- 1 | name: ROCM_CI 2 | 3 | on: 4 | pull_request: 5 | types: [labeled] 6 | 7 | jobs: 8 | build: 9 | if: contains(github.event.label.name, 'rocm') 10 | runs-on: rocm 11 | 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Get CPU info on Ubuntu 15 | if: contains(runner.os, 'linux') 16 | run: | 17 | cat /proc/cpuinfo 18 | - name: Get env vars 19 | run: | 20 | echo GITHUB_WORKFLOW = $GITHUB_WORKFLOW 21 | 
echo HOME = $HOME 22 | echo GITHUB_ACTION = $GITHUB_ACTION 23 | echo GITHUB_ACTIONS = $GITHUB_ACTIONS 24 | echo GITHUB_REPOSITORY = $GITHUB_REPOSITORY 25 | echo GITHUB_EVENT_NAME = $GITHUB_EVENT_NAME 26 | echo GITHUB_EVENT_PATH = $GITHUB_EVENT_PATH 27 | echo GITHUB_WORKSPACE = $GITHUB_WORKSPACE 28 | echo GITHUB_SHA = $GITHUB_SHA 29 | echo GITHUB_REF = $GITHUB_REF 30 | export GIT_BRANCH=${GITHUB_BASE_REF:-${GITHUB_REF#refs/heads/}} 31 | echo GIT_BRANCH = $GIT_BRANCH 32 | c++ --verbose 33 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/cutlass"] 2 | path = 3rdparty/cutlass 3 | url = https://github.com/facebookincubator/cutlass-fork.git 4 | [submodule "3rdparty/cub"] 5 | path = 3rdparty/cub 6 | url = https://github.com/NVIDIA/cub.git 7 | [submodule "3rdparty/composable_kernel"] 8 | path = 3rdparty/composable_kernel 9 | url = https://github.com/ROCmSoftwarePlatform/composable_kernel.git 10 | branch = develop 11 | [submodule "3rdparty/picojson"] 12 | path = 3rdparty/picojson 13 | url = https://github.com/kazuho/picojson.git 14 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to AITemplate 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | 1. For major change, submit RFC to discuss the change. 7 | 2. For feature extension, submit PR with tests and documentation. 8 | 3. For bug fix, submit PR with tests and documentation. 9 | 10 | ## Pull Requests 11 | We actively welcome your pull requests. 12 | 13 | 1. Fork the repo and create your branch from `main`. 14 | 2. If you've added code that should be tested, add tests. 15 | 3. If you've changed APIs, update the documentation. 16 | 4. 
Ensure the test suite passes. 17 | 5. Make sure your code lints. 18 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 19 | 20 | ## Contributor License Agreement ("CLA") 21 | In order to accept your pull request, we need you to submit a CLA. You only need 22 | to do this once to work on any of Meta's open source projects. 23 | 24 | Complete your CLA here: <https://code.facebook.com/cla> 25 | 26 | ## Issues 27 | We use GitHub issues to track public bugs. Please ensure your description is 28 | clear and has sufficient instructions to be able to reproduce the issue. 29 | 30 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 31 | disclosure of security bugs. In those cases, please go through the process 32 | outlined on that page and do not file a public issue. 33 | 34 | 35 | ## License 36 | By contributing to AITemplate, you agree that your contributions will be licensed 37 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /default.nix: -------------------------------------------------------------------------------- 1 | { pkgs ? 
import { 2 | config = { 3 | allowUnfree = true; 4 | cudaSupport = true; 5 | }; 6 | }}: 7 | 8 | let 9 | ait-deps = ps: with ps; [ 10 | pytorch-bin 11 | pip 12 | wheel 13 | click 14 | unidecode 15 | inflect 16 | librosa 17 | jinja2 18 | sympy 19 | einops 20 | parameterized 21 | transformers 22 | # ( 23 | # buildPythonPackage rec { 24 | # pname = "cuda_python"; 25 | # version = "12.1.0"; 26 | # format = "wheel"; 27 | # src = fetchPypi { 28 | # inherit pname version format; 29 | # sha256 = "94506d730baade1744767e2c05d5ddd84d7fbe4c9b6f694a54a3f376f7ffa525"; 30 | # abi = "cp39"; 31 | # python = "cp39"; 32 | # platform = "manylinux_2_17_x86_64.manylinux2014_x86_64"; 33 | # }; 34 | # doCheck = false; 35 | # } 36 | # ) 37 | ]; 38 | in 39 | pkgs.mkShell { 40 | buildInputs = [ 41 | pkgs.cmake 42 | pkgs.cudatoolkit 43 | (pkgs.python310.withPackages ait-deps) 44 | ]; 45 | 46 | shellHook = '' 47 | export CUDA_PATH=${pkgs.cudatoolkit} 48 | echo "You are now using a NIX environment" 49 | ''; 50 | } 51 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Docker + AITemplate 2 | 3 | AITemplate provides a Docker image with all test, benchmark, and documentation dependencies installed. 
4 | 5 | ## Build CUDA Docker Image 6 | 7 | ```bash docker/build.sh cuda``` 8 | This will build a CUDA 11 docker image with tag: `ait:latest` 9 | 10 | ## Build ROCM Docker Image 11 | 12 | ```DOCKER_BUILDKIT=1 bash docker/build.sh rocm``` 13 | This will build a ROCm 5 docker image with tag: `ait:latest` 14 | 15 | ## Running Unit Tests in Docker 16 | 17 | ```docker run --gpus all ait:latest bash /AITemplate/tests/nightly/unittest.sh``` 18 | 19 | ## Launching CUDA Docker 20 | ```docker run --gpus all -it ait:latest``` 21 | 22 | ## Launching ROCM Docker 23 | 24 | ```docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined ait:latest``` 25 | 26 | 27 | ## Common questions: 28 | - Q: When building ROCm Docker, I hit this error ` => ERROR [internal] load metadata for docker.io/library/ubuntu:20.04`, what shall I do? 29 | 30 | A: Run `docker pull docker.io/library/ubuntu:20.04` to pull base image manually, then re-run `./docker/build.sh rocm` 31 | -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | TARGET=$1 4 | COMMIT=$(git show --format="%H" --no-patch) 5 | COMMIT_AUTHOR=$(git show --format="%an" --no-patch) 6 | COMMIT_TIME=$(git show --format="%cI" --no-patch) 7 | echo "$COMMIT" > COMMIT_INFO 8 | echo "$COMMIT_AUTHOR" >> COMMIT_INFO 9 | echo "$COMMIT_TIME" >> COMMIT_INFO 10 | 11 | if [ "$TARGET" = "cuda" ]; then 12 | if [ "$2" = "debug" ]; then 13 | echo "Build in DEBUG mode with git files" 14 | echo "RUN apt install -y vim git" >> ./docker/Dockerfile.cuda 15 | echo "ADD .git /AITemplate/.git" >> ./docker/Dockerfile.cuda 16 | fi 17 | echo "Building CUDA Docker Image with tag ait:latest" 18 | docker build -f ./docker/Dockerfile.cuda -t ait . 
19 | elif [ "$TARGET" = "rocm" ]; then 20 | echo "Building ROCM Docker Image with tag ait:latest" 21 | docker build -f ./docker/Dockerfile.rocm -t ait . 22 | else 23 | echo "Unknown target" 24 | fi 25 | -------------------------------------------------------------------------------- /docker/install/install_ait.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd /AITemplate/python 4 | python3 setup.py bdist_wheel 5 | pip3 install --no-input /AITemplate/python/dist/*.whl 6 | -------------------------------------------------------------------------------- /docker/install/install_basic_dep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | apt install -y time 4 | pip3 install numpy 5 | pip3 install jinja2 6 | -------------------------------------------------------------------------------- /docker/install/install_detection_deps.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | apt install -y ffmpeg libsm6 libxext6 wget 4 | pip3 install yacs 5 | pip3 install opencv-python 6 | pip3 install tqdm 7 | pip3 install timm 8 | pip3 install transformers 9 | pip3 install diffusers 10 | -------------------------------------------------------------------------------- /docker/install/install_doc_dep.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | pip3 install autodocsumm 3 | pip3 install sphinx_rtd_theme 4 | pip3 install sphinx_gallery 5 | pip3 install sphinxcontrib-inlinesyntaxhighlight 6 | pip3 install sphinx_toolbox 7 | -------------------------------------------------------------------------------- /docker/install/install_test_dep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip3 install click 4 | pip3 install pytest 5 | pip3 install parameterized 6 | pip3 install pylint==2.13.9 7 | pip3 install ufmt 8 | pip3 install pyGithub 9 | pip3 install gitpython 10 | pip3 install xmltodict 11 | pip3 install einops 12 | -------------------------------------------------------------------------------- /docker/install/rocm_dev-requirements.txt: -------------------------------------------------------------------------------- 1 | ROCmSoftwarePlatform/rocm-recipes 2 | # 1.90+ 3 | danmar/cppcheck@dd05839a7e63ef04afd34711cb3e1e0ef742882f 4 | -------------------------------------------------------------------------------- /docker/rocm_fix/fix_10736.py: -------------------------------------------------------------------------------- 1 | src = "" 2 | with open("/opt/rocm/hip/bin/hipcc.pl", "r") as fi: 3 | src = fi.read() 4 | 5 | src = src.replace( 6 | "$HIP_CLANG_TARGET = chomp($HIP_CLANG_TARGET);", "chomp($HIP_CLANG_TARGET);" 7 | ) 8 | with open("/opt/rocm/hip/bin/hipcc.pl", "w") as fo: 9 | fo.write(src) 10 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | cp static/ait_model.html build/html/tutorial/ait_model.html 22 | 23 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # AITemplate Documentation 2 | 3 | 4 | ## Build locally 5 | 6 | 1. Install AITemplate 7 | 8 | 2. Install Sphinx 9 | ``` 10 | pip install autodocsumm 11 | pip install sphinx_rtd_theme 12 | pip install sphinx_gallery 13 | pip install sphinxcontrib-inlinesyntaxhighlight 14 | pip install sphinx_toolbox 15 | ``` 16 | 17 | 3. Build HTML 18 | ``` 19 | make html 20 | ``` 21 | -------------------------------------------------------------------------------- /docs/image/gpu_grid_block.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookincubator/AITemplate/5d909ca884250589256e8ccc11ea78ff1454da6b/docs/image/gpu_grid_block.png -------------------------------------------------------------------------------- /docs/image/pack_size_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookincubator/AITemplate/5d909ca884250589256e8ccc11ea78ff1454da6b/docs/image/pack_size_1.png -------------------------------------------------------------------------------- /docs/image/pack_size_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookincubator/AITemplate/5d909ca884250589256e8ccc11ea78ff1454da6b/docs/image/pack_size_2.png 
-------------------------------------------------------------------------------- /docs/image/pack_size_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookincubator/AITemplate/5d909ca884250589256e8ccc11ea78ff1454da6b/docs/image/pack_size_4.png -------------------------------------------------------------------------------- /docs/image/pack_size_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookincubator/AITemplate/5d909ca884250589256e8ccc11ea78ff1454da6b/docs/image/pack_size_8.png -------------------------------------------------------------------------------- /docs/image/softmax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookincubator/AITemplate/5d909ca884250589256e8ccc11ea78ff1454da6b/docs/image/softmax.png -------------------------------------------------------------------------------- /docs/image/vs_oneflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookincubator/AITemplate/5d909ca884250589256e8ccc11ea78ff1454da6b/docs/image/vs_oneflow.png -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. 
Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/arch/index.rst: -------------------------------------------------------------------------------- 1 | Design and Architecture 2 | ======================= 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 1 7 | 8 | philosophy 9 | 10 | 11 | 12 | Stay tuned for more... 13 | -------------------------------------------------------------------------------- /docs/source/arch/philosophy.rst: -------------------------------------------------------------------------------- 1 | Design Philosophy 2 | ================== 3 | 4 | 5 | KISS (Keep it simple and stupid) 6 | -------------------------------- 7 | 8 | AITemplate avoids deep IR lowering stacks to reduce the system's complexity. 9 | A highly modularized, multiple backend codegen system written in pure Python directly attacks the pain point in high-performance GPU inference. 10 | 11 | Pragmatism 12 | ---------- 13 | 14 | AITemplate provides a PyTorch-style frontend to enable engineers to manually match the PyTorch model & weights to AITemplate for optimization. 15 | Using it is less painful than debugging different lowering IR stacks, especially for complex models such as MaskRCNN. 16 | 17 | We believe most of the neural network workload can be decoupled. 18 | For example, most of the network can be decoupled into Encoder, Decoder, and Decoder logics. 19 | For encoder and decoder, it is a computation-bounded problem. 20 | For decoder logic, it may involve more control flows. 
21 | By using divide and conquer, we leave the decoder logic part to C++ or Python rather than building a unified language / IR stack as a silver bullet. 22 | -------------------------------------------------------------------------------- /docs/source/debughints.rst: -------------------------------------------------------------------------------- 1 | Debug Hints 2 | =========== 3 | 4 | AITemplate is a new project under active development. 5 | We have a rich test set to avoid bugs but don't be surprised if there is anything unexpected. 6 | 7 | Here are some helpful tips we learned during the development of AITemplate: 8 | 9 | 1. Once the codegen for an op that requires profiling is changed, remember to delete old profilers (usually located at workdir), and flush the cache by either deleting `~/.aitemplate` or setting the environment variable `FLUSH_PROFILE_CACHE=1`. 10 | 11 | 2. Check the pseudo code/visualization generated by each optimization pass if some optimization behaves in an unexpected way. 12 | 13 | 3. Always do the numerical test, from small to large, to make sure the entire model is correct. 14 | 15 | 4. Try to make the new fusion subgraph work in a manual way, then try to add an automatic pass to rewrite the graph with the fused subgraph. 16 | -------------------------------------------------------------------------------- /docs/source/genindex.rst: -------------------------------------------------------------------------------- 1 | Index 2 | ===== -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | 2 | AITemplate Documentation 3 | ======================== 4 | 5 | AITemplate (AIT) is a Python framework that transforms deep neural networks into CUDA (NVIDIA GPU) / HIP (AMD GPU) C++ code for lightning-fast inference serving. 
AITemplate highlights include: 6 | 7 | * High performance: close to roofline fp16 TensorCore (NVIDIA GPU) / MatrixCore (AMD GPU) performance on major models, including ResNet, MaskRCNN, BERT, VisionTransformer, Stable Diffusion, etc. 8 | * Unified, open, and flexible. Seamless fp16 deep neural network models for NVIDIA GPU or AMD GPU. Fully open source, Lego-style easily extendable high-performance primitives for new model support. Supports a significantly more comprehensive range of fusions than existing solutions for both GPU platforms. 9 | 10 | 11 | .. toctree:: 12 | :maxdepth: 1 13 | :caption: Getting Started 14 | 15 | install/index 16 | 17 | 18 | .. toctree:: 19 | :maxdepth: 1 20 | :caption: User Guide 21 | 22 | tutorial/index 23 | debughints 24 | 25 | .. toctree:: 26 | :maxdepth: 1 27 | :caption: Runtime Design 28 | 29 | runtime/index 30 | 31 | .. toctree:: 32 | :maxdepth: 1 33 | :caption: Architecture Guide 34 | 35 | arch/index 36 | 37 | 38 | .. toctree:: 39 | :maxdepth: 1 40 | :caption: Reference Guide 41 | 42 | reference/index 43 | reference/env 44 | genindex 45 | -------------------------------------------------------------------------------- /docs/source/reference/compiler.rst: -------------------------------------------------------------------------------- 1 | aitemplate.compiler 2 | ============================== 3 | 4 | 5 | base 6 | ------------------------ 7 | .. automodule:: aitemplate.compiler.base 8 | :members: 9 | :imported-members: 10 | :exclude-members: ABC, Enum, abstractmethod, dataclass, pformat, reduce 11 | :autosummary: 12 | 13 | 14 | tensor_accessor 15 | ----------------------------------- 16 | .. automodule:: aitemplate.compiler.tensor_accessor 17 | :members: 18 | :imported-members: 19 | :exclude-members: IntImm, IntVar, Tensor, pformat 20 | :autosummary: 21 | 22 | compiler 23 | ---------------------------- 24 | 25 | .. 
automodule:: aitemplate.compiler.compiler 26 | :members: 27 | :imported-members: 28 | :exclude-members: IntImm, IntVar, Tensor, pformat, DynamicProfileStrategy 29 | :autosummary: 30 | 31 | model 32 | ---------------------------- 33 | .. automodule:: aitemplate.compiler.model 34 | :members: 35 | :imported-members: 36 | :exclude-members: NamedTuple, TypeVar 37 | :autosummary: -------------------------------------------------------------------------------- /docs/source/reference/cuda.rst: -------------------------------------------------------------------------------- 1 | aitemplate.backend.cuda 2 | =========================== 3 | 4 | target_def 5 | ---------- 6 | .. automodule:: aitemplate.backend.cuda.target_def 7 | :members: 8 | :imported-members: 9 | :exclude-members: Path, ProfileCacheDB, TargetType 10 | :autosummary: 11 | 12 | 13 | -------------------------------------------------------------------------------- /docs/source/reference/frontend.rst: -------------------------------------------------------------------------------- 1 | aitemplate.frontend 2 | ==================== 3 | 4 | .. automodule:: aitemplate.frontend.nn 5 | :members: 6 | :imported-members: 7 | :exclude-members: 8 | :autosummary: 9 | 10 | .. automodule:: aitemplate.frontend.tensor 11 | :members: 12 | :imported-members: 13 | :exclude-members: 14 | :autosummary: 15 | -------------------------------------------------------------------------------- /docs/source/reference/index.rst: -------------------------------------------------------------------------------- 1 | Python API 2 | ========== 3 | 4 | 5 | .. 
toctree:: 6 | :maxdepth: 2 7 | 8 | compiler 9 | ops 10 | transform 11 | backend 12 | cuda 13 | rocm 14 | frontend 15 | testing 16 | utils 17 | -------------------------------------------------------------------------------- /docs/source/reference/ops.rst: -------------------------------------------------------------------------------- 1 | aitemplate.compiler.ops 2 | ======================== 3 | 4 | .. automodule:: aitemplate.compiler.ops 5 | :members: 6 | :imported-members: 7 | :exclude-members: Tensor, TensorAccessor, Enum, Operator, IntImm, IntVar, IntVarTensor, wrap_dim 8 | :autosummary: 9 | -------------------------------------------------------------------------------- /docs/source/reference/rocm.rst: -------------------------------------------------------------------------------- 1 | aitemplate.backend.rocm 2 | =========================== 3 | 4 | target_def 5 | ---------- 6 | .. automodule:: aitemplate.backend.rocm.target_def 7 | :members: 8 | :imported-members: 9 | :exclude-members: 10 | :autosummary: 11 | 12 | -------------------------------------------------------------------------------- /docs/source/reference/testing.rst: -------------------------------------------------------------------------------- 1 | aitemplate.testing 2 | ================== 3 | 4 | detect_target 5 | ------------- 6 | .. automodule:: aitemplate.testing.detect_target 7 | :members: 8 | :imported-members: 9 | :exclude-members: CUDA, ROCM, Popen 10 | :autosummary: 11 | 12 | 13 | benchmark_pt 14 | ------------ 15 | .. automodule:: aitemplate.testing.benchmark_pt 16 | :members: 17 | :imported-members: 18 | :exclude-members: CUDA, ROCM, Popen 19 | :autosummary: 20 | 21 | benchmark_ait 22 | ------------- 23 | .. 
automodule:: aitemplate.testing.benchmark_ait 24 | :members: 25 | :imported-members: 26 | :exclude-members: CUDA, ROCM, Popen 27 | :autosummary: -------------------------------------------------------------------------------- /docs/source/reference/utils.rst: -------------------------------------------------------------------------------- 1 | aitemplate.utils 2 | ================== 3 | 4 | 5 | visualization.plot 6 | ------------------ 7 | .. automodule:: aitemplate.utils.visualization.plot 8 | :members: 9 | :imported-members: 10 | :exclude-members: Tensor, Operator 11 | :autosummary: 12 | 13 | -------------------------------------------------------------------------------- /docs/source/runtime/cxx_design.rst: -------------------------------------------------------------------------------- 1 | ================ 2 | C++ Runtime Note 3 | ================ 4 | 5 | `Model` v.s. `ModelContainer` 6 | ============================= 7 | 8 | These are the two main classes involved in the C++ runtime implementation: 9 | 10 | * The bulk of the runtime implementation is in the `Model` class. 11 | * The `ModelContainer` class stores a set of shared constants and a collection of `Model` instances. Almost all functions in `model_interface.h` forward to a method in `ModelContainer`. When `Run` is invoked, `ModelContainer` looks for an available `Model`, or blocks until one becomes available (see the section on asynchronous predictions). It then forwards the run request to the runtime. 12 | 13 | Code Structure 14 | ============== 15 | 16 | Some important files: 17 | 18 | 1. `include/model_interface.h`: The interface that we expose in the compiled `.so`. 19 | 2. `include/model_container.h`: The bulk of the `ModelContainer` implementation. 20 | 21 | Some files are generated at compile time. These include: 22 | 23 | * `model-generated.h`: The implementation of the `Model`. 24 | * `model_container_base.cu`: A small part of the implementation for `ModelContainer` that needs to be generated. 
#!/bin/bash
# Usage: benchmark_mi250.sh <batch-size>
#
# Runs benchmark_ait.py twice at the same time, one run pinned to HIP device 0
# and one to HIP device 1, and exits only after both runs have finished.

HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size "$1" &
background_pid=$!
HIP_VISIBLE_DEVICES=1 python3 benchmark_ait.py --batch-size "$1"
foreground_status=$?

# `fg` requires job control, which bash disables in non-interactive scripts,
# so it cannot be used here; `wait` blocks on the background run and returns
# its exit status.
wait "$background_pid"
background_status=$?

# Report failure if either concurrent run failed.
if [ "$foreground_status" -ne 0 ]; then
    exit "$foreground_status"
fi
exit "$background_status"
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /examples/02_detectron2/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from .config import get_cfg_defaults 16 | 17 | __all__ = ["get_cfg_defaults"] 18 | -------------------------------------------------------------------------------- /examples/02_detectron2/configs/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def get_cfg_defaults() -> CfgNode:
    """
    Hand out a copy of the default configuration.

    The defaults object is imported lazily, inside the function, from the
    sibling ``defaults`` module and cloned before it is returned.

    Returns:
        the CfgNode obtained by calling ``clone()`` on ``defaults._C``.
    """
    from .defaults import _C as default_cfg

    cfg_copy = default_cfg.clone()
    return cfg_copy
("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 1 37 | BASE_LR: 0.02 38 | STEPS: (60000, 80000) 39 | MAX_ITER: 90000 40 | INPUT: 41 | MIN_SIZE_TEST: 800 42 | MAX_SIZE_TEST: 1344 43 | POSTPROCESS: 44 | POST_ON: True 45 | USE_TOPK: True 46 | TOPK: 100 47 | VERSION: 2 48 | -------------------------------------------------------------------------------- /examples/02_detectron2/configs/faster_rcnn_R_50_FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: "faster_rcnn_R_50_FPN" 3 | META_ARCHITECTURE: "GeneralizedRCNN" 4 | BACKBONE: 5 | NAME: "build_resnet_fpn_backbone" 6 | RESNETS: 7 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 8 | FPN: 9 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 10 | ANCHOR_GENERATOR: 11 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 12 | ASPECT_RATIOS: [0.5, 1.0, 2.0] # Three aspect ratios (same for all in feature maps) 13 | RPN: 14 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 15 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 16 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 17 | POST_NMS_TOPK_TRAIN: 1000 18 | POST_NMS_TOPK_TEST: 1000 19 | ROI_HEADS: 20 | NAME: "StandardROIHeads" 21 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 22 | ROI_BOX_HEAD: 23 | NAME: "FastRCNNConvFCHead" 24 | NUM_FC: 2 25 | POOLER_RESOLUTION: 7 26 | ROI_MASK_HEAD: 27 | NAME: "MaskRCNNConvUpsampleHead" 28 | NUM_CONV: 4 29 | POOLER_RESOLUTION: 14 30 | DATASETS: 31 | TRAIN: ("coco_2017_train",) 32 | TEST: ("coco_2017_val",) 33 | SOLVER: 34 | IMS_PER_BATCH: 1 35 | BASE_LR: 0.02 36 | STEPS: (60000, 80000) 37 | MAX_ITER: 90000 38 | INPUT: 39 | MIN_SIZE_TEST: 800 40 | MAX_SIZE_TEST: 1344 41 | POSTPROCESS: 42 | POST_ON: True 43 | USE_TOPK: True 44 | TOPK: 100 45 | VERSION: 2 46 | -------------------------------------------------------------------------------- /examples/02_detectron2/configs/mask_rcnn_R_101_FPN.yaml: -------------------------------------------------------------------------------- 1 
| MODEL: 2 | NAME: "mask_rcnn_R_101_FPN" 3 | MASK_ON: True 4 | META_ARCHITECTURE: "GeneralizedRCNN" 5 | BACKBONE: 6 | NAME: "build_resnet_fpn_backbone" 7 | RESNETS: 8 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 9 | DEPTH: 101 10 | STAGES: [3, 4, 23, 3] 11 | FPN: 12 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 13 | ANCHOR_GENERATOR: 14 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 15 | ASPECT_RATIOS: [0.5, 1.0, 2.0] # Three aspect ratios (same for all in feature maps) 16 | RPN: 17 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 18 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 19 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 20 | POST_NMS_TOPK_TRAIN: 1000 21 | POST_NMS_TOPK_TEST: 1000 22 | ROI_HEADS: 23 | NAME: "StandardROIHeads" 24 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 25 | ROI_BOX_HEAD: 26 | NAME: "FastRCNNConvFCHead" 27 | NUM_FC: 2 28 | POOLER_RESOLUTION: 7 29 | ROI_MASK_HEAD: 30 | NAME: "MaskRCNNConvUpsampleHead" 31 | NUM_CONV: 4 32 | POOLER_RESOLUTION: 14 33 | DATASETS: 34 | TRAIN: ("coco_2017_train",) 35 | TEST: ("coco_2017_val",) 36 | SOLVER: 37 | IMS_PER_BATCH: 1 38 | BASE_LR: 0.02 39 | STEPS: (60000, 80000) 40 | MAX_ITER: 90000 41 | INPUT: 42 | MIN_SIZE_TEST: 800 43 | MAX_SIZE_TEST: 1344 44 | POSTPROCESS: 45 | POST_ON: True 46 | USE_TOPK: False 47 | TOPK: 100 48 | VERSION: 2 49 | -------------------------------------------------------------------------------- /examples/02_detectron2/configs/mask_rcnn_R_50_FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: "mask_rcnn_R_50_FPN" 3 | MASK_ON: True 4 | META_ARCHITECTURE: "GeneralizedRCNN" 5 | BACKBONE: 6 | NAME: "build_resnet_fpn_backbone" 7 | RESNETS: 8 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 9 | FPN: 10 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 11 | ANCHOR_GENERATOR: 12 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 13 | ASPECT_RATIOS: [0.5, 1.0, 2.0] # Three aspect ratios 
(same for all in feature maps) 14 | RPN: 15 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 16 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 17 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 18 | POST_NMS_TOPK_TRAIN: 1000 19 | POST_NMS_TOPK_TEST: 1000 20 | ROI_HEADS: 21 | NAME: "StandardROIHeads" 22 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 23 | ROI_BOX_HEAD: 24 | NAME: "FastRCNNConvFCHead" 25 | NUM_FC: 2 26 | POOLER_RESOLUTION: 7 27 | ROI_MASK_HEAD: 28 | NAME: "MaskRCNNConvUpsampleHead" 29 | NUM_CONV: 4 30 | POOLER_RESOLUTION: 14 31 | DATASETS: 32 | TRAIN: ("coco_2017_train",) 33 | TEST: ("coco_2017_val",) 34 | SOLVER: 35 | IMS_PER_BATCH: 1 36 | BASE_LR: 0.02 37 | STEPS: (60000, 80000) 38 | MAX_ITER: 90000 39 | INPUT: 40 | MIN_SIZE_TEST: 800 41 | MAX_SIZE_TEST: 1344 42 | POSTPROCESS: 43 | POST_ON: True 44 | USE_TOPK: False 45 | TOPK: 100 46 | VERSION: 2 47 | -------------------------------------------------------------------------------- /examples/02_detectron2/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
@dataclass
class ShapeSpec:
    """
    A simple structure that contains basic shape specification about a tensor.
    It is often used as the auxiliary inputs/outputs of models,
    to complement the lack of shape inference ability among pytorch modules.

    All four fields are optional integers that default to ``None``, so a
    partially known shape can be described by setting only some of them.
    Field order (channels, height, width, stride) is also the positional
    constructor order generated by ``@dataclass``.
    """

    # NOTE(review): the meaning/units of each field are inferred from the
    # field names only; confirm against the code that builds ShapeSpec.
    channels: Optional[int] = None
    height: Optional[int] = None
    width: Optional[int] = None
    stride: Optional[int] = None
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # flake8: noqa 16 | from .rcnn import GeneralizedRCNN 17 | 18 | __all__ = list(globals().keys()) 19 | -------------------------------------------------------------------------------- /examples/02_detectron2/modeling/proposal_generator/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # flake8: noqa 16 | from .rpn import build_rpn_head, RPN, StandardRPNHead 17 | 18 | __all__ = list(globals().keys()) 19 | -------------------------------------------------------------------------------- /examples/02_detectron2/modeling/roi_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # flake8: noqa 16 | from .box_head import build_box_head, FastRCNNConvFCHead 17 | from .mask_head import MaskRCNNConvUpsampleHead 18 | from .roi_heads import build_roi_heads, StandardROIHeads 19 | 20 | __all__ = list(globals().keys()) 21 | -------------------------------------------------------------------------------- /examples/02_detectron2/predictor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
#!/bin/bash
# Usage: benchmark_mi250.sh <batch-size>
#
# Runs benchmark_ait.py three ways: once with HIP devices 0-7 visible and no
# batch size, once on device 0 alone, and finally on devices 0 and 1
# concurrently. The last two use the batch size given as the first argument.

#profile
HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 benchmark_ait.py

#1GCD
HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size "$1"

#2GCD
HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size "$1" &
background_pid=$!
HIP_VISIBLE_DEVICES=1 python3 benchmark_ait.py --batch-size "$1"
foreground_status=$?

# `fg` requires job control, which bash disables in non-interactive scripts,
# so it cannot be used here; `wait` blocks on the background run and returns
# its exit status.
wait "$background_pid"
background_status=$?

# Report failure if either concurrent run failed.
if [ "$foreground_status" -ne 0 ]; then
    exit "$foreground_status"
fi
exit "$background_status"
#!/bin/bash
# Usage: benchmark_mi250.sh <batch-size>
#
# Runs benchmark_ait.py three ways: once with HIP devices 0-7 visible and no
# batch size, once on device 0 alone, and finally on devices 0 and 1
# concurrently. The last two use the batch size given as the first argument.

#profile
HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 benchmark_ait.py

#1GCD
HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size "$1"

#2GCD
HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size "$1" &
background_pid=$!
HIP_VISIBLE_DEVICES=1 python3 benchmark_ait.py --batch-size "$1"
foreground_status=$?

# `fg` requires job control, which bash disables in non-interactive scripts,
# so it cannot be used here; `wait` blocks on the background run and returns
# its exit status.
wait "$background_pid"
background_status=$?

# Report failure if either concurrent run failed.
if [ "$foreground_status" -ne 0 ]; then
    exit "$foreground_status"
fi
exit "$background_status"
14 | # 15 | -------------------------------------------------------------------------------- /examples/05_stable_diffusion/src/compile_lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookincubator/AITemplate/5d909ca884250589256e8ccc11ea78ff1454da6b/examples/05_stable_diffusion/src/compile_lib/__init__.py -------------------------------------------------------------------------------- /examples/05_stable_diffusion/src/compile_lib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def mark_output(y):
    """Flag tensor(s) as model outputs and print the shape of each one.

    Args:
        y: a single tensor-like object, or a tuple/list of them. Each object
            must expose an ``_attrs`` dict whose ``"shape"`` entry is an
            iterable of dims, and each dim must expose ``_attrs["values"]``.

    For the object at position ``i`` this sets ``_attrs["is_output"] = True``
    and ``_attrs["name"] = "output_<i>"``, then prints one line with the
    collected dim values. Nothing is returned.
    """
    # A bare object is handled as a one-element sequence. Using isinstance
    # (instead of an exact type match) also accepts lists and tuple
    # subclasses, which would otherwise be wrapped whole and fail on `_attrs`.
    outputs = y if isinstance(y, (tuple, list)) else (y,)
    for index, tensor in enumerate(outputs):
        tensor._attrs["is_output"] = True
        tensor._attrs["name"] = "output_%d" % (index)
        y_shape = [dim._attrs["values"] for dim in tensor._attrs["shape"]]
        print("AIT output_{} shape: {}".format(index, y_shape))
# Refuse to load on interpreters older than Python 3.7.
# The version is compared as a whole tuple: checking major and minor
# independently would wrongly reject any release whose minor number is
# below 7 (for example 4.0).
if sys.version_info < (3, 7):
    PY3STATEMENT = "The minimal Python requirement is Python 3.7"
    raise Exception(PY3STATEMENT)
# Refuse to load on interpreters older than Python 3.7.
# The version is compared as a whole tuple: checking major and minor
# independently would wrongly reject any release whose minor number is
# below 7 (for example 4.0).
if sys.version_info < (3, 7):
    PY3STATEMENT = "The minimal Python requirement is Python 3.7"
    raise Exception(PY3STATEMENT)
# Register `split` as an acc op and map two call forms onto it: the
# ``.split(...)`` method call and the ``torch.split(...)`` function call.
# NOTE(review): each arg_replacement tuple looks like
# (original arg name, normalized kwarg name[, optional flag]) — confirm
# against ait_register_acc_op_mapping, which is defined outside this file.
@ait_register_acc_op_mapping(
    op_and_target=("call_method", "split"),
    arg_replacement_tuples=[
        ("tensor", "input"),
        ("split_size_or_sections", "split_size_or_sections"),
        ("dim", "dim", this_arg_is_optional),
    ],
)
@ait_register_acc_op_mapping(
    op_and_target=("call_function", torch.split),
    arg_replacement_tuples=[
        ("tensor", "input"),
        ("split_size_or_sections", "split_size_or_sections"),
        ("dim", "dim", this_arg_is_optional),
    ],
)
@register_acc_op
def split(*, input, split_size_or_sections, dim=0):
    """Keyword-only wrapper that forwards to ``torch.split``.

    Args:
        input: value passed as the first argument of ``torch.split``.
        split_size_or_sections: passed through unchanged.
        dim: passed through unchanged; defaults to 0.

    Returns:
        whatever ``torch.split(input, split_size_or_sections, dim)`` returns.
    """
    return torch.split(input, split_size_or_sections, dim)
def save_profile_cache(remote_cache_file_path, cache_path):
    """Copy the full contents of ``cache_path`` into ``remote_cache_file_path``.

    Both files are handled in binary mode; the destination is created or
    overwritten. The source is read into memory in one go.
    """
    with open(cache_path, "rb") as source, open(remote_cache_file_path, "wb") as sink:
        payload = source.read()
        sink.write(payload)


def load_profile_cache(remote_cache_file_path, cache_bytes):
    """Write the contents of ``remote_cache_file_path`` into ``cache_bytes``.

    ``cache_bytes`` only needs a ``write`` method accepting bytes. When the
    path is not an existing regular file, nothing is written.
    """
    if not path.isfile(remote_cache_file_path):
        return
    with open(remote_cache_file_path, "rb") as cached:
        cache_bytes.write(cached.read())
# Registry of converter callables, keyed by the fx target they were
# registered for.
AIT_CONVERTERS: Dict[Target, Any] = {}


def ait_converter(key: Target, enabled: bool = True) -> Callable[[Any], Any]:
    """Build a decorator that records a converter under ``key``.

    Args:
        key: the entry in ``AIT_CONVERTERS`` the decorated callable is
            stored under; an existing entry for the same key is replaced.
        enabled: when False, the decorator is a pass-through and the
            registry is left untouched.

    Returns:
        a decorator that returns the decorated callable unchanged.
    """
    if not enabled:
        # Disabled: hand the converter back without registering it.
        def _passthrough(converter):
            return converter

        return _passthrough

    def _register(converter):
        AIT_CONVERTERS[key] = converter
        return converter

    return _register
14 | # 15 | -------------------------------------------------------------------------------- /fx2ait/fx2ait/lower/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /fx2ait/fx2ait/passes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /fx2ait/fx2ait/test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
# Refuse to run on interpreters older than Python 3.7.
# The version is compared as a tuple: testing major and minor independently
# (major >= 3 and minor >= 7) would wrongly reject a release such as 4.0,
# whose minor component is smaller than 7.
if sys.version_info < (3, 7):
    PY3STATEMENT = "The minimal Python requirement is Python 3.7"
    raise Exception(PY3STATEMENT)
class TestVisionModelConverter(AITTestCase):
    """Runs a torchvision classification network through ``run_test``."""

    def test_resnet50(self):
        # NOTE(review): the test name says resnet50, but the wrapped network
        # is torchvision's resnet18 -- confirm which one is intended.
        torch.manual_seed(0)

        class TestModule(torch.nn.Module):
            """Thin wrapper that forwards its input to the torchvision model."""

            def __init__(self):
                super().__init__()
                self.mod = torchvision.models.resnet18()

            def forward(self, x):
                return self.mod(x)

        # The model is built before the input so the seeded RNG is consumed
        # in the same order as before.
        fp16_model = TestModule().cuda().half()
        batch = torch.randn(32, 3, 224, 224).half().cuda()
        self.run_test(
            fp16_model,
            [batch],
            expected_ops={},
            permute_outputs=None,
        )
class TestContiguousConverter(AITTestCase):
    """Converter test for ``Tensor.contiguous``."""

    # Renamed from the misspelled ``test_contigupus``.
    def test_contiguous(self):
        class TestModule(torch.nn.Module):
            def forward(self, x) -> torch.Tensor:
                x = x.contiguous()
                return x + x

        model = TestModule().cuda().half()
        inputs = [
            torch.randn(1, 2, 3).half().cuda(),
        ]

        # The harness is told to expect the acc_ops.contiguous op.
        self.run_test(model, inputs, expected_ops={acc_ops.contiguous})
class TestFlattenConverter(AITTestCase):
    """Converter test for ``torch.flatten`` with several dim ranges."""

    @parameterized.expand(
        [param("default"), param("start", start_dim=1), param("end", end_dim=2)]
    )
    # Renamed from ``test_clamp`` (a copy-paste leftover): this test exercises
    # flatten, and the generated case names now say so.
    def test_flatten(self, name, start_dim=0, end_dim=-1):
        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # start_dim / end_dim come from the enclosing test case.
                return torch.flatten(x, start_dim=start_dim, end_dim=end_dim)

        model = TestModule().cuda().half()
        inputs = [
            torch.randn(1, 2, 3).half().cuda(),
        ]

        # The harness is told to expect the acc_ops.flatten op.
        self.run_test(model, inputs, expected_ops={acc_ops.flatten})
class TestGroupNormTensor(AITTestCase):
    """Converter test for ``torch.nn.GroupNorm`` with and without affine."""

    @parameterized.expand(
        [
            [True],
            [False],
        ]
    )
    def test_group_norm(self, affine):
        class GN(nn.Module):
            def __init__(self):
                super().__init__()
                # 6 channels normalized in 3 groups.
                self.gn = torch.nn.GroupNorm(3, 6, affine=affine)

            def forward(self, x):
                return self.gn(x)

        fp16_module = GN().half().cuda()
        sample = torch.randn(2, 6, 4, 5).half().cuda()
        self.run_test(
            fp16_module,
            [sample],
            expected_ops={},
        )
class TestLeakyReluConverter(AITTestCase):
    """Converter test for ``torch.nn.functional.leaky_relu``."""

    def test_leaky_relu(self):
        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.leaky_relu(x, negative_slope=0.05)

        fp16_model = TestModule().cuda().half()
        sample = torch.randn(2, 3).half().cuda()

        # The harness is told to expect the acc_ops.leaky_relu op.
        self.run_test(fp16_model, [sample], expected_ops={acc_ops.leaky_relu})
class TestPowConverter(AITTestCase):
    """Converter test for ``torch.pow`` with an integer and a float exponent."""

    @parameterized.expand([("int", 3), ("float", 0.25)])
    def test_pow(self, _, exp):
        class Pow(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                # The exponent comes from the enclosing parameterized case.
                return torch.pow(x, exp)

        fp16_model = Pow().half().cuda()
        # Named ``inputs`` so the local no longer shadows the ``input`` builtin.
        inputs = [torch.randn(3, 3).half().cuda()]
        self.run_test(fp16_model, inputs, expected_ops={acc_ops.pow})
class TestSigmoidConverter(AITTestCase):
    """Converter test for ``torch.sigmoid``."""

    def test_sigmoid(self):
        class Sigmoid(nn.Module):
            def forward(self, x):
                return torch.sigmoid(x)

        # The module is only moved to the GPU; it holds no parameters, and the
        # input tensor below is what carries the half precision.
        gpu_model = Sigmoid().cuda()
        sample = torch.randn(1, 2, 3).half().cuda()
        self.run_test(gpu_model, [sample], expected_ops={acc_ops.sigmoid})
class TestSquareConverter(AITTestCase):
    """Converter test for ``torch.square``."""

    def test_square(self):
        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.square(x)

        sample = torch.randn(3, 10, 20).cuda().half()
        fp16_model = TestModule().cuda().half()

        # The harness is told to expect the acc_ops.square op.
        self.run_test(fp16_model, [sample], expected_ops={acc_ops.square})
class TestUnbindTensor(AITTestCase):
    """Converter test for ``torch.unbind`` followed by indexing the result."""

    @parameterized.expand(
        [
            ("positive_dim", 2),
            ("negative_dim", -1),
        ]
    )
    def test_unbind(self, name, dim):
        class GetItem(nn.Module):
            # No explicit __init__: the previous one only called super().

            def forward(self, x):
                # Split along ``dim`` and keep the first slice.
                return torch.unbind(x, dim=dim)[0]

        fp16_module = GetItem().half().cuda()
        sample = torch.randn(2, 3, 4).half().cuda()
        self.run_test(
            fp16_module,
            [sample],
            expected_ops={},
        )
def dtype_to_str(dtype):
    """Return the dtype name string for ``dtype``.

    ``None`` maps to ``"float16"``; any other value is delegated to
    ``torch_dtype_to_string``.
    """
    if dtype is None:
        return "float16"
    return torch_dtype_to_string(dtype)


def make_str_ait_friendly(s: str) -> str:
    """Sanitize ``s`` so it contains only alphanumerics and underscores.

    Every character for which ``str.isalnum`` is false is replaced by ``_``,
    and a single ``_`` is prepended when the result starts with a digit.

    Args:
        s: Arbitrary input string; may be empty.

    Returns:
        The sanitized string. An empty input yields an empty string instead
        of raising ``IndexError``.
    """
    # For an all-alphanumeric string this join reproduces ``s`` unchanged, so
    # no separate fast path is needed.
    ret = "".join(c if c.isalnum() else "_" for c in s)
    # Slice rather than index so the empty string is handled without raising.
    if ret[:1].isdigit():
        ret = "_" + ret
    return ret
14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /licenses/LICENSE.cutlass.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | SPDX-License-Identifier: BSD-3-Clause 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 
17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /licenses/LICENSE.markdown_table.txt: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2020 hvalev 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | -------------------------------------------------------------------------------- /licenses/LICENSE.pydot.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014 Carlos Jenkins 2 | Copyright (c) 2014 Lance Hepler 3 | Copyright (c) 2004 Ero Carrera 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /licenses/license.header.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /python/aitemplate/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# Refuse to run on interpreters older than Python 3.7.
# The version is compared as a tuple: testing major and minor independently
# (major >= 3 and minor >= 7) would wrongly reject a release such as 4.0,
# whose minor component is smaller than 7.
if sys.version_info < (3, 7):
    PY3STATEMENT = "The minimal Python requirement is Python 3.7"
    raise Exception(PY3STATEMENT)

# Sub-packages re-exported by ``from aitemplate import *``.
__all__ = ["backend", "compiler", "frontend", "testing", "utils"]

# Package-level logger, configured once at import time.
root_logger = setup_logger(__name__)
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Backend for AITemplate. 17 | """ 18 | 19 | from aitemplate.backend import ( # noqa 20 | backend_spec, 21 | builder, 22 | codegen, 23 | cuda, 24 | profiler_runner, 25 | registry, 26 | rocm, 27 | target, 28 | ) 29 | 30 | __all__ = [ 31 | "builder", 32 | "codegen", 33 | "cuda", 34 | "profiler_runner", 35 | "registry", 36 | "rocm", 37 | "target", 38 | ] 39 | -------------------------------------------------------------------------------- /python/aitemplate/backend/build_cache.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | """ 17 | The build_cache functionality is split into 18 | this file and build_cache_base.py 19 | 20 | This file is part of the AITemplate OSS distribution. 
__all__ = ["BUILD_CACHE", "BuildCache"]


def create_build_cache() -> BuildCache:
    """Instantiate the build cache selected by the environment.

    Returns:
        A ``FileBasedBuildCache`` for the directory reported by
        ``aitemplate_env.ait_build_cache_dir()``, or a ``NoBuildCache`` when
        that value is ``None`` or an empty string.
    """
    cache_dir = aitemplate_env.ait_build_cache_dir()
    if cache_dir is not None and cache_dir != "":
        return FileBasedBuildCache(cache_dir)
    return NoBuildCache()


# Module-level cache instance, created once when this module is imported.
BUILD_CACHE: BuildCache = create_build_cache()
14 | # 15 | """ 16 | cuda flash_attention module init 17 | """ 18 | 19 | from aitemplate.backend.cuda.attention import flash_attention, mem_eff_attention 20 | 21 | __all__ = ["flash_attention", "mem_eff_attention"] 22 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/b2b_bmm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # flake8: noqa 16 | 17 | """ 18 | b2b bmm module init 19 | """ 20 | 21 | from aitemplate.backend.cuda.b2b_bmm import ( 22 | classic_b2b_bmm, 23 | fmha_style_b2b_bmm, 24 | grouped_classic_b2b_bmm, 25 | grouped_fmha_style_b2b_bmm, 26 | ) 27 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # flake8: noqa 16 | """ 17 | CUDA Common module init 18 | """ 19 | 20 | from aitemplate.backend.cuda.common.dummy_op import * 21 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/common/dummy_op.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Dummy op codegen for CUDA. 
@registry.reg("cuda.size.gen_function")
def dummy_gen_function(func_attrs: Dict[str, Any]) -> str:
    """Return an empty function body for the ``cuda.size.gen_function`` hook.

    Parameters
    ----------
    func_attrs : Dict[str, Any]
        Attributes of the op being processed; not inspected here.

    Returns
    -------
    str
        Always the empty string.
    """
    # NOTE(review): presumably the size op needs no generated kernel -- confirm.
    empty_source = ""
    return empty_source


@registry.reg("cuda.size.func_decl")
def dummy_gen_function_decl(func_attrs):
    """Return an empty declaration for the ``cuda.size.func_decl`` hook.

    ``func_attrs`` is accepted to satisfy the hook signature and is ignored.
    """
    empty_decl = ""
    return empty_decl


@registry.reg("cuda.size.func_call")
def dummy_gen_function_call(func_attrs, indent):
    """Return an empty call snippet for the ``cuda.size.func_call`` hook.

    Both ``func_attrs`` and ``indent`` are accepted to satisfy the hook
    signature and are ignored.
    """
    empty_call = ""
    return empty_call
14 | # 15 | # flake8: noqa 16 | """ 17 | cuda conv2d module init 18 | """ 19 | 20 | from aitemplate.backend.cuda.conv2d import ( 21 | conv2d, 22 | conv2d_bias, 23 | conv2d_bias_add, 24 | conv2d_bias_add_hardswish, 25 | conv2d_bias_add_relu, 26 | conv2d_bias_few_channels, 27 | conv2d_bias_hardswish, 28 | conv2d_bias_hardswish_few_channels, 29 | conv2d_bias_relu, 30 | conv2d_bias_relu_few_channels, 31 | conv2d_bias_sigmoid, 32 | conv2d_depthwise, 33 | conv2d_depthwise_bias, 34 | transposed_conv2d, 35 | transposed_conv2d_bias, 36 | ) 37 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/conv3d/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | CUDA conv3d module init 17 | """ 18 | 19 | from aitemplate.backend.cuda.conv3d import ( 20 | conv3d, 21 | conv3d_bias, 22 | depthwise_conv3d, 23 | depthwise_conv3d_bias, 24 | ) 25 | 26 | __all__ = ["conv3d", "conv3d_bias", "depthwise_conv3d", "depthwise_conv3d_bias"] 27 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/elementwise/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 17 | """ 18 | 19 | from aitemplate.backend.cuda.elementwise import fused_elementwise, int_elementwise 20 | 21 | __all__ = ["fused_elementwise", "int_elementwise"] 22 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # flake8: noqa 16 | from aitemplate.backend.cuda.embedding.bert_embeddings import * 17 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/gemm_epilogue_vistor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from aitemplate.backend.cuda.gemm_epilogue_vistor import ( 17 | bmm_rcr_softmax, 18 | dual_bmm_rrr_div, 19 | dual_gemm_rcr_fast_gelu, 20 | dual_gemm_rcr_silu, 21 | gemm_rcr_bias_softmax, 22 | gemm_rcr_softmax, 23 | ) 24 | 25 | __all__ = [ 26 | "bmm_rcr_softmax", 27 | "dual_bmm_rrr_div", 28 | "dual_gemm_rcr_fast_gelu", 29 | "dual_gemm_rcr_silu", 30 | "gemm_rcr_bias_softmax", 31 | "gemm_rcr_softmax", 32 | ] 33 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/gemm_special/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | special gemm ops 17 | """ 18 | 19 | from aitemplate.backend.cuda.gemm_special import ( 20 | batched_dense_vec_jagged_2d_mul, 21 | bmm_rcr_n1, 22 | bmm_rrr_k1_tanh, 23 | gemm_rrr_small_nk, 24 | ) 25 | 26 | 27 | __all__ = [ 28 | "batched_dense_vec_jagged_2d_mul", 29 | "bmm_rcr_n1", 30 | "bmm_rrr_k1_tanh", 31 | "gemm_rrr_small_nk", 32 | ] 33 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/gemm_universal/bmm_softmax_bmm_permute.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# Single message shared by every codegen hook of this unimplemented kernel.
_KERNEL_NOT_IMPLEMENTED_MSG = "bmm_softmax_bmm_permute kernel is not implemented."


@registry.reg("cuda.bmm_softmax_bmm_permute.func_decl")
def gen_function_decl(func_attrs):
    """Declaration hook for ``bmm_softmax_bmm_permute`` on CUDA.

    Raises
    ------
    NotImplementedError
        Always; this backend entry is only a placeholder.
    """
    raise NotImplementedError(_KERNEL_NOT_IMPLEMENTED_MSG)


@registry.reg("cuda.bmm_softmax_bmm_permute.gen_function")
def gen_function(func_attrs):
    """Function-body hook for ``bmm_softmax_bmm_permute`` on CUDA.

    Raises
    ------
    NotImplementedError
        Always; this backend entry is only a placeholder.
    """
    raise NotImplementedError(_KERNEL_NOT_IMPLEMENTED_MSG)


@registry.reg("cuda.bmm_softmax_bmm_permute.func_call")
def gen_function_call(func_attrs, indent=" "):
    """Call-site hook for ``bmm_softmax_bmm_permute`` on CUDA.

    ``indent`` is accepted for signature compatibility and is not used.

    Raises
    ------
    NotImplementedError
        Always; this backend entry is only a placeholder.
    """
    raise NotImplementedError(_KERNEL_NOT_IMPLEMENTED_MSG)
@registry.reg("cuda.groupnorm.gen_function")
def gen_function(func_attrs: Dict[str, Any]) -> str:
    """Produce the groupnorm function source via the shared helper.

    Forwards ``func_attrs`` unchanged to ``groupnorm_gen_function`` from
    ``groupnorm_common`` and returns whatever that helper returns.
    """
    function_source = groupnorm_gen_function(func_attrs)
    return function_source


@registry.reg("cuda.groupnorm.func_decl")
def func_decl(func_attrs: Dict[str, Any]) -> str:
    """Produce the groupnorm function declaration via the shared helper.

    Forwards ``func_attrs`` unchanged to ``groupnorm_gen_func_decl``.
    """
    declaration = groupnorm_gen_func_decl(func_attrs)
    return declaration


@registry.reg("cuda.groupnorm.func_call")
def gen_func_call(func_attrs: Dict[str, Any], indent=" ") -> str:
    """Produce the groupnorm call snippet via the shared helper.

    Forwards ``func_attrs`` and ``indent`` positionally to
    ``groupnorm_gen_func_call``.
    """
    call_source = groupnorm_gen_func_call(func_attrs, indent)
    return call_source
@registry.reg("cuda.groupnorm_swish.gen_function")
def gen_function(func_attrs: Dict[str, Any]) -> str:
    """Produce the groupnorm_swish function source via the shared helper.

    Forwards ``func_attrs`` unchanged to ``groupnorm_gen_function`` from
    ``groupnorm_common`` and returns its result.
    """
    # NOTE(review): the same helper serves plain groupnorm; any swish-specific
    # behavior is presumably selected from func_attrs -- confirm in groupnorm_common.
    function_source = groupnorm_gen_function(func_attrs)
    return function_source


@registry.reg("cuda.groupnorm_swish.func_decl")
def func_decl(func_attrs: Dict[str, Any]) -> str:
    """Produce the groupnorm_swish declaration via the shared helper.

    Forwards ``func_attrs`` unchanged to ``groupnorm_gen_func_decl``.
    """
    declaration = groupnorm_gen_func_decl(func_attrs)
    return declaration


@registry.reg("cuda.groupnorm_swish.func_call")
def gen_func_call(func_attrs: Dict[str, Any], indent=" ") -> str:
    """Produce the groupnorm_swish call snippet via the shared helper.

    Forwards ``func_attrs`` and ``indent`` positionally to
    ``groupnorm_gen_func_call``.
    """
    call_source = groupnorm_gen_func_call(func_attrs, indent)
    return call_source
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | CUDA jagged tensor-specific ops module init 17 | """ 18 | 19 | from aitemplate.backend.cuda.jagged import ( 20 | jagged_lengths_to_offsets, 21 | jagged_lengths_to_presences, 22 | ) 23 | 24 | __all__ = [ 25 | "jagged_lengths_to_offsets", 26 | "jagged_lengths_to_presences", 27 | ] 28 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/layernorm_sigmoid_mul/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 17 | """ 18 | 19 | from aitemplate.backend.cuda.layernorm_sigmoid_mul import ( 20 | batch_layernorm_sigmoid_mul, 21 | group_layernorm_sigmoid_mul, 22 | layernorm_sigmoid_mul, 23 | ) 24 | 25 | __all__ = [ 26 | "batch_layernorm_sigmoid_mul", 27 | "group_layernorm_sigmoid_mul", 28 | "layernorm_sigmoid_mul", 29 | ] 30 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/padding/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | CUDA padding init 17 | """ 18 | 19 | from aitemplate.backend.cuda.padding import ndhwc3to8, nhwc3to4, nhwc3to8, pad_last_dim 20 | 21 | __all__ = ["ndhwc3to8", "nhwc3to8", "pad_last_dim", "nhwc3to4"] 22 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/pool2d/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | """ 16 | CUDA pool2d module init 17 | """ 18 | 19 | from aitemplate.backend.cuda.pool2d import avg_pool2d, max_pool2d 20 | 21 | __all__ = ["avg_pool2d", "max_pool2d"] 22 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/reduce/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | CUDA reduce module init 17 | """ 18 | 19 | from aitemplate.backend.cuda.reduce import ( 20 | reduce_3d, 21 | reduce_common, 22 | reduce_max, 23 | reduce_mean, 24 | reduce_min, 25 | reduce_sum, 26 | var, 27 | vector_norm, 28 | ) 29 | 30 | __all__ = [ 31 | "reduce_3d", 32 | "reduce_common", 33 | "reduce_max", 34 | "reduce_mean", 35 | "reduce_min", 36 | "reduce_sum", 37 | "var", 38 | "vector_norm", 39 | ] 40 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/softmax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 17 | """ 18 | 19 | from aitemplate.backend.cuda.softmax import softmax 20 | 21 | __all__ = ["softmax"] 22 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/upsample/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | CUDA upsampling module init 17 | """ 18 | 19 | from aitemplate.backend.cuda.upsample import upsampling2d, upsampling2d_add 20 | 21 | __all__ = ["upsampling2d", "upsampling2d_add"] 22 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/view_ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | CUDA view_ops module init 17 | """ 18 | 19 | from aitemplate.backend.cuda.view_ops import make_jagged, view_ops 20 | 21 | __all__ = [ 22 | "view_ops", 23 | "make_jagged", 24 | ] 25 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/vision_ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | """ 16 | CUDA vision ops 17 | """ 18 | # flake8: noqa 19 | 20 | from aitemplate.backend.cuda.vision_ops.nms import * 21 | from aitemplate.backend.cuda.vision_ops.roi_ops import * 22 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/vision_ops/nms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 17 | """ 18 | 19 | from aitemplate.backend.cuda.vision_ops.nms import ( # noqa 20 | batched_nms, 21 | efficient_nms, 22 | nms, 23 | ) 24 | -------------------------------------------------------------------------------- /python/aitemplate/backend/cuda/vision_ops/roi_ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | CUDA roi_align module init 17 | """ 18 | 19 | from aitemplate.backend.cuda.vision_ops.roi_ops import multi_level_roi_align, roi_align 20 | 21 | __all__ = ["roi_align", "multi_level_roi_align"] 22 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # flake8: noqa 16 | """ 17 | Rocm backend init. 
18 | """ 19 | 20 | from aitemplate.backend.rocm import lib_template, target_def, utils 21 | from aitemplate.backend.rocm.attention import * 22 | from aitemplate.backend.rocm.common import * 23 | from aitemplate.backend.rocm.conv2d import * 24 | from aitemplate.backend.rocm.embedding import * 25 | from aitemplate.backend.rocm.gemm import * 26 | from aitemplate.backend.rocm.pool2d import * 27 | from aitemplate.backend.rocm.view_ops import * 28 | from aitemplate.backend.rocm.elementwise import * 29 | from aitemplate.backend.rocm.tensor import * 30 | from aitemplate.backend.rocm.normalization import softmax 31 | from aitemplate.backend.rocm.upsample import * 32 | from aitemplate.backend.rocm.vision_ops import * 33 | from aitemplate.backend.rocm.padding import * 34 | from aitemplate.backend.rocm.normalization import groupnorm, groupnorm_swish, layernorm 35 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/attention/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | 16 | from aitemplate.backend.rocm.attention import mem_eff_attention 17 | 18 | __all__ = ["mem_eff_attention"] 19 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # flake8: noqa 16 | """ 17 | ROCM Common module init 18 | """ 19 | 20 | from aitemplate.backend.rocm.common.dummy_op import * 21 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/common/dummy_op.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@registry.reg("rocm.size.gen_function")
def dummy_gen_function(func_attrs: Dict[str, Any]) -> str:
    """Return an empty function body for the ``rocm.size.gen_function`` hook.

    Parameters
    ----------
    func_attrs : Dict[str, Any]
        Attributes of the op being processed; not inspected here.

    Returns
    -------
    str
        Always the empty string.
    """
    # NOTE(review): presumably the size op needs no generated kernel -- confirm.
    empty_source = ""
    return empty_source


@registry.reg("rocm.size.func_decl")
def dummy_gen_function_decl(func_attrs):
    """Return an empty declaration for the ``rocm.size.func_decl`` hook.

    ``func_attrs`` is accepted to satisfy the hook signature and is ignored.
    """
    empty_decl = ""
    return empty_decl


@registry.reg("rocm.size.func_call")
def dummy_gen_function_call(func_attrs, indent):
    """Return an empty call snippet for the ``rocm.size.func_call`` hook.

    Both ``func_attrs`` and ``indent`` are accepted to satisfy the hook
    signature and are ignored.
    """
    empty_call = ""
    return empty_call
17 | """ 18 | 19 | from aitemplate.backend.rocm.conv2d import ( 20 | conv2d, 21 | conv2d_bias, 22 | conv2d_bias_add, 23 | conv2d_bias_add_relu, 24 | conv2d_bias_relu, 25 | conv2d_bias_sigmoid, 26 | transposed_conv2d, 27 | transposed_conv2d_bias_relu, 28 | ) 29 | 30 | __all__ = [ 31 | "conv2d", 32 | "conv2d_bias", 33 | "conv2d_bias_add", 34 | "conv2d_bias_add_relu", 35 | "conv2d_bias_relu", 36 | "conv2d_bias_sigmoid", 37 | "transposed_conv2d", 38 | "transposed_conv2d_bias_relu", 39 | ] 40 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/elementwise/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 17 | """ 18 | 19 | from aitemplate.backend.rocm.elementwise import fused_elementwise 20 | 21 | __all__ = ["fused_elementwise"] 22 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # flake8: noqa 16 | from .bert_embeddings import * 17 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/normalization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Common modules for backends 17 | """ 18 | 19 | from aitemplate.backend.rocm.normalization import norm_common, softmax # noqa 20 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/padding/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | ROCM padding init 17 | """ 18 | 19 | from . import nhwc3to4, nhwc3to8, pad_last_dim 20 | 21 | __all__ = ["nhwc3to8", "pad_last_dim", "nhwc3to4"] 22 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/pool2d/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | """ 16 | ROCM pool2d init 17 | """ 18 | 19 | from aitemplate.backend.rocm.pool2d import avg_pool2d, max_pool2d 20 | 21 | __all__ = ["avg_pool2d", "max_pool2d"] 22 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/pool2d/avg_pool2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | """ 16 | ROCM avg_pool2d funcs 17 | """ 18 | 19 | from aitemplate.backend import registry 20 | from aitemplate.backend.rocm.pool2d import pool2d 21 | 22 | 23 | @registry.reg("rocm.avg_pool2d.gen_function") 24 | def max_pool2d_gen_function(  # NOTE(review): named max_pool2d_* but registered for avg_pool2d (looks copy-pasted); kept as-is because the name is part of the module interface 25 | func_attrs, 26 | exec_cond_template, 27 | shape_eval_template, 28 | shape_save_template, 29 | ): 30 | return pool2d.gen_function(  # forwards all four arguments unchanged to the shared pool2d helper 31 | func_attrs, 32 | exec_cond_template, 33 | shape_eval_template, 34 | shape_save_template, 35 | ) 36 | 37 | 38 | @registry.reg("rocm.avg_pool2d.func_decl") 39 | def avg_pool2d_gen_function_decl(func_attrs):  # only func_attrs["name"] is read; the rest is delegated to pool2d.gen_function_decl 40 | func_name = func_attrs["name"] 41 | return pool2d.gen_function_decl(func_name) 42 | 43 | 44 | @registry.reg("rocm.avg_pool2d.func_call") 45 | def avg_pool2d_gen_function_call(func_attrs, indent=" "):  # thin wrapper: delegates to pool2d.gen_function_call 46 | return pool2d.gen_function_call(func_attrs, indent) 47 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/pool2d/max_pool2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | """ 16 | ROCM max_pool2d funcs 17 | """ 18 | 19 | from aitemplate.backend import registry 20 | from aitemplate.backend.rocm.pool2d import pool2d 21 | 22 | 23 | @registry.reg("rocm.max_pool2d.gen_function") 24 | def max_pool2d_gen_function( 25 | func_attrs, 26 | exec_cond_template, 27 | shape_eval_template, 28 | shape_save_template, 29 | ): 30 | return pool2d.gen_function(  # forwards all four arguments unchanged to the shared pool2d helper 31 | func_attrs, 32 | exec_cond_template, 33 | shape_eval_template, 34 | shape_save_template, 35 | ) 36 | 37 | 38 | @registry.reg("rocm.max_pool2d.func_decl") 39 | def avg_pool2d_gen_function_decl(func_attrs):  # NOTE(review): named avg_pool2d_* but registered for max_pool2d (looks copy-pasted); kept as-is because the name is part of the module interface 40 | func_name = func_attrs["name"] 41 | return pool2d.gen_function_decl(func_name) 42 | 43 | 44 | @registry.reg("rocm.max_pool2d.func_call") 45 | def avg_pool2d_gen_function_call(func_attrs, indent=" "):  # NOTE(review): named avg_pool2d_* but registered for max_pool2d (looks copy-pasted) 46 | return pool2d.gen_function_call(func_attrs, indent) 47 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/tensor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | """ 16 | ROCM tensor ops module init 17 | """ 18 | 19 | from aitemplate.backend.rocm.tensor import ( # noqa 20 | argmax, 21 | batch_gather, 22 | concatenate, 23 | concatenate_tanh, 24 | dynamic_slice, 25 | expand, 26 | expand_static_shape, 27 | full, 28 | identity, 29 | permute021, 30 | permute0213, 31 | permute102, 32 | permute210, 33 | slice_reshape_scatter, 34 | slice_scatter, 35 | split, 36 | topk, 37 | ) 38 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/upsample/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | ROCM upsampling module init 17 | """ 18 | 19 | from aitemplate.backend.rocm.upsample import upsampling2d, upsampling2d_add 20 | 21 | __all__ = ["upsampling2d", "upsampling2d_add"] 22 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/view_ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | ROCM view_ops module init 17 | """ 18 | 19 | from aitemplate.backend.rocm.view_ops import view_ops 20 | 21 | __all__ = ["view_ops"] 22 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/vision_ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
17 | """ 18 | 19 | from aitemplate.backend.rocm.vision_ops import efficient_nms, nms # noqa 20 | from aitemplate.backend.rocm.vision_ops.roi_ops import ( # noqa # noqa 21 | multi_level_roi_align, 22 | roi_align, 23 | ) 24 | -------------------------------------------------------------------------------- /python/aitemplate/backend/rocm/vision_ops/roi_ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | ROCM roi_align module init 17 | """ 18 | 19 | from aitemplate.backend.rocm.vision_ops.roi_ops import multi_level_roi_align, roi_align 20 | 21 | __all__ = ["roi_align", "multi_level_roi_align"] 22 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from aitemplate.compiler import base, dtype, ops, tensor_accessor, transform 16 | from aitemplate.compiler.compiler import compile_model 17 | from aitemplate.compiler.model import AIT_DEFAULT_NUM_RUNTIMES, AITData, Model 18 | 19 | __all__ = [ 20 | "base", 21 | "dtype", 22 | "op_registry", 23 | "ops", 24 | "symbolic", 25 | "tensor_accessor", 26 | "transform", 27 | "compile_model", 28 | "Model", 29 | "AITData", 30 | "AIT_DEFAULT_NUM_RUNTIMES", 31 | ] 32 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/op_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | """ 17 | Registry for basic operators and math functions. 
18 | """ 19 | 20 | from typing import Callable, Dict 21 | 22 | # OP_REGISTRY defines a mapping from a FuncEnum name to a function to create this elementwise operator. 23 | # This object is initialized in elementwise.py, and referenced in base.py and math.py. 24 | OP_REGISTRY: Dict[str, Callable] = {} 25 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # flake8: noqa 16 | """ 17 | AIT operators. 
18 | """ 19 | 20 | from aitemplate.compiler.ops.common import * 21 | from aitemplate.compiler.ops.conv import * 22 | from aitemplate.compiler.ops.embedding import * 23 | from aitemplate.compiler.ops.gemm_special import * 24 | from aitemplate.compiler.ops.gemm_universal import * 25 | from aitemplate.compiler.ops.gemm_epilogue_vistor import * 26 | from aitemplate.compiler.ops.jagged import * 27 | from aitemplate.compiler.ops.layernorm import * 28 | from aitemplate.compiler.ops.padding import * 29 | from aitemplate.compiler.ops.pool import * 30 | from aitemplate.compiler.ops.reduce import * 31 | from aitemplate.compiler.ops.softmax import * 32 | from aitemplate.compiler.ops.tensor import * 33 | from aitemplate.compiler.ops.upsample import * 34 | from aitemplate.compiler.ops.vision_ops import * 35 | from aitemplate.compiler.ops.attention import * 36 | from aitemplate.compiler.ops.groupnorm import * 37 | from aitemplate.compiler.ops.b2b_bmm import * 38 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/attention/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | """ 16 | flash attention module init 17 | """ 18 | 19 | from aitemplate.compiler.ops.attention.flash_attention import flash_attention 20 | from aitemplate.compiler.ops.attention.mem_eff_attention import mem_eff_attention 21 | 22 | 23 | __all__ = ["flash_attention", "mem_eff_attention"] 24 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/b2b_bmm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # flake8: noqa 16 | """ 17 | B2B Bmm ops. 18 | """ 19 | 20 | from aitemplate.compiler.ops.b2b_bmm.classic_b2b_bmm import classic_b2b_bmm 21 | from aitemplate.compiler.ops.b2b_bmm.fmha_style_b2b_bmm import fmha_style_b2b_bmm 22 | from aitemplate.compiler.ops.b2b_bmm.grouped_classic_b2b_bmm import ( 23 | grouped_classic_b2b_bmm, 24 | ) 25 | from aitemplate.compiler.ops.b2b_bmm.grouped_fmha_style_b2b_bmm import ( 26 | grouped_fmha_style_b2b_bmm, 27 | ) 28 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # flake8: noqa 16 | """ 17 | Common ops. 18 | """ 19 | 20 | from aitemplate.compiler.ops.common.elementwise import * 21 | from aitemplate.compiler.ops.common.int_elementwise import * 22 | from aitemplate.compiler.ops.common.epilogue import * 23 | from aitemplate.compiler.ops.common.fused_elementwise import * 24 | from aitemplate.compiler.ops.common.math import * 25 | from aitemplate.compiler.ops.common.python_ops import * 26 | from aitemplate.compiler.ops.common.view_ops import * 27 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/conv/conv2d_bias_few_channels.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | """ 16 | Fused conv2d_bias_few_channels op. 17 | """ 18 | 19 | from aitemplate.compiler.ops.conv.special_conv2d_bias_activation import ( 20 | special_conv2d_bias_activation, 21 | ) 22 | 23 | 24 | # pylint: disable=C0103 25 | class conv2d_bias_few_channels(special_conv2d_bias_activation): 26 | """conv2d_bias_few_channels. 27 | 28 | This operator is equivalent to conv2d_bias but has improved performance for in_channels < 8. 29 | """ 30 | 31 | def __init__(self, stride, pad, dilate=1, auto_padding=True) -> None: 32 | """Initializes conv2d_bias_few_channels""" 33 | super().__init__("identity", stride, pad, dilate, auto_padding)  # first argument is presumably the activation name — confirm against special_conv2d_bias_activation 34 | self._attrs["op"] = "conv2d_bias_few_channels" 35 | self._attrs["epilogue"] = "LinearCombination" 36 | 37 | def _get_op_attributes(self): 38 | attr = super()._get_op_attributes() 39 | del attr["activation"]  # NOTE(review): presumably dropped because __init__ here takes no activation argument — confirm 40 | 41 | return attr 42 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/conv/conv2d_bias_hardswish_few_channels.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Fused conv2d_bias_hardswish_few_channels op. 
17 | """ 18 | 19 | from aitemplate.compiler.ops.conv.special_conv2d_bias_activation import ( 20 | special_conv2d_bias_activation, 21 | ) 22 | 23 | 24 | # pylint: disable=C0103 25 | class conv2d_bias_hardswish_few_channels(special_conv2d_bias_activation): 26 | """conv2d_bias_hardswish_few_channels. 27 | 28 | This operator is equivalent to conv2d_bias_hardswish but has improved performance for in_channels < 8. 29 | """ 30 | 31 | def __init__(self, stride, pad, dilate=1, auto_padding=True) -> None: 32 | """Initializes conv2d_bias_hardswish_few_channels""" 33 | super().__init__("hardswish", stride, pad, dilate, auto_padding) 34 | 35 | def _get_op_attributes(self): 36 | attr = super()._get_op_attributes() 37 | del attr["activation"]  # NOTE(review): presumably dropped because __init__ here takes no activation argument — confirm 38 | 39 | return attr 40 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/conv/conv2d_bias_relu_few_channels.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Fused conv2d_bias_relu_few_channels op. 17 | """ 18 | 19 | from aitemplate.compiler.ops.conv.special_conv2d_bias_activation import ( 20 | special_conv2d_bias_activation, 21 | ) 22 | 23 | 24 | # pylint: disable=C0103 25 | class conv2d_bias_relu_few_channels(special_conv2d_bias_activation): 26 | """conv2d_bias_relu_few_channels. 
27 | 28 | This operator is equivalent to conv2d_bias_relu but has improved performance for in_channels < 8. 29 | """ 30 | 31 | def __init__(self, stride, pad, dilate=1, auto_padding=True) -> None: 32 | """Initializes conv2d_bias_relu_few_channels""" 33 | super().__init__("relu", stride, pad, dilate, auto_padding) 34 | 35 | def _get_op_attributes(self): 36 | attr = super()._get_op_attributes() 37 | del attr["activation"]  # NOTE(review): presumably dropped because __init__ here takes no activation argument — confirm 38 | 39 | return attr 40 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # flake8: noqa 16 | from aitemplate.compiler.ops.embedding.bert_embeddings import bert_embeddings 17 | 18 | __all__ = [ 19 | "bert_embeddings", 20 | ] 21 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/gemm_epilogue_vistor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from aitemplate.compiler.ops.gemm_epilogue_vistor.bmm_rcr_softmax import bmm_rcr_softmax 16 | from aitemplate.compiler.ops.gemm_epilogue_vistor.dual_bmm_rrr_div import ( 17 | dual_bmm_rrr_div, 18 | ) 19 | from aitemplate.compiler.ops.gemm_epilogue_vistor.dual_gemm_rcr_fast_gelu import ( 20 | dual_gemm_rcr_fast_gelu, 21 | ) 22 | from aitemplate.compiler.ops.gemm_epilogue_vistor.dual_gemm_rcr_silu import ( 23 | dual_gemm_rcr_silu, 24 | ) 25 | from aitemplate.compiler.ops.gemm_epilogue_vistor.gemm_rcr_bias_softmax import ( 26 | gemm_rcr_bias_softmax, 27 | ) 28 | from aitemplate.compiler.ops.gemm_epilogue_vistor.gemm_rcr_softmax import ( 29 | gemm_rcr_softmax, 30 | ) 31 | 32 | 33 | __all__ = [ 34 | "bmm_rcr_softmax", 35 | "dual_bmm_rrr_div", 36 | "dual_gemm_rcr_fast_gelu", 37 | "dual_gemm_rcr_silu", 38 | "gemm_rcr_bias_softmax", 39 | "gemm_rcr_softmax", 40 | ] 41 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/gemm_special/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | special gemm ops 17 | """ 18 | 19 | from aitemplate.compiler.ops.gemm_special.batched_dense_vec_jagged_2d_mul import ( 20 | batched_dense_vec_jagged_2d_mul, 21 | ) 22 | from aitemplate.compiler.ops.gemm_special.bmm_rcr_n1 import bmm_rcr_n1 23 | from aitemplate.compiler.ops.gemm_special.bmm_rrr_k1_tanh import bmm_rrr_k1_tanh 24 | from aitemplate.compiler.ops.gemm_special.gemm_rrr_small_nk import gemm_rrr_small_nk 25 | 26 | 27 | __all__ = [ 28 | "batched_dense_vec_jagged_2d_mul", 29 | "bmm_rcr_n1", 30 | "bmm_rrr_k1_tanh", 31 | "gemm_rrr_small_nk", 32 | ] 33 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/gemm_universal/cache_entry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | """ 16 | GEMM profiling cache entries 17 | """ 18 | 19 | from dataclasses import dataclass 20 | 21 | 22 | @dataclass 23 | class GemmQueryEntry: 24 | """Query entry for the GEMM profiling cache; every field below is a required constructor argument.""" 25 | 26 | dtype_a: int 27 | dtype_b: int 28 | dtype_c: int 29 | dtype_acc: int 30 | major_a: int 31 | major_b: int 32 | major_c: int 33 | op_type: str 34 | device: str 35 | epilogue: int 36 | exec_entry_sha1: str 37 | pshape: str 38 | 39 | 40 | @dataclass 41 | class GemmRecordEntry: 42 | """Profile result record entry; carries the same fields as GemmQueryEntry plus exec_entry, algo, workspace and split_k.""" 43 | 44 | exec_entry: str 45 | exec_entry_sha1: str 46 | dtype_a: int 47 | dtype_b: int 48 | dtype_c: int 49 | dtype_acc: int 50 | major_a: int 51 | major_b: int 52 | major_c: int 53 | op_type: str 54 | epilogue: int 55 | pshape: str 56 | device: str 57 | algo: str 58 | workspace: int 59 | split_k: int 60 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | """ 16 | GEMM Specialization: GELU(GEMM_RCR(A, B) + Bias) 17 | """ 18 | 19 | from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias 20 | 21 | # pylint: disable=C0103,W0223,W0221 22 | 23 | 24 | class gemm_rcr_bias_gelu(gemm_rcr_bias): 25 | """GEMM Specialization: GELU(GEMM_RCR(A, B) + Bias) 26 | 27 | This operator is equivalent to the following pytorch code: 28 | 29 | .. highlight:: python 30 | .. code-block:: python 31 | A = torch.randn(M, K).cuda().half() 32 | B = torch.randn(N, K).cuda().half() 33 | Bias = torch.randn(N).cuda().half() 34 | 35 | linear = torch.nn.functional.linear(A, B, bias=Bias) 36 | y = torch.nn.functional.gelu(linear) 37 | """ 38 | 39 | def __init__(self): 40 | """Constructor for gemm_rcr_bias_gelu""" 41 | super().__init__() 42 | self._attrs["op"] = "gemm_rcr_bias_gelu" 43 | self._attrs["epilogue"] = "LinearCombinationGELU" 44 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_hardswish.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | """ 16 | GEMM Specialization: HardSwish(GEMM_RCR(A, B) + Bias) 17 | """ 18 | 19 | from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias 20 | 21 | # pylint: disable=C0103,W0223,W0221 22 | 23 | 24 | class gemm_rcr_bias_hardswish(gemm_rcr_bias): 25 | """GEMM Specialization: HardSwish(GEMM_RCR(A, B) + Bias) 26 | 27 | This operator is equivalent to the following pytorch code: 28 | 29 | .. highlight:: python 30 | .. code-block:: python 31 | A = torch.randn(M, K).cuda().half() 32 | B = torch.randn(N, K).cuda().half() 33 | Bias = torch.randn(N).cuda().half() 34 | 35 | linear = torch.nn.functional.linear(A, B, bias=Bias) 36 | y = torch.nn.functional.hardswish(linear) 37 | """ 38 | 39 | def __init__(self):  # Constructor for gemm_rcr_bias_hardswish 40 | super().__init__() 41 | self._attrs["op"] = "gemm_rcr_bias_hardswish" 42 | self._attrs["epilogue"] = "LinearCombinationHardSwish" 43 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_relu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | """ 16 | GEMM Specialization: ReLU(GEMM_RCR(A, B) + Bias) 17 | """ 18 | 19 | from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias 20 | 21 | # pylint: disable=C0103,W0223,W0221 22 | 23 | 24 | class gemm_rcr_bias_relu(gemm_rcr_bias): 25 | """GEMM Specialization: ReLU(GEMM_RCR(A, B) + Bias) 26 | 27 | This operator is equivalent to the following pytorch code: 28 | 29 | .. highlight:: python 30 | .. code-block:: python 31 | A = torch.randn(M, K).cuda().half() 32 | B = torch.randn(N, K).cuda().half() 33 | Bias = torch.randn(N).cuda().half() 34 | 35 | linear = torch.nn.functional.linear(A, B, bias=Bias) 36 | y = torch.nn.functional.relu(linear) 37 | """ 38 | 39 | def __init__(self): 40 | """Constructor for gemm_rcr_bias_relu""" 41 | super().__init__() 42 | self._attrs["op"] = "gemm_rcr_bias_relu" 43 | self._attrs["epilogue"] = "LinearCombinationRelu" 44 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_sigmoid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | """ 16 | GEMM Specialization: Sigmoid(GEMM_RCR(A, B) + Bias) 17 | """ 18 | 19 | from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias 20 | 21 | # pylint: disable=C0103,W0223,W0221 22 | 23 | 24 | class gemm_rcr_bias_sigmoid(gemm_rcr_bias): 25 | """GEMM Specialization: Sigmoid(GEMM_RCR(A, B) + Bias) 26 | 27 | This operator is equivalent to the following pytorch code: 28 | 29 | .. highlight:: python 30 | .. code-block:: python 31 | A = torch.randn(M, K).cuda().half() 32 | B = torch.randn(N, K).cuda().half() 33 | Bias = torch.randn(N).cuda().half() 34 | 35 | linear = torch.nn.functional.linear(A, B, bias=Bias) 36 | y = torch.sigmoid(linear) 37 | """ 38 | 39 | def __init__(self): 40 | """Constructor for gemm_rcr_bias_sigmoid""" 41 | super().__init__() 42 | self._attrs["op"] = "gemm_rcr_bias_sigmoid" 43 | self._attrs["epilogue"] = "LinearCombinationSigmoid" 44 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_swish.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | """ 16 | GEMM Specialization: SiLU(GEMM_RCR(A, B) + Bias) 17 | """ 18 | 19 | from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias 20 | 21 | # pylint: disable=C0103,W0223,W0221 22 | 23 | 24 | class gemm_rcr_bias_swish(gemm_rcr_bias): 25 | """GEMM Specialization: SiLU(GEMM_RCR(A, B) + Bias) 26 | 27 | This operator is equivalent to the following pytorch code: 28 | 29 | .. highlight:: python 30 | .. code-block:: python 31 | A = torch.randn(M, K).cuda().half() 32 | B = torch.randn(N, K).cuda().half() 33 | Bias = torch.randn(N).cuda().half() 34 | 35 | linear = torch.nn.functional.linear(A, B, bias=Bias) 36 | y = torch.nn.functional.silu(linear) 37 | """ 38 | 39 | def __init__(self): 40 | """Constructor for gemm_rcr_bias_swish""" 41 | super().__init__() 42 | self._attrs["op"] = "gemm_rcr_bias_swish" 43 | self._attrs["epilogue"] = "LinearCombinationSilu" 44 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_tanh.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | """ 16 | GEMM Specialization: Tanh(GEMM_RCR(A, B) + Bias) 17 | """ 18 | 19 | from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias 20 | 21 | # pylint: disable=C0103,W0223,W0221 22 | 23 | 24 | class gemm_rcr_bias_tanh(gemm_rcr_bias): 25 | """GEMM Specialization: Tanh(GEMM_RCR(A, B) + Bias) 26 | 27 | This operator is equivalent to the following pytorch code: 28 | 29 | .. highlight:: python 30 | .. code-block:: python 31 | A = torch.randn(M, K).cuda().half() 32 | B = torch.randn(N, K).cuda().half() 33 | Bias = torch.randn(N).cuda().half() 34 | 35 | linear = torch.nn.functional.linear(A, B, bias=Bias) 36 | y = torch.tanh(linear) 37 | """ 38 | 39 | def __init__(self): 40 | """Constructor for gemm_rcr_bias_tanh: sets the op name and the LinearCombinationTanh epilogue.""" 41 | super().__init__() 42 | self._attrs["op"] = "gemm_rcr_bias_tanh" 43 | self._attrs["epilogue"] = "LinearCombinationTanh" 44 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_fast_gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | """ 16 | GEMM Specialization: FastGELU(GEMM_RCR(A, B)) 17 | """ 18 | 19 | from aitemplate.compiler.ops.gemm_universal import gemm_rcr 20 | 21 | # pylint: disable=C0103,W0223,W0221 22 | 23 | 24 | class gemm_rcr_fast_gelu(gemm_rcr): 25 | """GEMM Specialization: FastGELU(GEMM_RCR(A, B)) 26 | 27 | This operator is equivalent to the following pytorch code: 28 | 29 | .. highlight:: python 30 | .. code-block:: python 31 | A = torch.randn(M, K).cuda().half() 32 | B = torch.randn(N, K).cuda().half() 33 | 34 | linear = torch.nn.functional.linear(A, B) 35 | y = torch.nn.functional.gelu(linear)  # applied through the LinearCombinationFastGELU epilogue 36 | """ 37 | 38 | def __init__(self): 39 | """Constructor for gemm_rcr_fast_gelu""" 40 | super().__init__() 41 | self._attrs["op"] = "gemm_rcr_fast_gelu" 42 | self._attrs["epilogue"] = "LinearCombinationFastGELU" 43 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_permute_elup1.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | A specialization of gemm_rcr_permute applying ELU + 1 as epilogue.
17 | """ 18 | 19 | from aitemplate.compiler.ops.gemm_universal import gemm_rcr_permute 20 | 21 | # pylint: disable=C0103,W0223,W0221,W0613 22 | 23 | 24 | class gemm_rcr_permute_elup1(gemm_rcr_permute):  # gemm_rcr_permute with the "LinearCombinationELUp1" epilogue 25 | def __init__(self, *args, **kwargs):  # forwards all arguments unchanged to gemm_rcr_permute.__init__ 26 | super().__init__(*args, **kwargs) 27 | self._attrs["op"] = "gemm_rcr_permute_elup1" 28 | self._attrs["epilogue"] = "LinearCombinationELUp1" 29 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/groupnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from aitemplate.compiler.ops.groupnorm.groupnorm import group_norm 17 | from aitemplate.compiler.ops.groupnorm.groupnorm_swish import group_norm_swish 18 | 19 | __all__ = ["group_norm", "group_norm_swish"] 20 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/groupnorm/groupnorm_swish.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from aitemplate.compiler.ops.groupnorm.groupnorm import group_norm 17 | 18 | 19 | class group_norm_swish(group_norm): 20 | """group_norm variant registered under the op name "groupnorm_swish" (presumably group norm followed by swish; confirm against the backend kernel). 21 | The grouped dim must be the last dim of the input tensor. 22 | """ 23 | 24 | def __init__(self, num_groups: int, num_channels: int) -> None: 25 | super().__init__(num_groups, num_channels) 26 | self._attrs["op"] = "groupnorm_swish" 27 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/jagged/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | from aitemplate.compiler.ops.jagged.jagged_lengths_to_offsets import ( 16 | jagged_lengths_to_offsets, 17 | ) 18 | from aitemplate.compiler.ops.jagged.jagged_lengths_to_presences import ( 19 | jagged_lengths_to_presences, 20 | ) 21 | 22 | __all__ = [ 23 | "jagged_lengths_to_offsets", 24 | "jagged_lengths_to_presences", 25 | ] 26 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/layernorm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | from aitemplate.compiler.ops.layernorm.batch_layernorm_sigmoid_mul import ( 16 | batch_layernorm_sigmoid_mul, 17 | ) 18 | from aitemplate.compiler.ops.layernorm.group_layernorm import group_layernorm 19 | from aitemplate.compiler.ops.layernorm.group_layernorm_sigmoid_mul import ( 20 | group_layernorm_sigmoid_mul, 21 | ) 22 | from aitemplate.compiler.ops.layernorm.layernorm import layernorm 23 | from aitemplate.compiler.ops.layernorm.layernorm_sigmoid_mul import ( 24 | layernorm_sigmoid_mul, 25 | ) 26 | 27 | 28 | __all__ = [ 29 | "batch_layernorm_sigmoid_mul", 30 | "group_layernorm", 31 | "group_layernorm_sigmoid_mul", 32 | "layernorm", 33 | "layernorm_sigmoid_mul", 34 | ] 35 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/padding/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Padding ops module init. 
17 | """ 18 | 19 | from aitemplate.compiler.ops.padding.ndhwc3to8 import ndhwc3to8 20 | from aitemplate.compiler.ops.padding.nhwc3to4 import nhwc3to4 21 | from aitemplate.compiler.ops.padding.nhwc3to8 import nhwc3to8 22 | from aitemplate.compiler.ops.padding.pad_last_dim import pad_last_dim 23 | 24 | 25 | __all__ = ["ndhwc3to8", "nhwc3to8", "nhwc3to4", "pad_last_dim"] 26 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/padding/nhwc3to4.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Nhwc 3 channel to 4 channel padding. 
17 | """ 18 | 19 | import jinja2 20 | 21 | from aitemplate.compiler.ops.padding.nhwc_pad_common import nhwc_pad_common 22 | 23 | 24 | SHAPE_FUNC_TEMPLATE = jinja2.Template(  # output N/H/W copy the input dims; CO is fixed at 4 25 | """ 26 | {{indent}}{{dtype}}NI = {{x_dim0}}; 27 | {{indent}}{{dtype}}HI = {{x_dim1}}; 28 | {{indent}}{{dtype}}WI = {{x_dim2}}; 29 | {{indent}}{{dtype}}NO = NI; 30 | {{indent}}{{dtype}}HO = HI; 31 | {{indent}}{{dtype}}WO = WI; 32 | {{indent}}{{dtype}}CO = 4; 33 | """ 34 | ) 35 | 36 | 37 | class nhwc3to4(nhwc_pad_common):  # nhwc_pad_common configured with SHAPE_FUNC_TEMPLATE and the constant 4 38 | def __init__(self):  # takes no arguments; the template and the value 4 are fixed here 39 | super().__init__(SHAPE_FUNC_TEMPLATE, 4) 40 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/padding/nhwc3to8.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Nhwc 3 channel to 8 channel padding.
17 | """ 18 | 19 | import jinja2 20 | 21 | from aitemplate.compiler.ops.padding.nhwc_pad_common import nhwc_pad_common 22 | 23 | SHAPE_FUNC_TEMPLATE = jinja2.Template(  # output N/H/W copy the input dims; CO is fixed at 8 24 | """ 25 | {{indent}}{{dtype}}NI = {{x_dim0}}; 26 | {{indent}}{{dtype}}HI = {{x_dim1}}; 27 | {{indent}}{{dtype}}WI = {{x_dim2}}; 28 | {{indent}}{{dtype}}NO = NI; 29 | {{indent}}{{dtype}}HO = HI; 30 | {{indent}}{{dtype}}WO = WI; 31 | {{indent}}{{dtype}}CO = 8; 32 | """ 33 | ) 34 | 35 | 36 | class nhwc3to8(nhwc_pad_common):  # nhwc_pad_common configured with SHAPE_FUNC_TEMPLATE and the constant 8 37 | def __init__(self):  # takes no arguments; the template and the value 8 are fixed here 38 | super().__init__(SHAPE_FUNC_TEMPLATE, 8) 39 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/pool/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Pool module init. 17 | """ 18 | 19 | from aitemplate.compiler.ops.pool.avg_pool2d import avg_pool2d 20 | from aitemplate.compiler.ops.pool.max_pool2d import max_pool2d 21 | 22 | 23 | __all__ = ["avg_pool2d", "max_pool2d"] 24 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/reduce/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Reduce module init. 17 | """ 18 | 19 | from aitemplate.compiler.ops.reduce.reduce_max import reduce_max 20 | from aitemplate.compiler.ops.reduce.reduce_mean import reduce_mean 21 | from aitemplate.compiler.ops.reduce.reduce_min import reduce_min 22 | from aitemplate.compiler.ops.reduce.reduce_sum import reduce_sum 23 | from aitemplate.compiler.ops.reduce.var import var 24 | from aitemplate.compiler.ops.reduce.vector_norm import vector_norm 25 | 26 | 27 | __all__ = [ 28 | "reduce_max", 29 | "reduce_mean", 30 | "reduce_min", 31 | "reduce_sum", 32 | "var", 33 | "vector_norm", 34 | ] 35 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/reduce/reduce_max.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | reduce_max op 17 | """ 18 | 19 | from aitemplate.compiler.ops.reduce.reduce_common import reduce_base 20 | 21 | # pylint: disable=C0103 22 | 23 | 24 | class reduce_max(reduce_base): 25 | """ 26 | Implements the reduce_max op. 27 | 28 | * .attr.:`dim` : int or tuple of python:ints 29 | the dimension or dimensions to reduce 30 | 31 | * .attr.:`keepdim` : bool, optional 32 | keep the reduced dimensions if True, default is False 33 | 34 | * .attr.:`dtype` : str, optional 35 | the type of the return tensor. If it is not None, 36 | the input tensor is cast to dtype before reduction. 37 | 38 | Args: 39 | input (Tensor): the input tensor. 40 | 41 | Return: 42 | Tensor that contains the max of the input tensor's elements along the reduced dimension(s). 43 | """ 44 | 45 | def __init__(self, dim, keepdim=False, dtype=None) -> None: 46 | super().__init__(dim, keepdim, dtype) 47 | self._attrs["op"] = "reduce_max" 48 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/reduce/reduce_min.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | """ 16 | reduce_min op 17 | """ 18 | 19 | from aitemplate.compiler.ops.reduce.reduce_common import reduce_base 20 | 21 | # pylint: disable=C0103 22 | 23 | 24 | class reduce_min(reduce_base): 25 | """ 26 | Implements the reduce_min op. 27 | 28 | * .attr.:`dim` : int or tuple of python:ints 29 | the dimension or dimensions to reduce 30 | 31 | * .attr.:`keepdim` : bool, optional 32 | keep the reduced dimensions if True, default is False 33 | 34 | * .attr.:`dtype` : str, optional 35 | the type of the return tensor. If it is not None, 36 | the input tensor is cast to dtype before reduction. 37 | 38 | Args: 39 | input (Tensor): the input tensor. 40 | 41 | Return: 42 | Tensor that contains the min of the input tensor's elements along the reduced dimension(s). 43 | """ 44 | 45 | def __init__(self, dim, keepdim=False, dtype=None) -> None: 46 | super().__init__(dim, keepdim, dtype) 47 | self._attrs["op"] = "reduce_min" 48 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/reduce/reduce_sum.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | """ 16 | reduce_sum op 17 | """ 18 | 19 | from aitemplate.compiler.ops.reduce.reduce_common import reduce_base 20 | 21 | # pylint: disable=C0103 22 | 23 | 24 | class reduce_sum(reduce_base): 25 | """ 26 | Implements the reduce_sum op. 27 | 28 | * .attr.:`dim` : int or tuple of python:ints 29 | the dimension or dimensions to reduce 30 | 31 | * .attr.:`keepdim` : bool, optional 32 | keep the reduced dimensions if True, default is False 33 | 34 | * .attr.:`dtype` : str, optional 35 | the type of the return tensor. If it is not None, 36 | the input tensor is cast to dtype before reduction. 37 | 38 | Args: 39 | input (Tensor): the input tensor. 40 | 41 | Return: 42 | Tensor that contains the sum of the input tensor's elements along the reduced dimension(s). 43 | """ 44 | 45 | def __init__(self, dim, keepdim=False, dtype=None) -> None: 46 | super().__init__(dim, keepdim, dtype) 47 | self._attrs["op"] = "reduce_sum" 48 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/softmax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | """ 16 | softmax module init 17 | """ 18 | 19 | from aitemplate.compiler.ops.softmax.softmax import softmax 20 | 21 | 22 | __all__ = ["softmax"] 23 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/softmax/cache_entry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Softmax cache entry. 
@dataclass
class NormQueryEntry:
    """Query entry for the softmax/normalization cache.

    NOTE(review): the field meanings below are read off the field names and
    type annotations only; confirm against the cache implementation.

    Attributes
    ----------
    dtype_in, dtype_acc, dtype_out : int
        Integer codes for the input, accumulation and output data types.
    rank : int
        Rank associated with the queried op.
    op_type : str
        Name of the op type being queried.
    device : str
        Device identifier.
    exec_entry_sha1 : str
        SHA1 string of the execution entry.
    """

    dtype_in: int
    dtype_acc: int
    dtype_out: int
    rank: int
    op_type: str
    device: str
    exec_entry_sha1: str


@dataclass
class NormRecordEntry:
    """Record entry for the softmax/normalization cache.

    Carries the same descriptive fields as ``NormQueryEntry`` plus the
    execution entry itself and the recorded ``algo`` / ``workspace`` values.

    NOTE(review): the field meanings below are read off the field names and
    type annotations only; confirm against the cache implementation.

    Attributes
    ----------
    exec_entry : str
        The execution entry string.
    exec_entry_sha1 : str
        SHA1 string of the execution entry.
    dtype_in, dtype_acc, dtype_out : int
        Integer codes for the input, accumulation and output data types.
    rank : int
        Rank associated with the recorded op.
    op_type : str
        Name of the op type.
    device : str
        Device identifier.
    algo : str
        Recorded algorithm name.
    workspace : int
        Recorded workspace value.
    """

    exec_entry: str
    exec_entry_sha1: str
    dtype_in: int
    dtype_acc: int
    dtype_out: int
    rank: int
    op_type: str
    device: str
    algo: str
    workspace: int
class concatenate_tanh(concatenate):
    """The fusion of concatenate and tanh."""

    def __init__(self):
        super().__init__()
        # Same construction as the parent op; only the recorded op name differs.
        self._attrs["op"] = "concatenate_tanh"


class transpose(permute):
    """
    Returns a tensor with its two dimensions transposed.
    This returned tensor is not a view. Dims can be negative.
    """

    def __call__(self, x: Tensor, dim0: int, dim1: int) -> Tensor:
        """Swap dimensions ``dim0`` and ``dim1`` of ``x``.

        Args:
            x (Tensor): the input tensor.
            dim0 (int): first dimension to swap; negative values count from
                the last dimension.
            dim1 (int): second dimension to swap; negative values count from
                the last dimension.

        Returns:
            Tensor: the result of the parent ``permute`` call with the two
            dimensions exchanged.

        Raises:
            IndexError: if either dimension is outside ``[-rank, rank)``.
        """
        rank = x._rank()

        def _normalize(dim: int) -> int:
            # Map a negative dim onto its non-negative equivalent so the
            # permutation handed to ``permute`` never contains negative
            # entries (previously e.g. ``[-1, 1, 0]`` could be produced).
            wrapped = dim + rank if dim < 0 else dim
            if not 0 <= wrapped < rank:
                raise IndexError(
                    "transpose dim {} out of range for rank {}".format(dim, rank)
                )
            return wrapped

        first, second = _normalize(dim0), _normalize(dim1)
        dims = list(range(rank))
        dims[first], dims[second] = second, first

        return super().__call__(x, dims)
# pylint: disable=C0103
class upsampling2d(upsampling2d_base):
    """
    Applies a 2D bilinear upsampling to an input signal composed of several input
    channels.

    To specify the scale, it takes the :attr:`scale_factor` as its constructor argument.

    * :attr:`scale_factor` (float): multiplier for spatial size.

    * :attr:`mode`: interpolation mode; passed to ``upsampling2d_base`` and
      also stored in the op attributes.
      NOTE(review): the summary above says "bilinear", but the accepted mode
      values are defined outside this class - confirm against the base class.

    Args:
        input (Tensor [N, H, W, C]): the input data.

    Return:
        Tensor [N, H_out, W_out, C].
    """

    def __init__(self, scale_factor, mode) -> None:
        super().__init__(scale_factor, mode)
        # Record the concrete op name and the requested mode on top of
        # whatever the base class has already set up.
        self._attrs["op"] = "upsampling2d"
        self._attrs["mode"] = mode
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Vision ops module init. 17 | """ 18 | 19 | from aitemplate.compiler.ops.vision_ops.nms import * # noqa 20 | from aitemplate.compiler.ops.vision_ops.roi_ops import * # noqa 21 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/vision_ops/nms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Nms family ops. 
17 | """ 18 | 19 | from aitemplate.compiler.ops.vision_ops.nms.batched_nms import batched_nms 20 | from aitemplate.compiler.ops.vision_ops.nms.efficient_nms import efficient_nms 21 | from aitemplate.compiler.ops.vision_ops.nms.nms import nms 22 | 23 | 24 | __all__ = ["batched_nms", "nms", "efficient_nms"] 25 | -------------------------------------------------------------------------------- /python/aitemplate/compiler/ops/vision_ops/roi_ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | Roi-align module init. 17 | """ 18 | 19 | from aitemplate.compiler.ops.vision_ops.roi_ops.multi_level_roi_align import ( 20 | multi_level_roi_align, 21 | ) 22 | from aitemplate.compiler.ops.vision_ops.roi_ops.roi_align import roi_align 23 | 24 | __all__ = ["roi_align", "multi_level_roi_align"] 25 | -------------------------------------------------------------------------------- /python/aitemplate/frontend/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
class Conv2dBiasFewChannels(SpecialConv2dBiasAct):
    r"""Applies 2D convolution with bias for few channels.

    This layer is equivalent to Conv2dBias but has improved performance for
    in_channels < 8.

    Every constructor argument is forwarded positionally to
    ``SpecialConv2dBiasAct`` after the fixed op name
    ``"conv2d_bias_few_channels"``. Unlike the other conv2d wrappers, this one
    takes ``auto_padding`` rather than ``groups``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        auto_padding=True,
        dtype="float16",
    ):
        op_name = "conv2d_bias_few_channels"
        super().__init__(
            op_name,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            auto_padding,
            dtype,
        )
class Conv2dBiasHardswish(Conv2dBiasAct):
    r"""Applies 2D convolution with bias + hardswish.

    Every constructor argument is forwarded positionally to ``Conv2dBiasAct``
    after the fixed op name ``"conv2d_bias_hardswish"``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        groups=1,
        dtype="float16",
    ):
        op_name = "conv2d_bias_hardswish"
        super().__init__(
            op_name,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            dtype,
        )
class Conv2dBiasRelu(Conv2dBiasAct):
    r"""Applies 2D convolution with bias + relu.

    Every constructor argument is forwarded positionally to ``Conv2dBiasAct``
    after the fixed op name ``"conv2d_bias_relu"``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        groups=1,
        dtype="float16",
    ):
        op_name = "conv2d_bias_relu"
        super().__init__(
            op_name,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            dtype,
        )
class Conv2dBiasSigmoid(Conv2dBiasAct):
    r"""Applies 2D convolution with bias + sigmoid.

    Every constructor argument is forwarded positionally to ``Conv2dBiasAct``
    after the fixed op name ``"conv2d_bias_sigmoid"``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        groups=1,
        dtype="float16",
    ):
        op_name = "conv2d_bias_sigmoid"
        super().__init__(
            op_name,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            dtype,
        )
class Conv2dDepthwise(Conv2d):
    r"""Conv2d module whose op is replaced by ``conv2d_depthwise``.

    The constructor first initialises the parent ``Conv2d`` with all of its
    arguments, then overwrites ``self.op`` with a ``conv2d_depthwise`` op built
    from ``stride``, ``padding``, ``dilation`` and ``groups``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        groups=1,
        dtype="float16",
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            dtype,
        )
        # Swap in the depthwise op; the keyword names are those of the op's
        # constructor (pad/dilate/group), not of this module.
        depthwise_op = conv2d_depthwise(
            stride=stride, pad=padding, dilate=dilation, group=groups
        )
        self.op = depthwise_op
class Conv2dDepthwiseBias(Conv2dBiasAct):
    r"""Applies 2D depthwise convolution with bias.

    Every constructor argument is forwarded positionally to ``Conv2dBiasAct``
    after the fixed op name ``"conv2d_depthwise_bias"``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        groups=1,
        dtype="float16",
    ):
        op_name = "conv2d_depthwise_bias"
        super().__init__(
            op_name,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            dtype,
        )
class ConvTranspose2dBiasRelu(ConvTranspose2dBiasAct):
    r"""Applies a 2D transposed convolution with bias + relu.

    Every constructor argument is forwarded positionally to
    ``ConvTranspose2dBiasAct`` after the fixed op name
    ``"transposed_conv2d_bias_relu"``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        groups=1,
        dtype="float16",
    ):
        op_name = "transposed_conv2d_bias_relu"
        super().__init__(
            op_name,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            dtype,
        )
class Dropout(Module):
    r"""Dropout placeholder.

    The constructor accepts ``p`` and ``dtype`` but does not use them, and
    :meth:`forward` hands its single input back unchanged.
    """

    def __init__(
        self,
        p=0,
        dtype="float16",
    ):
        # Both arguments are accepted and ignored.
        super().__init__()

    def forward(self, *args):
        r"""Return the single positional argument unchanged.

        Exactly one argument must be passed; anything else trips the assert.
        """
        assert len(args) == 1
        return args[0]


class DropPath(Dropout):
    r"""DropPath placeholder.

    Behaves like :class:`Dropout`; the constructor only accepts ``dtype``,
    which is not used.
    """

    def __init__(
        self,
        dtype="float16",
    ):
        # The parent is initialised with its own defaults.
        super().__init__()
class Identity(Module):
    """The identity of the input.

    The constructor accepts ``dtype`` but does not use it.
    """

    def __init__(
        self,
        dtype="float16",
    ):
        super().__init__()

    def forward(self, *args):
        """Return the single positional argument unchanged.

        Exactly one argument must be passed; anything else trips the assert.
        """
        assert len(args) == 1
        return args[0]
class Nhwc3to8(Module):
    r"""Pads the input data with nhwc dimensions from 3 channels to 8 channels"""

    def __init__(self):
        super().__init__()
        # The padding itself is done by the compiler op of the same name.
        self.op = nhwc3to8()

    def forward(self, *args):
        """Apply the ``nhwc3to8`` op to the single positional argument."""
        assert len(args) == 1
        return self.op(args[0])


class Ndhwc3to8(Module):
    r"""Pads the input data with ndhwc dimensions from 3 channels to 8 channels"""

    def __init__(self):
        super().__init__()
        # The padding itself is done by the compiler op of the same name.
        self.op = ndhwc3to8()

    def forward(self, *args):
        """Apply the ``ndhwc3to8`` op to the single positional argument."""
        assert len(args) == 1
        return self.op(args[0])
class Parameter:
    """Pairs a ``Tensor`` with an optional value.

    The constructor builds a ``Tensor`` from ``shape``, ``dtype`` and ``name``
    and stores ``value`` as given.
    """

    def __init__(self, shape, dtype, name=None, value=None):
        tensor = Tensor(shape=shape, dtype=dtype, name=name)
        self._tensor = tensor
        self._value = value

    def tensor(self):
        """Return the ``Tensor`` created in the constructor."""
        return self._tensor

    def value(self):
        """Return the value passed to the constructor (``None`` by default)."""
        return self._value
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | """ 16 | testing module 17 | """ 18 | 19 | from aitemplate.testing import benchmark_ait, benchmark_pt 20 | from aitemplate.testing.detect_target import detect_target 21 | from aitemplate.testing.profile import profile_callable 22 | 23 | __all__ = ["benchmark_pt", "benchmark_ait", "detect_target", "profile_callable"] 24 | -------------------------------------------------------------------------------- /python/aitemplate/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def import_parent(filepath: str, level: int) -> None:
    """Adjust ``sys.path`` relative to the location of ``filepath``.

    ``filepath`` is resolved to an absolute path. Its ancestor at index
    ``level`` (``Path.parents[level]``) is appended to ``sys.path``, and the
    first ``sys.path`` entry equal to the file's own directory, if any, is
    removed.

    Args:
        filepath: path of the file to start from (typically ``__file__``).
        level: index into ``Path.parents`` of the resolved path; ``0`` is the
            file's own directory.

    Raises:
        IndexError: if the resolved path has no ancestor at ``level``
            (raised before ``sys.path`` is touched).
    """
    resolved = Path(filepath).resolve()
    own_dir = str(resolved.parent)
    top_dir = str(resolved.parents[level])

    sys.path.append(top_dir)
    # Nothing to do when the directory is not on the path (already removed).
    if own_dir in sys.path:
        sys.path.remove(own_dir)
def wrap_dim(idx, rank):
    """
    Wrap tensor index, idx, if it's negative.

    Parameters
    ----------
    idx : int
        Dimension index. A negative value has ``rank`` added to it.
    rank : int
        Exclusive upper bound for the wrapped index.

    Returns
    -------
    int
        The wrapped index, guaranteed to satisfy ``0 <= index < rank``.

    Raises
    ------
    AssertionError
        If ``idx`` is not an int, or if it lies outside ``[-rank, rank)``.
    """
    assert isinstance(idx, int), "idx must be int, but got {}".format(type(idx))
    wrapped = idx + rank if idx < 0 else idx
    # An idx below -rank is still negative after wrapping, so the lower bound
    # has to be checked together with the upper one.
    assert 0 <= wrapped < rank, "idx {} out of range; rank {}".format(idx, rank)
    return wrapped
KEYS = [
    "op",
    "depth",
    "nop",
    "has_profiler",
    "epilogue",
    "epilogue_alignment",
    "split_k",
    "permute_shape",
]


def op_to_content(op):
    """Collect selected attributes of ``op`` into a plain dict.

    Each name in ``KEYS`` is looked up in ``op._attrs`` and copied over when
    its value is neither ``None`` nor the empty string. When the ``"op"``
    attribute equals ``"fused_elementwise"``, a ``"func"`` entry is added that
    joins ``str()`` of every sub-op's ``"func"`` attribute with ``", "``.
    """
    # TODO (XXX): Add op specialized attrs here, like gemm/conv
    attrs = op._attrs
    candidates = ((name, attrs.get(name)) for name in KEYS)
    content = {
        name: value
        for name, value in candidates
        if value is not None and value != ""
    }

    if attrs["op"] == "fused_elementwise":
        func_names = (
            str(sub_op._attrs["func"]) for sub_op in attrs["elementwise_ops"]
        )
        content["func"] = ", ".join(func_names)
    return content
14 | 15 | #pragma once 16 | #include "device_functions-generated.h" 17 | 18 | namespace { 19 | template 20 | __global__ void outputs_checker(const T* tensor, int64_t elem_cnt) { 21 | for (int64_t i = 0; i < elem_cnt; i++) { 22 | float v = (float)(*(tensor + i)); 23 | if (i != 0) { 24 | printf(", "); 25 | } 26 | printf("%f", v); 27 | } 28 | printf("\n"); 29 | } 30 | 31 | } // namespace 32 | 33 | namespace ait { 34 | void InvokeInfAndNanChecker( 35 | const half* tensor, 36 | const char* tensor_name, 37 | int64_t elem_cnt, 38 | ait::StreamType stream); 39 | 40 | template 41 | void InvokeOutputsChecker( 42 | const T* tensor, 43 | const char* tensor_name, 44 | int64_t elem_cnt, 45 | ait::StreamType stream) { 46 | printf("Tensor (%s) output:\n", tensor_name); 47 | outputs_checker<<<1, 1, 0, stream>>>(tensor, elem_cnt); 48 | ait::StreamSynchronize(stream); 49 | } 50 | } // namespace ait 51 | -------------------------------------------------------------------------------- /static/include/jagged.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | // 15 | #pragma once 16 | 17 | namespace ait { 18 | 19 | // This structure is used to pack the offset metadata related to a 20 | // jagged Tensor's first dimension: JaggedIntVar. 
The offsets are not 21 | // available in compile time, as they are coming in a rank-1 Tensor. 22 | // In runtime, the members of the structure are set by the make_jagged 23 | // op's back-end, from the corresponding rank-1 offset Tensors' length 24 | // and data. The OFFSET_TYPE can be either int32 or int64. The number 25 | // of offset arrays is known in compile time, hence specified as the 26 | // NUM_OFFSET_ARRAYS template argument here. 27 | template 28 | struct JaggedOffsets { 29 | // the lengths the individual offset arrays 30 | int64_t lengths[NUM_OFFSET_ARRAYS]{0}; 31 | // the data in each of the offset arrays 32 | // (i.e., the offsets of the JaggedIntVar) 33 | const OFFSET_TYPE* data[NUM_OFFSET_ARRAYS]{nullptr}; 34 | }; 35 | 36 | } // namespace ait 37 | -------------------------------------------------------------------------------- /static/include/windll.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | // 15 | 16 | #pragma once 17 | 18 | #include 19 | 20 | namespace ait { 21 | 22 | // throws std::runtime_error in case of problems 23 | void GetConstantsBin(void** address, size_t* size); 24 | 25 | } // namespace ait 26 | -------------------------------------------------------------------------------- /tests/ci_profile_cache/README.md: -------------------------------------------------------------------------------- 1 | # Profiling Database for CI (Deprecated) 2 | 3 | Profile Cache DB for CI is deprecated. Now CI will select the algorithm with the smallest tiling size and smallest alignments for CI. 4 | 5 | The selection function is defined at: `backend/target.py: Target:select_minimal_algo` and specialized in each backend target implementation. 6 | -------------------------------------------------------------------------------- /tests/lint/flake8_problem_matcher.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "flake8", 5 | "severity": "error", 6 | "pattern": [ 7 | { 8 | "regexp": "^([^:]+):(\\d+):(\\d+):\\s+(.*)$", 9 | "file": 1, 10 | "line": 2, 11 | "column": 3, 12 | "message": 4 13 | } 14 | ] 15 | } 16 | ] 17 | } 18 | --------------------------------------------------------------------------------