├── .flake8 ├── .gitignore ├── CMakeLists.txt ├── DESCRIPTION.md ├── LICENCE ├── README.md ├── THIRD-PARTY-NOTICES ├── benchmark └── recsys │ └── dlrmv2-mlperf │ ├── README.md │ ├── prepare_env.sh │ ├── python │ ├── backend.py │ ├── backend_pytorch_native.py │ ├── consumer.py │ ├── dataset.py │ ├── dlrm_model.py │ ├── items.py │ ├── multihot_criteo.py │ └── runner.py │ ├── run_common.sh │ ├── run_local.sh │ ├── run_main.sh │ ├── setup_env_offline.sh │ ├── setup_env_server.sh │ ├── tools │ ├── accuracy-dlrm.py │ └── dist_quantile.txt │ └── user.conf ├── cmake └── modules │ ├── FindCPUkernels.cmake │ ├── FindZENDNN.cmake │ └── FindZentorch.cmake ├── examples ├── README.md ├── bert_example.py ├── dlrm_example.py ├── dlrm_model.py ├── llama_bf16_example.py ├── llama_woq_example.py ├── requirements.txt └── resnet_example.py ├── linter ├── py_cpp_linter.sh └── requirements.txt ├── pyproject.toml ├── requirements.txt ├── scripts ├── README.md ├── dlrm_optimal_env_setup.sh └── zentorch_env_setup.sh ├── setup.py ├── src └── cpu │ ├── cpp │ ├── Bindings.cpp │ ├── CausalAttentionMask.cpp │ ├── Config.cpp │ ├── Config.hpp.in │ ├── Conv.cpp │ ├── ConvUtils.hpp │ ├── Embed.cpp │ ├── EmbedBag.cpp │ ├── EmbedUtils.hpp │ ├── Fused_EB_MLP.cpp │ ├── MaskedMultiHeadAttention.cpp │ ├── Matmul.cpp │ ├── MatmulUtils.hpp │ ├── Memory.hpp │ ├── MoE.cpp │ ├── Ops.hpp │ ├── QLinear.cpp │ ├── QLinearUtils.hpp │ ├── QuantEmbedUtils.hpp │ ├── Rope.cpp │ ├── RopeUtils.hpp │ ├── Sdpa_ref.cpp │ ├── Singletons.cpp │ ├── Threading.cpp │ ├── Threading.hpp │ ├── Utils.hpp │ ├── WOQMatmul.cpp │ ├── WOQMatmulUtils.hpp │ ├── WeightReorder.cpp │ └── kernels │ │ ├── vec │ │ ├── add_softmax.h │ │ ├── utils.h │ │ └── vec512_bfloat16.h │ │ ├── zen_MaskedMultiHeadAttention_512.cpp │ │ ├── zen_Sdpa.cpp │ │ └── zen_cpukernels.hpp │ └── python │ └── zentorch │ ├── _C │ └── __init__.py │ ├── _StaticQuantizedLinear.py │ ├── _WOQLinear.py │ ├── _WOQ_embedding_bag.py │ ├── __init__.py │ ├── _compile_backend.py │ ├── _custom_op_replacement.py │ ├── _eltwise_fusions.py │ ├── _freeze_utils.py │ ├── _freezing.py │ ├── _fusion_matcher.py │ ├── _fusion_patterns.py │ ├── _graph_cleanup.py │ ├── _graph_preprocess_matcher.py │ ├── _graph_preprocess_patterns.py │ ├── _info.py │ ├── _logging.py │ ├── _meta_registrations.py │ ├── _mkldnn.py │ ├── _op_replacement.py │ ├── _optimize.py │ ├── _quant_model_reload.py │ ├── _quantization_utils.py │ ├── _utils.py │ ├── llm │ ├── __init__.py │ ├── _checks.py │ ├── _custom_model_forward.py │ ├── _custom_models_reference_linear_fusion.py │ ├── _model_conversion_functions.py │ └── _optimize.py │ └── utils.py └── test ├── README.md ├── __init__.py ├── install_requirements.py ├── llm_tests ├── __init__.py ├── llm_utils.py ├── test_fused_rope_op.py └── test_masked_mha.py ├── pre_trained_model_tests ├── __init__.py ├── pre_trained_model_utils.py ├── test_bert.py └── test_cnn.py ├── unittests ├── __init__.py ├── miscellaneous_tests │ ├── __init__.py │ ├── test_avx512_device.py │ ├── test_bf16_device.py │ ├── test_model_reload.py │ └── test_zentorch_version.py ├── model_tests │ ├── __init__.py │ ├── test_addmm.py │ ├── test_addmm_1dbias.py │ ├── test_addmm_1dbias_add.py │ ├── test_addmm_1dbias_gelu.py │ ├── test_addmm_1dbias_mul_add.py │ ├── test_addmm_1dbias_relu.py │ ├── test_addmm_1dbias_silu_mul.py │ ├── test_addmm_gelu.py │ ├── test_addmm_relu.py │ ├── test_attn_qkv_fusion.py │ ├── test_baddbmm.py │ ├── test_convolution.py │ ├── test_embedding.py │ ├── test_embedding_bag.py │ ├── 
test_group_embeded_ops_with_sum_ops.py │ ├── test_horizontal_embedding_bag_group.py │ ├── test_horizontal_embedding_bag_group_addmm_1dbias.py │ ├── test_horizontal_embedding_bag_group_addmm_1dbias_relu.py │ ├── test_horizontal_embedding_group.py │ ├── test_mini_mha.py │ ├── test_mm_relu.py │ ├── test_mm_silu.py │ ├── test_mm_silu_mul.py │ ├── test_pattern_matcher.py │ ├── test_qlinear.py │ ├── test_qlinear_eltwise.py │ ├── test_qlinear_mul_add.py │ ├── test_qlinear_reorder.py │ ├── test_quant_embedding_bag.py │ ├── test_quant_embedding_bag_with_cat_fusion.py │ ├── test_sdpa.py │ └── test_woq_linear.py ├── op_tests │ ├── __init__.py │ ├── _pack.py │ ├── test_addmm.py │ ├── test_addmm_1dbias.py │ ├── test_addmm_1dbias_add.py │ ├── test_addmm_1dbias_mul_add.py │ ├── test_addmm_silu.py │ ├── test_addmm_silu_mul.py │ ├── test_attn_qkv_fusion.py │ ├── test_baddbmm.py │ ├── test_bmm.py │ ├── test_convolution.py │ ├── test_embedding.py │ ├── test_embedding_bag.py │ ├── test_embeg_pack_weight.py │ ├── test_fuse_index_mul_index_add_wrapper.py │ ├── test_horizontal_embedding_bag_group.py │ ├── test_horizontal_embedding_group.py │ ├── test_matmul_impl.py │ ├── test_mm.py │ ├── test_mm_silu.py │ ├── test_mm_silu_mul.py │ ├── test_prepare_4d_causal_attention_mask.py │ ├── test_qlinear.py │ ├── test_qlinear_eltwise.py │ ├── test_qlinear_mul_add.py │ ├── test_quant_embedding_bag.py │ ├── test_rope.py │ ├── test_weight_reorder_for_matmul_with_qlinear.py │ └── test_woq_linear.py ├── quant_utils.py └── unittest_utils.py └── utils.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .git,build,dist,third_party, benchmark/recsys/dlrmv2-mlperf/inference 3 | max-line-length = 80 4 | select = C,E,F,W,B,B950 5 | extend-ignore = E203,E501,W503 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | third_party/ 4 | *.egg-info/ 5 | src/cpu/cpp/Config.hpp 6 | src/cpu/python/zentorch/_build_info.py 7 | __pycache__ 8 | benchmark/recsys/dlrmv2-mlperf/mlperf.conf 9 | benchmark/recsys/dlrmv2-mlperf/output 10 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | #****************************************************************************** 2 | # Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | #****************************************************************************** 5 | 6 | cmake_minimum_required(VERSION 3.1 FATAL_ERROR) 7 | project(zentorch) 8 | 9 | # set cmake folder as a place to search for .cmake files 10 | set(CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/modules) 11 | 12 | # build and add ZenDNN and BLIS libraries 13 | find_package(ZENDNN REQUIRED) 14 | 15 | find_package(Torch REQUIRED) 16 | 17 | # build mha kernel 18 | find_package(CPUkernels REQUIRED) 19 | 20 | configure_file ("${CMAKE_CURRENT_SOURCE_DIR}/src/cpu/cpp/Config.hpp.in" 21 | "${CMAKE_CURRENT_SOURCE_DIR}/src/cpu/cpp/Config.hpp") 22 | 23 | #TODO: Restructure this block to remove warnings 24 | file(GLOB ZENTORCH_CPP_SOURCES "${CMAKE_SOURCE_DIR}/src/cpu/cpp/*.cpp") 25 | set(ZENTORCH_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/src/cpu/cpp/") 26 | list(REMOVE_ITEM ZENTORCH_CPP_SOURCES "${CMAKE_SOURCE_DIR}/src/cpu/cpp/Bindings.cpp") 27 | 28 | 29 | add_library(zentorch SHARED ${ZENTORCH_CPP_SOURCES}) 30 | 31 | add_dependencies(CPUkernels libamdZenDNN libamdblis) 32 | add_dependencies(zentorch libamdZenDNN CPUkernels) 33 | 34 | # Enable C++17 35 | target_compile_features(zentorch PUBLIC cxx_std_17) 36 | 37 | set_target_properties(zentorch PROPERTIES 38 | LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib/) 39 | 40 | target_include_directories(zentorch PUBLIC 41 | ${ZENTORCH_INCLUDE_DIR} 42 | ${MHA_INCLUDE_DIR} 43 | ${ZENDNN_INCLUDE_DIR} 44 | ${FBGEMM_INCLUDE_DIR} 45 | ${TORCH_INCLUDE_DIRS} 46 | ${BLIS_INCLUDE_DIR} 47 | ) 48 | 49 | target_link_libraries(zentorch PUBLIC 50 | ${MHA_LIBRARIES} 51 | ${ZENDNN_LIBRARIES} 52 | ${BLIS_LIBRARIES} 53 | ${FBGEMM_LIBRARIES} 54 | ${LIBXSMM_LIBRARIES} 55 | ${TORCH_LIBRARIES} 56 | ${CMAKE_CURRENT_BINARY_DIR}/lib/libasmjit.a) 57 | 58 | add_custom_command( 59 | TARGET zentorch POST_BUILD 60 | COMMAND ${CMAKE_COMMAND} -E copy 61 | ${CMAKE_CURRENT_BINARY_DIR}/lib/libzentorch.so 62 | ${CMAKE_SOURCE_DIR}/${INSTALL_LIB_DIR}/${PROJECT_NAME}/) 63 | 64 | # Set default build type 65 | if(NOT CMAKE_BUILD_TYPE) 66 | message(STATUS "Build type not set - defaulting to Release") 67 | set(CMAKE_BUILD_TYPE "Release") 68 | endif() 69 | -------------------------------------------------------------------------------- /DESCRIPTION.md: -------------------------------------------------------------------------------- 1 | __The latest ZenDNN Plugin for PyTorch* (zentorch) 5.0.2 is here!__ 2 | 3 | ZenDNN 5.0.2 is a minor release building upon the major ZenDNN 5.0 release. This upgrade continues the focus on optimizing inference with Recommender Systems and Large Language Models on AMD EPYC™ CPUs. 4 | ZenDNN includes AMD EPYC™ enhancements for bfloat16 performance, expanded support for cutting-edge models like Llama 3.1 and 3.2, Microsoft Phi, and more, as well as support for the INT4 quantized data type. 5 | This includes the advanced Activation-Aware Weight Quantization (AWQ) algorithm for LLMs and quantized support for the DLRM-v2 model with int8 weights. 6 | 7 | Under the hood, ZenDNN’s enhanced AMD-specific optimizations operate at every level. In addition to highly optimized operator microkernels, these include comprehensive graph optimizations such as pattern identification, graph reordering, and fusions. 8 | They also incorporate optimized embedding bag kernels and enhanced zenMatMul matrix splitting strategies that leverage the AMD EPYC™ microarchitecture to deliver higher throughput and lower latency. 9 | 10 | The ZenDNN PyTorch plugin is called zentorch.
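Enabling it typically requires only importing the package and selecting the `zentorch` backend for `torch.compile`. A minimal usage sketch is shown below; the model and input names are illustrative placeholders, and the bundled examples contain complete, runnable scripts.

```python
import torch
import zentorch  # importing zentorch makes the "zentorch" compile backend available

model = MyModel().eval()  # placeholder: any supported PyTorch inference model
model = torch.compile(model, backend="zentorch")

with torch.no_grad():
    output = model(example_input)  # placeholder input tensor(s)
```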
Combined with PyTorch's torch.compile, zentorch transforms deep learning pipelines into finely-tuned, AMD-specific engines, delivering unparalleled efficiency and speed for large-scale inference workloads. 11 | 12 | The zentorch 5.0.2 release plugs in seamlessly with PyTorch versions from 2.2 to 2.6, offering a high-performance experience for deep learning on AMD EPYC™ platforms. 13 | 14 | ## Support 15 | 16 | We welcome feedback, suggestions, and bug reports. Should you have any of these, please file an issue on the ZenDNN Plugin for PyTorch GitHub page [here](https://github.com/amd/ZenDNN-pytorch-plugin/issues). 17 | 18 | ## License 19 | 20 | AMD copyrighted code in ZenDNN is subject to the [Apache-2.0, MIT, or BSD-3-Clause](https://github.com/amd/ZenDNN-pytorch-plugin/blob/main/LICENSE) licenses; consult the source code file headers for the applicable license. Third party copyrighted code in ZenDNN is subject to the licenses set forth in the source code file headers of such code. -------------------------------------------------------------------------------- /benchmark/recsys/dlrmv2-mlperf/README.md: -------------------------------------------------------------------------------- 1 | # Running the Quantized DLRMv2 Model with Zentorch 2 | 3 | > **_NOTE:_** The following paths are relative to the directory this file is located in. 4 | 5 | ## 1. Environment Setup 6 | 7 | ### 1.1. Create a New Conda Environment 8 | 9 | ```bash 10 | conda create -n zentorch-env-py3.10 python=3.10 -y 11 | conda activate zentorch-env-py3.10 12 | ``` 13 | 14 | ### 1.2. Install Zentorch 15 | 16 | Ensure the GCC version is 12.2 or higher. 17 | 18 | Follow the zentorch installation steps in the [README](https://github.com/amd/ZenDNN-pytorch-plugin?tab=readme-ov-file#2-installation) file. 19 | 20 | ## 2. Data Preparation 21 | 22 | ### 2.1 To prepare the data, refer to [MLPerf DLRMv2](https://github.com/mlcommons/training/tree/master/recommendation_v2/torchrec_dlrm#create-the-synthetic-multi-hot-dataset). The data structure should be as follows: 23 | 24 | ```shell 25 | . 26 | ├── terabyte_input 27 | │   ├── day_23_dense.npy 28 | │   ├── day_23_labels.npy 29 | │   └── day_23_sparse_multi_hot.npz 30 | ``` 31 | 32 | Set the directory path to `$DATA_DIR`. 33 | 34 | ```bash 35 | export DATA_DIR=/path/to/terabyte_input/ 36 | ``` 37 | 38 | ## 3. Model Preparation 39 | 40 | Download and install Quark v0.8. Installation instructions can be found [here](https://quark.docs.amd.com/release-0.8/install.html). 41 | We suggest downloading the "zip release". 42 | 43 | > zentorch v5.0.2 is compatible with Quark v0.8. Please make sure you download the right version. 44 | 45 | Follow the steps in the README file in the `examples/torch/rm` directory to download, prepare, and quantize the model. 46 | 47 | Set the path for the quantized DLRM model directory. 48 | 49 | ```bash 50 | export MODEL_DIR=/path/to/dlrm_quark 51 | ``` 52 | 53 | ## 4. Execute DLRMv2 54 | 55 | ### 4.1 Dependency Installation 56 | 57 | ```shell 58 | bash prepare_env.sh 59 | ``` 60 | 61 | ### 4.2. Setup 62 | 63 | Ensure `$DATA_DIR` and `$MODEL_DIR` are set. Apply the optimal settings for performance by running 64 | 65 | ```shell 66 | source ../../../scripts/dlrm_optimal_env_setup.sh 67 | ``` 68 | 69 | Then modify the configuration file, `setup_env_offline.sh`, to match your machine's specifications.
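For example, with the default values shown in the block below (2 sockets, 128 CPUs per socket, `CPUS_PER_PROCESS=128`, and `CPUS_PER_INSTANCE=2`), each socket runs a single process (128/128 = 1) and each process hosts 64 instances (128/2 = 64).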
70 | 71 | ```shell 72 | export NUM_SOCKETS=2 # e.g., 2 73 | export CPUS_PER_SOCKET=128 # e.g., 128 74 | export CPUS_PER_PROCESS=128 # determines the number of processes used 75 | # process-per-socket = CPUS_PER_SOCKET/CPUS_PER_PROCESS 76 | export CPUS_PER_INSTANCE=2 # instance-per-process number=CPUS_PER_PROCESS/CPUS_PER_INSTANCE 77 | # total-instance = instance-per-process * process-per-socket 78 | export CPUS_FOR_LOADGEN=1 # number of CPUs for loadgen 79 | # finally used in our code is max(CPUS_FOR_LOADGEN, remaining cores for instances) 80 | export BATCH_SIZE=100 81 | ``` 82 | 83 | ### 4.3. Offline Performance 84 | 85 | To generate the performance numbers in offline mode, please execute the following command. 86 | 87 | ```shell 88 | source setup_env_offline.sh && ./run_main.sh offline int8-bf16 89 | ``` 90 | 91 | ### 4.4. Offline Accuracy 92 | 93 | To generate the accuracy numbers in offline mode, please execute the following command. 94 | 95 | ```shell 96 | source setup_env_offline.sh && ./run_main.sh offline accuracy int8-bf16 97 | ``` 98 | -------------------------------------------------------------------------------- /benchmark/recsys/dlrmv2-mlperf/prepare_env.sh: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # * Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 4 | # * 5 | # * Was sourced from 6 | # * https://github.com/mlcommons/inference_results_v3.1/blob/main/closed/Intel/code/dlrm-v2-99/pytorch-cpu-int8/prepare_env.sh 7 | # * commit ID: eaf622a 8 | # ****************************************************************************** 9 | 10 | # Install required libraries 11 | pip install scikit-learn pybind11 iopath==0.1.10 pyre_extensions==0.0.30 12 | pip install "git+https://github.com/mlperf/logging.git@3.0.0-rc2" 13 | conda install -c conda-forge gperftools llvm-openmp -y 14 | 15 | # Install torch and required libraries 16 | pip3 install torch==2.6.0 --index-url https://download.pytorch.org/whl/cpu 17 | pip install fbgemm-gpu==1.0.0 --index-url https://download.pytorch.org/whl/cpu 18 | pip install torchrec==0.7.0 19 | pip install torchsnapshot==0.1.0 20 | 21 | # Install mlperf loadgen 22 | git clone https://github.com/mlcommons/inference.git 23 | pushd inference 24 | git checkout v4.1 25 | git submodule update --init --recursive 26 | pushd loadgen 27 | CFLAGS="-std=c++14" python setup.py install 28 | popd 29 | cp -r mlperf.conf .. 30 | popd 31 | 32 | 33 | -------------------------------------------------------------------------------- /benchmark/recsys/dlrmv2-mlperf/python/backend.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # * Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 4 | # * 5 | # * Was sourced from 6 | # * https://github.com/mlcommons/inference_results_v3.1/blob/main/closed/Intel/code/dlrm-v2-99/pytorch-cpu-int8/python/backend.py # noqa: B950 7 | # * commit ID: eaf622a 8 | # ****************************************************************************** 9 | """ 10 | abstract backend class 11 | """ 12 | 13 | 14 | # TODO: Base class is currently not used, can be removed in the future.
15 | class Backend: 16 | def __init__(self): 17 | self.inputs = [] 18 | self.outputs = [] 19 | 20 | def version(self): 21 | raise NotImplementedError("Backend:version") 22 | 23 | def name(self): 24 | raise NotImplementedError("Backend:name") 25 | 26 | def load(self, model_path, inputs=None, outputs=None): 27 | raise NotImplementedError("Backend:load") 28 | -------------------------------------------------------------------------------- /benchmark/recsys/dlrmv2-mlperf/python/items.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # * Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 4 | # * 5 | # * Was sourced from 6 | # * https://github.com/mlcommons/inference_results_v3.1/blob/main/closed/Intel/code/dlrm-v2-99/pytorch-cpu-int8/python/items.py # noqa: B950 7 | # * commit ID: eaf622a 8 | # ****************************************************************************** 9 | import time 10 | 11 | 12 | class Item: 13 | """An item that we queue for processing by the thread pool.""" 14 | 15 | def __init__(self, query_id, content_id, idx_offsets): 16 | self.query_id = query_id 17 | self.content_id = content_id 18 | self.idx_offsets = idx_offsets 19 | self.start = time.time() 20 | 21 | 22 | class OItem: 23 | def __init__( 24 | self, 25 | presults, 26 | query_ids=None, 27 | array_ref=None, 28 | good=0, 29 | total=0, 30 | timing=0, 31 | ): 32 | self.good = good 33 | self.total = total 34 | self.timing = timing 35 | self.presults = presults 36 | self.query_ids = query_ids 37 | self.array_ref = array_ref 38 | -------------------------------------------------------------------------------- /benchmark/recsys/dlrmv2-mlperf/run_common.sh: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # * Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 
4 | # * 5 | # * Was sourced from 6 | # * https://github.com/mlcommons/inference_results_v3.1/blob/main/closed/Intel/code/dlrm-v2-99/pytorch-cpu-int8/run_common.sh 7 | # * commit ID: eaf622a 8 | # ****************************************************************************** 9 | #!/bin/bash 10 | 11 | if [ "x$DATA_DIR" == "x" ]; then 12 | echo "DATA_DIR not set" && exit 1 13 | fi 14 | if [ "x$MODEL_DIR" == "x" ]; then 15 | echo "MODEL_DIR not set" && exit 1 16 | fi 17 | 18 | # defaults 19 | backend=pytorch 20 | model=dlrm 21 | dataset=dlrm-multihot-pytorch 22 | device="cpu" 23 | mode="Offline" 24 | dtype="fp32" 25 | test_type="performance" 26 | 27 | for i in $* ; do 28 | case $i in 29 | pytorch) backend=$i; shift;; 30 | dlrm) model=$i; shift;; 31 | multihot-criteo) dataset=$i; shift;; 32 | cpu) device=$i; shift;; 33 | fp32|int8-fp32|int8-bf16) dtype=$i; shift;; 34 | performance|accuracy) test_type=$i; shift;; 35 | Server|Offline) mode=$i; 36 | esac 37 | done 38 | # debuging 39 | # echo $backend 40 | # echo $model 41 | # echo $dataset 42 | # echo $device 43 | # echo $MODEL_DIR 44 | # echo $DATA_DIR 45 | # echo $DLRM_DIR 46 | # echo $EXTRA_OPS 47 | 48 | if [[ $dtype == "int8-fp32" ]] ; then 49 | extra_args="$extra_args --use-int8-fp32" 50 | elif [[ $dtype == "int8-bf16" ]] ; then 51 | extra_args="$extra_args --use-int8-bf16" 52 | fi 53 | 54 | if [[ $test_type == "accuracy" ]] ; then 55 | extra_args="$extra_args --accuracy" 56 | fi 57 | 58 | name="$model-$dataset-$backend" 59 | 60 | echo $name 61 | # 62 | # pytorch 63 | # 64 | if [ $name == "dlrm-multihot-criteo-pytorch" ] ; then 65 | model_path="$MODEL_DIR/dlrm-multihot-pytorch.pt" 66 | profile=dlrm-multihot-pytorch 67 | fi 68 | # debuging 69 | # echo $model_path 70 | # echo $profile 71 | # echo $extra_args 72 | 73 | name="$backend-$device/$model" 74 | EXTRA_OPS="$extra_args $EXTRA_OPS" 75 | -------------------------------------------------------------------------------- /benchmark/recsys/dlrmv2-mlperf/run_local.sh: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # * Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 4 | # * 5 | # * Was sourced from 6 | # * https://github.com/mlcommons/inference_results_v3.1/blob/main/closed/Intel/code/dlrm-v2-99/pytorch-cpu-int8/run_local.sh 7 | # * commit ID: eaf622a 8 | # ****************************************************************************** 9 | #!/bin/bash 10 | 11 | source ./run_common.sh 12 | 13 | common_opt="--config ./mlperf.conf" 14 | OUTPUT_DIR=$PWD/output/$name/$mode/$test_type 15 | if [[ $test_type == "performance" ]]; then 16 | OUTPUT_DIR=$OUTPUT_DIR/run_1 17 | # OUTPUT_DIR="$OUTPUT_DIR/$(date +%Y-%m-%d-%H:%M:%S)" 18 | fi 19 | if [ ! 
-d $OUTPUT_DIR ]; then 20 | mkdir -p $OUTPUT_DIR 21 | fi 22 | 23 | set -x # echo the next command 24 | 25 | profiling=0 26 | if [ $profiling == 1 ]; then 27 | EXTRA_OPS="$EXTRA_OPS --enable-profiling=True" 28 | fi 29 | 30 | ## multi-instance 31 | python -u python/runner.py --profile $profile $common_opt --model $model --int8-model-path $MODEL_DIR \ 32 | --dataset $dataset --dataset-path $DATA_DIR --output $OUTPUT_DIR $EXTRA_OPS $@ 33 | -------------------------------------------------------------------------------- /benchmark/recsys/dlrmv2-mlperf/run_main.sh: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # * Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 4 | # * 5 | # * Was sourced from 6 | # * https://github.com/mlcommons/inference_results_v3.1/blob/main/closed/Intel/code/dlrm-v2-99/pytorch-cpu-int8/run_main.sh 7 | # * commit ID: eaf622a 8 | # ****************************************************************************** 9 | 10 | #!/bin/bash 11 | 12 | dtype="fp32" 13 | batch_size=$(($BATCH_SIZE + 0)) 14 | if [ $# -ge 2 ]; then 15 | if [[ $2 == "accuracy" ]]; then 16 | test_type="accuracy" 17 | fi 18 | if [[ $2 == "int8-fp32" ]] || [[ $3 == "int8-fp32" ]]; then 19 | dtype="int8-fp32" 20 | elif [[ $2 == "int8-bf16" ]] || [[ $3 == "int8-bf16" ]]; then 21 | dtype="int8-bf16" 22 | fi 23 | else 24 | test_type="performance" 25 | fi 26 | 27 | export TORCHINDUCTOR_SIZE_ASSERTS=0 28 | export TORCHINDUCTOR_NAN_ASSERTS=0 29 | export TORCHINDUCTOR_SCALAR_ASSERTS=0 30 | export ZENDNN_EB_THREAD_TYPE=2 31 | export USE_ZENDNN_EB=0 32 | export OMP_NUM_THREADS=$CPUS_PER_INSTANCE 33 | export ZENDNN_PRIMITIVE_CACHE_CAPACITY=20971520 # https://oneapi-src.github.io/oneDNN/dev_guide_primitive_cache.html. Kindly refer this link for more details 34 | export LD_PRELOAD="${CONDA_PREFIX}/lib/libtcmalloc.so:${CONDA_PREFIX}/lib/libomp.so:$LD_PRELOAD" 35 | export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=30469645312 # https://github.com/gperftools/gperftools/issues/360. Kindly refer this link for more details 36 | 37 | # echo $LD_PRELOAD 38 | # export LD_PRELOAD="/usr/local/lib/libjemalloc.so:/home/amd/anaconda3/envs/zentorch_dinesh/lib/libomp.so:$LD_PRELOAD" 39 | # echo $LD_PRELOAD 40 | 41 | mode="Offline" 42 | extra_option="--samples-per-query-offline=204800" 43 | if [ $1 == "server" ]; then 44 | mode="Server" 45 | extra_option="" 46 | fi 47 | 48 | # sudo ./run_clean.sh 49 | echo "Running $mode bs=$batch_size int8 $test_type" 50 | ./run_local.sh pytorch dlrm multihot-criteo cpu $dtype $test_type --scenario $mode --max-ind-range=40000000 --samples-to-aggregate-quantile-file=${PWD}/tools/dist_quantile.txt --max-batchsize=$batch_size $extra_option 51 | -------------------------------------------------------------------------------- /benchmark/recsys/dlrmv2-mlperf/setup_env_offline.sh: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # * Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 
4 | # * 5 | # * Was sourced from 6 | # * https://github.com/mlcommons/inference_results_v3.1/blob/main/closed/Intel/code/dlrm-v2-99/pytorch-cpu-int8/setup_env_offline.sh 7 | # * commit ID: eaf622a 8 | # ****************************************************************************** 9 | 10 | set -x 11 | export NUM_SOCKETS=2 # i.e. 2 12 | export CPUS_PER_SOCKET=128 # i.e. 128 13 | export CPUS_PER_CONSUMER=128 # which determine how much processes will be used 14 | # consumer-per-socket = CPUS_PER_SOCKET/CPUS_PER_CONSUMER 15 | export CPUS_PER_INSTANCE=2 # instance-per-consumer number=CPUS_PER_CONSUMER/CPUS_PER_INSTANCE 16 | # total-instance = instance-per-consumer * consumer-per-socket 17 | export CPUS_FOR_LOADGEN=1 # number of cpus for loadgen 18 | # finally used in our code is max(CPUS_FOR_LOADGEN, left cores for instances) 19 | export BATCH_SIZE=100 20 | set +x 21 | 22 | -------------------------------------------------------------------------------- /benchmark/recsys/dlrmv2-mlperf/setup_env_server.sh: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # * Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 4 | # * 5 | # * Was sourced from 6 | # * https://github.com/mlcommons/inference_results_v3.1/blob/main/closed/Intel/code/dlrm-v2-99/pytorch-cpu-int8/setup_env_server.sh 7 | # * commit ID: eaf622a 8 | # ****************************************************************************** 9 | set -x 10 | export NUM_SOCKETS=2 # i.e. 2 11 | export CPUS_PER_SOCKET=128 # i.e. 128 12 | export CPUS_PER_CONSUMER=128 # which determine how much processes will be used 13 | # consumer-per-socket = CPUS_PER_SOCKET/CPUS_PER_CONSUMER 14 | export CPUS_PER_INSTANCE=2 # instance-per-consumer number=CPUS_PER_CONSUMER/CPUS_PER_INSTANCE 15 | # total-instance = instance-per-consumer * consumer-per-socket 16 | export CPUS_FOR_LOADGEN=1 # number of cpus for loadgen 17 | # finally used in our code is max(CPUS_FOR_LOADGEN, left cores for instances) 18 | export BATCH_SIZE=200 19 | set +x 20 | -------------------------------------------------------------------------------- /benchmark/recsys/dlrmv2-mlperf/tools/dist_quantile.txt: -------------------------------------------------------------------------------- 1 | 100, 100, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 300, 300, 400, 500, 600, 700 2 | -------------------------------------------------------------------------------- /benchmark/recsys/dlrmv2-mlperf/user.conf: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # * Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 
4 | # * 5 | # * Was sourced from 6 | # * https://github.com/mlcommons/inference_results_v3.1/blob/main/closed/Intel/code/dlrm-v2-99/pytorch-cpu-int8/user.conf 7 | # * commit ID: eaf622a 8 | # ****************************************************************************** 9 | dlrm.Server.target_qps = 14500.0 10 | dlrm.Offline.target_qps = 10000.0 11 | -------------------------------------------------------------------------------- /cmake/modules/FindCPUkernels.cmake: -------------------------------------------------------------------------------- 1 | #****************************************************************************** 2 | # Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | #****************************************************************************** 5 | 6 | IF (NOT MHA_FOUND) 7 | 8 | find_package(Torch REQUIRED) 9 | 10 | # Collect all .cpp files in the src directory 11 | file(GLOB cpu_kernels "${CMAKE_CURRENT_SOURCE_DIR}/src/cpu/cpp/kernels/*.cpp") 12 | 13 | # setting necessary flags for .cpp files 14 | set(FLAGS "-Wall -Werror -Wno-unknown-pragmas -Wno-error=uninitialized \ 15 | -Wno-error=maybe-uninitialized -fPIC -fopenmp -fno-math-errno \ 16 | -fno-trapping-math -O2 -std=c++17 -D_GLIBCXX_USE_CXX11_ABI=0 \ 17 | -mavx512f -mavx512bf16 -mavx512vl -mavx512dq -DCPU_CAPABILITY_AVX512") 18 | 19 | set_source_files_properties(${cpu_kernels} PROPERTIES COMPILE_FLAGS "${FLAGS}") 20 | 21 | # creating library for mha and sdpa 22 | add_library(CPUkernels STATIC ${cpu_kernels}) 23 | 24 | set_target_properties(CPUkernels PROPERTIES 25 | ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib/) 26 | 27 | target_include_directories(CPUkernels PUBLIC 28 | ${TORCH_INCLUDE_DIRS} 29 | ${ZENDNN_INCLUDE_DIR} 30 | ${BLIS_INCLUDE_DIR}) 31 | 32 | LIST(APPEND MHA_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/lib/libCPUkernels.a) 33 | set(MHA_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src/cpu/cpp/kernels/") 34 | 35 | SET(MHA_FOUND ON) 36 | 37 | ENDIF (NOT MHA_FOUND) -------------------------------------------------------------------------------- /cmake/modules/FindZentorch.cmake: -------------------------------------------------------------------------------- 1 | #****************************************************************************** 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | #****************************************************************************** 5 | 6 | cmake_minimum_required(VERSION 3.1) 7 | 8 | execute_process( 9 | COMMAND python -c "import zentorch; print(zentorch.__path__[0])" 10 | OUTPUT_VARIABLE ZENTORCH_PACKAGE_DIR 11 | OUTPUT_STRIP_TRAILING_WHITESPACE 12 | ) 13 | 14 | # Set ZENTORCH_LIBRARY to the path of the libzentorch.so 15 | set(ZENTORCH_LIBRARY "${ZENTORCH_PACKAGE_DIR}/libzentorch.so") 16 | 17 | # Check if the library file exists 18 | if(EXISTS ${ZENTORCH_LIBRARY}) 19 | set(ZENTORCH_FOUND TRUE) 20 | message(STATUS "FOUND libzentorch: ${ZENTORCH_LIBRARY}") 21 | endif() 22 | 23 | if(ZENTORCH_FOUND AND NOT TARGET Zentorch::Zentorch) 24 | add_library(Zentorch::Zentorch SHARED IMPORTED) 25 | set_target_properties(Zentorch::Zentorch PROPERTIES 26 | IMPORTED_LOCATION "${ZENTORCH_LIBRARY}" 27 | ) 28 | endif() 29 | 30 | set(CMAKE_SKIP_BUILD_RPATH FALSE) 31 | set(CMAKE_BUILD_WITH_INSTALL_RPATH FALSE) 32 | set(CMAKE_INSTALL_RPATH "${ZENTORCH_PACKAGE_DIR}") 33 | set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) 34 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | * All rights reserved. 4 | ******************************************************************************/ 5 | 6 | # Examples 7 | Given below are some examples for running inference for various models with zentorch. Note that you may need to install additional packages in your environment if not already present. Assuming you have installed zentorch plugin in your environment, you can install the rest of the packages by running: 8 | ```bash 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | ## BERT 13 | ### Execute the following command to run inference for bert model: 14 | ```bash 15 | python bert_example.py 16 | ``` 17 | 18 | ### Output 19 | Last hidden states shape: torch.Size([3, 339, 1024]) 20 | 21 | ## DLRM 22 | ### Execute the following command to run inference for dlrm model: 23 | ```bash 24 | python dlrm_example.py 25 | ``` 26 | ### Output 27 | ```plain 28 | AUC Score: 0.5 29 | ``` 30 | 31 | ## LLAMA: Bfloat16 32 | ### Execute the following command to run inference for llama bf16 model: 33 | ```bash 34 | python llama_bf16_example.py 35 | ``` 36 | ### Output 37 | 38 | ```plain 39 | 'Hi, How are you today? I hope you are having a great day. I' 40 | ``` 41 | 42 | ## LLAMA: Weight Only Quantization 43 | Please update the following line with the correct path to your quantized model in llama_woq_example.py: 44 | ```python 45 | safetensor_path = "" 46 | ``` 47 | ### Execute the following command to run inference for llama woq model: 48 | ```bash 49 | python llama_woq_example.py 50 | ``` 51 | ### Output 52 | ```plain 53 | 'Hi, How are you today? I hope you are having a great day. 
I' 54 | ``` 55 | 56 | ## Resnet 57 | ### Execute the following command to run inference for resnet model: 58 | ```bash 59 | python resnet_example.py 60 | ``` 61 | ### Output 62 | ```plain 63 | plane, carpenter's plane, woodworking plane 64 | ``` 65 | -------------------------------------------------------------------------------- /examples/bert_example.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # * Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import torch 7 | import zentorch 8 | from transformers import BertTokenizer, BertModel 9 | from datasets import load_dataset 10 | 11 | print("\n" + "=" * 10 + " BERT Example Execution Started " + "=" * 10 + "\n") 12 | 13 | # Load the IMDB dataset 14 | print("Loading IMDB dataset") 15 | dataset = load_dataset("imdb", split="test") 16 | print("Sample text from dataset:", dataset[0]['text']) 17 | 18 | # Load the tokenizer 19 | tokenizer = BertTokenizer.from_pretrained( 20 | "bert-large-uncased", trust_remote_code=True 21 | ) 22 | 23 | # Load the BERT model 24 | model_id = "google-bert/bert-large-uncased" 25 | print(f"Loading BERT model: {model_id}") 26 | model = BertModel.from_pretrained( 27 | model_id, torch_dtype=torch.bfloat16, trust_remote_code=True 28 | ) 29 | model = model.eval() 30 | 31 | # Optimize model with ZenTorch 32 | print("Optimizing model with ZenTorch") 33 | model.forward = torch.compile(model.forward, backend="zentorch") 34 | 35 | # Inference 36 | print("Running inference") 37 | with torch.inference_mode(), torch.no_grad(), \ 38 | torch.amp.autocast("cpu", enabled=True), \ 39 | zentorch.freezing_enabled(): 40 | # Prepare inputs 41 | inputs = tokenizer( 42 | dataset["text"][:3], return_tensors="pt", padding=True, truncation=True 43 | ) 44 | 45 | # Generate outputs 46 | outputs = model(**inputs) 47 | 48 | # Get last hidden states 49 | last_hidden_states = outputs.last_hidden_state 50 | print("Last hidden states shape:", last_hidden_states.shape) 51 | print("\n" + "=" * 10 + " Script Executed Successfully " + "=" * 10 + "\n") 52 | -------------------------------------------------------------------------------- /examples/dlrm_example.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # * Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | # Import dependencies 7 | from dlrm_model import DLRMMLPerf 8 | import torch 9 | import numpy as np 10 | import zentorch 11 | import random 12 | from sklearn.metrics import roc_auc_score 13 | 14 | print("\n" + "=" * 10 + " DLRM Example Execution Started " + "=" * 10 + "\n") 15 | 16 | # Basic setup for reproducibility 17 | np.random.seed(123) 18 | random.seed(123) 19 | torch.manual_seed(123) 20 | 21 | # Initialize the model 22 | print("Initializing DLRM model") 23 | DEFAULT_INT_NAMES = [f'int_{i}' for i in range(13)] 24 | model = DLRMMLPerf( 25 | embedding_dim=128, 26 | num_embeddings_pool=[ 27 | 40000000, 39060, 17295, 7424, 20265, 3, 7122, 1543, 63, 40000000, 28 | 3067956, 405282, 10, 2209, 11938, 155, 4, 976, 14, 40000000, 29 | 40000000, 40000000, 590152, 12973, 108, 36 30 | ], 31 | dense_in_features=len(DEFAULT_INT_NAMES), 32 | dense_arch_layer_sizes=[512, 256, 128], 33 | over_arch_layer_sizes=[1024, 1024, 512, 256, 1], 34 | dcn_num_layers=3, 35 | dcn_low_rank_dim=512, 36 | use_int8=False, 37 | use_bf16=True 38 | ).bfloat16() 39 | 40 | # Prepare Inputs 41 | print("Preparing input tensors") 42 | multi_hot = [3, 2, 1, 2, 6, 1, 1, 1, 1, 7, 3, 8, 1, 6, 9, 5, 1, 1, 43 | 1, 12, 100, 27, 10, 3, 1, 1] 44 | batchsize = 32768 45 | densex = torch.randn((batchsize, 13), dtype=torch.float).to(torch.bfloat16) 46 | index = [torch.ones((batchsize * h), dtype=torch.long) for h in multi_hot] 47 | offset = [torch.arange(0, (batchsize + 1) * h, h, dtype=torch.long) for h in multi_hot] 48 | 49 | # Optimize Model with ZenTorch 50 | print("Optimizing model with ZenTorch") 51 | model = torch.compile(model, backend="zentorch") 52 | 53 | # Run Inference 54 | print("Running inference") 55 | with torch.inference_mode(), torch.no_grad(), \ 56 | torch.amp.autocast("cpu", enabled=True), \ 57 | zentorch.freezing_enabled(): 58 | out = model(densex, index, offset) 59 | 60 | # Simulating labels 61 | true_labels = torch.randint(0, 2, (32768,)) 62 | predicted_probabilities = out.to(torch.float32).cpu().detach().numpy().reshape(-1) 63 | true_labels = true_labels.cpu().detach().numpy() 64 | 65 | # Calculate AUC 66 | auc_score = roc_auc_score(true_labels, predicted_probabilities) 67 | print(f"AUC Score: {auc_score}") 68 | 69 | print("\n" + "=" * 10 + " Script Executed Successfully " + "=" * 10 + "\n") 70 | -------------------------------------------------------------------------------- /examples/llama_bf16_example.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # * Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import torch 7 | import zentorch 8 | from transformers import AutoModelForCausalLM, AutoTokenizer 9 | 10 | print("\n" + "=" * 10 + " Llama BF16 Example Execution Started " + "=" * 10 + "\n") 11 | 12 | # Load Tokenizer and Model 13 | model_id = "meta-llama/Llama-3.1-8B" 14 | print(f"Loading model: {model_id}") 15 | model = AutoModelForCausalLM.from_pretrained( 16 | model_id, 17 | torchscript=True, 18 | return_dict=False, 19 | torch_dtype=torch.bfloat16, 20 | ) 21 | tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) 22 | model = model.eval() 23 | 24 | # Prepare Inputs 25 | print("Preparing inputs") 26 | generate_kwargs = { 27 | "do_sample": False, 28 | "num_beams": 4, 29 | "max_new_tokens": 10, 30 | "min_new_tokens": 2, 31 | } 32 | prompt = "Hi, How are you today?" 33 | print(f"Input prompt: {prompt}") 34 | 35 | # Optimize Model with ZenTorch 36 | print("Optimizing model with ZenTorch") 37 | model = zentorch.llm.optimize(model, dtype=torch.bfloat16) 38 | 39 | # Run Inference 40 | print("Running inference") 41 | with torch.inference_mode(), torch.no_grad(), \ 42 | torch.amp.autocast("cpu", enabled=True): 43 | model.forward = torch.compile(model.forward, backend="zentorch") 44 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 45 | output = model.generate(input_ids, **generate_kwargs) 46 | 47 | # Decode Output 48 | gen_text = tokenizer.batch_decode(output, skip_special_tokens=True) 49 | print(f"Generated text: {gen_text}") 50 | 51 | print("\n" + "=" * 10 + " Script Executed Successfully " + "=" * 10 + "\n") 52 | -------------------------------------------------------------------------------- /examples/llama_woq_example.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # * Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import torch 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 8 | import zentorch 9 | 10 | print("\n" + "=" * 10 + " Llama WOQ Example Execution Started " + "=" * 10 + "\n") 11 | 12 | # Load Tokenizer and WOQ Model 13 | model_id = "meta-llama/Llama-3.1-8B" 14 | safetensor_path = "" 15 | config = AutoConfig.from_pretrained( 16 | model_id, 17 | torchscript=True, 18 | return_dict=False, 19 | torch_dtype=torch.bfloat16, 20 | ) 21 | 22 | model = AutoModelForCausalLM.from_config( 23 | config, trust_remote_code=True, torch_dtype=torch.bfloat16 24 | ) 25 | 26 | # Load WOQ Model 27 | print(f"Loading quantized model: {model_id}") 28 | model = zentorch.load_quantized_model(model, safetensor_path) 29 | model = model.eval() 30 | 31 | # Load Tokenizer 32 | tokenizer = AutoTokenizer.from_pretrained( 33 | model_id, trust_remote_code=True, padding_side="left", use_fast=False 34 | ) 35 | 36 | # Prepare Inputs 37 | print("Preparing inputs") 38 | generate_kwargs = { 39 | "do_sample": False, 40 | "num_beams": 4, 41 | "max_new_tokens": 10, 42 | "min_new_tokens": 2, 43 | } 44 | prompt = "Hi, How are you today?" 
45 | print(f"Input prompt: {prompt}") 46 | 47 | # Optimize Model with ZenTorch 48 | print("Optimizing model with ZenTorch") 49 | model = zentorch.llm.optimize(model, dtype=torch.bfloat16) 50 | 51 | # Run Inference 52 | print("Running inference") 53 | with torch.inference_mode(), torch.no_grad(), \ 54 | torch.amp.autocast("cpu", enabled=True): 55 | model.forward = torch.compile(model.forward, backend="zentorch") 56 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 57 | output = model.generate(input_ids, **generate_kwargs) 58 | 59 | # Decode Output 60 | gen_text = tokenizer.batch_decode(output, skip_special_tokens=True) 61 | print(f"Generated text: \"{gen_text[0]}\"") 62 | 63 | print("\n" + "=" * 10 + " Script Executed Successfully " + "=" * 10 + "\n") 64 | -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | transformers==4.46.2 3 | scikit-learn 4 | Pillow 5 | urllib3 6 | -------------------------------------------------------------------------------- /examples/resnet_example.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # * Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | # * All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import torch 7 | import zentorch 8 | from transformers import AutoImageProcessor, ResNetForImageClassification 9 | from PIL import Image 10 | import urllib.request 11 | 12 | print("\n" + "=" * 10 + " ResNet Example Execution Started " + "=" * 10 + "\n") 13 | 14 | # Load Processor and Model 15 | model_id = "microsoft/resnet-50" 16 | print(f"Loading RESNET model: {model_id}") 17 | processor = AutoImageProcessor.from_pretrained(model_id, use_fast=False) 18 | model = ResNetForImageClassification.from_pretrained( 19 | model_id, torch_dtype=torch.bfloat16 20 | ) 21 | 22 | # Prepare Inputs 23 | print("Downloading and loading image") 24 | urllib.request.urlretrieve( 25 | "https://raw.githubusercontent.com/EliSchwartz/" 26 | "imagenet-sample-images/master/n03954731_plane.JPEG", 27 | "airplane.jpeg" 28 | ) 29 | image = Image.open("airplane.jpeg") 30 | 31 | inputs = processor(image, return_tensors="pt") 32 | 33 | # Convert inputs to BF16 34 | inputs = {k: v.to(torch.bfloat16) for k, v in inputs.items()} 35 | 36 | # Optimize Model with ZenTorch 37 | print("Optimizing model with ZenTorch") 38 | model.forward = torch.compile(model.forward, backend="zentorch") 39 | 40 | # Run Inference 41 | print("Running inference") 42 | with torch.inference_mode(), torch.no_grad(), \ 43 | torch.amp.autocast("cpu", enabled=True), \ 44 | zentorch.freezing_enabled(): 45 | logits = model(**inputs).logits 46 | 47 | # Get prediction 48 | predicted_label = logits.argmax(-1).item() 49 | print(f"Predicted label: {model.config.id2label[predicted_label]}") 50 | 51 | print("\n" + "=" * 10 + " Script Executed Successfully " + "=" * 10 + "\n") 52 | -------------------------------------------------------------------------------- /linter/py_cpp_linter.sh: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2023 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | #!/bin/bash 7 | 8 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 9 | cd $SCRIPT_DIR 10 | pip install --upgrade -r requirements.txt 11 | cd ../ 12 | 13 | # Flake8 14 | echo "********************************************************************************" 15 | echo " * Starting flake8 linting for python scripts... *" 16 | echo "********************************************************************************" 17 | flake8 18 | echo "To re-format the above files (if any) with black, run the following command (files will not always be changed): " 19 | tput bold 20 | echo "flake8 --quiet | xargs black --verbose" 21 | tput sgr0 22 | echo "Completed py-linting!" 23 | echo -e "********************************************************************************\n\n" 24 | 25 | # clang-format 26 | echo "********************************************************************************" 27 | echo " * Now executing clang-format checks for C++ files... *" 28 | echo "********************************************************************************" 29 | git clang-format --commit `git rev-list HEAD | tail -n 1` --diff 30 | echo "If clang-format suggested some modifications, then you can use the following command to re-format: " 31 | tput bold 32 | echo "git clang-format -f" 33 | tput sgr0 34 | echo "CPP linting completed!" 35 | echo "********************************************************************************" 36 | -------------------------------------------------------------------------------- /linter/requirements.txt: -------------------------------------------------------------------------------- 1 | black>=23.7.0 2 | clang-format>=16.0.6 3 | flake8>=6.1.0 4 | flake8-bugbear>=23.3.23 5 | flake8-comprehensions>=3.12.0 6 | flake8-executable>=2.1.3 7 | flake8-logging-format>=0.9.0 8 | flake8-noqa>=1.3.2 9 | flake8-pyi>=23.3.1 10 | flake8-simplify>=0.19.3 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ['torch', 'numpy'] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cmake==3.31.6 2 | expecttest==0.1.6 3 | hypothesis 4 | ninja 5 | numpy 6 | parameterized 7 | setuptools>=50.0 8 | setupext_janitor 9 | deprecated -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | ### For CNN/RECSYS/LLM/NLP models: 2 | This script sets up the optimal environment settings for zentorch/ipex llm/recsys/cnn/nlp runs. 3 | 4 | #### Usage: 5 | Create a conda environment in which to run the benchmarks. (Do not use the conda base environment.) 6 | 7 | Before you run the benchmarks, activate the conda environment and source the zentorch_env_setup.sh file. 8 | 9 | source zentorch_env_setup.sh --help 10 | 11 | source zentorch_env_setup.sh --framework zentorch/ipex --model llm/recsys/cnn/nlp --threads 96/128/192 --precision bf16_amp/bf16/fp32/woq/int8 12 | 13 | It sets all the necessary variables for the respective runs based on the options provided. 14 | 15 | ### For DLRM model: 16 | This script sets up the optimal environment settings for zentorch DLRM runs.
17 | 18 | #### Usage: 19 | Create a conda environment in which to run the benchmarks. (Do not use the conda base environment.) 20 | 21 | Before you run the benchmarks, activate the conda environment and source the dlrm_optimal_env_setup.sh file. 22 | 23 | source dlrm_optimal_env_setup.sh --help 24 | 25 | source dlrm_optimal_env_setup.sh --threads 96/128/192 --precision bf16/fp32/int8 26 | 27 | It sets all the necessary variables for the respective runs based on the options provided. 28 | 29 | -------------------------------------------------------------------------------- /src/cpu/cpp/Bindings.cpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | * All rights reserved. 4 | ******************************************************************************/ 5 | 6 | /* 7 | TORCH_LIBRARY is used for all ops which will replace ATen/prims ops in 8 | fx based graph optimizations. A few guidelines for prototypes: 9 | - If there is a similar op in ATen of PyTorch, please check 10 | "aten/src/ATen/native/native_functions.yaml" in PyTorch repo. 11 | - Our op arguments should be a superset of the corresponding 12 | arguments in the ATen op. 13 | - Our op arguments should match the arguments of the corresponding 14 | op in both order of the arguments and type. 15 | - Our op specific arguments should be at the end of the list. 16 | - All ops should have the prefix "zentorch_", for example 17 | zentorch_. 18 | */ 19 | 20 | // needs to be included only once in the library. 21 | #include "Ops.hpp" 22 | #include "Threading.hpp" 23 | #include "Utils.hpp" 24 | 25 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 26 | m.def("show_config", &zentorch::show_config, 27 | "Show the current configuration of ZenTorch."); 28 | 29 | m.def("is_avx512_supported", zentorch::is_avx512_supported, 30 | "Check if AVX512 instructions are supported on the current hardware." 31 | "\n\nReturns:\n" 32 | "\tBool: True if AVX512 instructions are supported, False otherwise."); 33 | 34 | m.def("is_bf16_supported", zendnn::utils::zendnn_bf16_device_check, 35 | "Check if BF16 is supported on the current device.\n\n" 36 | "Returns:\n" 37 | " Bool: True if BF16 is supported, False otherwise."); 38 | 39 | m.def("zentorch_matmul_impl", &zentorch::zentorch_matmul_impl, 40 | py::arg("input"), py::arg("weight"), py::arg("bias"), 41 | py::arg("self_or_result"), py::arg("post_op_ids"), 42 | py::arg("post_op_buffers"), py::arg("beta") = 0.0f, 43 | py::arg("alpha") = 1.0f, 44 | py::arg("zentorch_op_name") = "zentorch::zendnn_matmul_impl", 45 | "Perform matrix multiplication with ZenTorch optimizations.\n\n" 46 | "Args:\n" 47 | " input (torch.Tensor): The input tensor.\n" 48 | " weight (torch.Tensor): The weight tensor.\n" 49 | " bias (torch.Tensor): The bias tensor.\n" 50 | " self_or_result (torch.Tensor): The result tensor.\n" 51 | " post_op_ids (List[int]): Post Op IDs.\n" 52 | " post_op_buffers (List[torch.Tensor]): Post Op buffers.\n" 53 | " beta (float, optional): The beta value. Default is 0.0.\n" 54 | " alpha (float, optional): The alpha value. Default is 1.0.\n" 55 | " zentorch_op_name (str, optional): The operator name. Default is " 56 | "'zentorch::zendnn_matmul_impl'."
57 | "Returns:\n" 58 | " Tensor: Result of the maxtrix multiplication."); 59 | 60 | m.def("zentorch_get_packed_embedding_weight", 61 | &zentorch::zentorch_get_packed_embedding_weight, py::arg("weight"), 62 | py::arg("weight_scales"), py::arg("weight_zero_points"), 63 | "Get packed embedding weights for ZenTorch.\n\n" 64 | "Args:\n" 65 | " weight (torch.Tensor): The weight tensor.\n" 66 | " weight_scales (List[float]): The weight scales.\n" 67 | " weight_zero_points (List[int]): The weight zero points." 68 | "Returns:\n" 69 | " Tensor: Packed embedding weights."); 70 | 71 | m.def("thread_bind", &zentorch::thread_bind, py::arg("core_ids"), 72 | "Bind threads to specified CPU cores.\n\n" 73 | "Args:\n" 74 | " core_ids (List[int]): A list of core IDs to bind threads to."); 75 | 76 | m.def("zentorch_weight_reorder_for_matmul", 77 | &zentorch::zentorch_weight_reorder_for_matmul, py::arg("weight"), 78 | py::arg("is_weight_oc_x_ic") = true, 79 | "Reorder the weight tensor to desired format.\n\n" 80 | "Args:\n" 81 | " weight (torch.Tensor): The weight tensor.\n" 82 | " is_weight_oc_x_ic (bool, optional): True if weight is stored as " 83 | "OCxIC.\n" 84 | "Returns:\n" 85 | " Tensor: Reordered weight tensor."); 86 | } 87 | -------------------------------------------------------------------------------- /src/cpu/cpp/CausalAttentionMask.cpp: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************************************** 2 | Modifications Copyright(c) 2025 Advanced Micro Devices, Inc. 3 | All rights reserved. 4 | 5 | Was sourced from 6 | https: // 7 | github.com/intel/intel-extension-for-pytorch/blob/main/csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp 8 | ***************************************************************************************************************************/ 9 | 10 | #include "Utils.hpp" 11 | #include 12 | 13 | namespace zentorch { 14 | 15 | template 16 | inline void 17 | attention_mask_2d_to_4d(const T *attention_mask_ptr, T *causal_4d_mask_ptr, 18 | const at::Tensor &finfo_min, int64_t batch_size, 19 | int64_t seq_length, int64_t src_length, 20 | int64_t past_key_value_length, int64_t length, 21 | int64_t diagonal) { 22 | T finfo_min_val = finfo_min.item(); 23 | 24 | #pragma omp parallel for collapse(3) 25 | for (int64_t b = 0; b < batch_size; ++b) { 26 | for (int64_t l = 0; l < seq_length; ++l) { 27 | for (int64_t c = 0; c < length; ++c) { 28 | int64_t idx = b * seq_length * length + l * length + c; 29 | T value = finfo_min_val; 30 | if (l + diagonal <= c && l + past_key_value_length >= c) { 31 | value = 0; 32 | } 33 | if (c < src_length) { 34 | T inverted_mask_value = 1.0 - attention_mask_ptr[b * src_length + c]; 35 | if (inverted_mask_value != 0) { 36 | value = finfo_min_val; 37 | } 38 | } 39 | causal_4d_mask_ptr[idx] = value; 40 | } 41 | } 42 | } 43 | } 44 | 45 | at::Tensor prepare_4d_causal_attention_mask_kernel_impl( 46 | const at::Tensor &attention_mask, const at::Tensor &inputs_embeds, 47 | int64_t past_key_value_length, const at::Tensor &finfo_min, 48 | int64_t sliding_window) { 49 | 50 | auto dtype = inputs_embeds.scalar_type(); 51 | int64_t batch_size = inputs_embeds.size(0); 52 | int64_t seq_length = inputs_embeds.size(1); 53 | int64_t src_length = attention_mask.size(-1); 54 | int64_t length = seq_length + past_key_value_length; 55 | int64_t diagonal = past_key_value_length - sliding_window; 56 | 57 | at::Tensor causal_4d_mask = 
torch::empty({batch_size, 1, seq_length, length}, 58 | inputs_embeds.options()); 59 | if (dtype == at::kFloat) { 60 | float *attention_mask_ptr = attention_mask.data_ptr(); 61 | float *causal_4d_mask_ptr = causal_4d_mask.data_ptr(); 62 | attention_mask_2d_to_4d( 63 | attention_mask_ptr, causal_4d_mask_ptr, finfo_min, batch_size, 64 | seq_length, src_length, past_key_value_length, length, diagonal); 65 | } else if (dtype == at::kBFloat16) { 66 | at::BFloat16 *attention_mask_ptr = attention_mask.data_ptr(); 67 | at::BFloat16 *causal_4d_mask_ptr = causal_4d_mask.data_ptr(); 68 | attention_mask_2d_to_4d( 69 | attention_mask_ptr, causal_4d_mask_ptr, finfo_min, batch_size, 70 | seq_length, src_length, past_key_value_length, length, diagonal); 71 | } else { 72 | ZENTORCH_CHECK(false, "zentorch::prepare_4d_causal_attention_mask_" 73 | "kernel_impl supports only float and bfloat16 " 74 | "datatypes"); 75 | } 76 | 77 | return causal_4d_mask; 78 | } 79 | 80 | TORCH_LIBRARY_FRAGMENT(zentorch, m) { 81 | m.def("prepare_4d_causal_attention_mask(Tensor attention_mask, " 82 | "Tensor inputs_embeds, int past_key_value_length, Tensor " 83 | "finfo_min, int " 84 | "sliding_window)-> (Tensor)"); 85 | } 86 | TORCH_LIBRARY_IMPL(zentorch, CPU, m) { 87 | m.impl("prepare_4d_causal_attention_mask", 88 | prepare_4d_causal_attention_mask_kernel_impl); 89 | } 90 | 91 | } // namespace zentorch 92 | -------------------------------------------------------------------------------- /src/cpu/cpp/Config.cpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023-2024 Advanced Micro Devices, Inc. 3 | * All rights reserved. 4 | ******************************************************************************/ 5 | 6 | #include "Config.hpp" 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | namespace zentorch { 14 | 15 | #define TO_STRING2(x) #x 16 | #define TO_STRING(x) TO_STRING2(x) 17 | 18 | std::string get_zendnn_version() { 19 | std::ostringstream ss; 20 | ss << ZENDNN_VERSION_MAJOR << "." << ZENDNN_VERSION_MINOR << "." 21 | << ZENDNN_VERSION_PATCH; 22 | return ss.str(); 23 | } 24 | 25 | std::string show_config() { 26 | std::ostringstream ss; 27 | ss << "zentorch Version: " << TO_STRING(ZENTORCH_VERSION) << "\n"; 28 | ss << "zentorch built with:\n"; 29 | ss << " - Commit-id: " << TO_STRING(ZENTORCH_VERSION_HASH) << "\n"; 30 | ss << " - PyTorch: " << TO_STRING(PT_VERSION) << "\n"; 31 | #if defined(__GNUC__) 32 | ss << " - GCC Version: " << __GNUC__ << "." 
<< __GNUC_MINOR__ << "\n"; 33 | #endif 34 | #if defined(__cplusplus) 35 | ss << " - C++ Version: " << __cplusplus << "\n"; 36 | #endif 37 | ss << "Third_party libraries:\n"; 38 | ss << " - " 39 | << "AMD " << bli_info_get_version_str() << " ( Git Hash " 40 | << BLIS_VERSION_HASH << " )" 41 | << "\n"; 42 | ss << " - " 43 | << "AMD ZENDNN v" << get_zendnn_version() << " ( Git Hash " 44 | << ZENDNN_LIB_VERSION_HASH << " )" 45 | << "\n"; 46 | ss << " - " 47 | << "FBGEMM " << FBGEMM_VERSION_TAG << " ( Git Hash " << FBGEMM_VERSION_HASH 48 | << " )" 49 | << "\n"; 50 | ss << " - " 51 | << "LIBXSMM " << LIBXSMM_VERSION_TAG << " ( Git Hash " 52 | << LIBXSMM_VERSION_HASH << " )" 53 | << "\n"; 54 | 55 | return ss.str(); 56 | } 57 | 58 | } // namespace zentorch 59 | -------------------------------------------------------------------------------- /src/cpu/cpp/Config.hpp.in: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023 Advanced Micro Devices, Inc. 3 | * All rights reserved. 4 | ******************************************************************************/ 5 | 6 | #pragma once 7 | 8 | #define ZENDNN_LIB_VERSION_HASH "@ZENDNN_LIB_VERSION_HASH@" 9 | #define BLIS_VERSION_HASH "@BLIS_VERSION_HASH@" 10 | #define FBGEMM_VERSION_HASH "@FBGEMM_VERSION_HASH@" 11 | #define FBGEMM_VERSION_TAG "@FBGEMM_VERSION_TAG@" 12 | #define LIBXSMM_VERSION_HASH "@LIBXSMM_VERSION_HASH@" 13 | #define LIBXSMM_VERSION_TAG "@LIBXSMM_VERSION_TAG@" 14 | -------------------------------------------------------------------------------- /src/cpu/cpp/ConvUtils.hpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024 Advanced Micro Devices, Inc. 3 | * All rights reserved. 4 | ******************************************************************************/ 5 | 6 | #pragma once 7 | 8 | #include "Memory.hpp" 9 | namespace zentorch { 10 | using namespace zendnn; 11 | 12 | inline std::vector 13 | get_conv_output_sizes(const at::Tensor &input, const at::Tensor &weight, 14 | const at::IntArrayRef &stride, 15 | const at::IntArrayRef &padding, 16 | const at::IntArrayRef &dilation) { 17 | 18 | // Convert the tensor to a list of integers 19 | const at::IntArrayRef &input_size = input.sizes(); 20 | const at::IntArrayRef &weight_size = weight.sizes(); 21 | 22 | bool has_dilation = !dilation.empty(); 23 | auto dim = input_size.size(); 24 | std::vector output_size(dim); 25 | 26 | output_size[0] = input_size[0]; 27 | output_size[1] = weight_size[0]; 28 | for (const auto d : c10::irange(2, dim)) { 29 | auto dilation_ = has_dilation ? dilation[d - 2] : 1; 30 | auto kernel = dilation_ * (weight_size[d] - 1) + 1; 31 | output_size[d] = 32 | (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1; 33 | } 34 | return output_size; 35 | } 36 | 37 | // TODO: Use zen_memory_desc function from Memory.hpp for this purpose 38 | // and modify accordingly 39 | inline std::tuple 41 | conv_tensors_to_memory_desc(const at::Tensor &input, const at::Tensor &weight, 42 | const at::Tensor &bias, const at::Tensor &output) { 43 | 44 | memory::data_type dtype = input.dtype() == at::ScalarType::BFloat16 45 | ? memory::data_type::bf16 46 | : memory::data_type::f32; 47 | 48 | zendnn::memory::format_tag format_tag = 49 | input.is_contiguous(at::MemoryFormat::ChannelsLast) 50 | ? 
zendnn::memory::format_tag::nhwc 51 | : zendnn::memory::format_tag::nchw; 52 | 53 | std::vector dst_dims(output.sizes().begin(), output.sizes().end()); 54 | memory::desc dst_desc(dst_dims, dtype, format_tag); 55 | 56 | // Get the tensor's shape as a vector of int64_t 57 | std::vector src_dims(input.sizes().begin(), input.sizes().end()); 58 | std::vector weights_dims(weight.sizes().begin(), 59 | weight.sizes().end()); 60 | std::vector bias_dims(bias.sizes().begin(), bias.sizes().end()); 61 | 62 | // Create a descriptor with a different format tag 63 | memory::desc src_desc(src_dims, dtype, format_tag); 64 | memory::desc weights_desc(weights_dims, dtype, memory::format_tag::any); 65 | memory::desc bias_desc(bias_dims, dtype, memory::format_tag::x); 66 | memory::desc t_weights_desc(weights_dims, dtype, format_tag); 67 | 68 | std::tuple 70 | out; 71 | out = std::make_tuple(src_desc, weights_desc, bias_desc, dst_desc, 72 | t_weights_desc); 73 | 74 | return out; 75 | } 76 | 77 | inline void check_conv_inputs(const at::Tensor &input, const at::Tensor &weight, 78 | const at::IntArrayRef &dilation) { 79 | 80 | ZENTORCH_CHECK((input.dim() == 4 && weight.dim() == 4), 81 | "unsupported dims for conv input and weight"); 82 | 83 | ZENTORCH_CHECK((dilation[0] == 1 && dilation[1] == 1), 84 | "unsupported value of dilation, only [1,1] supported for now"); 85 | 86 | ZENTORCH_CHECK((input.dtype() == at::ScalarType::BFloat16 || 87 | input.dtype() == at::ScalarType::Float), 88 | "unsupported data type, only bf16 and fp32 supported for now"); 89 | } 90 | 91 | } // namespace zentorch 92 | -------------------------------------------------------------------------------- /src/cpu/cpp/Embed.cpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | * All rights reserved. 4 | ******************************************************************************/ 5 | 6 | #include "EmbedUtils.hpp" 7 | #include "Memory.hpp" 8 | #define ZENDNN_EMBED_THRDS 16 9 | 10 | using namespace zendnn; 11 | 12 | namespace zentorch { 13 | at::Tensor zentorch_embedding_impl(const at::Tensor &weight, 14 | const at::Tensor &indices, 15 | int64_t padding_idx, bool scale_grad_by_freq, 16 | bool sparse, std::string zentorch_op_name) { 17 | LOG(INFO) << "[" << __FILE__ << ": " << __LINE__ << "] " 18 | << "Executing function: " << __FUNCTION__; 19 | 20 | at::Tensor cindices, output; 21 | memory z_weight, z_indices, z_dst; 22 | std::tie(cindices, output) = 23 | embed_tensors_to_memory(weight, indices, z_weight, z_indices, z_dst); 24 | 25 | // Currently there is no primitive for embedding as an op. 26 | // So, the manipulations on the embeddingbag op are taken care by the 27 | // ZenDNN library and the ZenDNN library call is made from the plugin side. 
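  // Note: zendnn_custom_op::zendnn_embedding consumes the zen memory handles
  // (z_weight, z_indices, z_dst) prepared by embed_tensors_to_memory() above;
  // z_dst wraps the storage of the ATen `output` tensor, so the gathered rows
  // are written in place and `output` is returned as-is after the call.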
28 | LOG(INFO) << "Embedding compute in progress..."; 29 | zendnn_custom_op::zendnn_embedding( 30 | z_weight, z_indices, static_cast(padding_idx), 31 | scale_grad_by_freq, sparse, z_dst, zentorch_op_name.c_str()); 32 | 33 | LOG(INFO) << "Finished executing: " << __FUNCTION__ << "!\n"; 34 | 35 | return output; 36 | } 37 | 38 | std::vector zentorch_horizontal_embedding_group( 39 | at::TensorList weight, at::TensorList indices, at::IntArrayRef padding_idx, 40 | at::IntArrayRef scale_grad_by_freq, at::IntArrayRef sparse, 41 | std::string zentorch_op_name) { 42 | 43 | LOG(INFO) << "[" << __FILE__ << ": " << __LINE__ << "] " 44 | << "Executing function: " << __FUNCTION__; 45 | int num_eb_ops = weight.size(); 46 | 47 | std::vector z_weights(num_eb_ops); 48 | std::vector z_indices(num_eb_ops); 49 | std::vector z_padding_idx(num_eb_ops); 50 | std::vector z_scale_grad_by_freq(num_eb_ops); 51 | std::vector z_sparse(num_eb_ops); 52 | 53 | std::vector temp_indices(num_eb_ops); 54 | std::vector output(num_eb_ops); 55 | std::vector z_destination(num_eb_ops); 56 | 57 | at::parallel_for(0, num_eb_ops, 0, [&](int64_t start, int64_t end) { 58 | for (auto i = start; i < end; i++) { 59 | 60 | std::tie(temp_indices[i], output[i]) = embed_tensors_to_memory( 61 | weight[i], indices[i], z_weights[i], z_indices[i], z_destination[i]); 62 | 63 | z_padding_idx[i] = padding_idx[i]; 64 | z_scale_grad_by_freq[i] = scale_grad_by_freq[i]; 65 | z_sparse[i] = sparse[i]; 66 | } 67 | }); 68 | 69 | LOG(INFO) << "GroupEmbedding compute in progress..."; 70 | zendnn_custom_op::zendnn_grp_embedding( 71 | z_weights, z_indices, z_padding_idx, z_scale_grad_by_freq, z_sparse, 72 | z_destination, zentorch_op_name.c_str()); 73 | LOG(INFO) << "Finished executing: " << __FUNCTION__ << "!\n"; 74 | 75 | return output; 76 | } 77 | 78 | TORCH_LIBRARY_FRAGMENT(zentorch, m) { 79 | m.def("zentorch_embedding(Tensor weight, Tensor indices, " 80 | "int padding_idx=-1, bool scale_grad_by_freq=False, " 81 | "bool sparse=False, str " 82 | "zentorch_op_name='zentorch::zentorch_embedding') -> " 83 | "Tensor"); 84 | m.def( 85 | "zentorch_horizontal_embedding_group(Tensor[] weight, Tensor[] indices, " 86 | "int[] padding_idx, int[] scale_grad_by_freq, " 87 | "int[] sparse, str zentorch_op_name = " 88 | "'zentorch::zentorch_horizontal_embedding_group') -> Tensor[]"); 89 | } 90 | 91 | TORCH_LIBRARY_IMPL(zentorch, CPU, m) { 92 | m.impl("zentorch_embedding", zentorch_embedding_impl); 93 | m.impl("zentorch_horizontal_embedding_group", 94 | zentorch_horizontal_embedding_group); 95 | } 96 | } // namespace zentorch 97 | -------------------------------------------------------------------------------- /src/cpu/cpp/MoE.cpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | * All rights reserved. 
4 | * 5 | * Was sourced from 6 | * https://github.com/intel/intel-extension-for-pytorch/blob/v2.4.0%2Bcpu/csrc/cpu/aten/kernels/MoEKrnl.cpp 7 | * IPEX commit ID: 070f1d7 8 | ******************************************************************************/ 9 | #include 10 | #if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR > 3 11 | #include "Utils.hpp" 12 | #include 13 | 14 | namespace zentorch { 15 | at::Tensor fuse_index_mul_index_add(const at::Tensor &curr_state, 16 | const at::Tensor &top_x, 17 | const at::Tensor &idx, 18 | const at::Tensor &routing_weights, 19 | const at::Tensor &output, 20 | std::string zentorch_op_name) { 21 | if (curr_state.scalar_type() != at::ScalarType::BFloat16) { 22 | ZENTORCH_CHECK(false, "zentorch::fuse_index_mul_index_add supports" 23 | " only bfloat16 datatype"); 24 | } 25 | using lpVec = at::vec::Vectorized; 26 | using fVec = at::vec::Vectorized; 27 | auto vec_size = lpVec::size(); 28 | auto topx_s0 = top_x.size(0); 29 | auto *output_ptr = output.data_ptr(); 30 | auto *curr_state_ptr = curr_state.data_ptr(); 31 | auto *routing_weights_ptr = routing_weights.data_ptr(); 32 | auto *top_x_ptr = top_x.data_ptr(); 33 | auto *idx_ptr = idx.data_ptr(); 34 | 35 | int64_t output_stride0 = output.stride(0); 36 | int64_t output_stride1 = output.stride(1); 37 | int64_t curr_state_size2 = curr_state.size(2); 38 | int64_t curr_state_stride1 = curr_state.stride(1); 39 | int64_t curr_state_stride2 = curr_state.stride(2); 40 | int64_t routing_weights_stride0 = routing_weights.stride(0); 41 | int64_t routing_weights_stride1 = routing_weights.stride(1); 42 | #pragma omp parallel for 43 | for (int i = 0; i < topx_s0; ++i) { 44 | int64_t rw_index = top_x_ptr[i] * routing_weights_stride0 + 45 | idx_ptr[i] * routing_weights_stride1; 46 | auto rw_v = lpVec(static_cast(routing_weights_ptr[rw_index])); 47 | for (int j = 0; j < curr_state_size2 - (curr_state_size2 % vec_size); 48 | j += vec_size) { 49 | int64_t cs_index = i * curr_state_stride1 + j * curr_state_stride2; 50 | int64_t output_index = top_x_ptr[i] * output_stride0 + j * output_stride1; 51 | auto cs_v = lpVec::loadu(curr_state_ptr + cs_index); 52 | auto out_v = lpVec::loadu(output_ptr + output_index); 53 | fVec rw_v1, rw_v2, cs_v1, cs_v2, out_v1, out_v2; 54 | std::tie(rw_v1, rw_v2) = at::vec::convert_to_float(rw_v); 55 | std::tie(cs_v1, cs_v2) = at::vec::convert_to_float(cs_v); 56 | std::tie(out_v1, out_v2) = at::vec::convert_to_float(out_v); 57 | auto output_v1 = out_v1 + cs_v1 * rw_v1; 58 | auto output_v2 = out_v2 + cs_v2 * rw_v2; 59 | at::vec::convert_from_float(output_v1, output_v2) 60 | .store(output_ptr + output_index); 61 | } 62 | for (int j = curr_state_size2 - (curr_state_size2 % vec_size); 63 | j < curr_state_size2; ++j) { 64 | int64_t cs_index = i * curr_state_stride1 + j * curr_state_stride2; 65 | int64_t output_index = top_x_ptr[i] * output_stride0 + j * output_stride1; 66 | output_ptr[output_index] += 67 | routing_weights_ptr[rw_index] * curr_state_ptr[cs_index]; 68 | } 69 | } 70 | return output; 71 | } 72 | 73 | TORCH_LIBRARY_FRAGMENT(zentorch, m) { 74 | m.def("fuse_index_mul_index_add(Tensor curr_state, Tensor top_x, " 75 | "Tensor idx, " 76 | "Tensor routing_weights, Tensor output, " 77 | "str zentorch_op_name='zentorch::fuse_index_mul_index_add') -> " 78 | "Tensor"); 79 | } 80 | TORCH_LIBRARY_IMPL(zentorch, CPU, m) { 81 | m.impl("fuse_index_mul_index_add", fuse_index_mul_index_add); 82 | } 83 | } // namespace zentorch 84 | #endif 85 | 
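The fused op above collapses the `index -> mul -> index_add_` sequence used in Mixtral-style MoE token dispatch into a single pass over the selected tokens. The sketch below is illustrative only and is not part of the repository: the tensor shapes, the int64 dtype of `top_x`/`idx`, and the tolerances are assumptions, while the bfloat16 requirement on `curr_state` comes from the dtype check in the kernel. It also assumes that importing `zentorch` loads the compiled library that registers the op, which is only built for PyTorch releases newer than 2.3 (see the version guard at the top of the file).

```python
import torch
import zentorch  # noqa: F401  # assumed to register torch.ops.zentorch.*

tokens, hidden, experts = 8, 64, 4
curr_state = torch.randn(1, tokens, hidden, dtype=torch.bfloat16)
routing_weights = torch.rand(tokens, experts, dtype=torch.bfloat16)
top_x = torch.arange(tokens, dtype=torch.long)                # destination rows in `output`
idx = torch.randint(0, experts, (tokens,), dtype=torch.long)  # expert slot per row
output = torch.zeros(tokens, hidden, dtype=torch.bfloat16)

# Eager reference mirrored by the kernel's scalar tail loop:
#   output[top_x[i], :] += routing_weights[top_x[i], idx[i]] * curr_state[0, i, :]
reference = output.clone().index_add_(
    0, top_x, routing_weights[top_x, idx, None] * curr_state.squeeze(0)
)

fused = torch.ops.zentorch.fuse_index_mul_index_add(
    curr_state, top_x, idx, routing_weights, output.clone()
)
torch.testing.assert_close(fused, reference, rtol=1.6e-2, atol=1e-2)
```

Comparing the fused call against the eager `index_add_` reference in this way is a quick sanity check that the two formulations agree.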
-------------------------------------------------------------------------------- /src/cpu/cpp/Ops.hpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | * All rights reserved. 4 | ******************************************************************************/ 5 | 6 | // Declarations for ZenTorchOps (EmbedBag etc.) 7 | 8 | #pragma once 9 | 10 | #include 11 | 12 | namespace zentorch { 13 | 14 | at::Tensor zentorch_matmul_impl( 15 | const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, 16 | at::Tensor &self_or_result, const std::vector &post_op_ids, 17 | const std::vector &post_op_buffers, const float &beta, 18 | const float &alpha, std::string zentorch_op_name); 19 | 20 | at::Tensor zentorch_get_packed_embedding_weight(at::Tensor &weight, 21 | at::Tensor &weight_scales, 22 | at::Tensor &weight_zero_points); 23 | 24 | at::Tensor zentorch_weight_reorder_for_matmul(at::Tensor &weight, 25 | const bool &is_weight_oc_x_ic); 26 | 27 | std::string show_config(); 28 | } // namespace zentorch 29 | -------------------------------------------------------------------------------- /src/cpu/cpp/Rope.cpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | * All rights reserved. 4 | * 5 | * Was sourced from 6 | * https://github.com/intel/intel-extension-for-pytorch/blob/v2.6.0%2Bcpu/csrc/cpu/aten/kernels/RotaryPositionEmbeddingKnl.cpp 7 | * IPEX commit ID: 18eeefa 8 | ******************************************************************************/ 9 | 10 | #include "RopeUtils.hpp" 11 | #include 12 | #include 13 | #include 14 | 15 | namespace zentorch { 16 | 17 | std::tuple 18 | zentorch_rope_impl(at::Tensor &t_in, at::Tensor &t_emb_pos, at::Tensor &t_pos, 19 | int64_t N, // N: number of head, H: head size 20 | int64_t H, int64_t offset, int64_t rotary_dim, 21 | std::string zentorch_op_name) { 22 | t_in = t_in.contiguous(); 23 | t_emb_pos = t_emb_pos.contiguous(); 24 | t_pos = t_pos.contiguous(); 25 | // only supported types are fp32 and bf16 26 | if (t_in.scalar_type() == at::kFloat) { 27 | return zentorch::cpu::kernel::ApplyROPEKernel( 28 | t_in, t_emb_pos, t_pos, N, H, offset, rotary_dim); 29 | } else if (t_in.scalar_type() == at::kBFloat16) { 30 | return zentorch::cpu::kernel::ApplyROPEKernel( 31 | t_in, t_emb_pos, t_pos, N, H, offset, rotary_dim); 32 | } else { 33 | ZENTORCH_CHECK(false, "unsupported '", t_in.scalar_type(), "'"); 34 | return std::make_tuple(at::Tensor(), at::Tensor(), at::Tensor()); 35 | } 36 | } 37 | 38 | std::tuple 39 | zentorch_rope_deepseek_kernel_impl(at::Tensor &q, at::Tensor &kv, 40 | at::Tensor &k_pe, at::Tensor &t_emb_pos, 41 | at::Tensor &t_pos, 42 | int64_t N, // N: number of head, H: head size 43 | int64_t H, int64_t offset, 44 | int64_t rotary_dim) { 45 | q = q.contiguous(); 46 | kv = kv.contiguous(); 47 | k_pe = k_pe.contiguous(); 48 | t_emb_pos = t_emb_pos.contiguous(); 49 | t_pos = t_pos.contiguous(); 50 | if (q.scalar_type() == at::kFloat) { 51 | return zentorch::cpu::kernel::ApplyDeepseekROPEKernel( 52 | q, kv, k_pe, t_emb_pos, t_pos, N, H, offset, rotary_dim); 53 | } else if (q.scalar_type() == at::kBFloat16) { 54 | return zentorch::cpu::kernel::ApplyDeepseekROPEKernel( 55 | q, kv, k_pe, 
t_emb_pos, t_pos, N, H, offset, rotary_dim); 56 | } else { 57 | ZENTORCH_CHECK(false, "unsupported '", q.scalar_type(), "'"); 58 | return std::make_tuple(at::Tensor(), at::Tensor(), at::Tensor()); 59 | } 60 | } 61 | 62 | TORCH_LIBRARY_FRAGMENT(zentorch, m) { 63 | m.def("zentorch_rope(Tensor t_in, Tensor t_emb_pos, Tensor t_pos, int N, int " 64 | "H, int offset, int rotary_dim, str zentorch_op_name = " 65 | "'zentorch::zentorch_rope') -> (Tensor, Tensor, Tensor)"); 66 | m.def("zentorch_rope_deepseek(Tensor q, Tensor kv, Tensor k_pe, Tensor " 67 | "t_emb_pos, Tensor t_pos, int N, int H, int offset, int " 68 | "rotary_ndims)-> (Tensor, Tensor, Tensor)"); 69 | } 70 | 71 | TORCH_LIBRARY_IMPL(zentorch, CPU, m) { 72 | m.impl("zentorch_rope", zentorch_rope_impl); 73 | m.impl("zentorch_rope_deepseek", zentorch_rope_deepseek_kernel_impl); 74 | } 75 | 76 | } // namespace zentorch 77 | -------------------------------------------------------------------------------- /src/cpu/cpp/Singletons.cpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | * All rights reserved. 4 | ******************************************************************************/ 5 | 6 | #include "Utils.hpp" 7 | 8 | namespace zendnn { 9 | namespace utils { 10 | 11 | engine &engine::cpu_engine() { 12 | static engine cpu_engine(kind::cpu, 0); 13 | return cpu_engine; 14 | } 15 | 16 | } // namespace utils 17 | } // namespace zendnn 18 | -------------------------------------------------------------------------------- /src/cpu/cpp/Threading.cpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | * All rights reserved. 4 | ******************************************************************************/ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "Threading.hpp" 13 | 14 | namespace zentorch { 15 | 16 | void thread_bind(const std::vector &cpu_core_list) { 17 | omp_set_num_threads(cpu_core_list.size()); 18 | 19 | #pragma omp parallel num_threads(cpu_core_list.size()) 20 | { 21 | int thread_index = omp_get_thread_num(); 22 | cpu_set_t cpuset; 23 | CPU_ZERO(&cpuset); 24 | CPU_SET(cpu_core_list[thread_index], &cpuset); 25 | if (pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset) != 26 | 0) { 27 | throw std::runtime_error("Fail to bind cores."); 28 | } 29 | } 30 | } 31 | 32 | } // namespace zentorch 33 | -------------------------------------------------------------------------------- /src/cpu/cpp/Threading.hpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | * All rights reserved. 
4 | ******************************************************************************/ 5 | 6 | #include 7 | 8 | namespace zentorch { 9 | 10 | /** 11 | * @brief Accepts a list of cores and binds the current process 12 | * to the gives cores 13 | * 14 | * @param cpu_core_list Lists of cores for the process to be binded 15 | */ 16 | void thread_bind(const std::vector &cpu_core_list); 17 | 18 | } // namespace zentorch 19 | -------------------------------------------------------------------------------- /src/cpu/cpp/Utils.hpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | * All rights reserved. 4 | ******************************************************************************/ 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | 19 | // TODO: Make the __FILE__ give the name of the file relative to only 20 | // ZenDNN_PyTorch_Plugin 21 | #define ZENTORCH_CHECK(condition, ...) \ 22 | TORCH_CHECK(condition, __FILE__, ":", __LINE__, " ", __FUNCTION__, " : ", \ 23 | ##__VA_ARGS__) 24 | 25 | namespace zentorch { 26 | 27 | // zentorch:: Check if m/c supports AVX512/AVX256 28 | inline bool is_avx512_supported() { 29 | return cpuinfo_initialize() && cpuinfo_has_x86_avx512f() && 30 | cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512dq() && 31 | cpuinfo_has_x86_avx512vnni() && cpuinfo_has_x86_avx512bf16() && 32 | cpuinfo_has_x86_avx512bw(); 33 | } 34 | 35 | enum UNARY_POST_OP { 36 | // Add unary post ops here 37 | POST_OP_NONE, 38 | RELU, 39 | GELU_TANH, 40 | GELU_ERF, 41 | SILU, 42 | SIGMOID, 43 | // Add unary post op before this, 44 | // if you add any post op 45 | // update UNARY_OP_COUNT by that post op. 46 | UNARY_OP_COUNT = SIGMOID 47 | }; 48 | // Initializing the first enum in BINARY_POST_OP so that all post ops will have 49 | // unique value. 50 | enum BINARY_POST_OP { MUL = UNARY_POST_OP::UNARY_OP_COUNT + 1, ADD }; 51 | 52 | // Each value of QUANT_GRANULARITY enum indicates the mappings for various 53 | // quantization granularity levels(PER_TENSOR/PER_CHANNEL/PER_GROUP) 54 | // with the zendnn library's tensor mask values. 55 | enum QUANT_GRANULARITY { PER_TENSOR = 0, PER_CHANNEL = 2, PER_GROUP = 3 }; 56 | } // namespace zentorch 57 | 58 | namespace zendnn { 59 | 60 | using kind = zendnn::primitive::kind; 61 | 62 | namespace utils { 63 | // CPU execution engine only. 
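// The wrappers below expose process-wide singletons: engine::cpu_engine()
// (defined in Singletons.cpp) and stream::default_stream() are constructed
// lazily as function-local statics, so all zentorch primitives share a single
// ZenDNN CPU engine and stream.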
64 | struct engine : public zendnn::engine { 65 | 66 | // Singleton CPU engine for all primitives 67 | static engine &cpu_engine(); 68 | 69 | engine(kind akind = kind::cpu, size_t index = 0) 70 | : zendnn::engine(akind, index) {} 71 | }; 72 | 73 | // A default stream 74 | struct stream : public zendnn::stream { 75 | static zendnn::stream &default_stream() { 76 | static zendnn::stream s(engine::cpu_engine()); 77 | return s; 78 | } 79 | }; 80 | 81 | // Check AVX512 bf16 support 82 | inline bool zendnn_bf16_device_check() { 83 | return cpuinfo_initialize() && cpuinfo_has_x86_avx512bf16(); 84 | } 85 | 86 | } // namespace utils 87 | } // namespace zendnn 88 | -------------------------------------------------------------------------------- /src/cpu/cpp/WeightReorder.cpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | * All rights reserved. 4 | ******************************************************************************/ 5 | 6 | #include "Memory.hpp" 7 | 8 | namespace zentorch { 9 | 10 | using namespace zendnn; 11 | 12 | // Reorder the weight tensor for quantized matmul performance 13 | // weight: 2D tensor to be reordered 14 | // is_weight_oc_x_ic: true if the weight tensor is in OCxIC format 15 | // Return: reordered weight tensor 16 | // 17 | // Normally Linear layers store the weight tensor in 18 | // output_channel(OC) x input_channel(IC) format, but matmul 19 | // operation works on input_channel(IC) x output_channel(OC) 20 | // format. So we need pass is_weight_oc_x_ic as true if weight is 21 | // in OCxIC format to apporpriately reorder the weight tensor. 22 | inline at::Tensor 23 | zentorch_weight_reorder_for_matmul(at::Tensor &weight, 24 | const bool &is_weight_oc_x_ic) { 25 | ZENTORCH_CHECK(weight.scalar_type() == at::kChar, 26 | "only int8 weight is supported"); 27 | ZENTORCH_CHECK(weight.dim() == 2, 28 | "only 2-dimensional weight tensor is supported"); 29 | ZENTORCH_CHECK(weight.is_contiguous(), 30 | "reorder of weight tensor which is stored as contiguous is " 31 | "only supported") 32 | 33 | // Execute the zendnn_custom_op::zendnn_reorder api to reorder the weight 34 | if (is_weight_oc_x_ic) { 35 | zendnn_custom_op::zendnn_reorder(/*src*/ weight.data_ptr(), 36 | /*dst*/ weight.data_ptr(), 37 | /*k*/ weight.size(1), /*n*/ weight.size(0), 38 | /*trans*/ true, /*dtype*/ zendnn_s8); 39 | } else { 40 | zendnn_custom_op::zendnn_reorder(/*src*/ weight.data_ptr(), 41 | /*dst*/ weight.data_ptr(), 42 | /*k*/ weight.size(0), /*n*/ weight.size(1), 43 | /*trans*/ false, /*dtype*/ zendnn_s8); 44 | } 45 | return weight; 46 | } 47 | 48 | TORCH_LIBRARY_FRAGMENT(zentorch, m) { 49 | m.def("zentorch_weight_reorder_for_matmul(Tensor weight, bool " 50 | "is_weight_oc_x_ic=True) -> Tensor"); 51 | } 52 | 53 | TORCH_LIBRARY_IMPL(zentorch, CPU, m) { 54 | m.impl("zentorch_weight_reorder_for_matmul", 55 | zentorch::zentorch_weight_reorder_for_matmul); 56 | } 57 | 58 | } // namespace zentorch 59 | -------------------------------------------------------------------------------- /src/cpu/cpp/kernels/vec/utils.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Modifications Copyright (c) 2024 Advanced Micro Devices, Inc. 3 | * All rights reserved. 
4 | * 5 | * Was sourced from 6 | * https://github.com/intel/intel-extension-for-pytorch/blob/v2.3.0%2Bcpu/csrc/cpu/vec/ref 7 | * IPEX commit ID: d3c52443e 8 | ******************************************************************************/ 9 | 10 | // zentorch:: APIs that are suffix with 256 will be invoked on AVX256 m/c 11 | 12 | // TODO:: Observing seg-falut issues when enabling omp pragmas in AVX256 13 | // implementations will be fixed in later patches 14 | #pragma once 15 | #define ZENTORCH_FORCE_INLINE inline __attribute__((always_inline)) 16 | 17 | namespace zentorch { 18 | template 19 | ZENTORCH_FORCE_INLINE void move_ker_ref(dst_type *inout, const src_type *in, 20 | int64_t len) { 21 | #pragma omp simd 22 | for (int64_t i = 0; i < len; i++) { 23 | *(inout + i) = *(in + i); 24 | } 25 | } 26 | 27 | template 28 | ZENTORCH_FORCE_INLINE void move_ker(dst_type *inout, const src_type *in, 29 | int64_t len) { 30 | #pragma omp simd 31 | for (int64_t i = 0; i < len; i++) { 32 | *(inout + i) = *(in + i); 33 | } 34 | } 35 | 36 | template 37 | ZENTORCH_FORCE_INLINE void zero_ker_ref(T *out, int64_t len) { 38 | #pragma omp simd 39 | for (int64_t i = 0; i < len; i++) { 40 | *(out + i) = 0; 41 | } 42 | } 43 | 44 | template ZENTORCH_FORCE_INLINE void zero_ker(T *out, int64_t len) { 45 | #pragma omp simd 46 | for (int64_t i = 0; i < len; i++) { 47 | *(out + i) = 0; 48 | } 49 | } 50 | 51 | template 52 | ZENTORCH_FORCE_INLINE void add_ker(dst_type *inout, const src_type *in, 53 | int64_t len) { 54 | #pragma omp simd 55 | for (int64_t i = 0; i < len; i++) { 56 | *(inout + i) += *(in + i); 57 | } 58 | } 59 | 60 | template 61 | ZENTORCH_FORCE_INLINE void add_ker_ref(dst_type *inout, const src_type *in, 62 | int64_t len) { 63 | #pragma omp simd 64 | for (int64_t i = 0; i < len; i++) { 65 | *(inout + i) += *(in + i); 66 | } 67 | } 68 | 69 | } // namespace zentorch -------------------------------------------------------------------------------- /src/cpu/cpp/kernels/zen_cpukernels.hpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | * All rights reserved. 4 | ******************************************************************************/ 5 | 6 | namespace zentorch { 7 | 8 | std::tuple 9 | masked_multihead_self_attention_kernel_impl_512( 10 | at::Tensor &query, at::Tensor &key, at::Tensor &value, 11 | at::Tensor &key_cache, at::Tensor &value_cache, at::Tensor &beam_idx, 12 | at::Tensor seq_info, const double scale_attn, int64_t max_positions, 13 | const c10::optional &head_mask /* optional */, 14 | const c10::optional &attention_mask /* optional */, 15 | c10::optional add_causal_mask /* optional */); 16 | 17 | template 18 | void flash_attention_kernel_impl_512( 19 | const at::Tensor &output, const at::Tensor &logsumexp, 20 | const at::Tensor &query, const at::Tensor &key, const at::Tensor &value, 21 | double dropout_p, bool is_causal, std::optional attn_mask, 22 | std::optional scale); 23 | } // namespace zentorch 24 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/_C/__init__.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2023-2024 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/__init__.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | from ._build_info import __torchversion__ as buildtime_torchversion 7 | from torch.torch_version import __version__ as runtime_torchversion 8 | 9 | # Pytorch lacks symbol-level compatibility, requiring extensions 10 | # to be pinned to the same minor version. To avoid issues, it is 11 | # necessary to error out if the runtime Pytorch version 12 | # differs from the build-time version. 13 | 14 | if runtime_torchversion[:3] != buildtime_torchversion[:3]: 15 | raise ImportError( 16 | f"Incompatible PyTorch version {runtime_torchversion} detected. " 17 | f"The installed zentorch binary is only compatible " 18 | f"with PyTorch versions {buildtime_torchversion[:3]}.x" 19 | ) 20 | 21 | from ._optimize import optimize # noqa 22 | from ._info import __config__, __version__ # noqa 23 | from ._compile_backend import * # noqa 24 | from ._meta_registrations import * # noqa 25 | from ._freeze_utils import freezing_enabled # noqa 26 | 27 | # llm optimizations 28 | from . import llm # noqa 29 | 30 | # model reload utility for quantized models 31 | from ._quant_model_reload import load_quantized_model, load_woq_model # noqa F401 32 | from . import utils # noqa F401 33 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/_freeze_utils.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import torch 7 | import contextlib 8 | from ._freezing import freeze 9 | 10 | 11 | @contextlib.contextmanager 12 | def freezing_enabled(): 13 | # read previous/default values of freeze, this works even if 14 | # the user has set the values manually as we restore to that value 15 | previous_freeze_config = torch._inductor.config.freezing 16 | previous_freeze_path = torch._inductor.freezing.freeze 17 | # monkey patch pytorch freeze 18 | torch._inductor.config.freezing = True 19 | torch._inductor.freezing.freeze = freeze 20 | yield 21 | # reset to the previous values 22 | torch._inductor.config.freezing = previous_freeze_config 23 | torch._inductor.freezing.freeze = previous_freeze_path 24 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/_freezing.py: -------------------------------------------------------------------------------- 1 | # *************************************************************************** 2 | # Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # 5 | # Was sourced from: 6 | # https://github.com/pytorch/pytorch/blob/v2.4.1/torch/_inductor/freezing.py 7 | # *************************************************************************** 8 | 9 | import torch 10 | from torch._inductor.freezing import ( 11 | replace_params_with_constants, 12 | invalidate_eager_modules, 13 | discard_traced_gm_params, 14 | ) 15 | from torch._inductor.compile_fx import fake_tensor_prop 16 | from torch._inductor.pattern_matcher import stable_topological_sort 17 | from torch._inductor.fx_passes.post_grad import view_to_reshape 18 | from torch._functorch.compile_utils import fx_graph_cse 19 | from typing import List, Tuple 20 | from ._logging import get_logger 21 | from ._optimize import optimize 22 | from ._utils import is_version_compatible_import 23 | 24 | if is_version_compatible_import(["_inductor", "constant_folding"], ["constant_fold"]): 25 | from torch._inductor.constant_folding import constant_fold 26 | else: 27 | from torch._inductor.freezing import constant_fold # for PT 2.1.x 28 | if is_version_compatible_import(["fx", "_utils"], ["lazy_format_graph_code"]): 29 | from torch.fx._utils import lazy_format_graph_code 30 | else: 31 | from torch._dynamo.utils import lazy_format_graph_code # for PT 2.1.x, 2.2.x, 2.3.x 32 | 33 | logger = get_logger(__name__) 34 | 35 | 36 | # freeze to monkey-patch in PyTorch 37 | # constant propogation is unsupported in forward_compiler_base 38 | # pass, we have to define a custom freeze function and use that 39 | # in forward_compiler_freezing to avoid unnecessary/inapplicable 40 | # optimizations in native (like mkldnn); selectively porting 41 | # constant_fold logic also results in multiple downstream errors. 42 | def freeze( 43 | dynamo_gm: torch.fx.GraphModule, 44 | aot_autograd_gm: torch.fx.GraphModule, 45 | example_inputs: List[torch._subclasses.FakeTensor], 46 | ) -> Tuple[torch.fx.GraphModule, List[int]]: 47 | logger.info("Optimizing the model with zentorch ops.") 48 | zen_gm = optimize(aot_autograd_gm) 49 | # we do view to reshape to avoid lowering exception 50 | view_to_reshape(zen_gm) 51 | 52 | if tracing_context := torch._guards.TracingContext.try_get(): 53 | fw_metadata = tracing_context.fw_metadata 54 | if hasattr(tracing_context, "params_flat_unwrap_subclasses"): 55 | # added in PT 2.6.0 56 | assert tracing_context.params_flat_unwrap_subclasses is not None 57 | params_flat = tracing_context.params_flat_unwrap_subclasses 58 | else: 59 | # for PT 2.5.x and below 60 | params_flat = tracing_context.params_flat 61 | assert fw_metadata is not None and params_flat is not None 62 | 63 | preserved_arg_indices = replace_params_with_constants( 64 | zen_gm, params_flat, fw_metadata 65 | ) 66 | else: 67 | inputs = zen_gm.graph.find_nodes(op="placeholder") 68 | preserved_arg_indices = list(range(len(inputs))) 69 | 70 | # we eliminate commom subexpressions from the graph (CSE) 71 | logger.info("Running common subexpression elimination on the fx-graph.") 72 | cse_graph = fx_graph_cse(zen_gm.graph) 73 | zen_gm.graph = cse_graph 74 | zen_gm.recompile() 75 | 76 | aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices] 77 | fake_tensor_prop(zen_gm, aot_example_inputs, force_allow_non_fake_inputs=True) 78 | 79 | logger.info("Constant folding the fx-graph.") 80 | constant_fold(zen_gm) 81 | fake_tensor_prop(zen_gm, aot_example_inputs, force_allow_non_fake_inputs=True) 82 | stable_topological_sort(zen_gm.graph) 83 | zen_gm.recompile() 84 | zen_gm.graph.lint() 85 | 86 | # invalidate nn Modules 87 | if 
torch._inductor.config.freezing_discard_parameters: 88 | invalidate_eager_modules() 89 | discard_traced_gm_params(dynamo_gm) 90 | 91 | logger.debug("%s", lazy_format_graph_code("FROZEN GRAPH", zen_gm)) 92 | 93 | return zen_gm, preserved_arg_indices 94 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/_fusion_matcher.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import torch 7 | import inspect 8 | 9 | # pattern matcher relevant imports 10 | from torch._inductor.pattern_matcher import ( 11 | PatternMatcherPass, 12 | stable_topological_sort, 13 | init_once_fakemode, 14 | ) 15 | from torch._inductor import config 16 | from ._logging import get_logger 17 | 18 | logger = get_logger(__name__) 19 | 20 | matcher_pass = PatternMatcherPass() 21 | 22 | 23 | # for fake tensors 24 | @init_once_fakemode 25 | def lazy_init(): 26 | from ._fusion_patterns import _replace_init 27 | 28 | _replace_init() 29 | 30 | 31 | # applies the registered patterns to fx_graph 32 | def fusions_graph_pass(gm: torch.fx.GraphModule): 33 | 34 | # In PyTorch versions > 2.1, the arguments in `kwargs` are treated as optional. 35 | # This allows fusion to work whether `kwargs` are provided or not. 36 | # However, in PyTorch versions <= 2.1, all arguments are considered part of the 37 | # function signature. 38 | # As a result, if some arguments are not explicitly defined, it raises an error, 39 | # complaining that the signature expects x + 2 arguments, but only x are provided in 40 | # pattern matcher. 41 | # For this reason we are disabling pattern matcher for 2.1 42 | 43 | sig = inspect.signature(torch._inductor.pattern_matcher.register_replacement) 44 | 45 | if "search_fn_pattern" not in sig.parameters: 46 | return gm 47 | 48 | lazy_init() 49 | count = 0 50 | if config.pattern_matcher: 51 | count += matcher_pass.apply(gm.graph) 52 | else: 53 | logger.info( 54 | "Inductor config for pattern matching is set to False," 55 | + " no matcher passes will be run." 56 | ) 57 | if count: 58 | stable_topological_sort(gm.graph) 59 | gm.graph.lint() 60 | gm.recompile() 61 | return gm 62 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/_graph_cleanup.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import torch 7 | import operator 8 | 9 | # import the custom logging module 10 | from ._logging import get_logger 11 | 12 | # make a logger for this file 13 | logger = get_logger(__name__) 14 | 15 | at_ops = torch.ops.aten 16 | 17 | 18 | def unused_node_elimination(fx_graph: torch.fx.GraphModule): 19 | 20 | # TODO 21 | # This function will always be under progress as this will undergo 22 | # continuous enhancement and changes as generalization of removal of unused 23 | # nodes and specialization for removal of special unused nodes will keep 24 | # changing based on the scenarios and the models we encounter. 
25 | 26 | """ 27 | unused_node_elimination: 28 | removes the nodes with no users from the fx_graph 29 | """ 30 | logger.info("Removing unused nodes from the fx_graph.") 31 | 32 | # Why the following nodes ? 33 | # operator.getitem 34 | # We want to remove the getitem nodes from the graph which are the outputs 35 | # of the aten embedding bag if they don't have any users. Based on the 36 | # historical evidence, for inference, only the first output of the aten 37 | # embedding bag is used and the other getitem nodes are not used. So, we 38 | # can safely remove them from the graph. 39 | # at_ops.clone.default 40 | # Clone nodes are used to duplicate a tensor with its properties and 41 | # contents to be used by some other node in the graph or model. If its 42 | # output is not used by any other node in the graph, then it is unused 43 | # and can be removed. 44 | # at_ops.view.default 45 | # View nodes are used to reshape the tensor. If its output is not used by 46 | # any other node in the graph, then it is an unused node and can be 47 | # removed. 48 | # at_ops.detach.default 49 | # Detach node can impact the number of users. If the output of the detach 50 | # node is not used by any other node in the graph, then it is an unused 51 | # node and can be removed. 52 | 53 | supported_nodes_for_removal = { 54 | operator.getitem, 55 | at_ops.clone.default, 56 | at_ops.view.default, 57 | at_ops.detach.default, 58 | } 59 | 60 | for node in fx_graph.graph.nodes: 61 | if (node.target in supported_nodes_for_removal) and (len(node.users) == 0): 62 | fx_graph.graph.erase_node(node) 63 | 64 | return fx_graph 65 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/_graph_preprocess_matcher.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import torch 7 | 8 | # pattern matcher relevant imports 9 | from torch._inductor.pattern_matcher import ( 10 | PatternMatcherPass, 11 | stable_topological_sort, 12 | init_once_fakemode, 13 | ) 14 | from torch._inductor import config 15 | from ._logging import get_logger 16 | 17 | logger = get_logger(__name__) 18 | 19 | matcher_pass = PatternMatcherPass() 20 | 21 | 22 | # for fake tensors 23 | @init_once_fakemode 24 | def lazy_init(): 25 | from ._graph_preprocess_patterns import _replace_init 26 | 27 | _replace_init() 28 | 29 | 30 | # applies the registered patterns to fx_graph 31 | def preprocess_graph_pass(gm: torch.fx.GraphModule): 32 | lazy_init() 33 | count = 0 34 | if config.pattern_matcher: 35 | count += matcher_pass.apply(gm.graph) 36 | else: 37 | logger.info( 38 | "Inductor config for pattern matching is set to False," 39 | + " no matcher passes will be run." 40 | ) 41 | if count: 42 | stable_topological_sort(gm.graph) 43 | gm.graph.lint() 44 | gm.recompile() 45 | return gm 46 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/_info.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2023-2024 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import zentorch._C 7 | import sys 8 | 9 | __config__ = zentorch._C.show_config() 10 | if sys.version_info >= (3, 8): 11 | from importlib import metadata 12 | __version__ = metadata.version('zentorch') 13 | else: 14 | __version__ = '' 15 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/_logging.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2023-2024 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import os 7 | import logging 8 | 9 | 10 | def get_logger(__name__): 11 | """ 12 | get_logger: 13 | takes in the filename and, 14 | returns a logger based on logging.conf file 15 | """ 16 | # define a message format 17 | FORMAT = "[%(levelname)s %(name)s - %(funcName)s:%(lineno)d] %(message)s" 18 | logging.basicConfig(format=FORMAT) 19 | # make a logger for this file 20 | logger = logging.getLogger(__name__) 21 | # check if user has set some logging level 22 | if os.environ.get("ZENTORCH_PY_LOG_LEVEL") is not None: 23 | logger.setLevel(os.environ.get("ZENTORCH_PY_LOG_LEVEL")) 24 | 25 | return logger 26 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/_quantization_utils.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | from typing import Union, Dict, Any, Tuple, Iterable 7 | import torch.nn as nn 8 | 9 | 10 | def set_op_by_name( 11 | layer: Union[nn.Module, nn.ModuleList], name: str, new_module: nn.Module 12 | ) -> None: 13 | """ 14 | Replaces a submodule in a given neural network layer with a new module 15 | (e.g. quantized module). The submodule to be replaced is identified by 16 | the 'name' parameter, which specifies the name of the submodule using 17 | dot notation. If the name includes dots, it navigates through nested 18 | submodules to find the specific layer to replace. Otherwise, it 19 | directly replaces the submodule in the provided layer. 20 | 21 | Parameters: 22 | - layer: The top-level module containing the submodule. 23 | - name: name of the submodule, split by dots. 24 | - new_module: The new module to replace the existing one, 25 | for example the quantized module. 
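
    Example (illustrative only, the module names are hypothetical):
        set_op_by_name(model, "layers.0.mlp.gate_proj", quant_linear)
        replaces model.layers[0].mlp.gate_proj with quant_linear: a digit
        component indexes into an nn.ModuleList, while every other component
        is resolved with getattr before the final setattr.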
26 | """ 27 | levels = name.split(".") 28 | if len(levels) > 1: 29 | mod_ = layer 30 | for l_idx in range(len(levels) - 1): 31 | if levels[l_idx].isdigit() and isinstance(mod_, nn.ModuleList): 32 | mod_ = mod_[int(levels[l_idx])] 33 | else: 34 | mod_ = getattr(mod_, levels[l_idx]) 35 | setattr(mod_, levels[-1], new_module) 36 | else: 37 | setattr(layer, name, new_module) 38 | 39 | 40 | def get_name_and_info( 41 | model_info: Dict[str, Any], parent_key: str = "" 42 | ) -> Iterable[Tuple[str, Dict[str, Any]]]: 43 | for key, value in model_info.items(): 44 | new_key = f"{parent_key}.{key}" if parent_key else key 45 | if isinstance(value, dict): 46 | if ( 47 | value.get("type", None) is not None 48 | and value.get("weight", None) is not None 49 | ): 50 | yield new_key, value 51 | else: 52 | yield from get_name_and_info(value, new_key) 53 | else: 54 | continue 55 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/_utils.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import torch 7 | from torch.fx import passes 8 | from os import environ 9 | import collections 10 | import importlib.util 11 | from typing import List 12 | 13 | counters = collections.defaultdict(collections.Counter) 14 | 15 | 16 | # getattr can result in false negatives if the submodule 17 | # isn't already imported in __init.py__ 18 | # To check if a submodule exists without importing it, 19 | # we use importlib.util.find_spec 20 | def is_version_compatible_import(modules: List[str], functions: List[str]) -> bool: 21 | """ 22 | Checks if the specified modules and functions exist in the current 23 | version of PyTorch. 24 | The check is done sequentially for each module and function. 25 | 26 | Args: 27 | modules (list): A list of module names to check sequentially 28 | in torch (e.g., [_x1, x2]). 29 | functions (list): A list of function names to check for within 30 | the final module (e.g., [a1, a2]). 31 | 32 | Returns: 33 | bool: True if all modules and functions are available in the current 34 | PyTorch version, False otherwise. 35 | """ 36 | current_module = torch # Start with the base 'torch' module 37 | full_name = "torch" 38 | # Sequentially check if each module exists in the hierarchy 39 | for module_name in modules: 40 | full_name = f"{full_name}.{module_name}" 41 | spec = importlib.util.find_spec(full_name) 42 | if spec is None: 43 | return False 44 | 45 | # Move to the next level of module 46 | current_module = importlib.import_module(f"{full_name}") 47 | 48 | # Check if the functions exist in the final module 49 | for func in functions: 50 | if not hasattr(current_module, func): 51 | return False 52 | 53 | # If all checks pass 54 | return True 55 | 56 | 57 | def save_graph(fx_graph, graph_name): 58 | env_var = "ZENTORCH_SAVE_GRAPH" 59 | if env_var in environ and environ[env_var] == "1": 60 | g = passes.graph_drawer.FxGraphDrawer(fx_graph, graph_name) 61 | with open(f"{graph_name}.svg", "wb") as f: 62 | f.write(g.get_dot_graph().create_svg()) 63 | 64 | 65 | def add_version_suffix(major: str, minor: str, patch: str = 0): 66 | # This function will add a ".dev" substring to the input arguments. 
67 | # This will extend the pytorch version comparisions done using TorchVersion 68 | # class to include nightly and custom build versions as well. 69 | # The following tables shows the behaviour of TorchVersion comparisons 70 | # for release, nightlies and custom binaries, when the substring is used. 71 | # ".dev" is added to second column i.e A.B.C -> A.B.C.dev 72 | 73 | # This function is intended for only lesser than comparisons. 74 | 75 | # X.Y.Z < A.B.C 76 | # +---------------+----------------+-----------------+ 77 | # | Torch Version | Torch Version | Implementation | 78 | # | used by user | to be | Behaviour | 79 | # | (X.Y.Z) | compared with | | 80 | # | | (A.B.C) | | 81 | # +---------------+----------------+-----------------+ 82 | # | 2.3.1 | 2.4.0 | True | 83 | # +---------------+----------------+-----------------+ 84 | # | 2.4.0 | 2.4.0 | False | 85 | # +---------------+----------------+-----------------+ 86 | # | 2.4.0.dev | 2.4.0 | False | 87 | # | (Nightly | | | 88 | # | binaries) | | | 89 | # +---------------+----------------+-----------------+ 90 | # | 2.5.0.dev | 2.4.0 | False | 91 | # | 2.6.0.dev | | | 92 | # | (Nightly | | | 93 | # | binaries) | | | 94 | # +---------------+----------------+-----------------+ 95 | # | 2.4.0a0+git | 2.4.0 | False | 96 | # | d990dad | | | 97 | # +---------------+----------------+-----------------+ 98 | 99 | return f"{major}.{minor}.{patch}.dev" 100 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/llm/__init__.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | from ._optimize import optimize # noqa 7 | from ._checks import SUPPORTED_MODELS # noqa 8 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/llm/_checks.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import torch 7 | from torch.torch_version import TorchVersion 8 | from .._logging import get_logger 9 | 10 | # make a logger for this file 11 | logger = get_logger(__name__) 12 | 13 | # This set contains the strings found in the model.config.architectures[0], for 14 | # a valid huggingface transformer model 15 | SUPPORTED_MODELS = { 16 | "Qwen2ForCausalLM", 17 | "ChatGLMModel", 18 | "GPTJForCausalLM", 19 | "LlamaForCausalLM", 20 | "PhiForCausalLM", 21 | "Phi3ForCausalLM", 22 | "MistralForCausalLM", 23 | "GPTNeoXForCausalLM", 24 | "OPTForCausalLM", 25 | "BloomForCausalLM", 26 | "CodeGenForCausalLM", 27 | "GPTBigCodeForCausalLM", 28 | "StableLmForCausalLM", 29 | "GitForCausalLM", 30 | "MixtralForCausalLM", 31 | "QWenLMHeadModel", 32 | "YuanForCausalLM", 33 | } 34 | 35 | 36 | def get_installed_ipex_version(): 37 | # Previous approach made use of freeze API from pip._internal.operations 38 | # This caused the script to error out in certain cases. This was due to 39 | # conflicts in imports of distutils used in pip and setuptools. 
To avoid 40 | # the above situation, the usage of importlib.metadata.version is done. The 41 | # usage of importlib.metadata is the recommended way of achieving this. 42 | # This will not actually import module, but will find the version from 43 | # metadata stored in dist-info or egg-info. 44 | 45 | from importlib.metadata import version, PackageNotFoundError 46 | try: 47 | return version("intel_extension_for_pytorch") 48 | except PackageNotFoundError: 49 | return None 50 | 51 | 52 | def essential_checks(model, dtype): 53 | if hasattr(model, "config") and hasattr(model.config, "architectures"): 54 | is_well_supported_model = model.config.architectures[0] in SUPPORTED_MODELS 55 | 56 | if is_well_supported_model: 57 | installed_ipex_version = get_installed_ipex_version() 58 | if installed_ipex_version: 59 | # Zentorch will work with IPEX of atleast 2.3 60 | min_version = TorchVersion("2.3") 61 | installed_ipex_version = TorchVersion(installed_ipex_version) 62 | 63 | if installed_ipex_version >= min_version: 64 | # All checks good... 65 | if dtype != torch.bfloat16: 66 | logger.warning( 67 | "The supported datatype for the most optimal " 68 | "performance with zentorch is bfloat16." 69 | ) 70 | return False 71 | return True 72 | else: 73 | logger.warning( 74 | "zentorch.llm.optimize requires IPEX: at least " 75 | f"{min_version} but your IPEX is " 76 | f"{installed_ipex_version}. Some of the ZenTorch " 77 | "specific optimizations for LLMs might not be " 78 | "triggered." 79 | ) 80 | return False 81 | 82 | else: 83 | logger.warning( 84 | "Intel Extension for PyTorch not installed. So, the " 85 | "ZenTorch specific optimizations for LLMs might not " 86 | "be triggered." 87 | ) 88 | return False 89 | else: 90 | logger.warning( 91 | "Complete set of optimizations are currently unavailable" 92 | " for this model." 93 | ) 94 | return False 95 | else: 96 | logger.warning( 97 | "Cannot detect the model transformers family by " 98 | "model.config.architectures. Please pass a valid HuggingFace LLM " 99 | "model to the zentorch.llm.optimize API.", 100 | ) 101 | return False 102 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/llm/_optimize.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import torch 7 | from ._checks import essential_checks 8 | 9 | 10 | def optimize(model, dtype=torch.bfloat16): 11 | if essential_checks(model, dtype): 12 | import intel_extension_for_pytorch as ipex 13 | from ._model_conversion_functions import model_convert_lowering, customize_model 14 | 15 | ipex_t = ipex.transformers 16 | # Runtime over-riding of IPEX model_convert_lowering with ZenTorch 17 | # model_convert_lowering. So, after this line whenever IPEX 18 | # model_convert_lowering is called control would go to ZenTorch 19 | # model_convert_lowering. 
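        # customize_model() (from _model_conversion_functions) is applied
        # before ipex.llm.optimize so that any zentorch-specific model-level
        # replacements are already in place when the patched lowering runs.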
20 |         model = customize_model(model) 21 |         ipex_t.optimize.model_convert_lowering = model_convert_lowering 22 | 23 |         model = ipex.llm.optimize(model, optimizer=None, dtype=dtype) 24 | 25 |     return model 26 | -------------------------------------------------------------------------------- /src/cpu/python/zentorch/utils.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | from zentorch._C import thread_bind # noqa F401 7 | -------------------------------------------------------------------------------- /test/README.md: -------------------------------------------------------------------------------- 1 | # About Tests 2 | 3 | ## Overview 4 | 5 | All the tests have been divided into the following groups: 6 | 7 | - **unittests** :- This folder consists of all the unit tests for Zentorch. 8 |   To run these tests use 9 | 10 |   ```bash 11 |   python -m unittest discover -s ./test/unittests 12 |   ``` 13 | 14 |   - **_model_tests_** :- These are the unit tests consisting of Custom Models. 15 |   - **_op_tests_** :- These are the unit tests for individual ops or multiple ops not in the form of Custom Models. 16 |   - **_miscellaneous_tests_** :- These are the unit tests for miscellaneous (non-op) things. 17 | 18 | - **llm_tests** :- These are the tests for major operators used in LLMs. 19 |   To run these tests use 20 | 21 |   ```bash 22 |   python -m unittest discover -s ./test/llm_tests 23 |   ``` 24 | 25 | - **pre_trained_model_tests** :- These are tests for pre-trained models consisting of multiple operators. 26 |   To run these tests use 27 | 28 |   ```bash 29 |   python -m unittest discover -s ./test/pre_trained_model_tests 30 |   ``` 31 | 32 | ## Testing Guide 33 | 34 | Follow the steps below to run the tests. 35 | 36 | - To run all tests. 37 | 38 |   ```bash 39 |   python -m unittest discover -s ./test 40 |   ``` 41 | 42 | - To run all tests in a particular folder use. 43 | 44 |   ```bash 45 |   python -m unittest discover -s ./test/unittests 46 |   ``` 47 | 48 | - To run tests in a particular file use. 49 | 50 |   ```bash 51 |   python -m unittest test/unittests/op_tests/test_bmm.py 52 |   ``` 53 | 54 | To run only a subset of tests, use the following filters. 55 | 56 | - `-k "<pattern>"` to filter based on test names (file_name+class_name+function_name). 57 | 58 |   Example: To run all tests with "woq" in their name, please use the below command. 59 | 60 |   ```bash 61 |   python -m unittest discover -s ./test/unittests -k "woq" 62 |   ``` 63 | 64 | - `-p "<pattern>"` to filter based on file names. 65 | 66 |   Example: To run tests in all the files with "test_mm" in their file name, please use the below command. 67 | 68 |   ```bash 69 |   python -m unittest discover -s ./test/unittests -p "test_mm*" 70 |   ``` 71 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024 Advanced Micro Devices, Inc. 3 | # All rights reserved.
4 | # ******************************************************************************
5 | -------------------------------------------------------------------------------- /test/install_requirements.py: --------------------------------------------------------------------------------
1 | # ******************************************************************************
2 | # Copyright (c) 2023-2025 Advanced Micro Devices, Inc.
3 | # All rights reserved.
4 | # ******************************************************************************
5 |
6 | import importlib.util as importutil
7 | import torch
8 | import subprocess
9 |
10 | _all_ = ["transformers", "expecttest==0.1.6", "parameterized"]
11 |
12 |
13 | def install_package(cmd):
14 | """
15 | Installs a package using the provided command.
16 |
17 | This function uses subprocess.Popen to run the given command in a new
18 | subprocess, which helps to avoid conflicts that might arise from running
19 | pip.main() within the current Python process.
20 | Args:
21 | cmd (str): The command to run for installing the package, typically a
22 | pip install command.
23 | Returns:
24 | tuple: A tuple containing the return code (int), standard output (str),
25 | and standard error (str).
26 | """
27 | p1 = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
28 | out, err = p1.communicate()
29 | rc = p1.returncode
30 | return rc, out.decode("ascii", "ignore"), err.decode("ascii", "ignore")
31 |
32 |
33 | def install(packages):
34 | for package in packages:
35 | rc, out, err = install_package("pip install %s" % package)
36 | if rc != 0:
37 | print("Issue while installing the package=%s" % package)
38 | print(err)
39 | exit(1)
40 | else:
41 | print(out)
42 |
43 |
44 | if __name__ == "__main__":
45 |
46 | install(_all_)
47 | extra_args = (
48 | " --index-url https://download.pytorch.org/whl/cpu"
49 | if not torch.version.cuda
50 | else ""
51 | )
52 | torch_version = torch.__version__
53 | torch_version = torch_version.split("+")[0]
54 | torchvision_compatibility = {
55 | # "2.0.0": "torchvision==0.15.0",
56 | # "2.0.1": "torchvision==0.15.2",
57 | # "2.1.0": "torchvision==0.16.0",
58 | # "2.1.1": "torchvision==0.16.1",
59 | # "2.1.2": "torchvision==0.16.2",
60 | "2.2.0": "torchvision==0.17.0",
61 | "2.2.1": "torchvision==0.17.1",
62 | "2.2.2": "torchvision==0.17.2",
63 | "2.3.0": "torchvision==0.18.0",
64 | "2.3.1": "torchvision==0.18.1",
65 | "2.4.0": "torchvision==0.19.0",
66 | "2.4.1": "torchvision==0.19.1",
67 | "2.5.0": "torchvision==0.20.0",
68 | "2.5.1": "torchvision==0.20.1",
69 | "2.6.0": "torchvision==0.21.0",
70 | }
71 | if importutil.find_spec("torchvision") is not None:
72 | print("Warning: torchvision is already installed, skipping its installation")
73 | exit(1)
74 | elif torch_version in torchvision_compatibility.keys():
75 | torchvision_cmd = [torchvision_compatibility[torch_version] + extra_args]
76 | install(torchvision_cmd)
77 | else:
78 | print(
79 | "Could not find a valid torchvision version that is \
80 | compatible with the installed torch version. Supported Torch versions \
81 | are 2.2.*/2.3.*/2.4.*/2.5.*/2.6.*"
82 | )
83 | exit(1)
84 | -------------------------------------------------------------------------------- /test/llm_tests/__init__.py: --------------------------------------------------------------------------------
1 | # ******************************************************************************
2 | # Copyright (c) 2024 Advanced Micro Devices, Inc.
3 | # All rights reserved.
4 | # ****************************************************************************** 5 | -------------------------------------------------------------------------------- /test/llm_tests/llm_utils.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import sys 7 | from pathlib import Path 8 | 9 | sys.path.append(str(Path(__file__).parent.parent)) 10 | from utils import ( # noqa: 402 # noqa: F401 11 | BaseZentorchTestCase, 12 | run_tests, 13 | zentorch, 14 | set_seed, 15 | skip_test_pt_2_3, 16 | freeze_opt, 17 | test_with_freeze_opt, 18 | Test_Data, 19 | ) 20 | 21 | 22 | class Zentorch_TestCase(BaseZentorchTestCase): 23 | def setUp(self): 24 | super().setUp() 25 | self.data = Test_Data() 26 | 27 | def tearDown(self): 28 | del self.data 29 | -------------------------------------------------------------------------------- /test/pre_trained_model_tests/__init__.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | -------------------------------------------------------------------------------- /test/pre_trained_model_tests/pre_trained_model_utils.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2023-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import sys 7 | from pathlib import Path 8 | 9 | sys.path.append(str(Path(__file__).parent.parent)) 10 | from utils import ( # noqa: 402 # noqa: F401 11 | BaseZentorchTestCase, 12 | run_tests, 13 | zentorch, 14 | has_zentorch, 15 | supported_dtypes, 16 | reset_dynamo, 17 | set_seed, 18 | freeze_opt, 19 | test_with_freeze_opt, 20 | Test_Data, 21 | ) 22 | 23 | 24 | class Zentorch_TestCase(BaseZentorchTestCase): 25 | def setUp(self): 26 | super().setUp() 27 | self.data = Test_Data() 28 | 29 | def tearDown(self): 30 | del self.data 31 | -------------------------------------------------------------------------------- /test/pre_trained_model_tests/test_bert.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import copy 7 | import unittest 8 | import torch 9 | from parameterized import parameterized 10 | from itertools import product 11 | from transformers import BertModel 12 | import sys 13 | from pathlib import Path 14 | 15 | sys.path.append(str(Path(__file__).parent)) 16 | from pre_trained_model_utils import ( # noqa: 402 17 | Zentorch_TestCase, 18 | has_zentorch, 19 | run_tests, 20 | supported_dtypes, 21 | reset_dynamo, 22 | set_seed, 23 | zentorch, 24 | freeze_opt, 25 | test_with_freeze_opt, 26 | ) 27 | 28 | 29 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 30 | class Test_Bert_Model(Zentorch_TestCase): 31 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 32 | @torch.inference_mode() 33 | def test_bert_base_model(self, dtype, freeze_opt): 34 | self.skip_if_bfloat16_path_issue(dtype) 35 | self.data.create_pretrained_model_data(dtype) 36 | native_model = BertModel.from_pretrained("bert-large-uncased").eval() 37 | inductor_model = copy.deepcopy(native_model) 38 | zentorch_graph = torch.compile(native_model, backend="zentorch") 39 | zentorch_graph_output = test_with_freeze_opt( 40 | zentorch_graph, 41 | (self.data.input_tensor), 42 | freeze_opt 43 | ) 44 | reset_dynamo() 45 | inductor_graph = torch.compile(inductor_model, backend="inductor") 46 | inductor_graph_output = inductor_graph(self.data.input_tensor) 47 | 48 | self.assertEqual( 49 | zentorch_graph_output, inductor_graph_output, atol=1e-2, rtol=1e-5 50 | ) 51 | 52 | 53 | if __name__ == "__main__": 54 | run_tests() 55 | -------------------------------------------------------------------------------- /test/pre_trained_model_tests/test_cnn.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import copy 7 | import unittest 8 | import torch 9 | from parameterized import parameterized 10 | from itertools import product 11 | from torchvision import models 12 | import sys 13 | from pathlib import Path 14 | 15 | sys.path.append(str(Path(__file__).parent)) 16 | from pre_trained_model_utils import ( # noqa: 402 17 | Zentorch_TestCase, 18 | has_zentorch, 19 | run_tests, 20 | supported_dtypes, 21 | reset_dynamo, 22 | set_seed, 23 | zentorch, 24 | freeze_opt, 25 | test_with_freeze_opt, 26 | ) 27 | 28 | 29 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 30 | class Test_CNN_Model(Zentorch_TestCase): 31 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 32 | @torch.inference_mode() 33 | def test_resnet18_model(self, dtype, freeze_opt): 34 | self.skip_if_bfloat16_path_issue(dtype) 35 | self.data.create_pretrained_model_data(dtype) 36 | model = models.__dict__["resnet18"](pretrained=True).eval() 37 | inductor_model = copy.deepcopy(model) 38 | zentorch_graph = torch.compile(model, backend="zentorch", dynamic=False) 39 | zentorch_graph_output = test_with_freeze_opt( 40 | zentorch_graph, 41 | (self.data.input3d), 42 | freeze_opt 43 | ) 44 | reset_dynamo() 45 | inductor_graph = torch.compile(inductor_model, backend="inductor") 46 | 47 | inductor_graph_output = inductor_graph(self.data.input3d) 48 | 49 | self.assertEqual(inductor_graph_output, zentorch_graph_output) 50 | 51 | 52 | if __name__ == "__main__": 53 | run_tests() 54 | -------------------------------------------------------------------------------- /test/unittests/__init__.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | -------------------------------------------------------------------------------- /test/unittests/miscellaneous_tests/__init__.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | -------------------------------------------------------------------------------- /test/unittests/miscellaneous_tests/test_avx512_device.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import sys 8 | from pathlib import Path 9 | 10 | sys.path.append(str(Path(__file__).parent.parent)) 11 | from unittest_utils import( # noqa: 402 12 | Zentorch_TestCase, 13 | has_zentorch, 14 | run_tests, 15 | zentorch, 16 | ) 17 | 18 | 19 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 20 | class Test_AVX512_Device(Zentorch_TestCase): 21 | @unittest.skipIf( 22 | not zentorch._C.is_avx512_supported(), 23 | "CPU does not support AVX512 instructions.", 24 | ) 25 | def test_avx512_device(self): 26 | self.assertTrue( 27 | zentorch._C.is_avx512_supported(), "CPU supports AVX512 instructions." 28 | ) 29 | 30 | 31 | if __name__ == "__main__": 32 | run_tests() 33 | -------------------------------------------------------------------------------- /test/unittests/miscellaneous_tests/test_bf16_device.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import sys 8 | from pathlib import Path 9 | 10 | sys.path.append(str(Path(__file__).parent.parent)) 11 | from unittest_utils import( # noqa: 402 12 | Zentorch_TestCase, 13 | has_zentorch, 14 | run_tests, 15 | zentorch, 16 | ) 17 | 18 | 19 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 20 | class Test_BF16_Device(Zentorch_TestCase): 21 | @unittest.skipIf( 22 | not zentorch._C.is_bf16_supported(), "CPU does not support AVX512 BF16." 23 | ) 24 | def test_bf16_device(self): 25 | self.assertTrue(zentorch._C.is_bf16_supported(), "CPU supports AVX512 BF16.") 26 | 27 | 28 | if __name__ == "__main__": 29 | run_tests() 30 | -------------------------------------------------------------------------------- /test/unittests/miscellaneous_tests/test_zentorch_version.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | from importlib import metadata 8 | import sys 9 | from pathlib import Path 10 | 11 | sys.path.append(str(Path(__file__).parent.parent)) 12 | from unittest_utils import( # noqa: 402 13 | Zentorch_TestCase, 14 | has_zentorch, 15 | run_tests, 16 | zentorch, 17 | ) 18 | 19 | 20 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 21 | class Test_ZenTorch_Version(Zentorch_TestCase): 22 | def test_zentorch_version(self): 23 | self.assertTrue(zentorch.__version__, metadata.version("zentorch")) 24 | 25 | 26 | if __name__ == "__main__": 27 | run_tests() 28 | -------------------------------------------------------------------------------- /test/unittests/model_tests/__init__.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | -------------------------------------------------------------------------------- /test/unittests/model_tests/test_addmm_1dbias_gelu.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | from itertools import product 10 | from torch import nn 11 | import sys 12 | from pathlib import Path 13 | 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | from unittest_utils import ( # noqa: 402 16 | Zentorch_TestCase, 17 | has_zentorch, 18 | reset_dynamo, 19 | run_tests, 20 | supported_dtypes, 21 | zentorch, 22 | freeze_opt, 23 | test_with_freeze_opt, 24 | ) 25 | 26 | 27 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 28 | class Test_Addmm_1dbias_Gelu_Model(Zentorch_TestCase): 29 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 30 | @torch.inference_mode() 31 | def test_addmm_1dbias_gelu_tanh_model(self, dtype, freeze_opt): 32 | 33 | self.data.create_unittest_data(dtype) 34 | model = nn.Sequential( 35 | nn.Linear(self.data.n, self.data.m), nn.GELU(approximate="tanh") 36 | ) 37 | if dtype == "bfloat16": 38 | model = model.bfloat16() 39 | model_output = model(self.data.input) 40 | reset_dynamo() 41 | compiled_graph = torch.compile(model, backend="zentorch") 42 | compiled_graph_output = test_with_freeze_opt( 43 | compiled_graph, 44 | (self.data.input), 45 | freeze_opt 46 | ) 47 | self.assertEqual(model_output, compiled_graph_output) 48 | 49 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 50 | @torch.inference_mode() 51 | def test_addmm_1dbias_gelu_none_model(self, dtype, freeze_opt): 52 | 53 | self.data.create_unittest_data(dtype) 54 | model = nn.Sequential( 55 | nn.Linear(self.data.n, self.data.m), nn.GELU(approximate="none") 56 | ) 57 | if dtype == "bfloat16": 58 | model = model.bfloat16() 59 | model_output = model(self.data.input) 60 | reset_dynamo() 61 | compiled_graph = torch.compile(model, backend="zentorch") 62 | compiled_graph_output = test_with_freeze_opt( 63 | compiled_graph, 64 | (self.data.input), 65 | freeze_opt 66 | ) 67 | self.assertEqual(model_output, compiled_graph_output) 68 | 69 | 70 | if __name__ == "__main__": 71 | run_tests() 72 | -------------------------------------------------------------------------------- /test/unittests/model_tests/test_addmm_1dbias_relu.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | from torch import nn 10 | from torch.fx.experimental.proxy_tensor import make_fx 11 | import sys 12 | from pathlib import Path 13 | 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | from unittest_utils import ( # noqa: 402 16 | Zentorch_TestCase, 17 | has_zentorch, 18 | run_tests, 19 | supported_dtypes, 20 | zentorch, 21 | ) 22 | 23 | 24 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 25 | class Test_Addmm_1dbias_Relu_Model(Zentorch_TestCase): 26 | @parameterized.expand(supported_dtypes) 27 | @torch.inference_mode() 28 | def test_addmm_1dbias_relu_model(self, dtype): 29 | self.data.create_unittest_data(dtype) 30 | model = nn.Sequential(nn.Linear(self.data.n, self.data.m), nn.ReLU()) 31 | if dtype == "bfloat16": 32 | model = model.bfloat16() 33 | fx_g = make_fx(model)(self.data.input) 34 | fx_g_modified = zentorch.optimize(fx_g) 35 | fx_g_output = fx_g(self.data.input) 36 | fx_g_modified_output = fx_g_modified(self.data.input) 37 | self.assertEqual(fx_g_output, fx_g_modified_output) 38 | for node in fx_g_modified.graph.nodes: 39 | if isinstance(node.target, torch._ops.OpOverload): 40 | if node.target.name() in ["aten::addmm"]: 41 | self.assertEqual( 42 | node.target, torch.ops.zentorch.zentorch_addmm_1dbias 43 | ) 44 | 45 | 46 | if __name__ == "__main__": 47 | run_tests() 48 | -------------------------------------------------------------------------------- /test/unittests/model_tests/test_baddbmm.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | from itertools import product 10 | from torch import nn 11 | import sys 12 | from pathlib import Path 13 | 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | from unittest_utils import ( # noqa: 402 16 | Zentorch_TestCase, 17 | has_zentorch, 18 | reset_dynamo, 19 | run_tests, 20 | supported_dtypes, 21 | zentorch, 22 | freeze_opt, 23 | test_with_freeze_opt, 24 | ) 25 | 26 | 27 | class Custom_Model_Baddbmm(nn.Module): 28 | def __init__(self): 29 | super(Custom_Model_Baddbmm, self).__init__() 30 | 31 | def forward(self, input, batch1, batch2): 32 | bmm_res = torch.bmm(batch1, batch2) 33 | add_res = torch.add(bmm_res, input) 34 | baddbmm_res = torch.baddbmm(add_res, batch1, batch2, beta=1.5, alpha=1.4) 35 | return baddbmm_res 36 | 37 | 38 | class Custom_Model_Baddbmm_Unsupport(nn.Module): 39 | def __init__(self): 40 | super(Custom_Model_Baddbmm_Unsupport, self).__init__() 41 | 42 | def forward(self, input, batch1, batch2): 43 | bmm_res = torch.bmm(batch1, batch2) 44 | add_res = torch.add(bmm_res, input) 45 | return add_res 46 | 47 | 48 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 49 | class Test_Baddbmm_Model(Zentorch_TestCase): 50 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 51 | @torch.inference_mode() 52 | def test_baddbmm_model(self, dtype, freeze_opt): 53 | self.skip_if_bfloat16_path_issue(dtype) 54 | self.data.create_unittest_data(dtype) 55 | model = Custom_Model_Baddbmm().eval() 56 | for i in range(len(self.data.x2)): 57 | for j in range(len(self.data.y2)): 58 | model_output = model(self.data.M2, self.data.x2[i], self.data.y2[j]) 59 | reset_dynamo() 60 | compiled_graph = torch.compile(model, backend="zentorch") 61 | compiled_graph_output = test_with_freeze_opt( 62 | compiled_graph, 63 | (self.data.M2, self.data.x2[i], self.data.y2[j]), 64 | freeze_opt 65 | ) 66 | self.assertEqual( 67 | model_output, compiled_graph_output, atol=1e-5, rtol=1e-3 68 | ) 69 | 70 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 71 | @torch.inference_mode() 72 | def test_baddbmm_unsupport_model(self, dtype, freeze_opt): 73 | self.skip_if_bfloat16_path_issue(dtype) 74 | self.data.create_unittest_data(dtype) 75 | model = Custom_Model_Baddbmm_Unsupport().eval() 76 | model_output = model(self.data.M3, self.data.x2[0], self.data.y2[0]) 77 | reset_dynamo() 78 | compiled_graph = torch.compile(model, backend="zentorch") 79 | compiled_graph_output = test_with_freeze_opt( 80 | compiled_graph, 81 | (self.data.M3, self.data.x2[0], self.data.y2[0]), 82 | freeze_opt 83 | ) 84 | self.assertEqual(model_output, compiled_graph_output, atol=1e-5, rtol=1e-3) 85 | 86 | 87 | if __name__ == "__main__": 88 | run_tests() 89 | -------------------------------------------------------------------------------- /test/unittests/model_tests/test_convolution.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | from torch import nn 10 | import sys 11 | from pathlib import Path 12 | from itertools import product 13 | 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | from unittest_utils import ( # noqa: 402 16 | Zentorch_TestCase, 17 | has_zentorch, 18 | reset_dynamo, 19 | run_tests, 20 | supported_dtypes, 21 | zentorch, 22 | freeze_opt, 23 | test_with_freeze_opt, 24 | ) 25 | 26 | 27 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 28 | class Custom_Model_Convolution(nn.Module): 29 | def __init__(self): 30 | super(Custom_Model_Convolution, self).__init__() 31 | self.convolution = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1) 32 | 33 | def forward(self, input): 34 | output = self.convolution(input) 35 | return output 36 | 37 | 38 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 39 | class Test_Convolution_Model(Zentorch_TestCase): 40 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 41 | @torch.inference_mode() 42 | def test_convolution_model(self, dtype, freeze_opt): 43 | self.data.create_unittest_data(dtype) 44 | model = Custom_Model_Convolution() 45 | if dtype == "bfloat16": 46 | model = model.to(torch.bfloat16) 47 | model_output = model(self.data.conv_input) 48 | reset_dynamo() 49 | zentorch_graph = torch.compile(model, backend="zentorch") 50 | zentorch_graph_output = test_with_freeze_opt( 51 | zentorch_graph, 52 | (self.data.conv_input), 53 | freeze_opt 54 | ) 55 | self.assertEqual( 56 | model_output, 57 | zentorch_graph_output, 58 | ) 59 | 60 | 61 | if __name__ == "__main__": 62 | run_tests() 63 | -------------------------------------------------------------------------------- /test/unittests/model_tests/test_embedding.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | from itertools import product 10 | from torch import nn 11 | import sys 12 | from pathlib import Path 13 | 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | from unittest_utils import ( # noqa: 402 16 | Zentorch_TestCase, 17 | has_zentorch, 18 | reset_dynamo, 19 | run_tests, 20 | supported_dtypes, 21 | zentorch, 22 | freeze_opt, 23 | test_with_freeze_opt, 24 | ) 25 | 26 | 27 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 28 | class Custom_Model_Embedding(nn.Module): 29 | def __init__(self, embedding_dim, dtype=torch.float): 30 | super(Custom_Model_Embedding, self).__init__() 31 | self.embedding = nn.Embedding(10000, embedding_dim, dtype=dtype) 32 | 33 | def forward(self, input): 34 | embed = self.embedding(input) 35 | return embed 36 | 37 | 38 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 39 | class Test_Embedding_Model(Zentorch_TestCase): 40 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 41 | @torch.inference_mode() 42 | def test_embedding_compile_model(self, dtype, freeze_opt): 43 | new_dtype = self.data.get_torch_type(dtype) 44 | model = Custom_Model_Embedding(100, dtype=new_dtype) 45 | input = torch.randint(0, 10000, (10,)) 46 | model_output = model(input) 47 | reset_dynamo() 48 | compiled_graph = torch.compile(model, backend="zentorch") 49 | compiled_graph_output = test_with_freeze_opt( 50 | compiled_graph, 51 | (input), 52 | freeze_opt 53 | ) 54 | self.assertEqual(model_output, compiled_graph_output) 55 | 56 | 57 | if __name__ == "__main__": 58 | run_tests() 59 | -------------------------------------------------------------------------------- /test/unittests/model_tests/test_embedding_bag.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ******************************************************************************
5 |
6 | import unittest
7 | import torch
8 | from parameterized import parameterized
9 | from itertools import product
10 | from torch import nn
11 | import sys
12 | from pathlib import Path
13 |
14 | sys.path.append(str(Path(__file__).parent.parent))
15 | from unittest_utils import ( # noqa: 402
16 | Zentorch_TestCase,
17 | has_zentorch,
18 | reset_dynamo,
19 | run_tests,
20 | supported_dtypes,
21 | zentorch,
22 | freeze_opt,
23 | test_with_freeze_opt,
24 | )
25 |
26 |
27 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed")
28 | class Custom_Model_Embedding_Bag(nn.Module):
29 | def __init__(self, embedding_dim, output_dim, dtype=torch.float):
30 | super(Custom_Model_Embedding_Bag, self).__init__()
31 | self.embedding = nn.EmbeddingBag(10000, embedding_dim, dtype=dtype)
32 | self.intermediate = nn.Linear(embedding_dim, output_dim, dtype=dtype)
33 | self.output = nn.Linear(output_dim, 1, dtype=dtype)
34 |
35 | def forward(self, input):
36 | embed = self.embedding(input)
37 | intermediate = self.intermediate(embed)
38 | output = self.output(intermediate)
39 | return output
40 |
41 |
42 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed")
43 | class Test_Embedding_Bag_Model(Zentorch_TestCase):
44 | @parameterized.expand(product(supported_dtypes, freeze_opt))
45 | @torch.inference_mode()
46 | def test_embedding_bag_compile_model(self, dtype, freeze_opt):
47 | new_dtype = self.data.get_torch_type(dtype)
48 | model = Custom_Model_Embedding_Bag(100, 10, dtype=new_dtype)
49 | input = torch.randint(0, 10000, (1, 10))
50 | model_output = model(input)
51 | reset_dynamo()
52 | compiled_graph = torch.compile(model, backend="zentorch")
53 | compiled_graph_output = test_with_freeze_opt(
54 | compiled_graph,
55 | (input),
56 | freeze_opt
57 | )
58 | # TODO
59 | # Increased the tolerance for bfloat16 dtype to atol=1e-03, rtol=0.01
60 | # Getting failures due to a higher diff than allowed
61 | # The change will be reverted after the fix
62 | # ZENAI-858
63 | self.assertEqual(model_output, compiled_graph_output, atol=1e-03, rtol=0.01)
64 |
65 |
66 | if __name__ == "__main__":
67 | run_tests()
68 | -------------------------------------------------------------------------------- /test/unittests/model_tests/test_group_embeded_ops_with_sum_ops.py: --------------------------------------------------------------------------------
1 | # ******************************************************************************
2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc.
3 | # All rights reserved.
4 | # ******************************************************************************
5 |
6 | import unittest
7 | import torch
8 | from parameterized import parameterized
9 | from itertools import product
10 | from torch import nn
11 | import sys
12 | from pathlib import Path
13 |
14 | sys.path.append(str(Path(__file__).parent.parent))
15 | from unittest_utils import ( # noqa: 402
16 | Zentorch_TestCase,
17 | has_zentorch,
18 | reset_dynamo,
19 | run_tests,
20 | supported_dtypes,
21 | zentorch,
22 | freeze_opt,
23 | test_with_freeze_opt,
24 | )
25 |
26 |
27 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed")
28 | class Custom_Model_Embedding_Bag_Sum_nodes(nn.Module):
29 | def __init__(self, num_embeddings):
30 | super(Custom_Model_Embedding_Bag_Sum_nodes, self).__init__()
31 | self.eb_bags_grp = [
32 | torch.nn.EmbeddingBag(num_embeddings, 3, mode="sum") for _ in range(10)
33 | ]
34 |
35 | def forward(self, eb_input, eb_offset):
36 | outputs_grp = [op(eb_input, eb_offset) for op in self.eb_bags_grp]
37 |
38 | outputs_grp[5] = torch.sum(outputs_grp[5], dim=1, keepdim=True)
39 | outputs_grp[6] = torch.sum(outputs_grp[6], dim=1, keepdim=True)
40 |
41 | output = torch.sum(torch.cat(outputs_grp, dim=1), dim=0)
42 |
43 | return output
44 |
45 |
46 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed")
47 | class Custom_Model_Embedding_Sum_nodes(nn.Module):
48 | def __init__(self, num_embeddings):
49 | super(Custom_Model_Embedding_Sum_nodes, self).__init__()
50 | self.embedding_grp = [torch.nn.Embedding(num_embeddings, 3) for _ in range(10)]
51 |
52 | def forward(self, inputs):
53 | outputs_grp = [op(inputs) for op in self.embedding_grp]
54 |
55 | outputs_grp[3] = torch.sum(outputs_grp[3], dim=1, keepdim=True)
56 | outputs_grp[5] = torch.sum(outputs_grp[3], dim=1, keepdim=True)
57 |
58 | output = torch.sum(torch.cat(outputs_grp, dim=1), dim=0)
59 |
60 | return output
61 |
62 |
63 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed")
64 | # Testing revealed one of the corner cases where the common output node can
65 | # have heterogeneous nodes like embedding1, embedding2, sum1, sum2, embedding3.
66 | # To test the above scenario, the following test cases are added.
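# For example, in the embedding-bag model above the cat node receives ten
# embedding-bag outputs, two of which (indices 5 and 6) are first reduced
# with torch.sum, so the grouping pass has to cope with a mix of embedding
# and sum producers feeding the same concatenation.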
67 | # Both the group ops are being tested here, with the heterogeneous op being sum 68 | class Test_Group_Embeded_Ops_With_Sum_Ops_Model(Zentorch_TestCase): 69 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 70 | @torch.inference_mode() 71 | def test_group_eb_with_sum_model(self, dtype, freeze_opt): 72 | self.data.create_unittest_data(dtype) 73 | 74 | indices = self.data.emb_input 75 | offsets = self.data.offsets 76 | 77 | model = Custom_Model_Embedding_Bag_Sum_nodes(self.data.R) 78 | 79 | native_output = model(indices, offsets) 80 | reset_dynamo() 81 | compiled_graph = torch.compile(model, backend="zentorch") 82 | compiled_output = test_with_freeze_opt( 83 | compiled_graph, 84 | (indices, offsets), 85 | freeze_opt 86 | ) 87 | self.assertEqual(native_output, compiled_output) 88 | 89 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 90 | @torch.inference_mode() 91 | def test_group_embedding_with_sum_model(self, dtype, freeze_opt): 92 | self.data.create_unittest_data(dtype) 93 | indices = self.data.emb_input 94 | model = Custom_Model_Embedding_Sum_nodes(self.data.R) 95 | native_output = model(indices) 96 | reset_dynamo() 97 | compiled_graph = torch.compile(model, backend="zentorch") 98 | compiled_output = test_with_freeze_opt( 99 | compiled_graph, 100 | (indices), 101 | freeze_opt 102 | ) 103 | self.assertEqual(native_output, compiled_output) 104 | 105 | 106 | if __name__ == "__main__": 107 | run_tests() 108 | -------------------------------------------------------------------------------- /test/unittests/model_tests/test_horizontal_embedding_bag_group.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | from itertools import product 10 | from torch import nn 11 | from torch.fx.experimental.proxy_tensor import make_fx 12 | import sys 13 | from pathlib import Path 14 | 15 | sys.path.append(str(Path(__file__).parent.parent)) 16 | from unittest_utils import ( # noqa: 402 17 | Zentorch_TestCase, 18 | has_zentorch, 19 | reset_dynamo, 20 | run_tests, 21 | supported_dtypes, 22 | zentorch, 23 | freeze_opt, 24 | test_with_freeze_opt, 25 | ) 26 | 27 | 28 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 29 | class Custom_Model_Embedding_Bag_Group(nn.Module): 30 | def __init__(self, num_embeddings): 31 | super(Custom_Model_Embedding_Bag_Group, self).__init__() 32 | self.eb_bags_grp_0 = [torch.nn.EmbeddingBag(num_embeddings, 3, mode="sum")] * 5 33 | self.eb_bags_grp_1 = [torch.nn.EmbeddingBag(num_embeddings, 3, mode="sum")] * 10 34 | self.eb_bags_grp_2 = [torch.nn.EmbeddingBag(num_embeddings, 3, mode="sum")] * 6 35 | 36 | def forward(self, eb_input, eb_offset): 37 | eb_outputs_grp_0 = [ 38 | self.eb_bags_grp_0[i](eb_input, eb_offset) for i in range(5) 39 | ] 40 | concat_eb_tensors_0 = torch.cat(eb_outputs_grp_0) 41 | 42 | eb_outputs_grp_1 = [ 43 | self.eb_bags_grp_1[i](eb_input, eb_offset) for i in range(10) 44 | ] 45 | concat_eb_tensors_1 = torch.cat(eb_outputs_grp_1) 46 | 47 | eb_outputs_grp_2 = [ 48 | self.eb_bags_grp_2[i](eb_input, eb_offset) for i in range(6) 49 | ] 50 | concat_eb_tensors_2 = torch.cat(eb_outputs_grp_2) 51 | 52 | output = torch.cat( 53 | [concat_eb_tensors_0, concat_eb_tensors_1, concat_eb_tensors_2] 54 | ) 55 | 56 | return output 57 | 58 | 59 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 60 | class Test_Embedding_Bag_Group_Model(Zentorch_TestCase): 61 | @parameterized.expand(supported_dtypes) 62 | @torch.inference_mode() 63 | def test_embedding_bag_group_model(self, dtype): 64 | self.data.create_unittest_data(dtype) 65 | model = Custom_Model_Embedding_Bag_Group(self.data.R) 66 | indices = self.data.emb_input 67 | offsets = self.data.offsets 68 | fx_g = make_fx(model)(indices, offsets) 69 | fx_g_output = fx_g(indices, offsets) 70 | fx_g_optimized = zentorch.optimize(fx_g) 71 | fx_g_optimized_output = fx_g_optimized(indices, offsets) 72 | self.assertEqual(fx_g_output, fx_g_optimized_output) 73 | target = torch.ops.zentorch.zentorch_horizontal_embedding_bag_group.default 74 | group_eb_count = 0 75 | for node in fx_g_optimized.graph.nodes: 76 | if isinstance(node.target, torch._ops.OpOverload) and node.target == target: 77 | group_eb_count += 1 78 | 79 | self.assertEqual(group_eb_count, 3) 80 | 81 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 82 | @torch.inference_mode() 83 | def test_embedding_bag_group_compile_model(self, dtype, freeze_opt): 84 | self.data.create_unittest_data(dtype) 85 | model = Custom_Model_Embedding_Bag_Group(self.data.R) 86 | indices = self.data.emb_input 87 | offset = self.data.offsets 88 | native_output = model(indices, offset) 89 | reset_dynamo() 90 | compiled_graph = torch.compile(model, backend="zentorch") 91 | compiled_output = test_with_freeze_opt( 92 | compiled_graph, 93 | (indices, offset), 94 | freeze_opt 95 | ) 96 | self.assertEqual(native_output, compiled_output) 97 | 98 | 99 | if __name__ == "__main__": 100 | run_tests() 101 | -------------------------------------------------------------------------------- 
/test/unittests/model_tests/test_horizontal_embedding_bag_group_addmm_1dbias.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | from itertools import product 10 | from torch import nn 11 | import sys 12 | from pathlib import Path 13 | 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | from unittest_utils import ( # noqa: 402 16 | Zentorch_TestCase, 17 | has_zentorch, 18 | reset_dynamo, 19 | run_tests, 20 | supported_dtypes, 21 | zentorch, 22 | freeze_opt, 23 | test_with_freeze_opt, 24 | ) 25 | 26 | 27 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 28 | class Custom_Model_Group_Embedding_Bag_Addmm_1dbias(nn.Module): 29 | def __init__(self, num_embeddings, k): 30 | super(Custom_Model_Group_Embedding_Bag_Addmm_1dbias, self).__init__() 31 | self.eb_bags_grp = [torch.nn.EmbeddingBag(num_embeddings, 3)] * 3 32 | self.mlp_0 = torch.nn.Linear(k, 12) 33 | self.mlp_1 = torch.nn.Linear(12, 6) 34 | self.mlp_2 = torch.nn.Linear(6, 3) 35 | 36 | def forward(self, eb_input, eb_offset, mlp_input): 37 | eb_grp_outputs = [self.eb_bags_grp[i](eb_input, eb_offset) for i in range(3)] 38 | mlp_output = self.mlp_0(mlp_input) 39 | mlp_output = self.mlp_1(mlp_output) 40 | mlp_output = self.mlp_2(mlp_output) 41 | 42 | outputs = eb_grp_outputs + [mlp_output] 43 | outputs = torch.cat(outputs, dim=1) 44 | 45 | return outputs 46 | 47 | 48 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 49 | class Custom_Model_Group_Addmm_1dbias_Embedding_Bag(nn.Module): 50 | def __init__(self, num_embeddings, k): 51 | super(Custom_Model_Group_Addmm_1dbias_Embedding_Bag, self).__init__() 52 | self.eb_bags_grp = [torch.nn.EmbeddingBag(num_embeddings, 3)] * 3 53 | self.mlp_0 = torch.nn.Linear(k, 12) 54 | self.mlp_1 = torch.nn.Linear(12, 6) 55 | self.mlp_2 = torch.nn.Linear(6, 3) 56 | 57 | def forward(self, eb_input, eb_offset, mlp_input): 58 | mlp_output = self.mlp_0(mlp_input) 59 | mlp_output = self.mlp_1(mlp_output) 60 | mlp_output = self.mlp_2(mlp_output) 61 | 62 | eb_grp_outputs = [self.eb_bags_grp[i](eb_input, eb_offset) for i in range(3)] 63 | 64 | outputs = eb_grp_outputs + [mlp_output] 65 | outputs = torch.cat(outputs, dim=1) 66 | 67 | return outputs 68 | 69 | 70 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 71 | class Test_Group_Embedding_Bad_Addmm_1dbias_Model(Zentorch_TestCase): 72 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 73 | @torch.inference_mode() 74 | def test_group_embedding_bag_addmm_1dbias_model(self, dtype, freeze_opt): 75 | self.data.create_unittest_data(dtype) 76 | indices = self.data.emb_input 77 | offsets = self.data.offsets 78 | mlp_inputs = self.data.mlp_inputs 79 | model = Custom_Model_Group_Embedding_Bag_Addmm_1dbias(self.data.R, self.data.k) 80 | native_output = model(indices, offsets, mlp_inputs) 81 | reset_dynamo() 82 | compiled_graph = torch.compile(model, backend="zentorch") 83 | compiled_output = test_with_freeze_opt( 84 | compiled_graph, (indices, offsets, mlp_inputs), freeze_opt 85 | ) 86 | self.assertEqual(native_output, compiled_output) 87 | 88 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 89 | @torch.inference_mode() 90 | def 
test_group_addmm_1dbias_embedding_bag_model(self, dtype, freeze_opt): 91 | self.data.create_unittest_data(dtype) 92 | indices = self.data.emb_input 93 | offsets = self.data.offsets 94 | mlp_inputs = self.data.mlp_inputs 95 | model = Custom_Model_Group_Addmm_1dbias_Embedding_Bag(self.data.R, self.data.k) 96 | native_output = model(indices, offsets, mlp_inputs) 97 | reset_dynamo() 98 | compiled_graph = torch.compile(model, backend="zentorch") 99 | compiled_output = test_with_freeze_opt( 100 | compiled_graph, (indices, offsets, mlp_inputs), freeze_opt 101 | ) 102 | self.assertEqual(native_output, compiled_output) 103 | 104 | 105 | if __name__ == "__main__": 106 | run_tests() 107 | -------------------------------------------------------------------------------- /test/unittests/model_tests/test_horizontal_embedding_bag_group_addmm_1dbias_relu.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | from itertools import product 10 | import sys 11 | from pathlib import Path 12 | 13 | sys.path.append(str(Path(__file__).parent.parent)) 14 | from unittest_utils import ( # noqa: 402 15 | Zentorch_TestCase, 16 | has_zentorch, 17 | reset_dynamo, 18 | run_tests, 19 | supported_dtypes, 20 | zentorch, 21 | freeze_opt, 22 | test_with_freeze_opt, 23 | ) 24 | 25 | 26 | class Custom_Model_Group_Embedding_Bag_Addmm_1dbias_Relu(torch.nn.Module): 27 | def __init__(self, num_embeddings, k): 28 | super(Custom_Model_Group_Embedding_Bag_Addmm_1dbias_Relu, self).__init__() 29 | # Common Nodes 30 | self.relu = torch.nn.ReLU() 31 | self.sigmoid = torch.nn.Sigmoid() 32 | 33 | self.eb_bags = [torch.nn.EmbeddingBag(num_embeddings, 3)] * 2 34 | 35 | self.bmlp_0 = torch.nn.Linear(k, 4) 36 | self.bmlp_1 = torch.nn.Linear(4, 4) 37 | self.bmlp_2 = torch.nn.Linear(4, 3) 38 | 39 | self.tmlp_0 = torch.nn.Linear(12, 4) 40 | self.tmlp_1 = torch.nn.Linear(4, 2) 41 | self.tmlp_2 = torch.nn.Linear(2, 2) 42 | self.tmlp_3 = torch.nn.Linear(2, 1) 43 | 44 | def forward(self, eb_inputs, eb_offsets, mlp_inputs): 45 | 46 | outputs = [] 47 | 48 | for _ in range(3): 49 | eb_outputs = [eb_op(eb_inputs, eb_offsets) for eb_op in self.eb_bags] 50 | 51 | mlp_outputs = self.bmlp_0(mlp_inputs) 52 | mlp_outputs = self.relu(mlp_outputs) 53 | mlp_outputs = self.bmlp_1(mlp_outputs) 54 | mlp_outputs = self.relu(mlp_outputs) 55 | mlp_outputs = self.bmlp_2(mlp_outputs) 56 | mlp_outputs = self.relu(mlp_outputs) 57 | 58 | interaction_input = eb_outputs + [mlp_outputs] 59 | interaction_output = torch.concat(interaction_input, dim=1) 60 | 61 | tmlp_input = torch.concat([mlp_outputs, interaction_output], dim=1) 62 | 63 | tmlp_outputs = self.tmlp_0(tmlp_input) 64 | tmlp_outputs = self.relu(tmlp_outputs) 65 | tmlp_outputs = self.tmlp_1(tmlp_outputs) 66 | tmlp_outputs = self.relu(tmlp_outputs) 67 | tmlp_outputs = self.tmlp_2(tmlp_outputs) 68 | tmlp_outputs = self.relu(tmlp_outputs) 69 | tmlp_outputs = self.tmlp_3(tmlp_outputs) 70 | tmlp_outputs = self.sigmoid(tmlp_outputs) 71 | 72 | outputs.append(tmlp_outputs) 73 | 74 | return outputs 75 | 76 | 77 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 78 | class Test_Group_Embedding_Bag_Addmm_1dbias_Relu_Model(Zentorch_TestCase): 79 | 
@parameterized.expand(product(supported_dtypes, freeze_opt)) 80 | @torch.inference_mode() 81 | def test_group_embedding_bag_addmm_1dbias_relu_model(self, dtype, freeze_opt): 82 | self.data.create_unittest_data(dtype) 83 | indices = self.data.emb_input 84 | offsets = self.data.offsets 85 | mlp_inputs = self.data.mlp_inputs 86 | model = Custom_Model_Group_Embedding_Bag_Addmm_1dbias_Relu( 87 | self.data.R, self.data.k 88 | ) 89 | native_output = model(indices, offsets, mlp_inputs) 90 | reset_dynamo() 91 | compiled_model = torch.compile(model, backend="zentorch") 92 | compiled_output = test_with_freeze_opt( 93 | compiled_model, 94 | (indices, offsets, mlp_inputs), 95 | freeze_opt 96 | ) 97 | self.assertEqual(native_output, compiled_output) 98 | 99 | 100 | if __name__ == "__main__": 101 | run_tests() 102 | -------------------------------------------------------------------------------- /test/unittests/model_tests/test_mini_mha.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import sys 8 | from pathlib import Path 9 | 10 | sys.path.append(str(Path(__file__).parent.parent.parent)) 11 | from unittests.unittest_utils import ( # noqa: 402 12 | Zentorch_TestCase, 13 | run_tests, 14 | skip_test_pt_2_3, 15 | ) 16 | from llm_tests.test_masked_mha import Test_Masked_MHA # noqa: 402 17 | 18 | 19 | @unittest.skipIf( 20 | skip_test_pt_2_3, "Skipping test as OP support available from PyTorch 2.3" 21 | ) 22 | class Test_MHA_Model(Zentorch_TestCase): 23 | def setUp(self): 24 | super().setUp() 25 | self.mha = Test_Masked_MHA() 26 | self.beam_size = 1 27 | self.batch_size = 1 28 | self.head_size = 256 29 | self.head_num = 16 30 | self.head_num_kv = 1 31 | self.max_seq_len = 64 32 | self.first_seq_len = 32 33 | 34 | def tearDown(self): 35 | del self.mha 36 | super().tearDown() 37 | 38 | def test_mha_model(self): 39 | self.mha._test_mha( 40 | self.beam_size, 41 | self.batch_size, 42 | self.head_size, 43 | self.head_num, 44 | self.head_num_kv, 45 | self.max_seq_len, 46 | self.first_seq_len, 47 | ) 48 | 49 | 50 | if __name__ == "__main__": 51 | run_tests() 52 | -------------------------------------------------------------------------------- /test/unittests/model_tests/test_mm_silu.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | from itertools import product 10 | from torch import nn 11 | import sys 12 | from pathlib import Path 13 | 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | from unittest_utils import ( # noqa: 402 16 | Zentorch_TestCase, 17 | has_zentorch, 18 | reset_dynamo, 19 | run_tests, 20 | supported_dtypes, 21 | zentorch, 22 | freeze_opt, 23 | test_with_freeze_opt, 24 | ) 25 | 26 | 27 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 28 | class Test_MM_Silu_Model(Zentorch_TestCase): 29 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 30 | @torch.inference_mode() 31 | def test_mm_with_bias_silu_model(self, dtype, freeze_opt): 32 | self.data.create_unittest_data(dtype) 33 | model = nn.Sequential(nn.Linear(self.data.n, self.data.m, bias=True), nn.SiLU()) 34 | if dtype == "bfloat16": 35 | model = model.bfloat16() 36 | model_output = model(self.data.input) 37 | reset_dynamo() 38 | compiled_graph = torch.compile(model, backend="zentorch") 39 | compiled_graph_output = test_with_freeze_opt( 40 | compiled_graph, 41 | (self.data.input), 42 | freeze_opt 43 | ) 44 | self.assertEqual(model_output, compiled_graph_output) 45 | 46 | @parameterized.expand(product(supported_dtypes, freeze_opt)) 47 | @torch.inference_mode() 48 | def test_mm_without_bias_silu_model(self, dtype, freeze_opt): 49 | self.data.create_unittest_data(dtype) 50 | model = nn.Sequential( 51 | nn.Linear(self.data.n, self.data.m, bias=False), nn.SiLU() 52 | ) 53 | if dtype == "bfloat16": 54 | model = model.bfloat16() 55 | model_output = model(self.data.input) 56 | reset_dynamo() 57 | compiled_graph = torch.compile(model, backend="zentorch") 58 | compiled_graph_output = test_with_freeze_opt( 59 | compiled_graph, 60 | (self.data.input), 61 | freeze_opt 62 | ) 63 | self.assertEqual(model_output, compiled_graph_output) 64 | 65 | 66 | if __name__ == "__main__": 67 | run_tests() 68 | -------------------------------------------------------------------------------- /test/unittests/model_tests/test_quant_embedding_bag.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from torch import nn 9 | import sys 10 | from pathlib import Path 11 | from parameterized import parameterized 12 | 13 | sys.path.append(str(Path(__file__).parent.parent)) 14 | from unittest_utils import ( # noqa: 402 15 | Zentorch_TestCase, 16 | has_zentorch, 17 | zentorch, 18 | run_tests, 19 | supported_dtypes, 20 | ) 21 | 22 | 23 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 24 | class Custom_Model_Quant_Embedding_Group(nn.Module): 25 | def __init__(self): 26 | super(Custom_Model_Quant_Embedding_Group, self).__init__() 27 | 28 | def forward(self, weights, indices, offsets, cat_input, output_dtype): 29 | eb1 = torch.ops.zentorch.zentorch_quant_embedding_bag( 30 | weights, 31 | indices, 32 | offsets, 33 | 4, # assumes that weights has been quantized to uint4 hence 4 bits 34 | output_dtype, 35 | scale_grad_by_freq=False, 36 | mode=0, 37 | sparse=False, 38 | per_sample_weights=None, 39 | include_last_offset=False, 40 | padding_idx=-1, 41 | ) 42 | 43 | eb2 = torch.ops.zentorch.zentorch_quant_embedding_bag( 44 | weights, 45 | indices, 46 | offsets, 47 | 4, # assumes that weights has been quantized to uint4 hence 4 bits 48 | output_dtype, 49 | scale_grad_by_freq=False, 50 | mode=0, 51 | sparse=False, 52 | per_sample_weights=None, 53 | include_last_offset=False, 54 | padding_idx=-1, 55 | ) 56 | res = torch.cat([eb1, eb2, cat_input], dim=1) 57 | return res 58 | 59 | 60 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 61 | class Test_WOQ_Embedding_Bag_Group(Zentorch_TestCase): 62 | 63 | @parameterized.expand(supported_dtypes) 64 | @torch.inference_mode() 65 | def test_quant_embedding_bag_group(self, dtype): 66 | torch_type = self.data.get_torch_type(dtype) 67 | weight = torch.randint(low=0, high=15, size=(4, 16), dtype=torch_type) 68 | indices = torch.tensor([1, 2, 3], dtype=torch.long) 69 | offsets = torch.tensor([0, 2], dtype=torch.long) 70 | # TODO Check with zendnn for decimal place rounding 71 | scales = torch.rand(weight.size(0), 1).round(decimals=2) 72 | zero_points = torch.tensor([0, 0, 0, 0], dtype=torch.int32) 73 | dequant_weight = weight * scales 74 | dequant_weight = dequant_weight.to(torch_type) 75 | 76 | from op_tests._pack import create_pack_method 77 | 78 | packmethod = create_pack_method("awq", "int4") 79 | packed_weight = packmethod.pack( 80 | (weight.to(torch.int32)), False, transpose=False 81 | ) 82 | 83 | zentorch_packed_weights = zentorch._C.zentorch_get_packed_embedding_weight( 84 | packed_weight, scales, zero_points 85 | ) 86 | 87 | cat_input = torch.randn(2, 8).type(torch_type) 88 | 89 | eb1 = torch.nn.functional.embedding_bag( 90 | indices, 91 | dequant_weight, 92 | offsets, 93 | mode="sum", 94 | ) 95 | eb2 = torch.nn.functional.embedding_bag( 96 | indices, 97 | dequant_weight, 98 | offsets, 99 | mode="sum", 100 | ) 101 | 102 | ref_result = torch.cat([eb1, eb2, cat_input], dim=1) 103 | model = Custom_Model_Quant_Embedding_Group() 104 | model = torch.compile(model, backend="zentorch") 105 | model_result = model(zentorch_packed_weights, indices, offsets, 106 | cat_input, torch_type) 107 | self.assertEqual(ref_result, model_result, atol=1e-3, rtol=1e-3) 108 | 109 | 110 | if __name__ == "__main__": 111 | run_tests() 112 | -------------------------------------------------------------------------------- /test/unittests/model_tests/test_sdpa.py: 
-------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | import sys 10 | from pathlib import Path 11 | from torch.nn.functional import scaled_dot_product_attention 12 | from itertools import product 13 | 14 | sys.path.append(str(Path(__file__).parent.parent)) 15 | from unittest_utils import ( # noqa: 402 16 | Zentorch_TestCase, 17 | has_zentorch, 18 | reset_dynamo, 19 | run_tests, 20 | supported_dtypes, 21 | seq_length_opt, 22 | batch_size_opt, 23 | ) 24 | 25 | skip_test_pt_2_4 = False 26 | 27 | if torch.__version__[:3] < "2.4": 28 | skip_test_pt_2_4 = True 29 | 30 | 31 | class Custom_Model_Sdpa(torch.nn.Module): 32 | def __init__(self, *args, **kwargs) -> None: 33 | super().__init__(*args, **kwargs) 34 | self.sdpa = scaled_dot_product_attention 35 | 36 | def forward(self, query, key, value, attention_mask): 37 | # Scale= 1/sqrt(hidden_size_per_head) and here we considered head size as 64, 38 | # Hence scale 0.125 is used 39 | # TODO Update test case with 40 | # parameterizing the num_heads and hidden_size_per_head 41 | return self.sdpa(query, key, value, attn_mask=attention_mask, scale=0.125) 42 | 43 | 44 | @unittest.skipIf( 45 | skip_test_pt_2_4, "Skipping test as OP support available from PyTorch 2.4" 46 | ) 47 | class Test_Sdpa_Model(Zentorch_TestCase): 48 | @parameterized.expand(product(supported_dtypes, seq_length_opt, batch_size_opt)) 49 | @torch.inference_mode() 50 | def test_sdpa_model(self, dtype, seq_length, batch_size): 51 | self.data.create_unittest_data(dtype) 52 | torch_type = self.data.get_torch_type(dtype) 53 | amp_enabled = True if dtype == "bfloat16" else False 54 | native_model = Custom_Model_Sdpa().eval() 55 | zentorch_model = Custom_Model_Sdpa().eval() 56 | zentorch_model = torch.compile(zentorch_model, backend="zentorch") 57 | with torch.inference_mode(), torch.autocast( 58 | device_type="cpu", enabled=amp_enabled 59 | ): 60 | sdpa_query = torch.randn( 61 | batch_size, 16, seq_length, 64, device="cpu", requires_grad=False 62 | ).type(torch_type) 63 | sdpa_key = torch.randn( 64 | batch_size, 16, seq_length, 64, device="cpu", requires_grad=False 65 | ).type(torch_type) 66 | sdpa_value = torch.randn( 67 | batch_size, 16, seq_length, 64, device="cpu", requires_grad=False 68 | ).type(torch_type) 69 | sdpa_attention_mask = None 70 | native_output = native_model( 71 | sdpa_query, 72 | sdpa_key, 73 | sdpa_value, 74 | sdpa_attention_mask, 75 | ) 76 | zentorch_output = zentorch_model( 77 | sdpa_query, 78 | sdpa_key, 79 | sdpa_value, 80 | sdpa_attention_mask, 81 | ) 82 | self.assertEqual(native_output, zentorch_output, atol=1e-3, rtol=1e-2) 83 | 84 | 85 | if __name__ == "__main__": 86 | run_tests() 87 | -------------------------------------------------------------------------------- /test/unittests/op_tests/__init__.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_addmm_1dbias.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | import sys 10 | from pathlib import Path 11 | 12 | sys.path.append(str(Path(__file__).parent.parent)) 13 | from unittest_utils import ( # noqa: 402 14 | Zentorch_TestCase, 15 | has_zentorch, 16 | run_tests, 17 | supported_dtypes, 18 | ) 19 | 20 | 21 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 22 | class Test_Addmm_1dbias(Zentorch_TestCase): 23 | 24 | @parameterized.expand(supported_dtypes) 25 | @torch.inference_mode() 26 | def test_addmm_1dbias_incorrect_dims(self, dtype): 27 | 28 | self.data.create_unittest_data(dtype) 29 | 30 | with self.assertRaises(RuntimeError) as context: 31 | torch.ops.zentorch.zentorch_addmm_1dbias( 32 | self.data.x, self.data.x, self.data.x 33 | ) 34 | self.assertTrue( 35 | "unsupported dims for self, mat1 and mat2" in str(context.exception) 36 | ) 37 | 38 | 39 | if __name__ == "__main__": 40 | run_tests() 41 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_addmm_1dbias_mul_add.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | import sys 10 | from pathlib import Path 11 | 12 | sys.path.append(str(Path(__file__).parent.parent)) 13 | from unittest_utils import ( # noqa: 402 14 | Zentorch_TestCase, 15 | has_zentorch, 16 | reset_dynamo, 17 | run_tests, 18 | supported_dtypes, 19 | ) 20 | 21 | 22 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 23 | class Test_Addmm_1dbias_Mul_Add(Zentorch_TestCase): 24 | @parameterized.expand(supported_dtypes) 25 | @torch.inference_mode() 26 | def test_addmm_1dbias_mul_add_mismatched_dimensions(self, dtype): 27 | self.data.create_unittest_data(dtype) 28 | with self.assertRaises(RuntimeError) as context: 29 | torch.ops.zentorch.zentorch_addmm_1dbias_mul_add( 30 | self.data.input1d, 31 | self.data.x, 32 | self.data.y, 33 | torch.reshape( 34 | self.data.input, 35 | (1, list(self.data.input.shape)[0], list(self.data.input.shape)[1]), 36 | ), 37 | torch.reshape( 38 | self.data.input, 39 | (1, list(self.data.input.shape)[0], list(self.data.input.shape)[1]), 40 | ), 41 | ) 42 | self.assertTrue( 43 | "unsupported dims for mat1, mat2, " 44 | "binary1_input and binary2_input" in str(context.exception) 45 | ) 46 | 47 | @parameterized.expand(supported_dtypes) 48 | @torch.inference_mode() 49 | def test_addmm_1dbias_mul_add_mismatched_sizes(self, dtype): 50 | self.data.create_unittest_data(dtype) 51 | with self.assertRaises(RuntimeError) as context: 52 | torch.ops.zentorch.zentorch_addmm_1dbias_mul_add( 53 | self.data.input1d, self.data.x, self.data.y, self.data.x, self.data.x 54 | ) 55 | self.assertTrue( 56 | "unsupported sizes for mat1, mat2, " 57 | "binary1_input and binary2_input" in str(context.exception) 58 | ) 59 | 60 | @parameterized.expand(supported_dtypes) 61 | @torch.inference_mode() 62 | def test_addmm_1dbias_mul_add(self, dtype): 63 | self.skip_if_bfloat16_path_issue(dtype) 64 | new_dtype = self.data.get_torch_type(dtype) 65 | arg_0 = torch.rand((30), dtype=new_dtype) 66 | arg_1 = torch.rand((20, 40), dtype=new_dtype) 67 | arg_2 = torch.rand((30, 40), dtype=new_dtype) 68 | arg_3 = torch.rand((20, 30), dtype=new_dtype) 69 | reset_dynamo() 70 | output_1 = torch.add( 71 | torch.mul(torch.nn.functional.linear(arg_1, arg_2, arg_0), arg_3), arg_3 72 | ) 73 | output_2 = torch.ops.zentorch.zentorch_addmm_1dbias_mul_add( 74 | arg_0, arg_1, arg_2.t(), arg_3, arg_3 75 | ) 76 | self.assertEqual(output_1, output_2, atol=1e-9, rtol=1e-2) 77 | 78 | 79 | if __name__ == "__main__": 80 | run_tests() 81 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_addmm_silu.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | import sys 10 | from pathlib import Path 11 | 12 | sys.path.append(str(Path(__file__).parent.parent)) 13 | from unittest_utils import ( # noqa: 402 14 | Zentorch_TestCase, 15 | has_zentorch, 16 | run_tests, 17 | supported_dtypes, 18 | ) 19 | 20 | 21 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 22 | class Test_Addmm_Silu(Zentorch_TestCase): 23 | @parameterized.expand(supported_dtypes) 24 | @torch.inference_mode() 25 | def test_addmm_silu(self, dtype): 26 | self.data.create_unittest_data(dtype) 27 | native_output = torch.nn.functional.silu( 28 | torch.addmm(self.data.input, self.data.x, self.data.y) 29 | ) 30 | zentorch_output = torch.ops.zentorch.zentorch_addmm_silu( 31 | self.data.input, self.data.x, self.data.y 32 | ) 33 | 34 | self.assertEqual(native_output, zentorch_output) 35 | 36 | 37 | if __name__ == "__main__": 38 | run_tests() 39 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_addmm_silu_mul.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | import sys 10 | from pathlib import Path 11 | 12 | sys.path.append(str(Path(__file__).parent.parent)) 13 | from unittest_utils import ( # noqa: 402 14 | Zentorch_TestCase, 15 | has_zentorch, 16 | run_tests, 17 | supported_dtypes, 18 | ) 19 | 20 | 21 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 22 | class Test_Addmm_SiLU_Mul(Zentorch_TestCase): 23 | 24 | @parameterized.expand(supported_dtypes) 25 | @torch.inference_mode() 26 | def test_addmm_silu_mul(self, dtype): 27 | self.data.create_unittest_data(dtype) 28 | bias = self.data.input.clone() 29 | native_output = ( 30 | torch.nn.functional.silu(torch.addmm(bias, self.data.x, self.data.y)) 31 | * self.data.input 32 | ) 33 | zentorch_output = torch.ops.zentorch.zentorch_addmm_silu_mul( 34 | bias, self.data.x, self.data.y, self.data.input 35 | ) 36 | self.assertEqual(native_output, zentorch_output) 37 | 38 | @parameterized.expand(supported_dtypes) 39 | @torch.inference_mode() 40 | def test_addmm_silu_mul_mismatched_dimensions(self, dtype): 41 | self.data.create_unittest_data(dtype) 42 | with self.assertRaises(RuntimeError) as context: 43 | torch.ops.zentorch.zentorch_addmm_silu_mul( 44 | self.data.input, 45 | self.data.x, 46 | self.data.y, 47 | torch.reshape( 48 | self.data.input, 49 | (1, list(self.data.input.shape)[0], list(self.data.input.shape)[1]), 50 | ), 51 | ) 52 | self.assertTrue( 53 | "unsupported dims for mat1, mat2 and post op buffer" 54 | in str(context.exception) 55 | ) 56 | 57 | @parameterized.expand(supported_dtypes) 58 | @torch.inference_mode() 59 | def test_addmm_silu_mul_mismatched_sizes(self, dtype): 60 | self.data.create_unittest_data(dtype) 61 | with self.assertRaises(RuntimeError) as context: 62 | torch.ops.zentorch.zentorch_addmm_silu_mul( 63 | self.data.input, self.data.x, self.data.y, self.data.x 64 | ) 65 | self.assertTrue( 66 | "unsupported shapes for mat1, mat2 and post op buffer" 67 | in str(context.exception) 68 | ) 69 | 
70 | 71 | if __name__ == "__main__": 72 | run_tests() 73 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_bmm.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | import sys 10 | from pathlib import Path 11 | 12 | sys.path.append(str(Path(__file__).parent.parent)) 13 | from unittest_utils import ( # noqa: E402 14 | Zentorch_TestCase, 15 | has_zentorch, 16 | run_tests, 17 | skip_test_pt_2_0, 18 | supported_dtypes, 19 | ) 20 | 21 | 22 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 23 | class Test_BMM_Op(Zentorch_TestCase): 24 | @parameterized.expand(supported_dtypes) 25 | @unittest.skipIf(skip_test_pt_2_0, "Skipping test due to PT2.0 instability") 26 | def test_bmm_variants(self, dtype): 27 | 28 | self.data.create_unittest_data(dtype) 29 | self.assertEqual( 30 | torch._C._VariableFunctions.bmm(self.data.x3d, self.data.y3d), 31 | torch.ops.zentorch.zentorch_bmm(self.data.x3d, self.data.y3d), 32 | ) 33 | 34 | @parameterized.expand(supported_dtypes) 35 | def test_bmm_unsupported_dims(self, dtype): 36 | 37 | self.data.create_unittest_data(dtype) 38 | with self.assertRaises(RuntimeError) as context: 39 | torch.ops.zentorch.zentorch_bmm(self.data.x, self.data.y) 40 | 41 | self.assertTrue("unsupported dims for self and mat2" in str(context.exception)) 42 | with self.assertRaises(RuntimeError) as context: 43 | torch.ops.zentorch.zentorch_bmm(self.data.x, self.data.x) 44 | self.assertTrue("unsupported dims for self and mat2" in str(context.exception)) 45 | 46 | @parameterized.expand([("int",)]) 47 | def test_bmm_unsupported_dtype(self, dtype): 48 | 49 | self.data.create_unittest_data(dtype) 50 | with self.assertRaises(RuntimeError) as context: 51 | torch.ops.zentorch.zentorch_bmm(self.data.x3d, self.data.y3d) 52 | 53 | self.assertTrue( 54 | "zentorch_matmul only supports Float and BFloat16" in str(context.exception) 55 | ) 56 | 57 | 58 | if __name__ == "__main__": 59 | run_tests() 60 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_convolution.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from itertools import product 9 | from parameterized import parameterized 10 | import sys 11 | from pathlib import Path 12 | 13 | sys.path.append(str(Path(__file__).parent.parent)) 14 | from unittest_utils import ( # noqa: 402 15 | Zentorch_TestCase, 16 | has_zentorch, 17 | run_tests, 18 | supported_dtypes, 19 | conv_stride, 20 | conv_padding, 21 | ) 22 | 23 | 24 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 25 | class Test_Convolution(Zentorch_TestCase): 26 | @parameterized.expand([("int",)]) 27 | def test_convolution_unsupported_dtype(self, dtype): 28 | self.data.create_unittest_data(dtype) 29 | with self.assertRaises(RuntimeError) as context: 30 | torch.ops.zentorch.zentorch_convolution( 31 | self.data.conv_input, 32 | self.data.conv_weight, 33 | self.data.conv_bias, 34 | self.data.stride, 35 | self.data.padding, 36 | self.data.dilation, 37 | False, 38 | self.data.output_padding, 39 | 1, 40 | ) 41 | self.assertTrue( 42 | "unsupported data type, only bf16 and fp32 supported for now" 43 | in str(context.exception) 44 | ) 45 | 46 | @parameterized.expand(supported_dtypes) 47 | def test_convolution_invalid_dims(self, dtype): 48 | self.data.create_unittest_data(dtype) 49 | with self.assertRaises(RuntimeError) as context: 50 | torch.ops.zentorch.zentorch_convolution( 51 | self.data.conv_input3d, 52 | self.data.conv_weight3d, 53 | self.data.conv_bias, 54 | self.data.stride, 55 | self.data.padding, 56 | self.data.dilation, 57 | False, 58 | self.data.output_padding, 59 | 1, 60 | ) 61 | self.assertTrue( 62 | "unsupported dims for conv input and weight" in str(context.exception) 63 | ) 64 | 65 | @parameterized.expand(supported_dtypes) 66 | def test_convolution_unsupported_dilation(self, dtype): 67 | self.data.create_unittest_data(dtype) 68 | with self.assertRaises(RuntimeError) as context: 69 | torch.ops.zentorch.zentorch_convolution( 70 | self.data.conv_input, 71 | self.data.conv_weight, 72 | self.data.conv_bias, 73 | self.data.stride, 74 | self.data.padding, 75 | self.data.dilation2, 76 | False, 77 | self.data.output_padding, 78 | 1, 79 | ) 80 | self.assertTrue( 81 | "unsupported value of dilation, only [1,1] supported for now" 82 | in str(context.exception) 83 | ) 84 | 85 | @parameterized.expand( 86 | product( 87 | supported_dtypes, 88 | conv_stride, 89 | conv_padding, 90 | ) 91 | ) 92 | def test_convolution(self, dtype, stride, padding): 93 | self.data.create_unittest_data(dtype) 94 | conv_output = torch._C._VariableFunctions.convolution( 95 | self.data.conv_input, 96 | self.data.conv_weight, 97 | self.data.conv_bias, 98 | stride, 99 | padding, 100 | self.data.dilation, 101 | False, 102 | self.data.output_padding, 103 | 1, 104 | ) 105 | 106 | conv_output_z = torch.ops.zentorch.zentorch_convolution( 107 | self.data.conv_input, 108 | self.data.conv_weight, 109 | self.data.conv_bias, 110 | stride, 111 | padding, 112 | self.data.dilation, 113 | False, 114 | self.data.output_padding, 115 | 1, 116 | ) 117 | self.assertEqual(conv_output, conv_output_z) 118 | 119 | 120 | if __name__ == "__main__": 121 | run_tests() 122 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_embedding.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro 
Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | import sys 10 | from pathlib import Path 11 | 12 | sys.path.append(str(Path(__file__).parent.parent)) 13 | from unittest_utils import ( # noqa: 402 14 | Zentorch_TestCase, 15 | has_zentorch, 16 | run_tests, 17 | supported_dtypes, 18 | ) 19 | 20 | 21 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 22 | class Test_Embedding(Zentorch_TestCase): 23 | @parameterized.expand(supported_dtypes) 24 | def test_embedding(self, dtype): 25 | self.data.create_unittest_data(dtype) 26 | y_eb = torch._C._VariableFunctions.embedding( 27 | self.data.embedding_matrix, self.data.emb_input 28 | ) 29 | y_ebz = torch.ops.zentorch.zentorch_embedding( 30 | self.data.embedding_matrix, self.data.emb_input 31 | ) 32 | self.assertEqual(y_eb, y_ebz) 33 | 34 | @parameterized.expand(supported_dtypes) 35 | def test_embedding_sparse_scale(self, dtype): 36 | self.data.create_unittest_data(dtype) 37 | sparse_opt = [True, False] 38 | scale_grad_opt = [True, False] 39 | 40 | for sprs_opt in sparse_opt: 41 | for scale_opt in scale_grad_opt: 42 | y_eb = torch._C._VariableFunctions.embedding( 43 | self.data.embedding_matrix, 44 | self.data.emb_input, 45 | -1, 46 | scale_opt, 47 | sprs_opt, 48 | ) 49 | y_ebz = torch.ops.zentorch.zentorch_embedding( 50 | self.data.embedding_matrix, 51 | self.data.emb_input, 52 | -1, 53 | scale_opt, 54 | sprs_opt, 55 | ) 56 | self.assertEqual(y_eb, y_ebz) 57 | 58 | 59 | if __name__ == "__main__": 60 | run_tests() 61 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_embedding_bag.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | from itertools import product 8 | import torch 9 | from parameterized import parameterized 10 | import sys 11 | from pathlib import Path 12 | 13 | sys.path.append(str(Path(__file__).parent.parent)) 14 | from unittest_utils import ( # noqa: 402 15 | Zentorch_TestCase, 16 | has_zentorch, 17 | include_last_offset_opt, 18 | mode_opt, 19 | run_tests, 20 | scale_grad_opt, 21 | sparse_opt, 22 | supported_dtypes, 23 | ) 24 | 25 | 26 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 27 | class Test_Embedding_Bag(Zentorch_TestCase): 28 | @parameterized.expand(supported_dtypes) 29 | def test_embedding_bag(self, dtype): 30 | 31 | self.data.create_unittest_data(dtype) 32 | y_eb, _, _, _ = torch._C._VariableFunctions._embedding_bag( 33 | self.data.embedding_matrix, 34 | self.data.emb_input, 35 | self.data.offsets, 36 | False, 37 | 0, 38 | False, 39 | None, 40 | False, 41 | ) 42 | y_ebz = torch.ops.zentorch.zentorch_embedding_bag( 43 | self.data.embedding_matrix, 44 | self.data.emb_input, 45 | self.data.offsets, 46 | False, 47 | 0, 48 | False, 49 | None, 50 | False, 51 | -1, 52 | ) 53 | self.assertEqual(y_eb, y_ebz) 54 | 55 | @parameterized.expand( 56 | product( 57 | supported_dtypes, 58 | mode_opt, 59 | include_last_offset_opt, 60 | sparse_opt, 61 | scale_grad_opt, 62 | ) 63 | ) 64 | def test_embedding_bag_sparse_scale_mode( 65 | self, dtype, mode, include_last_offset, sprs_opt, scale_opt 66 | ): 67 | 68 | self.data.create_unittest_data(dtype) 69 | 70 | # max mode is not supported whenever any of the sparse_opt 71 | # or scale_grad_opt is True 72 | y_eb, _, _, _ = torch._C._VariableFunctions._embedding_bag( 73 | self.data.embedding_matrix, 74 | self.data.emb_input, 75 | self.data.offsets, 76 | scale_opt, 77 | mode, 78 | sprs_opt, 79 | None, 80 | include_last_offset, 81 | ) 82 | 83 | y_ebz = torch.ops.zentorch.zentorch_embedding_bag( 84 | self.data.embedding_matrix, 85 | self.data.emb_input, 86 | self.data.offsets, 87 | scale_opt, 88 | mode, 89 | sprs_opt, 90 | None, 91 | include_last_offset, 92 | -1, 93 | ) 94 | self.assertEqual(y_eb, y_ebz) 95 | 96 | 97 | if __name__ == "__main__": 98 | run_tests() 99 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_embeg_pack_weight.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | import zentorch 9 | import sys 10 | from pathlib import Path 11 | 12 | sys.path.append(str(Path(__file__).parent.parent)) 13 | from unittest_utils import ( # noqa: 402 14 | Zentorch_TestCase, 15 | has_zentorch, 16 | run_tests, 17 | ) 18 | 19 | 20 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 21 | class Test_Embedding_Packed_Weight(Zentorch_TestCase): 22 | def test_embedding_packed_weight(self): 23 | weight = torch.randint(low=0, high=255, size=(20, 40), dtype=torch.int32) 24 | zero_points = torch.zeros(20).type(torch.int32) 25 | weight_scales = torch.randn(20) 26 | packed_weight = zentorch._C.zentorch_get_packed_embedding_weight( 27 | weight, weight_scales, zero_points 28 | ) 29 | # the original weight must remain recoverable from all but the 30 | # last column of the packed tensor 31 | self.assertEqual(weight, packed_weight[:, :-1]) 32 | 33 | 34 | if __name__ == "__main__": 35 | run_tests() 36 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_horizontal_embedding_bag_group.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | from itertools import product 8 | import torch 9 | from parameterized import parameterized 10 | import sys 11 | from pathlib import Path 12 | 13 | sys.path.append(str(Path(__file__).parent.parent)) 14 | from unittest_utils import ( # noqa: 402 15 | Zentorch_TestCase, 16 | has_zentorch, 17 | include_last_offset_opt, 18 | mode_opt, 19 | run_tests, 20 | scale_grad_opt, 21 | sparse_opt, 22 | supported_dtypes, 23 | ) 24 | 25 | 26 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 27 | class Test_Horizontal_Embedding_Bag_Group(Zentorch_TestCase): 28 | @parameterized.expand( 29 | product( 30 | supported_dtypes, 31 | mode_opt, 32 | include_last_offset_opt, 33 | sparse_opt, 34 | scale_grad_opt, 35 | ) 36 | ) 37 | def test_horizontal_embedding_bag_group( 38 | self, dtype, mode, include_last_offset, sprs_opt, scale_opt 39 | ): 40 | self.data.create_unittest_data(dtype) 41 | y_eb, _, _, _ = torch._C._VariableFunctions._embedding_bag( 42 | self.data.embedding_matrix, 43 | self.data.emb_input, 44 | self.data.offsets, 45 | scale_opt, 46 | mode, 47 | sprs_opt, 48 | None, 49 | include_last_offset, 50 | ) 51 | y_ebz_list = torch.ops.zentorch.zentorch_horizontal_embedding_bag_group( 52 | [self.data.embedding_matrix] * 3, 53 | [self.data.emb_input] * 3, 54 | [self.data.offsets] * 3, 55 | [scale_opt] * 3, 56 | [mode] * 3, 57 | [sprs_opt] * 3, 58 | [None] * 3, 59 | [include_last_offset] * 3, 60 | [-1] * 3, 61 | ) 62 | for i in range(0, int(len(y_ebz_list) / 4)): 63 | self.assertEqual(y_eb, y_ebz_list[i * 4]) 64 | 65 | 66 | if __name__ == "__main__": 67 | run_tests() 68 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_horizontal_embedding_group.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved.
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | import sys 10 | from pathlib import Path 11 | 12 | sys.path.append(str(Path(__file__).parent.parent)) 13 | from unittest_utils import ( # noqa: 402 14 | Zentorch_TestCase, 15 | has_zentorch, 16 | run_tests, 17 | supported_dtypes, 18 | ) 19 | 20 | 21 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 22 | class Test_Horizontal_Embedding_Group(Zentorch_TestCase): 23 | @parameterized.expand(supported_dtypes) 24 | def test_horizontal_embedding_group(self, dtype): 25 | self.data.create_unittest_data(dtype) 26 | y_eb = torch._C._VariableFunctions.embedding( 27 | self.data.embedding_matrix, self.data.emb_input 28 | ) 29 | y_ebz_list = torch.ops.zentorch.zentorch_horizontal_embedding_group( 30 | [self.data.embedding_matrix] * 3, 31 | [self.data.emb_input] * 3, 32 | [-1] * 3, 33 | [False] * 3, 34 | [False] * 3, 35 | ) 36 | for i in range(0, int(len(y_ebz_list))): 37 | self.assertEqual(y_eb, y_ebz_list[i]) 38 | 39 | 40 | if __name__ == "__main__": 41 | run_tests() 42 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_matmul_impl.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | import sys 10 | from pathlib import Path 11 | 12 | sys.path.append(str(Path(__file__).parent.parent)) 13 | from unittest_utils import ( # noqa: 402 14 | Zentorch_TestCase, 15 | has_zentorch, 16 | run_tests, 17 | supported_dtypes, 18 | zentorch, 19 | ) 20 | 21 | 22 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 23 | class Test_Matmul_Impl_Op(Zentorch_TestCase): 24 | @parameterized.expand(supported_dtypes) 25 | def test_matmul_impl_for_mv_and_dot(self, dtype): 26 | 27 | self.data.create_unittest_data(dtype) 28 | # mv 29 | self.assertEqual( 30 | torch.mv(self.data.input, self.data.input1d), 31 | zentorch._C.zentorch_matmul_impl( 32 | self.data.input, 33 | self.data.input1d, 34 | self.data.empty_bias, 35 | self.data.result_m, 36 | [], 37 | [], 38 | ), 39 | atol=1e-3, 40 | rtol=1e-2, 41 | ) 42 | # dot 43 | self.assertEqual( 44 | torch.dot(self.data.input1d, self.data.input1d), 45 | zentorch._C.zentorch_matmul_impl( 46 | self.data.input1d, 47 | self.data.input1d, 48 | self.data.empty_bias, 49 | self.data.result_1, 50 | [], 51 | [], 52 | ), 53 | ) 54 | 55 | def test_matmul_impl_bfloat16_postop(self): 56 | self.data.create_unittest_data("float32") 57 | with self.assertRaises(RuntimeError) as context: 58 | bias_as_postop = self.data.x.clone().to(torch.bfloat16) 59 | post_op_add = 6 60 | zentorch._C.zentorch_matmul_impl( 61 | self.data.x, 62 | self.data.x, 63 | self.data.empty_bias, 64 | self.data.result.to(self.data.x.dtype), 65 | [post_op_add], 66 | [bias_as_postop], 67 | ) 68 | 69 | self.assertTrue( 70 | "zentorch_matmul only supports Float post ops " 71 | "when input matrix is Float" in str(context.exception) 72 | ) 73 | 74 | def test_matmul_impl_int_postop(self): 75 | self.skip_if_bfloat16_unsupported_hardware() 76 | self.data.create_unittest_data("bfloat16") 77 | with 
self.assertRaises(RuntimeError) as context_int: 78 | bias_as_postop = self.data.x.clone().to(torch.int) 79 | post_op_add = 6 80 | zentorch._C.zentorch_matmul_impl( 81 | self.data.x, 82 | self.data.x, 83 | self.data.empty_bias, 84 | self.data.result.to(self.data.x.dtype), 85 | [post_op_add], 86 | [bias_as_postop], 87 | ) 88 | 89 | self.assertTrue( 90 | "zentorch_matmul only supports BFloat16 post ops " 91 | "when input matrix is BFloat16" in str(context_int.exception) 92 | ) 93 | 94 | def test_int_matmul_impl_postop(self): 95 | self.data.create_unittest_data("int") 96 | with self.assertRaises(RuntimeError) as context_int: 97 | bias_as_postop = self.data.x3d.clone().to(torch.int) 98 | post_op_add = 6 99 | zentorch._C.zentorch_matmul_impl( 100 | self.data.x, 101 | self.data.x, 102 | self.data.empty_bias, 103 | self.data.result.to(self.data.x.dtype), 104 | [post_op_add], 105 | [bias_as_postop], 106 | ) 107 | 108 | self.assertTrue( 109 | "zentorch_matmul only supports Float and BFloat16" 110 | in str(context_int.exception) 111 | ) 112 | 113 | 114 | if __name__ == "__main__": 115 | run_tests() 116 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_mm.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | import sys 10 | from pathlib import Path 11 | 12 | sys.path.append(str(Path(__file__).parent.parent)) 13 | from unittest_utils import ( # noqa: 402 14 | Zentorch_TestCase, 15 | has_zentorch, 16 | run_tests, 17 | skip_test_pt_2_0, 18 | supported_dtypes, 19 | ) 20 | 21 | 22 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 23 | class Test_MM_Op(Zentorch_TestCase): 24 | @parameterized.expand(supported_dtypes) 25 | @unittest.skipIf(skip_test_pt_2_0, "Skipping test due to PT2.0 instability") 26 | def test_mm_variants(self, dtype): 27 | self.data.create_unittest_data(dtype) 28 | # mm 29 | self.assertEqual( 30 | torch._C._VariableFunctions.mm(self.data.x, self.data.y), 31 | torch.ops.zentorch.zentorch_mm(self.data.x, self.data.y), 32 | ) 33 | self.assertEqual( 34 | torch.matmul(self.data.x, self.data.y), 35 | torch.ops.zentorch.zentorch_mm(self.data.x, self.data.y), 36 | ) 37 | self.assertEqual( 38 | torch.mm(self.data.x, self.data.y), 39 | torch.ops.zentorch.zentorch_mm(self.data.x, self.data.y), 40 | ) 41 | 42 | self.assertEqual( 43 | self.data.x @ self.data.y, 44 | torch.ops.zentorch.zentorch_mm(self.data.x, self.data.y), 45 | ) 46 | 47 | self.assertEqual( 48 | torch.mul(self.data.A, self.data.B), 49 | torch.ops.zentorch.zentorch_mm(self.data.A, self.data.B), 50 | ) 51 | 52 | @parameterized.expand(supported_dtypes) 53 | def test_mm_mismatched_dimensions(self, dtype): 54 | self.data.create_unittest_data(dtype) 55 | with self.assertRaises(RuntimeError) as context: 56 | torch.ops.zentorch.zentorch_mm( 57 | self.data.x, 58 | torch.reshape( 59 | self.data.x, 60 | (1, list(self.data.x.shape)[0], list(self.data.x.shape)[1]), 61 | ), 62 | ) 63 | self.assertTrue("unsupported dims for self and mat2" in str(context.exception)) 64 | with self.assertRaises(RuntimeError) as context: 65 | torch.ops.zentorch.zentorch_mm(self.data.x3d, self.data.x3d) 66 | 
self.assertTrue("unsupported dims for self and mat2" in str(context.exception)) 67 | 68 | @parameterized.expand([("int",)]) 69 | def test_mm_unsupported_dtype(self, dtype): 70 | 71 | self.data.create_unittest_data(dtype) 72 | with self.assertRaises(RuntimeError) as context: 73 | torch.ops.zentorch.zentorch_mm(self.data.x, self.data.y) 74 | self.assertTrue( 75 | "zentorch_matmul only supports Float and BFloat16" in str(context.exception) 76 | ) 77 | 78 | @parameterized.expand(supported_dtypes) 79 | def test_mm_relu(self, dtype): 80 | 81 | self.data.create_unittest_data(dtype) 82 | # mm->relu 83 | self.assertEqual( 84 | torch._C._VariableFunctions.relu( 85 | torch._C._VariableFunctions.mm(self.data.x, self.data.y) 86 | ), 87 | torch.ops.zentorch.zentorch_mm_relu(self.data.x, self.data.y), 88 | ) 89 | 90 | 91 | if __name__ == "__main__": 92 | run_tests() 93 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_mm_silu.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | import sys 10 | from pathlib import Path 11 | 12 | sys.path.append(str(Path(__file__).parent.parent)) 13 | from unittest_utils import ( # noqa: 402 14 | Zentorch_TestCase, 15 | has_zentorch, 16 | run_tests, 17 | supported_dtypes, 18 | ) 19 | 20 | 21 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 22 | class Test_MM_Silu(Zentorch_TestCase): 23 | @parameterized.expand(supported_dtypes) 24 | @torch.inference_mode() 25 | def test_mm_silu(self, dtype): 26 | self.data.create_unittest_data(dtype) 27 | native_output = torch.nn.functional.silu(torch.matmul(self.data.x, self.data.y)) 28 | zentorch_output = torch.ops.zentorch.zentorch_mm_silu(self.data.x, self.data.y) 29 | 30 | self.assertEqual(native_output, zentorch_output) 31 | 32 | 33 | if __name__ == "__main__": 34 | run_tests() 35 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_mm_silu_mul.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | import sys 10 | from pathlib import Path 11 | 12 | sys.path.append(str(Path(__file__).parent.parent)) 13 | from unittest_utils import ( # noqa: 402 14 | Zentorch_TestCase, 15 | has_zentorch, 16 | run_tests, 17 | supported_dtypes, 18 | ) 19 | 20 | 21 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 22 | class Test_MM_SiLU_Mul(Zentorch_TestCase): 23 | @parameterized.expand(supported_dtypes) 24 | @torch.inference_mode() 25 | def test_mm_silu_mul(self, dtype): 26 | self.data.create_unittest_data(dtype) 27 | native_output = ( 28 | torch.nn.functional.silu(torch.matmul(self.data.x, self.data.y)) 29 | * self.data.input 30 | ) 31 | zentorch_output = torch.ops.zentorch.zentorch_mm_silu_mul( 32 | self.data.x, self.data.y, self.data.input 33 | ) 34 | self.assertEqual(native_output, zentorch_output) 35 | 36 | @parameterized.expand(supported_dtypes) 37 | @torch.inference_mode() 38 | def test_mm_silu_mul_mismatched_dimensions(self, dtype): 39 | self.data.create_unittest_data(dtype) 40 | with self.assertRaises(RuntimeError) as context: 41 | torch.ops.zentorch.zentorch_mm_silu_mul( 42 | self.data.x, 43 | self.data.y, 44 | torch.reshape( 45 | self.data.input, 46 | (1, list(self.data.input.shape)[0], list(self.data.input.shape)[1]), 47 | ), 48 | ) 49 | self.assertTrue( 50 | "unsupported dims for mat1, mat2 and post op buffer" 51 | in str(context.exception) 52 | ) 53 | 54 | @parameterized.expand(supported_dtypes) 55 | @torch.inference_mode() 56 | def test_mm_silu_mul_mismatched_sizes(self, dtype): 57 | self.data.create_unittest_data(dtype) 58 | with self.assertRaises(RuntimeError) as context: 59 | torch.ops.zentorch.zentorch_mm_silu_mul( 60 | self.data.x, self.data.y, self.data.x 61 | ) 62 | self.assertTrue( 63 | "unsupported shapes for mat1, mat2 and post op buffers" 64 | in str(context.exception) 65 | ) 66 | 67 | 68 | if __name__ == "__main__": 69 | run_tests() 70 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_prepare_4d_causal_attention_mask.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | from parameterized import parameterized 9 | from itertools import product 10 | import sys 11 | from pathlib import Path 12 | 13 | sys.path.append(str(Path(__file__).parent.parent)) 14 | from unittest_utils import ( # noqa: 402 15 | Zentorch_TestCase, 16 | has_zentorch, 17 | run_tests, 18 | supported_dtypes, 19 | Test_Data, 20 | ) 21 | 22 | 23 | sliding_windows = [10, 40] 24 | seq_lens = [1, 32] 25 | 26 | 27 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 28 | class Test_Prepare_4d_causal_Attention_Mask(Zentorch_TestCase): 29 | @parameterized.expand(product(supported_dtypes, sliding_windows, seq_lens)) 30 | @torch.inference_mode() 31 | def test_prepare_4d_causal_attention_mask(self, dtype, sliding_window, seq_len): 32 | torch_dtype = torch.float32 if dtype == "float32" else torch.bfloat16 33 | inputs_embeds = torch.rand((1, seq_len, 768), dtype=torch_dtype) 34 | finfo_min = torch.finfo(torch_dtype).min 35 | past_key_values_length = 0 36 | if seq_len == 1: 37 | past_key_values_length = 32 38 | attention_mask = torch.ones( 39 | (1, past_key_values_length + seq_len), dtype=torch_dtype 40 | ) 41 | output = torch.ops.zentorch.prepare_4d_causal_attention_mask( 42 | attention_mask, 43 | inputs_embeds, 44 | past_key_values_length, 45 | torch.tensor(finfo_min).contiguous(), 46 | sliding_window, 47 | ) 48 | 49 | from transformers.modeling_attn_mask_utils import ( 50 | _prepare_4d_causal_attention_mask, 51 | ) 52 | 53 | output_ref = _prepare_4d_causal_attention_mask( 54 | attention_mask, 55 | (inputs_embeds.shape[0], inputs_embeds.shape[1]), 56 | inputs_embeds, 57 | past_key_values_length, 58 | sliding_window, 59 | ) 60 | self.assertEqual(output, output_ref) 61 | 62 | @parameterized.expand(product(sliding_windows, seq_lens)) 63 | def test_prepare_4d_causal_attention_mask_incorrect_dtype( 64 | self, sliding_window, seq_len 65 | ): 66 | inputs_embeds = torch.randint(low=0, high=100, size=(1, seq_len, 768)) 67 | finfo_min = torch.iinfo(torch.int).min 68 | past_key_values_length = 0 69 | if seq_len == 1: 70 | past_key_values_length = 32 71 | attention_mask = torch.ones( 72 | (1, past_key_values_length + seq_len), dtype=torch.long 73 | ) 74 | 75 | with self.assertRaises(RuntimeError) as context: 76 | _ = torch.ops.zentorch.prepare_4d_causal_attention_mask( 77 | attention_mask, 78 | inputs_embeds, 79 | past_key_values_length, 80 | torch.tensor(finfo_min).contiguous(), 81 | sliding_window, 82 | ) 83 | self.assertTrue( 84 | "zentorch::prepare_4d_causal_attention_mask_kernel_impl supports " 85 | "only float and bfloat16 datatypes" in str(context.exception) 86 | ) 87 | 88 | 89 | if __name__ == "__main__": 90 | run_tests() 91 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_qlinear_eltwise.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | from itertools import product 8 | import torch 9 | from parameterized import parameterized 10 | import sys 11 | from pathlib import Path 12 | 13 | sys.path.append(str(Path(__file__).parent.parent)) 14 | from unittest_utils import ( # noqa: 402 15 | Zentorch_TestCase, 16 | has_zentorch, 17 | run_tests, 18 | qlinear_dtypes, 19 | input_dim_opt, 20 | q_weight_list_opt, 21 | bias_opt, 22 | q_granularity_opt, 23 | q_zero_points_dtype_opt, 24 | q_linear_dtype_opt, 25 | qlinear_eltwise_map, 26 | ) 27 | from quant_utils import qdq_linear # noqa: 402 28 | 29 | 30 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 31 | class Test_Qlinear_Eltwise(Zentorch_TestCase): 32 | @parameterized.expand( 33 | product( 34 | qlinear_dtypes, 35 | input_dim_opt, 36 | q_weight_list_opt, 37 | bias_opt, 38 | q_granularity_opt, 39 | q_zero_points_dtype_opt, 40 | q_linear_dtype_opt, 41 | q_linear_dtype_opt, 42 | qlinear_eltwise_map.keys(), 43 | ), 44 | skip_on_empty=True, 45 | ) 46 | @torch.inference_mode() 47 | def test_qlinear_eltwise_fused_op_accuracy( 48 | self, 49 | dtype, 50 | input_dim, 51 | q_weight_idx, 52 | bias_opt_idx, 53 | q_granularity_val, 54 | q_zero_points_dtype, 55 | input_dtype, 56 | output_dtype, 57 | eltwise_op, 58 | ): 59 | self.data.create_unittest_data(dtype) 60 | self.skip_if_does_not_support_arg_combination_for_qlinear( 61 | bias_opt_idx, input_dtype, output_dtype 62 | ) 63 | 64 | # simulated qlinear + eltwise op 65 | qdq_linear_eltwise_output = qdq_linear( 66 | self.data.x_for_qlinear[input_dtype][input_dim], 67 | self.data.y_int8[q_weight_idx], 68 | self.data.bias_for_qlinear[bias_opt_idx], 69 | self.data.x_scales["per_tensor"], 70 | self.data.x_zero_points["per_tensor"][input_dtype][q_zero_points_dtype], 71 | self.data.y_scales[q_granularity_val], 72 | self.data.y_zero_points[q_granularity_val], 73 | qlinear_eltwise_map[eltwise_op][0].eval(), 74 | self.data.get_torch_type(output_dtype), 75 | self.data.output_scales["per_tensor"][output_dtype]["positive_scales"], 76 | self.data.output_zero_points["per_tensor"][output_dtype], 77 | ) 78 | 79 | # zentorch qlinear + eltwise fused op 80 | zentorch_qlinear_eltwise_output = qlinear_eltwise_map[eltwise_op][1]( 81 | self.data.x_for_qlinear[input_dtype][input_dim], 82 | self.data.y_int8[q_weight_idx], 83 | self.data.bias_for_qlinear[bias_opt_idx], 84 | self.data.x_scales["per_tensor"], 85 | self.data.x_zero_points["per_tensor"][input_dtype][q_zero_points_dtype], 86 | self.data.y_scales[q_granularity_val], 87 | self.data.y_zero_points[q_granularity_val], 88 | output_dtype=self.data.get_torch_type(output_dtype), 89 | output_scales=self.data.output_scales["per_tensor"][output_dtype][ 90 | "positive_scales" 91 | ], 92 | output_zero_points=self.data.output_zero_points["per_tensor"][output_dtype], 93 | ) 94 | 95 | self.assertEqual( 96 | qdq_linear_eltwise_output, 97 | zentorch_qlinear_eltwise_output, 98 | atol=1e-2, 99 | rtol=1e-2, 100 | ) 101 | 102 | 103 | if __name__ == "__main__": 104 | run_tests() 105 | -------------------------------------------------------------------------------- /test/unittests/op_tests/test_quant_embedding_bag.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import unittest 7 | import torch 8 | import sys 9 | from pathlib import Path 10 | from parameterized import parameterized 11 | 12 | sys.path.append(str(Path(__file__).parent.parent)) 13 | from unittest_utils import ( # noqa: 402 14 | Zentorch_TestCase, 15 | has_zentorch, 16 | reset_dynamo, 17 | run_tests, 18 | supported_dtypes, 19 | zentorch, 20 | ) 21 | 22 | 23 | @unittest.skipIf(not has_zentorch, "ZENTORCH is not installed") 24 | class Test_WOQ_Embedding_Bag(Zentorch_TestCase): 25 | 26 | @parameterized.expand(supported_dtypes) 27 | @torch.inference_mode() 28 | def test_quant_embedding_bag(self, dtype): 29 | torch_type = self.data.get_torch_type(dtype) 30 | weight = torch.randint(low=0, high=15, size=(4, 16), dtype=torch_type) 31 | indices = torch.tensor([1, 2, 3], dtype=torch.long) 32 | offsets = torch.tensor([0, 2], dtype=torch.long) 33 | scales = torch.tensor([[1.0], [4.0], [5.0], [7.0]]) 34 | zero_points = torch.tensor([0, 0, 0, 0], dtype=torch.int32) 35 | # 1 2 3 4 5 6 7 8 36 | # will be packed as 37 | # 8 7 6 5 4 3 2 1 38 | # 1000 0111 0110 0101 0100 0011 0010 0001 39 | 40 | from op_tests._pack import create_pack_method 41 | 42 | packmethod = create_pack_method("awq", "int4") 43 | packed_weight = packmethod.pack( 44 | (weight.to(torch.int32)), False, transpose=False 45 | ) 46 | dequant_weight = weight * scales 47 | 48 | ref_result = torch.nn.functional.embedding_bag( 49 | indices, dequant_weight, offsets, mode="sum" 50 | ).to(torch_type) 51 | 52 | zentorch_packed_weights = zentorch._C.zentorch_get_packed_embedding_weight( 53 | packed_weight, scales, zero_points 54 | ) 55 | op_result = torch.ops.zentorch.zentorch_quant_embedding_bag( 56 | zentorch_packed_weights, 57 | indices, 58 | offsets, 59 | 4, # assumes that weights has been quantized to uint4 hence 4 bits 60 | torch_type, 61 | False, 62 | 0, 63 | False, 64 | None, 65 | 0, 66 | -1, 67 | ) 68 | self.assertEqual(ref_result, op_result) 69 | 70 | 71 | if __name__ == "__main__": 72 | run_tests() 73 | -------------------------------------------------------------------------------- /test/unittests/quant_utils.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import sys 7 | from pathlib import Path 8 | import torch 9 | 10 | sys.path.append(str(Path(__file__).parent.parent)) 11 | 12 | 13 | def qdq_linear( 14 | inp, 15 | weight, 16 | bias, 17 | inp_scales, 18 | inp_zero_points, 19 | weight_scales, 20 | weight_zero_points, 21 | eltwise_op, 22 | output_dtype, 23 | output_scales=None, 24 | output_zero_points=None, 25 | ): 26 | inp_min_val = -128 if inp_zero_points.dtype == torch.int8 else 0 27 | inp_max_val = 127 if inp_zero_points.dtype == torch.int8 else 255 28 | weight_min_val = -128 if weight_zero_points.dtype == torch.int8 else 0 29 | weight_max_val = 127 if weight_zero_points.dtype == torch.int8 else 255 30 | out_features_axis = 0 31 | 32 | if inp.dtype == torch.float32 or inp.dtype == torch.bfloat16: 33 | # fake_quantize_per_tensor_affine only supports fp32 inputs 34 | qdq_inp = torch.fake_quantize_per_tensor_affine( 35 | inp.to(torch.float32), inp_scales, inp_zero_points, inp_min_val, inp_max_val 36 | ) 37 | else: 38 | qdq_inp = torch.ops.quantized_decomposed.dequantize_per_tensor.default( 39 | inp, 40 | inp_scales, 41 | inp_zero_points, 42 | inp_min_val, 43 | inp_max_val, 44 | inp.dtype, 45 | ) 46 | if weight_scales.numel() == 1: 47 | dq_weight = torch.ops.quantized_decomposed.dequantize_per_tensor.default( 48 | weight, 49 | weight_scales, 50 | weight_zero_points, 51 | weight_min_val, 52 | weight_max_val, 53 | weight.dtype, 54 | ) 55 | else: 56 | dq_weight = torch.ops.quantized_decomposed.dequantize_per_channel.default( 57 | weight, 58 | weight_scales, 59 | weight_zero_points, 60 | out_features_axis, 61 | weight_min_val, 62 | weight_max_val, 63 | weight.dtype, 64 | ) 65 | 66 | qdq_linear_output = torch.nn.functional.linear(qdq_inp, dq_weight, bias) 67 | 68 | if eltwise_op is not None: 69 | qdq_linear_output = eltwise_op(qdq_linear_output) 70 | 71 | if output_scales is not None and output_zero_points is not None: 72 | output_min_val = -128 if output_zero_points.dtype == torch.int8 else 0 73 | output_max_val = 127 if output_zero_points.dtype == torch.int8 else 255 74 | return torch.ops.quantized_decomposed.quantize_per_tensor.default( 75 | qdq_linear_output, 76 | output_scales, 77 | output_zero_points, 78 | output_min_val, 79 | output_max_val, 80 | output_zero_points.dtype, 81 | ).to(output_dtype) 82 | else: 83 | return qdq_linear_output.to(output_dtype) 84 | -------------------------------------------------------------------------------- /test/unittests/unittest_utils.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright (c) 2024-2025 Advanced Micro Devices, Inc. 3 | # All rights reserved. 
4 | # ****************************************************************************** 5 | 6 | import sys 7 | from pathlib import Path 8 | import os 9 | import shutil 10 | 11 | sys.path.append(str(Path(__file__).parent.parent)) 12 | 13 | from utils import ( # noqa: 402 # noqa: F401 14 | BaseZentorchTestCase, 15 | run_tests, 16 | zentorch, 17 | has_zentorch, 18 | counters, 19 | supported_dtypes, 20 | qlinear_dtypes, 21 | skip_test_pt_2_0, 22 | skip_test_pt_2_1, 23 | skip_test_pt_2_3, 24 | skip_test_pt_2_4, 25 | reset_dynamo, 26 | set_seed, 27 | freeze_opt, 28 | test_with_freeze_opt, 29 | Test_Data, 30 | woq_dtypes, 31 | include_last_offset_opt, 32 | scale_grad_opt, 33 | mode_opt, 34 | sparse_opt, 35 | input_dim_opt, 36 | q_weight_list_opt, 37 | bias_opt, 38 | woq_qzeros_opt, 39 | group_size_opt, 40 | q_granularity_opt, 41 | q_zero_points_dtype_opt, 42 | q_linear_dtype_opt, 43 | conv_stride, 44 | conv_padding, 45 | at_ops, 46 | zt_ops, 47 | qlinear_eltwise_map, 48 | seq_length_opt, 49 | batch_size_opt, 50 | torch, 51 | ) 52 | 53 | 54 | path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 55 | 56 | 57 | class Zentorch_TestCase(BaseZentorchTestCase): 58 | def setUp(self): 59 | super().setUp() 60 | os.makedirs(os.path.join(path, "data"), exist_ok=True) 61 | self.data = Test_Data() 62 | 63 | def tearDown(self): 64 | del self.data 65 | shutil.rmtree(os.path.join(path, "data")) 66 | 67 | def skip_if_does_not_support_arg_combination_for_qlinear( 68 | self, bias_opt_idx, input_dtype, output_dtype 69 | ): 70 | if ( 71 | self.data.bias_for_qlinear[bias_opt_idx] is None 72 | and input_dtype in ("float32", "bfloat16") 73 | and output_dtype not in (input_dtype, "int8", "uint8") 74 | ): 75 | self.skipTest( 76 | "Skipping test, if bias is None and input is floating-point, then " 77 | "output dtype has to match either input dtype or be any of int8 " 78 | "or uint8" 79 | ) 80 | 81 | if ( 82 | self.data.bias_for_qlinear[bias_opt_idx] is not None 83 | and self.data.bias_for_qlinear[bias_opt_idx].dtype == torch.float32 84 | and output_dtype == "bfloat16" 85 | ): 86 | self.skipTest( 87 | "Skipping test, if bias is fp32, then output dtype cannot be bf16." 88 | ) 89 | 90 | if ( 91 | self.data.bias_for_qlinear[bias_opt_idx] is not None 92 | and self.data.bias_for_qlinear[bias_opt_idx].dtype == torch.bfloat16 93 | and output_dtype == "float32" 94 | ): 95 | self.skipTest( 96 | "Skipping test, if bias is bf16, then output dtype cannot be fp32." 97 | ) 98 | 99 | if ( 100 | self.data.bias_for_qlinear[bias_opt_idx] is not None 101 | and input_dtype in ("float32", "bfloat16") 102 | and self.data.bias_for_qlinear[bias_opt_idx].dtype 103 | != self.data.get_torch_type(input_dtype) 104 | ): 105 | self.skipTest( 106 | "Skipping test, if bias is not None and input is floating-point, then " 107 | "bias dtype has to match input dtype" 108 | ) 109 | --------------------------------------------------------------------------------
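The op-level tests above all follow the same pattern: build small input tensors, compute a native PyTorch reference, call the corresponding torch.ops.zentorch.* fused kernel on the same operands, and compare the two results within a tolerance. A minimal standalone sketch of that pattern, using zentorch_mm_silu_mul as in test_mm_silu_mul.py (the shapes, variable names, and tolerances below are illustrative, and zentorch must be installed for the custom ops to be registered):

import torch
import zentorch  # noqa: F401  # importing zentorch registers the torch.ops.zentorch.* ops

x = torch.randn(8, 16)
y = torch.randn(16, 32)
post_op_buffer = torch.randn(8, 32)

# native reference: silu(x @ y) * post_op_buffer
reference = torch.nn.functional.silu(torch.matmul(x, y)) * post_op_buffer
# zentorch fused op with the same operands
fused = torch.ops.zentorch.zentorch_mm_silu_mul(x, y, post_op_buffer)

torch.testing.assert_close(fused, reference, atol=1e-3, rtol=1e-2)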