├── .bazelrc ├── .clang-format ├── .github ├── CODEOWNERS └── workflows │ └── IDC_1100_Public_CI.yml ├── .gitignore ├── .pylintrc ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── WORKSPACE ├── configure ├── configure.py ├── docs ├── acc_jax.md ├── images │ ├── perfetto_profile.png │ ├── tensorboard_profile.png │ ├── tensorboard_profile_kernelstats.png │ ├── tensorboard_profile_opstats.png │ └── tensorboard_profile_traceview.png └── profiler.md ├── example ├── bert │ ├── README.md │ ├── requirements.txt │ ├── run_qa.py │ └── utils_qa.py ├── fp8 │ ├── README.md │ └── run.py ├── gemma │ ├── README.md │ ├── finetune.py │ ├── gemma.patch │ ├── inference.py │ ├── keras_nlp.patch │ └── prompt.json ├── gptj │ ├── README.md │ ├── jax_gptj.py │ └── prompt.json ├── grok │ ├── README.md │ ├── inference.py │ └── prompt.json ├── qkv_fusion │ └── test_qkv_fusion.py ├── resnet50 │ └── README.md ├── sdxl │ ├── README.md │ ├── inference.py │ └── target.png ├── stable_diffusion │ ├── README.md │ ├── jax_stable.py │ └── target.png └── t5 │ ├── README.md │ ├── install_xpu.sh │ ├── patch │ └── t5.patch │ ├── quick_start.sh │ ├── reference.json │ ├── xl_infer.gin │ └── xxl_infer.gin ├── openxla_for_intel_gpu.png ├── security.md ├── test ├── BRANCH_NAME ├── jax.patch ├── jax_ut.patch └── requirements.txt ├── third-party-programs └── third-party-programs-of-onednn.txt ├── third_party ├── BUILD ├── common.bzl ├── gpus │ ├── BUILD │ ├── crosstool │ │ ├── BUILD │ │ ├── BUILD.sycl.tpl │ │ ├── clang │ │ │ └── bin │ │ │ │ ├── BUILD │ │ │ │ ├── ar_driver.tpl │ │ │ │ └── crosstool_wrapper_driver.tpl │ │ └── sycl_cc_toolchain_config.bzl.tpl │ ├── find_sycl_config.py │ ├── sycl │ │ ├── BUILD │ │ ├── BUILD.tpl │ │ └── build_defs.bzl.tpl │ └── sycl_configure.bzl ├── llvm_spir │ ├── BUILD │ ├── llvm_spir.BUILD │ └── llvm_spir.patch ├── onednn │ ├── BUILD │ ├── LICENSE │ ├── MKL_LICENSE │ ├── gen_gpu_kernel_list.py │ ├── gen_onednn_version.py │ ├── onednn.bzl │ └── onednn_gpu.BUILD ├── openxla.patch ├── version_check.bzl └── xetla │ ├── BUILD │ └── xetla.patch └── xla ├── BUILD ├── profiler ├── BUILD ├── correlator.cc ├── correlator.h ├── device_tracer_sycl.cc ├── trace_options.h ├── tracing.h ├── utils.h ├── ze_api_collector.h ├── ze_kernel_collector.h ├── ze_tracer.h └── ze_utils.h ├── python ├── BUILD ├── __init__.py ├── build_defs.bzl ├── gen_xla_version.py ├── version.py.in └── xpu_plugin_extension.cc ├── service ├── BUILD ├── gpu │ ├── BUILD │ ├── ccl_all_gather_thunk.cc │ ├── ccl_all_gather_thunk.h │ ├── ccl_all_reduce_thunk.cc │ ├── ccl_all_reduce_thunk.h │ ├── ccl_all_to_all_thunk.cc │ ├── ccl_all_to_all_thunk.h │ ├── ccl_collective_broadcast_thunk.cc │ ├── ccl_collective_broadcast_thunk.h │ ├── ccl_collective_permute_thunk.cc │ ├── ccl_collective_permute_thunk.h │ ├── ccl_collective_thunk.cc │ ├── ccl_collective_thunk.h │ ├── ccl_ops.cc │ ├── ccl_ops.h │ ├── ccl_p2p_thunk_common.cc │ ├── ccl_p2p_thunk_common.h │ ├── dot_expand_dims.cc │ ├── dot_expand_dims.h │ ├── gemm_impl_picker.cc │ ├── gemm_impl_picker.h │ ├── matrix_descriptor.h │ ├── onednn_gpu_conv_runner.cc │ ├── onednn_gpu_conv_runner.h │ ├── onednn_matmul_utils.cc │ ├── onednn_matmul_utils.h │ ├── redundant_convert_mover.cc │ ├── redundant_convert_mover.h │ ├── scratch_allocator.cc │ ├── scratch_allocator.h │ ├── sycl_custom_call.cc │ ├── sycl_onednn.cc │ ├── sycl_onednn.h │ ├── utils.h │ ├── xetla │ │ ├── gemm │ │ │ ├── BUILD │ │ │ ├── dispatch_col_major.cc │ │ │ ├── dispatch_col_major.h │ │ │ ├── dispatch_row_major.cc │ │ │ ├── 
dispatch_row_major.h │ │ │ ├── epilogue_impl.h │ │ │ ├── gemm.cc │ │ │ ├── gemm.h │ │ │ ├── gemm_common.h │ │ │ ├── gemm_dispatch.h │ │ │ └── hgemm_impl.h │ │ └── sdp │ │ │ ├── BUILD │ │ │ ├── fmha_backward.h │ │ │ ├── fmha_forward.h │ │ │ ├── fmha_policy.h │ │ │ ├── fmha_utils.h │ │ │ ├── sdp_backward.cc │ │ │ ├── sdp_backward.h │ │ │ ├── sdp_forward.cc │ │ │ └── sdp_forward.h │ ├── xetla_gpu_fused_mha_runner.cc │ └── xetla_gpu_fused_mha_runner.h └── onednn_util.h ├── stream_executor ├── BUILD └── sycl │ ├── BUILD │ ├── hw_info.cc │ ├── hw_info.h │ ├── sycl_blas.cc │ ├── sycl_blas.h │ ├── sycl_collectives.cc │ ├── sycl_conditional_kernels.cc │ ├── sycl_dnn.cc │ ├── sycl_dnn.h │ ├── sycl_driver.cc │ ├── sycl_event.cc │ ├── sycl_event.h │ ├── sycl_executor.cc │ ├── sycl_executor.h │ ├── sycl_fft.cc │ ├── sycl_fft.h │ ├── sycl_gpu_runtime.cc │ ├── sycl_gpu_runtime.h │ ├── sycl_kernel.cc │ ├── sycl_kernel.h │ ├── sycl_platform.cc │ ├── sycl_platform.h │ ├── sycl_platform_id.cc │ ├── sycl_platform_id.h │ └── sycl_stream.h ├── tools └── pip_package │ ├── BUILD │ ├── MANIFEST.in │ ├── README.md │ ├── build_pip_package.sh │ ├── simple_console.py │ └── xla_setup.py ├── workspace.bzl └── xla.bzl /.bazelrc: -------------------------------------------------------------------------------- 1 | # Required by OpenXLA 2 | # https://github.com/openxla/xla/issues/1323 3 | build --nocheck_visibility 4 | 5 | # Make Bazel print out all options from rc files. 6 | build --announce_rc 7 | 8 | build --config=gpu 9 | 10 | # This config option is used for GPU backend. 11 | build:gpu --crosstool_top=@local_config_sycl//crosstool:toolchain 12 | build:gpu --define=using_sycl=true 13 | build:gpu --repo_env TF_NEED_SYCL=1 14 | build:gpu --define=tensorflow_mkldnn_contraction_kernel=0 15 | build:gpu --cxxopt=-std=c++17 16 | build:gpu --host_cxxopt=-std=c++17 17 | 18 | build --define=use_fast_cpp_protos=true 19 | build --define=allow_oversize_protos=true 20 | 21 | build --spawn_strategy=standalone 22 | build --strategy=Genrule=standalone 23 | build -c opt 24 | 25 | # Default paths for TF_SYSTEM_LIBS 26 | build --define=PREFIX=/usr 27 | build --define=LIBDIR=$(PREFIX)/lib 28 | build --define=INCLUDEDIR=$(PREFIX)/include 29 | 30 | # host build is useless 31 | build --distinct_host_configuration=false 32 | 33 | # Flag to enable remote config 34 | common --experimental_repo_remote_exec 35 | 36 | # Default options should come above this line 37 | 38 | # Options from ./configure 39 | try-import %workspace%/.xla_extension_configure.bazelrc 40 | 41 | # Put user-specific options in .bazelrc.user 42 | try-import %workspace%/.bazelrc.user 43 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | # Run manually to reformat a file: 2 | # clang-format -i --style=file 3 | BasedOnStyle: Google 4 | DerivePointerAlignment: false 5 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | 2 | * @agramesh1 @vsanghavi @mdfaijul @akhilgoe @Solaryee @bhavani-subramanian @ashraf-bhuiyan 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .ipynb_checkpoints 3 | node_modules 4 | /.bazelrc.user 5 | /.xla_extension_configure.bazelrc 6 | /bazel-* 7 | 
/bazel_pip 8 | /itex/tools/python_bin_path.sh 9 | /_python_build 10 | *.pyc 11 | __pycache__ 12 | *.swp 13 | .vscode/ 14 | .idea/** 15 | .clwb/ 16 | /build/ 17 | [Bb]uild/ 18 | *.whl 19 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | ### License 4 | 5 | Intel® Extension for OpenXLA is licensed under the terms in [LICENSE](LICENSE.txt). By contributing to the project, you agree to the license and copyright terms therein and release your contribution under these terms. 6 | 7 | ### Sign your work 8 | 9 | Please use the sign-off line at the end of the patch. Your signature certifies that you wrote the patch or otherwise have the right to pass it on as an open-source patch. The rules are pretty simple: if you can certify 10 | the below (from [developercertificate.org](http://developercertificate.org/)): 11 | 12 | ``` 13 | Developer Certificate of Origin 14 | Version 1.1 15 | 16 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors. 17 | 660 York Street, Suite 102, 18 | San Francisco, CA 94110 USA 19 | 20 | Everyone is permitted to copy and distribute verbatim copies of this 21 | license document, but changing it is not allowed. 22 | 23 | Developer's Certificate of Origin 1.1 24 | 25 | By making a contribution to this project, I certify that: 26 | 27 | (a) The contribution was created in whole or in part by me and I 28 | have the right to submit it under the open source license 29 | indicated in the file; or 30 | 31 | (b) The contribution is based upon previous work that, to the best 32 | of my knowledge, is covered under an appropriate open source 33 | license and I have the right under that license to submit that 34 | work with modifications, whether created in whole or in part 35 | by me, under the same open source license (unless I am 36 | permitted to submit under a different license), as indicated 37 | in the file; or 38 | 39 | (c) The contribution was provided directly to me by some other 40 | person who certified (a), (b) or (c) and I have not modified 41 | it. 42 | 43 | (d) I understand and agree that this project and the contribution 44 | are public and that a record of the contribution (including all 45 | personal information I submit with it, including my sign-off) is 46 | maintained indefinitely and may be redistributed consistent with 47 | this project or the open source license(s) involved. 48 | ``` 49 | 50 | Then you just add a line to every git commit message: 51 | 52 | Signed-off-by: Joe Smith 53 | 54 | Use your real name (sorry, no pseudonyms or anonymous contributions.) 55 | 56 | If you set your `user.name` and `user.email` git configs, you can sign your 57 | commit automatically with `git commit -s`. 58 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | workspace(name = "intel_extension_for_openxla") 2 | 3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 4 | load("//third_party:version_check.bzl", "check_bazel_version_at_least") 5 | 6 | check_bazel_version_at_least("6.5.0") 7 | 8 | # To update XLA to a new revision, 9 | # a) update URL and strip_prefix to the new git commit hash 10 | # b) get the sha256 hash of the commit by running: 11 | # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum 12 | # and update the sha256 with the result. 
13 | http_archive( 14 | name = "xla", 15 | patch_args = ["-p1"], 16 | patches = ["//third_party:openxla.patch"], 17 | sha256 = "0870fcd86678cae31c56cfc57018f52ceec8e4691472af62c847ade746a0eb13", 18 | strip_prefix = "xla-20a482597b7dd3067b26ca382b88084ee5a21cf7", 19 | urls = [ 20 | "https://github.com/openxla/xla/archive/20a482597b7dd3067b26ca382b88084ee5a21cf7.tar.gz", 21 | ], 22 | ) 23 | 24 | # For development, one often wants to make changes to the XLA repository as well 25 | # as the JAX repository. You can override the pinned repository above with a 26 | # local checkout by either: 27 | # a) overriding the XLA repository by passing a flag like: 28 | # bazel build --bazel_options=--override_repository=xla=/path/to/xla 29 | # or 30 | # b) by commenting out the http_archive above and uncommenting the following: 31 | # local_repository( 32 | # name = "xla", 33 | # path = "/path/to/xla", 34 | # ) 35 | 36 | # Initialize hermetic Python 37 | load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") 38 | 39 | python_init_rules() 40 | 41 | load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") 42 | 43 | python_init_repositories( 44 | default_python_version = "system", 45 | requirements = { 46 | "3.9": "@xla//:requirements_lock_3_9.txt", 47 | "3.10": "@xla//:requirements_lock_3_10.txt", 48 | "3.11": "@xla//:requirements_lock_3_11.txt", 49 | "3.12": "@xla//:requirements_lock_3_12.txt", 50 | "3.13": "@xla//:requirements_lock_3_13.txt", 51 | }, 52 | ) 53 | 54 | load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") 55 | 56 | python_init_toolchains() 57 | 58 | load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip") 59 | 60 | python_init_pip() 61 | 62 | load("@pypi//:requirements.bzl", "install_deps") 63 | 64 | install_deps() 65 | 66 | load("@xla//:workspace4.bzl", "xla_workspace4") 67 | 68 | xla_workspace4() 69 | 70 | load("@xla//:workspace3.bzl", "xla_workspace3") 71 | 72 | xla_workspace3() 73 | 74 | load("@xla//:workspace2.bzl", "xla_workspace2") 75 | 76 | xla_workspace2() 77 | 78 | load("@xla//:workspace1.bzl", "xla_workspace1") 79 | 80 | xla_workspace1() 81 | 82 | load("@xla//:workspace0.bzl", "xla_workspace0") 83 | 84 | xla_workspace0() 85 | 86 | load( 87 | "@xla//third_party/tsl/third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", 88 | "cuda_json_init_repository", 89 | ) 90 | 91 | cuda_json_init_repository() 92 | 93 | load( 94 | "@cuda_redist_json//:distributions.bzl", 95 | "CUDA_REDISTRIBUTIONS", 96 | "CUDNN_REDISTRIBUTIONS", 97 | ) 98 | load( 99 | "@xla//third_party/tsl/third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", 100 | "cuda_redist_init_repositories", 101 | "cudnn_redist_init_repository", 102 | ) 103 | 104 | cuda_redist_init_repositories( 105 | cuda_redistributions = CUDA_REDISTRIBUTIONS, 106 | ) 107 | 108 | cudnn_redist_init_repository( 109 | cudnn_redistributions = CUDNN_REDISTRIBUTIONS, 110 | ) 111 | 112 | load( 113 | "@xla//third_party/tsl/third_party/gpus/cuda/hermetic:cuda_configure.bzl", 114 | "cuda_configure", 115 | ) 116 | 117 | cuda_configure(name = "local_config_cuda") 118 | 119 | load( 120 | "@xla//third_party/tsl/third_party/nccl/hermetic:nccl_redist_init_repository.bzl", 121 | "nccl_redist_init_repository", 122 | ) 123 | 124 | nccl_redist_init_repository() 125 | 126 | load( 127 | "@xla//third_party/tsl/third_party/nccl/hermetic:nccl_configure.bzl", 128 | "nccl_configure", 129 | ) 130 | 131 | nccl_configure(name = "local_config_nccl") 132 | 133 | load( 134 | 
"@bazel_toolchains//repositories:repositories.bzl", 135 | bazel_toolchains_repositories = "repositories", 136 | ) 137 | 138 | bazel_toolchains_repositories() 139 | 140 | load("//xla:workspace.bzl", "workspace") 141 | 142 | workspace() 143 | -------------------------------------------------------------------------------- /configure: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -o pipefail 5 | 6 | if [ -z "$PYTHON_BIN_PATH" ]; then 7 | PYTHON_BIN_PATH=$(which python || which python3 || true) 8 | fi 9 | 10 | # Set all env variables 11 | CONFIGURE_DIR=$(dirname "$0") 12 | "$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@" 13 | 14 | echo "Configuration finished" 15 | -------------------------------------------------------------------------------- /docs/images/perfetto_profile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel/intel-extension-for-openxla/68b2112b3466d0fcb111369799246d39281c7452/docs/images/perfetto_profile.png -------------------------------------------------------------------------------- /docs/images/tensorboard_profile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel/intel-extension-for-openxla/68b2112b3466d0fcb111369799246d39281c7452/docs/images/tensorboard_profile.png -------------------------------------------------------------------------------- /docs/images/tensorboard_profile_kernelstats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel/intel-extension-for-openxla/68b2112b3466d0fcb111369799246d39281c7452/docs/images/tensorboard_profile_kernelstats.png -------------------------------------------------------------------------------- /docs/images/tensorboard_profile_opstats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel/intel-extension-for-openxla/68b2112b3466d0fcb111369799246d39281c7452/docs/images/tensorboard_profile_opstats.png -------------------------------------------------------------------------------- /docs/images/tensorboard_profile_traceview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel/intel-extension-for-openxla/68b2112b3466d0fcb111369799246d39281c7452/docs/images/tensorboard_profile_traceview.png -------------------------------------------------------------------------------- /docs/profiler.md: -------------------------------------------------------------------------------- 1 | # GPU Profiler 2 | 3 | Intel® Extension for OpenXLA* provides Profiler to track the performance of workloads running on the Intel GPU. [PJRT_C_API](https://github.com/openxla/xla/blob/main/xla/backends/profiler/plugin/profiler_c_api.h) lets third-party plugin communicate profiling data in XLA's native profiling format which takes a serialized XSpace object and fills the object with runtime data obtained through the oneAPI [Level Zero](https://www.intel.com/content/www/us/en/developer/articles/technical/using-oneapi-level-zero-interface.html) low-level device interface. 4 | 5 | Users can display the execution profile of specific XLA modules, XLA ops, GPU kernels and so on with profiling visualizer like TensorBoard or Perfetto. 
6 | 7 | ## How to use 8 | * Necessary environment variables: 9 | 10 | | env | functionality | 11 | | --- | --- | 12 | | ZE_ENABLE_TRACING_LAYER | Set to 1 to enable the Tracing Layer for Level-Zero API Tracing, see [L0 loader APIs](https://github.com/oneapi-src/level-zero/blob/77d092e314365cc54b9b873a47210a799ed5a77c/doc/loader_api.md?plain=1#L40) for more details. | 13 | | UseCyclesPerSecondTimer | Set to 1 to help libraries transition to the new resolution, since the time resolution returned by device properties has been changed to cycles/second in Level-Zero. | 14 | | NEOReadDebugKeys | Set to 1 to read debug environment variables on Linux release builds, see [NEOReadDebugKeys](https://github.com/intel/compute-runtime/blob/master/FAQ.md#how-can-i-enable-reading-debug-environment-variables-on-linux-release-builds) for more details. | 15 | 16 | * Script: 17 | ``` 18 | import jax 19 | import jax.numpy as jnp 20 | 21 | print("jax.local_devices(): ", jax.local_devices()) 22 | 23 | @jax.jit 24 | def lax_conv(): 25 | key = jax.random.PRNGKey(0) 26 | lhs = jax.random.uniform(key, (2,1,9,9), jnp.float32) 27 | rhs = jax.random.uniform(key, (1,1,4,4), jnp.float32) 28 | side = jax.random.uniform(key, (1,1,1,1), jnp.float32) 29 | out = jax.lax.conv_with_general_padding(lhs, rhs, (1,1), ((0,0),(0,0)), (1,1), (1,1)) 30 | out = jax.nn.relu(out) 31 | out = jnp.multiply(out, side) 32 | return out 33 | 34 | jax.profiler.start_trace("./profile_tmp") 35 | print(lax_conv()) 36 | jax.profiler.stop_trace() 37 | ``` 38 | 39 | * Run: 40 | ``` 41 | $ export ZE_ENABLE_TRACING_LAYER=1 42 | $ export UseCyclesPerSecondTimer=1 43 | $ python jax_conv.py 44 | ``` 45 | When this computation is done, the program generates a directory "profile_tmp"; choose one of the following tools to visualize the profiling data collected in this directory. 46 | 47 | ### TensorBoard profiling 48 | TensorBoard's profiler can be used to profile JAX or TensorFlow programs. TensorBoard is a great way to acquire and visualize performance traces and profiles of your program. 49 | 50 | The end result looks something like this: 51 |

52 | ![image](images/tensorboard_profile.png) 53 |

54 | 55 | 56 | * Requirement: 57 | ``` 58 | pip install -U tensorboard-plugin-profile 59 | ``` 60 | 61 | * Run TensorBoard: 62 | After executing the above Python script, you will find the log files in ./profile_tmp. Then run TensorBoard with the following command: 63 | ``` 64 | tensorboard --logdir=./profile_tmp --bind_all 65 | ``` 66 | 67 | * Analyze the result from the Profile tab: 68 | The GPU profiler supports the following profiling items: 69 | * kernel_stats: 70 | ![image](images/tensorboard_profile_kernelstats.png) 71 | * framework_op_stats: 72 | ![image](images/tensorboard_profile_opstats.png) 73 | 74 | * trace_viewer: 75 | ![image](images/tensorboard_profile_traceview.png) 76 | 77 | 78 | ### Perfetto profiling 79 | [Perfetto](https://ui.perfetto.dev/) is a high-performance system tracing and analysis tool primarily used for capturing and analyzing various performance events in Linux or Android systems. We can use Perfetto to visualize the profiling data generated by the JAX profiler. 80 | After executing the above Python script, you will find the log files in ./profile_tmp. Then follow the steps below: 81 | * Preparation 82 | Decompress the .gz file in ./profile_tmp: 83 | ``` 84 | gunzip xxx.trace.json.gz 85 | ``` 86 | This produces a xxx.trace.json file. 87 | 88 | * Open the trace file within Perfetto: 89 | ![image](images/perfetto_profile.png) 90 | 91 | ## FAQ 92 | 1. If you see "No dashboards are activated for the current data set." the first time you open TensorBoard in the browser: 93 | 94 | Refresh the page, and the profile should be shown. 95 | -------------------------------------------------------------------------------- /example/bert/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.20.0 2 | optax>=0.0.8 3 | transformers>=4.48.0 4 | evaluate>=0.4.1 5 | -------------------------------------------------------------------------------- /example/fp8/README.md: -------------------------------------------------------------------------------- 1 | # FP8 training and inference 2 | FP8 becomes increasingly important as models become larger. Here we introduce how to enable FP8 in OpenXLA via the Keras quantization API. 3 | 4 | ## Setup Kaggle key to access model 5 | Get your Kaggle key by following [Configure your API key](https://ai.google.dev/gemma/docs/setup#:~:text=T4%20GPU.-,Configure%20your%20API%20key,-To%20use%20Gemma) and then 6 | 7 | ``` 8 | export KAGGLE_USERNAME=xxxxxxxxx 9 | export KAGGLE_KEY=xxxxxxxx 10 | ``` 11 | 12 | ## Package dependency 13 | 14 | Mark `intel-extension-for-openxla` folder as \, then 15 | ```bash 16 | cd /example/fp8/ 17 | pip install keras-nlp==0.10.0 keras==3.3.2 kagglehub==0.2.5 18 | pip install -r ../../test/requirements.txt 19 | ``` 20 | 21 | ## Dataset 22 | 23 | Download the [databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k/blob/main/databricks-dolly-15k.jsonl) dataset. 24 | 25 | ## Options 26 | ``` 27 | --model: The model name. Choices are "gpt2_base_en", "gpt2_medium_en", "gemma_2b_en". Default is "gpt2_base_en". 28 | --fp8: Store true. Whether to use the float8 technique. 29 | --batch-size: The batch size. Default is 32. 
30 | ``` 31 | 32 | ## Environment Variables 33 | | **ENV** | **Description** | **PVC Platform** | **ATSM/DG2 Platform** | 34 | | :---: | :---: | :---: |:---: | 35 | | ZE_AFFINITY_MASK | Run this model on single GPU tile |export ZE_AFFINITY_MASK=0 | export ZE_AFFINITY_MASK=0 | 36 | | XLA_FLAGS | Disable SimplifyFPConversions pass | export XLA_FLAGS="--xla_allow_excess_precision=0" | export XLA_FLAGS="--xla_allow_excess_precision=0" | 37 | | KERAS_BACKEND | Set keras backend | export KERAS_BACKEND=jax | export KERAS_BACKEND=jax | 38 | 39 | 40 | ### Example 41 | 42 | ```bash 43 | python run.py --model=gpt2_base_en --batch-size=32 --fp8 44 | ``` 45 | 46 | ### Expected Output 47 | 48 | ``` 49 | transformer_layer_10/feedforward_output_dense/outputs_grad_scale 2.494614e-09 50 | transformer_layer_10/feedforward_intermediate_dense/outputs_grad_scale 1.4901163e-08 51 | transformer_layer_10/self_attention/attention_output/outputs_grad_scale 1.9374835e-09 52 | transformer_layer_10/self_attention/value/outputs_grad_scale 9.313226e-09 53 | transformer_layer_10/self_attention/key/outputs_grad_scale 4.6898747e-09 54 | ``` 55 | -------------------------------------------------------------------------------- /example/gemma/gemma.patch: -------------------------------------------------------------------------------- 1 | diff --git a/lm_eval/models/__init__.py b/lm_eval/models/__init__.py 2 | index 8ca27fac..6127ce6c 100644 3 | --- a/lm_eval/models/__init__.py 4 | +++ b/lm_eval/models/__init__.py 5 | @@ -4,6 +4,7 @@ from . import anthropic_llms 6 | from . import huggingface 7 | from . import textsynth 8 | from . import dummy 9 | +from . import gemma 10 | 11 | MODEL_REGISTRY = { 12 | "hf": gpt2.HFLM, 13 | @@ -15,6 +16,7 @@ MODEL_REGISTRY = { 14 | "anthropic": anthropic_llms.AnthropicLM, 15 | "textsynth": textsynth.TextSynthLM, 16 | "dummy": dummy.DummyLM, 17 | + "gemma": gemma.Gemma, 18 | } 19 | 20 | 21 | diff --git a/lm_eval/models/gemma.py b/lm_eval/models/gemma.py 22 | new file mode 100644 23 | index 00000000..732185c4 24 | --- /dev/null 25 | +++ b/lm_eval/models/gemma.py 26 | @@ -0,0 +1,79 @@ 27 | +import os 28 | +import numpy as np 29 | +import keras 30 | +import keras_nlp 31 | +from keras import ops 32 | +from lm_eval.base import BaseLM 33 | +from tqdm import tqdm 34 | +import time 35 | +import math 36 | + 37 | +class Gemma(BaseLM): 38 | + def __init__(self, model_name="gemma_2b_en", dtype="bfloat16", num_beams=1, **kwargs): 39 | + super().__init__() 40 | + keras.config.set_floatx(dtype) 41 | + self.model = keras_nlp.models.GemmaCausalLM.from_preset(model_name) 42 | + if num_beams > 1: 43 | + from keras_nlp.samplers import BeamSampler 44 | + self.model.compile(sampler=BeamSampler(num_beams=num_beams)) 45 | + 46 | + @property 47 | + def eot_token_id(self): 48 | + raise NotImplementedError() 49 | + 50 | + @property 51 | + def max_length(self): 52 | + raise NotImplementedError() 53 | + 54 | + @property 55 | + def max_gen_toks(self): 56 | + raise NotImplementedError() 57 | + 58 | + @property 59 | + def batch_size(self): 60 | + raise NotImplementedError() 61 | + 62 | + @property 63 | + def device(self): 64 | + raise NotImplementedError() 65 | + 66 | + def tok_encode(self, string: str): 67 | + raise NotImplementedError() 68 | + 69 | + def tok_decode(self, tokens): 70 | + raise NotImplementedError() 71 | + 72 | + def loglikelihood(self, requests, disable_tqdm=False): 73 | + results = [] 74 | + for chunk in tqdm( 75 | + requests, total=math.ceil(len(requests)), disable=disable_tqdm 76 | + ): 77 | + context, 
continuation = chunk 78 | + ctx_encode = self.model.preprocessor.generate_preprocess(context) 79 | + cont_encode = self.model.preprocessor.generate_preprocess(continuation) 80 | + pred_encode = self.model.preprocessor.generate_preprocess(context + continuation) 81 | + ctx_len = ops.sum(ctx_encode["padding_mask"]) 82 | + cont_len = ops.sum(cont_encode["padding_mask"]) 83 | + pred_len = ops.sum(pred_encode["padding_mask"]) 84 | + logits = self.model.score(ops.expand_dims(pred_encode["token_ids"][:pred_len], 0), ops.expand_dims(pred_encode["padding_mask"][:pred_len], 0)) 85 | + cont_token = cont_encode["token_ids"][1:cont_len] 86 | + logits = logits[:, ctx_len-1:pred_len-1, :] 87 | + log_softmax = ops.log_softmax(logits, axis=-1) 88 | + greedy_tokens = ops.squeeze(ops.argmax(log_softmax, axis=-1), 0) 89 | + max_equal = ops.all((greedy_tokens == cont_token)) 90 | + cont_logits = ops.squeeze(ops.take_along_axis(ops.squeeze(log_softmax, 0), ops.expand_dims(cont_token, -1), axis=1), -1) 91 | + answer = (float(ops.convert_to_numpy(ops.sum(cont_logits))), bool(ops.convert_to_numpy(max_equal))) 92 | + results.append(answer) 93 | + 94 | + return results 95 | + 96 | + def greedy_until(self, requests): 97 | + raise NotImplementedError() 98 | + 99 | + def _model_call(self, inps): 100 | + # Isn't used because we override _loglikelihood_tokens 101 | + raise NotImplementedError() 102 | + 103 | + def _model_generate(self, context, max_length, eos_token_id): 104 | + # Isn't used because we override greedy_until 105 | + raise NotImplementedError() 106 | diff --git a/setup.py b/setup.py 107 | index 5db43c17..ea8b627a 100644 108 | --- a/setup.py 109 | +++ b/setup.py 110 | @@ -21,7 +21,7 @@ setuptools.setup( 111 | ], 112 | python_requires=">=3.9", 113 | install_requires=[ 114 | - "datasets>=2.0.0", 115 | + "datasets>=2.20.0", 116 | "jsonlines", 117 | "numexpr", 118 | "openai>=0.6.4", 119 | -------------------------------------------------------------------------------- /example/gemma/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["KERAS_BACKEND"] = "jax" 3 | import time 4 | import json 5 | import pathlib 6 | import argparse 7 | import sys 8 | import keras 9 | import keras_nlp 10 | import jax 11 | 12 | MODEL_CLASSES = { 13 | "gemma_2b": "gemma_2b_en", 14 | "gemma_7b": "gemma_7b_en", 15 | "gemma_2b_it": "gemma_instruct_2b_en", 16 | "gemma_7b_it": "gemma_instruct_7b_en", 17 | } 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | "--model", 22 | type=str, 23 | choices=["gemma_2b", "gemma_7b", "gemma_2b_it", "gemma_7b_it"], 24 | default="gemma_2b", 25 | help="the mdoel name", 26 | ) 27 | parser.add_argument( 28 | "--dtype", 29 | type=str, 30 | choices=["float32", "bfloat16"], 31 | default="float32", 32 | help="bfloat16, float32", 33 | ) 34 | parser.add_argument( 35 | "--input-tokens", 36 | default="32", 37 | choices=["32", "64", "128", "256", "512", "1024", "2016", "2017", "2048", "4096", "8192"], 38 | type=str, 39 | help="input tokens length if needed from prompt.json", 40 | ) 41 | parser.add_argument( 42 | "--max-new-tokens", default=32, type=int, help="output max new tokens" 43 | ) 44 | parser.add_argument( 45 | "--prompt", default=None, type=str, help="input prompt for self-defined if needed" 46 | ) 47 | parser.add_argument("--num-beams", default=1, type=int, help="beam width") 48 | parser.add_argument("--num-iter", default=10, type=int, help="num iter") 49 | parser.add_argument("--num-warmup", default=3, 
type=int, help="num warmup") 50 | parser.add_argument("--batch-size", default=1, type=int, help="batch size") 51 | args = parser.parse_args() 52 | 53 | if args.dtype == "bfloat16": 54 | keras.config.set_floatx("bfloat16") 55 | model = keras_nlp.models.GemmaCausalLM.from_preset(MODEL_CLASSES[args.model]) 56 | if args.num_beams > 1: 57 | from keras_nlp.samplers import BeamSampler 58 | print("beam") 59 | model.compile(sampler=BeamSampler(num_beams=args.num_beams)) 60 | current_path = os.path.dirname(__file__) 61 | with open(str(current_path) + "/prompt.json") as f: 62 | prompt_pool = json.load(f) 63 | prompt = prompt_pool[args.input_tokens] 64 | 65 | total_time = 0.0 66 | num_iter = args.num_iter 67 | num_warmup = args.num_warmup 68 | prompt = [prompt] * args.batch_size 69 | total_list = [] 70 | output = model.generate(prompt, max_length=int(args.max_new_tokens)+int(args.input_tokens)) 71 | for i in range(num_iter): 72 | tic = time.time() 73 | if i == 5 and False: 74 | jax.profiler.start_trace("./profile_data_bs4") 75 | output = model.generate( 76 | prompt, max_length=int(args.max_new_tokens)+int(args.input_tokens) 77 | ) 78 | print(output) 79 | if i == 5 and False: 80 | jax.profiler.stop_trace() 81 | toc = time.time() 82 | print("Iteration: %d, Time: %.6f sec" % (i, toc - tic), flush = True) 83 | if i >= num_warmup: 84 | total_time += toc - tic 85 | 86 | print("\n", "-" * 10, "Summary:", "-" * 10) 87 | latency = total_time / (num_iter - num_warmup) 88 | print("Inference latency: %.3f sec." % latency) 89 | -------------------------------------------------------------------------------- /example/gemma/keras_nlp.patch: -------------------------------------------------------------------------------- 1 | diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py 2 | index 4b39126..c180752 100644 3 | --- a/keras_nlp/models/gemma/gemma_attention.py 4 | +++ b/keras_nlp/models/gemma/gemma_attention.py 5 | @@ -155,15 +155,15 @@ class CachedGemmaAttention(keras.layers.Layer): 6 | query = self._apply_rope(query, cache_update_index) 7 | 8 | if cache is not None: 9 | - key_cache = cache[:, 0, ...] 10 | - value_cache = cache[:, 1, ...] 11 | + key_cache = cache[0] 12 | + value_cache = cache[1] 13 | key_update = self.key_dense(x) 14 | key_update = self._apply_rope(key_update, cache_update_index) 15 | value_update = self.value_dense(x) 16 | start = [0, cache_update_index, 0, 0] 17 | key = ops.slice_update(key_cache, start, key_update) 18 | value = ops.slice_update(value_cache, start, value_update) 19 | - cache = ops.stack((key, value), axis=1) 20 | + cache = [key, value] 21 | else: 22 | key = self.key_dense(x) 23 | key = self._apply_rope(key, cache_update_index) 24 | diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py 25 | index 26e9aad..d29238c 100644 26 | --- a/keras_nlp/models/gemma/gemma_causal_lm.py 27 | +++ b/keras_nlp/models/gemma/gemma_causal_lm.py 28 | @@ -215,17 +215,17 @@ class GemmaCausalLM(CausalLM): 29 | # Each decoder layer has a cache; we update them separately. 30 | caches = [] 31 | for i, transformer_layer in enumerate(self.backbone.transformer_layers): 32 | - current_cache = cache[:, i, ...] 
33 | + current_cache = cache[i] 34 | x, next_cache = transformer_layer( 35 | x, 36 | cache=current_cache, 37 | cache_update_index=cache_update_index, 38 | ) 39 | caches.append(next_cache) 40 | - cache = ops.stack(caches, axis=1) 41 | + 42 | hidden_states = x = self.backbone.layer_norm(x) 43 | logits = self.backbone.token_embedding(x, reverse=True) 44 | - return logits, hidden_states, cache 45 | + return logits, hidden_states, caches 46 | 47 | def _build_cache(self, token_ids): 48 | """Build an empty cache for use with `call_with_cache()`.""" 49 | @@ -234,11 +234,13 @@ class GemmaCausalLM(CausalLM): 50 | num_layers = self.backbone.num_layers 51 | num_heads = self.backbone.num_key_value_heads 52 | head_dim = self.backbone.head_dim 53 | - shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim] 54 | - cache = ops.zeros(shape, dtype=self.compute_dtype) 55 | + shape = [batch_size, max_length, num_heads, head_dim] 56 | + cache_list = [] 57 | + for _ in range(0, num_layers): 58 | + cache_list.append([ops.zeros(shape, dtype=self.compute_dtype), ops.zeros(shape, dtype=self.compute_dtype)]) 59 | # Seed the cache. 60 | - _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) 61 | - return hidden_states, cache 62 | + _, hidden_states, cache_list = self.call_with_cache(token_ids, cache_list, 0) 63 | + return hidden_states, cache_list 64 | 65 | def generate_step( 66 | self, 67 | diff --git a/keras_nlp/models/gemma/gemma_decoder_block.py b/keras_nlp/models/gemma/gemma_decoder_block.py 68 | index 0a91655..3ae7f8a 100644 69 | --- a/keras_nlp/models/gemma/gemma_decoder_block.py 70 | +++ b/keras_nlp/models/gemma/gemma_decoder_block.py 71 | @@ -117,7 +117,7 @@ class GemmaDecoderBlock(keras.layers.Layer): 72 | batch_size = ops.shape(x)[0] 73 | input_length = output_length = ops.shape(x)[1] 74 | if cache is not None: 75 | - input_length = ops.shape(cache)[2] 76 | + input_length = ops.shape(cache[0])[1] 77 | 78 | causal_mask = compute_causal_mask( 79 | batch_size=batch_size, 80 | -------------------------------------------------------------------------------- /example/gptj/README.md: -------------------------------------------------------------------------------- 1 | # GPT-J-6B Jax Example 2 | 3 | Script jax_gptj.py for [EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6b). 4 | 5 | ## Prerequisites 6 | 7 | Mark `intel-extension-for-openxla` folder as \, then 8 | ```bash 9 | cd /example/gptj/ 10 | pip install transformers==4.49 datasets==2.20.0 11 | pip install -r ../../test/requirements.txt 12 | ``` 13 | 14 | ## Options 15 | 16 | | Option | Default Value | Description| 17 | | :-- | :--: | :--: | 18 | | *```--dtype```*| *```float16```*| Data type, support *```float16```*, *```bfloat16```*, and *```float32```*. | 19 | | *```--batch-size```*| *```1```*| Batch size | 20 | | *```--prompt```*| *```None```*| Customized prompt, not supported when *```--accuracy-only```* is on. | 21 | | *```--input-tokens```*| *```32```*| Input tokens. | 22 | | *```--max-new-tokens```*| *```32```*| Output max new tokens. | 23 | | *```--greedy```*| *```False```*| Enable greedy search or beam search. | 24 | | *```--num-iter```*| *```10```*| Number of iterations. | 25 | | *```--num-layer```*| *```28```*| Number of hidden layers. | 26 | | *```--num-warmup```*| *```3```*| Number of warmup iterations. | 27 | | *```--accuracy```*| *```False```*| Run accuracy check. 
| 28 | 29 | ## Example 30 | 31 | To fully utilize the hardware capabilities and achieve the best performance, you may consider setting the ENV variables below to enable our customized optimization strategies. 32 | 33 | | **ENV** | **Description** | **PVC Platform** | **ATSM/DG2 Platform** | 34 | | :---: | :---: | :---: |:---: | 35 | | ZE_AFFINITY_MASK | Run this model on single GPU tile |export ZE_AFFINITY_MASK=0 | export ZE_AFFINITY_MASK=0 | 36 | | XETLA_GEMM | Call the [XETLA](https://github.com/intel/xetla) library to run GEMMs, instead of using oneDNN.|export XETLA_GEMM=1 | NA | 37 | 38 | ### Greedy Search 39 | 40 | ```bash 41 | export ZE_AFFINITY_MASK=0 42 | python jax_gptj.py --greedy 43 | ``` 44 | 45 | ### Beam Search = 4 46 | 47 | ```bash 48 | export ZE_AFFINITY_MASK=0 49 | python jax_gptj.py --input-tokens 1024 --max-new-tokens 128 50 | ``` 51 | 52 | ### Performance Output 53 | 54 | ``` 55 | Inference latency: x.xxx sec. 56 | Inference throughput: x.xxx samples/sec. 57 | ``` 58 | 59 | ### Accuracy Output 60 | 61 | ```bash 62 | export ZE_AFFINITY_MASK=0 63 | python jax_gptj.py --input-tokens 1024 --max-new-tokens 128 --accuracy 64 | ``` 65 | 66 | ``` 67 | Inference latency: x.xxx sec. 68 | Inference throughput: x.xxx samples/sec. 69 | 70 | Accuracy = 1.00 71 | ``` 72 | 73 | ### Test with less memory 74 | 75 | Set option `--num-layer` (default value: `28`) to a small number, to reduce the memory footprint for testing. 76 | ```bash 77 | export ZE_AFFINITY_MASK=0 78 | python jax_gptj.py --input-tokens 1024 --max-new-tokens 128 --accuracy --num-layer 14 79 | ``` 80 | -------------------------------------------------------------------------------- /example/grok/README.md: -------------------------------------------------------------------------------- 1 | # Quick Start for Grok Inference 2 | 3 | Load and run the Grok-1 open-weights model from [Grok-1](https://github.com/xai-org/grok-1). 4 | 5 | Running the Grok-1 model requires at least an 8-tile GPU device. 6 | 7 | ## 1. Install intel-extension-for-openxla 8 | 9 | Please go to the [main page](https://github.com/intel/intel-extension-for-openxla/blob/main/README.md#build-and-install), and follow the instructions to build and install intel-extension-for-openxla. 10 | 11 | ## 2. Install dependency 12 | 13 | Mark `intel-extension-for-openxla` folder as \, then 14 | ```bash 15 | cd /example/grok/ 16 | git clone https://github.com/xai-org/grok-1.git 17 | pip install -r grok-1/requirements.txt 18 | pip install -r ../../test/requirements.txt 19 | ``` 20 | 21 | ## 3. Download the weights 22 | 23 | Please follow [Downloading the weights](https://github.com/xai-org/grok-1#downloading-the-weights) to get the weights. 24 | 25 | Make sure to download the checkpoint and place the ckpt-0 directory in grok-1/checkpoints. 26 | 27 | Or you can create a soft link to it. 28 | 29 | ``` 30 | grok-1/ 31 | ├── checkpoints 32 | │   └── ckpt-0 -> /your_data_path/ckpt-0 33 | ├── ... 34 | ``` 35 | 36 | ## 4. 
Copy files & Run 37 | ```bash 38 | cp prompt.json inference.py grok-1 39 | python grok-1/inference.py 40 | ``` 41 | | **Parameter** | **Default Value** | 42 | | :---: | :--- | 43 | | **input-tokens** | 32 | 44 | | **max-new-tokens** | 32 | 45 | | **num-iter** | 4 | 46 | | **num-warmup** | 1 | 47 | -------------------------------------------------------------------------------- /example/grok/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | import json 5 | import os 6 | import jax 7 | 8 | from model import LanguageModelConfig, TransformerConfig 9 | from runners import InferenceRunner, ModelRunner, sample_from_model 10 | 11 | def main(args): 12 | num_iter = args.num_iter 13 | num_warmup = args.num_warmup 14 | input_tokens = args.input_tokens 15 | max_new_tokens = args.max_new_tokens 16 | compilcation_cache = args.compilcation_cache 17 | input_len = int(input_tokens) 18 | 19 | current_path = str(os.path.dirname(__file__)) 20 | 21 | if compilcation_cache: 22 | COMPILATION_CACHE_PATH = current_path +"/compilcation_cache/" 23 | jax.config.update("jax_compilation_cache_dir", COMPILATION_CACHE_PATH) 24 | 25 | CKPT_PATH = current_path +"/checkpoints/" 26 | with open(current_path + "/prompt.json") as f: 27 | content = f.read() 28 | content_dict = json.loads(content) 29 | prompt = content_dict[input_tokens] 30 | 31 | print("initialize start", flush=True) 32 | start_time = time.time() 33 | grok_1_model = LanguageModelConfig( 34 | vocab_size=128 * 1024, 35 | pad_token=0, 36 | eos_token=2, 37 | sequence_len=8192, 38 | embedding_init_scale=1.0, 39 | output_multiplier_scale=0.5773502691896257, 40 | embedding_multiplier_scale=78.38367176906169, 41 | model=TransformerConfig( 42 | emb_size=48 * 128, 43 | widening_factor=8, 44 | key_size=128, 45 | num_q_heads=48, 46 | num_kv_heads=8, 47 | num_layers=64, 48 | attn_output_multiplier=0.08838834764831845, 49 | shard_activations=True, 50 | # MoE. 51 | num_experts=8, 52 | num_selected_experts=2, 53 | # Activation sharding. 
54 | data_axis="data", 55 | model_axis="model", 56 | ), 57 | ) 58 | inference_runner = InferenceRunner( 59 | pad_sizes=(input_len,), 60 | runner=ModelRunner( 61 | model=grok_1_model, 62 | bs_per_device=0.125, 63 | checkpoint_path=CKPT_PATH, 64 | ), 65 | name="local", 66 | load=CKPT_PATH, 67 | tokenizer_path=current_path+"/tokenizer.model", 68 | local_mesh_config=(1, 8), 69 | between_hosts_config=(1, 1), 70 | ) 71 | inference_runner.initialize() 72 | gen = inference_runner.run() 73 | end_time = time.time() 74 | print("initialize time: {:.2f} s".format(end_time - start_time), flush=True) 75 | 76 | step = num_iter + num_warmup 77 | all_time = 0.0 78 | for i in range(step): 79 | print(f"inference start: {i}", flush=True) 80 | s_time = time.time() 81 | print(f"Output_{max_new_tokens} for prompt_{input_tokens}:", sample_from_model(gen, prompt, max_len=max_new_tokens, temperature=0.01), flush=True) 82 | e_time = time.time() 83 | print("inference time: {:.2f} s".format(e_time - s_time), flush=True) 84 | if(i >= num_warmup): 85 | all_time += (e_time - s_time) 86 | print("averange inference time: {:.2f} s except warmup steps.".format(all_time / float(num_iter)), flush=True) 87 | 88 | 89 | if __name__ == "__main__": 90 | logging.basicConfig(level=logging.INFO) 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument("--num-iter", default=4, type=int, help="num iter") 93 | parser.add_argument("--num-warmup", default=1, type=int, help="num warmup") 94 | parser.add_argument("--input-tokens",default="32",choices=["32", "64", "128", "256", "512", "1024", "2016", "2017", "2048", "4096", "8192"],type=str,help="input tokens length if needed from prompt.json") 95 | parser.add_argument("--max-new-tokens", default=32, type=int, help="output max new tokens") 96 | parser.add_argument("--compilcation-cache", default=False, type=bool, help="compilcation cache") 97 | args = parser.parse_args() 98 | main(args) 99 | -------------------------------------------------------------------------------- /example/qkv_fusion/test_qkv_fusion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Intel Corporation 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import jax 17 | from jax import lax 18 | from jax import random 19 | import jax.numpy as jnp 20 | import jax._src.test_util as jtu 21 | 22 | import numpy as np 23 | 24 | def seperateQKVGEMM(input, weight_q, weight_k, weight_v): 25 | out_q = jax.numpy.matmul(input, weight_q) 26 | out_k = jax.numpy.matmul(input, weight_k) 27 | out_v = jax.numpy.matmul(input, weight_v) 28 | return out_q, out_k, out_v 29 | 30 | @jax.jit 31 | def fusedQKVGEMM(input, weight_q, weight_k, weight_v): 32 | out_q = jax.numpy.matmul(input, weight_q) 33 | out_k = jax.numpy.matmul(input, weight_k) 34 | out_v = jax.numpy.matmul(input, weight_v) 35 | return out_q, out_k, out_v 36 | 37 | def testQKVFusion(): 38 | # Inputs 39 | m = 4 40 | k = 4096 41 | n = 4096 42 | key = jax.random.PRNGKey(1701) 43 | input = jax.random.uniform(key, (4, 4096)).astype(jnp.float16) 44 | weight_q = jax.random.uniform(key, (k, k)).astype(jnp.float16) # weights for Q 45 | weight_k = jax.random.uniform(key, (k, k)).astype(jnp.float16) # weights for K 46 | weight_v = jax.random.uniform(key, (k, k)).astype(jnp.float16) # weights for V 47 | cpu_q, cpu_k, cpu_v = seperateQKVGEMM(input, weight_q, weight_k, weight_v) 48 | xpu_q, xpu_k, xpu_v = fusedQKVGEMM(input, weight_q, weight_k, weight_v) 49 | print(np.allclose(xpu_q, cpu_q, atol=1e-3, rtol=1e-3)) 50 | print(np.allclose(xpu_k, cpu_k, atol=1e-3, rtol=1e-3)) 51 | print(np.allclose(xpu_v, cpu_v, atol=1e-3, rtol=1e-3)) 52 | 53 | if __name__ == '__main__': 54 | testQKVFusion() -------------------------------------------------------------------------------- /example/resnet50/README.md: -------------------------------------------------------------------------------- 1 | # Quick Start for ResNet50 Training 2 | 3 | Trains ResNet50 model (He et al., 2016) for the ImageNet classification task (Russakovsky et al., 2015) by [FLAX RN50 example](https://github.com/google/flax/tree/main/examples/imagenet) 4 | 5 | ## Requirements 6 | 7 | ### 1. Install intel-extension-for-openxla 8 | 9 | Please got the [main page](https://github.com/intel/intel-extension-for-openxla/blob/main/README.md#build-and-install), and follow the instructions to build and install intel-extension-for-openxla. 10 | 11 | ### 2. Install dependency 12 | 13 | Mark `intel-extension-for-openxla` folder as \, then 14 | ```bash 15 | cd /example/resnet50/ 16 | git clone --branch=main https://github.com/google/flax 17 | cd flax 18 | git checkout ba9e24a7b697e6407465cb4b05a24a4cea152248 19 | pip install -r examples/imagenet/requirements.txt 20 | cd .. 21 | pip install -r ../../test/requirements.txt 22 | ``` 23 | 24 | ### 3. Download dataset 25 | 26 | Please follow [Preparing the dataset](https://github.com/google/flax/tree/main/examples/imagenet#preparing-the-dataset) to get imagenet dataset. 27 | 28 | ## Run 29 | 30 | ### Set environment 31 | ```bash 32 | export PYTHONPATH=${path_to_flax} 33 | ``` 34 | 35 | ### Running command 36 | ```bash 37 | python main.py --workdir=./imagenet --config=configs/default.py 38 | ``` 39 | `config.batch_size` is global batchsize for all devices you selected. 40 | 41 | ### Select devices 42 | All GPU devices in same node will be used by default. If you only want some of devices, please use environmental variable `ZE_AFFINITY_MASK` to select. 
43 | 44 | | **ENV** | **Description** | **PVC Platform** | 45 | | :---: | :---: | :---: | 46 | | ZE_AFFINITY_MASK | Run this model on single GPU device |export ZE_AFFINITY_MASK as your selected device list, like 0,1,2,3| 47 | -------------------------------------------------------------------------------- /example/sdxl/README.md: -------------------------------------------------------------------------------- 1 | # Quick Start for Stable Diffusion XL Inference 2 | 3 | [Stable Diffusion XL](https://arxiv.org/abs/2307.01952)(SDXL) is a powerful text-to-image generation model that iterates on the previous Stable Diffusion models in three key ways: 4 | 5 | 1. the UNet is 3x larger and SDXL combines a second text encoder (OpenCLIP ViT-bigG/14) with the original text encoder to significantly increase the number of parameters 6 | 2. introduces size and crop-conditioning to preserve training data from being discarded and gain more control over how a generated image should be cropped 7 | 3. introduces a two-stage model process; the base model (can also be run as a standalone model) generates an image as an input to the refiner model which adds additional high-quality details 8 | 9 | More details could be found in [Stability-AI/generative-models](https://github.com/Stability-AI/generative-models) 10 | 11 | ## Requirements 12 | 13 | ### 1. Install intel-extension-for-openxla 14 | 15 | please got the [main page](https://github.com/intel/intel-extension-for-openxla/blob/main/README.md#build-and-install), and follow the instructions to build and install intel-extension-for-openxla. 16 | 17 | ### 2. Install packages 18 | 19 | Mark `intel-extension-for-openxla` folder as \, then 20 | ```bash 21 | cd /example/sdxl/ 22 | pip install transformers==4.49 diffusers==0.31.0 datasets==2.20.0 msgpack==1.1.0 23 | pip install -r ../../test/requirements.txt 24 | ``` 25 | 26 | ## Run 27 | 28 | ### 1. Environmental Variables 29 | 30 | | **ENV** | **Description** | **PVC Platform** | **ATSM/DG2 Platform** | 31 | | :---: | :---: | :---: |:---: | 32 | | ZE_AFFINITY_MASK | Run this model on single GPU tile |export ZE_AFFINITY_MASK=0 | export ZE_AFFINITY_MASK=0 | 33 | | XLA_FLAGS | Customize xla debug options | export XLA_FLAGS="--xla_gpu_force_conv_nhwc --xla_disable_hlo_passes=dot-merger" | export XLA_FLAGS="--xla_gpu_force_conv_nhwc --xla_disable_hlo_passes=dot-merger" | 34 | 35 | ### 2. Options 36 | 37 | ``` 38 | --dtype: Support bfloat16 and float16, default is bfloat16. 39 | --num-iter: The number of times to run generation, default is 1. 40 | --num-inference-steps: The inference steps for each generated image, default is 25. 41 | --accuracy: Check whether the demo result is expected. Output range is `0`~`1`, higher is better. 42 | ``` 43 | 44 | ### 3. 
Inference Command Example 45 | 46 | ```shell 47 | python inference.py --dtype=bfloat16 --accuracy 48 | ``` 49 | 50 | ## Expected Output 51 | 52 | ### Performance 53 | ``` 54 | Average Latency per image is: x.xxx s 55 | Average Throughput per second is: x.xxx steps 56 | ``` 57 | 58 | ### Accuracy 59 | ``` 60 | RMSE accuracy is: 0.979 61 | ``` 62 | -------------------------------------------------------------------------------- /example/sdxl/target.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel/intel-extension-for-openxla/68b2112b3466d0fcb111369799246d39281c7452/example/sdxl/target.png -------------------------------------------------------------------------------- /example/stable_diffusion/README.md: -------------------------------------------------------------------------------- 1 | # Quick Start for Stable Diffusion Inference 2 | 3 | [Stable Diffusion](https://arxiv.org/abs/2112.10752) is a latent text-to-image diffusion model. More details could be found in [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion). 4 | 5 | ## Requirements 6 | 7 | ### 1. Install intel-extension-for-openxla 8 | 9 | please got the [main page](https://github.com/intel/intel-extension-for-openxla/blob/main/README.md#build-and-install), and follow the instructions to build and install intel-extension-for-openxla. 10 | 11 | ### 2. Install packages 12 | 13 | Mark `intel-extension-for-openxla` folder as \, then 14 | ```bash 15 | cd /example/stable_diffusion/ 16 | pip install transformers==4.49 diffusers==0.31.0 datasets==2.20.0 msgpack==1.1.0 17 | pip install -r ../../test/requirements.txt 18 | ``` 19 | 20 | ## Run 21 | 22 | ### 1. Environmental Variables 23 | 24 | | **ENV** | **Description** | **PVC Platform** | **ATSM/DG2 Platform** | 25 | | :---: | :---: | :---: |:---: | 26 | | ZE_AFFINITY_MASK | Run this model on single GPU tile |export ZE_AFFINITY_MASK=0 | export ZE_AFFINITY_MASK=0 | 27 | | XLA_FLAGS | Customize xla debug options | export XLA_FLAGS="--xla_gpu_force_conv_nhwc" | export XLA_FLAGS="--xla_gpu_force_conv_nhwc" | 28 | 29 | ### 2. Inference Command 30 | 31 | | **Model** | **Output Image Resolution** | **Command** | 32 | | :---: | :---: | :---: | 33 | | [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) | 512x512 | ```python jax_stable.py``` | 34 | | [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) | 768x768 | ```python jax_stable.py -m stabilityai/stable-diffusion-2``` | 35 | | [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) | 768x768 | ```python jax_stable.py -m stabilityai/stable-diffusion-2-1``` | 36 | 37 | Add option `--accuracy` to check whether the demo result is expected. 
Output range is `0`~`1`, higher is better: 38 | ```shell 39 | python jax_stable.py -m stabilityai/stable-diffusion-2-1 --accuracy 40 | ``` 41 | 42 | ## Expected Output 43 | 44 | ### Performance 45 | ``` 46 | Average Latency per image is: x.xxx s 47 | Average Throughput per second is: x.xxx steps 48 | ``` 49 | 50 | ### Accuracy 51 | ``` 52 | RMSE accuracy is: 0.976 53 | ``` 54 | -------------------------------------------------------------------------------- /example/stable_diffusion/target.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel/intel-extension-for-openxla/68b2112b3466d0fcb111369799246d39281c7452/example/stable_diffusion/target.png -------------------------------------------------------------------------------- /example/t5/README.md: -------------------------------------------------------------------------------- 1 | # FLAN-T5 Quick Start 2 | 3 | ## Checkpoint 4 | ### Download Checkpoint from URL 5 | #### 1. Choose a pre-trained model 6 | Click the URL attached to the checkpoint location, and download your chosen model manually. For more T5-like model checkpoints, please go to the official [T5 website](https://t5x.readthedocs.io/en/latest/models.html#public-research-models). 7 | Model | Gin File Location | Checkpoint Location 8 | -------------------- | ------------------------------------------------------------------------------------------------------------------- | ------------------- 9 | Flan-T5 XL | [t5_1_1_xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xl.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_xl/checkpoint_1138000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_xl/checkpoint_1138000) 10 | Flan-T5 XXL | [t5_1_1_xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xxl.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_xxl/checkpoint_1114000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_xxl/checkpoint_1114000) 11 | 12 | ##### How to download the above checkpoint 13 | Please follow the official [guidelines](https://cloud.google.com/storage/docs/gsutil_install#deb) to install gsutil first. 14 | 15 | ``` 16 | gsutil -m cp -r "gs://t5-data/pretrained_models/t5x/flan_t5_xl/checkpoint_1138000" . 17 | ``` 18 | 19 | #### 2. Download Vocabulary 20 | T5 Vocabulary: [cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra) 21 | 22 | For more T5-like models, please go to the official [T5 website](https://t5x.readthedocs.io/en/latest/models.html#t5-1-1-checkpoints), and download its corresponding vocabulary manually. 23 | 24 | ##### How to download the above vocabulary 25 | ``` 26 | gsutil -m cp -r "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" . 27 | ``` 28 | #### 3. Download Dataset (optional) 29 | 30 | We use The Pile for our pretraining experiments. If you would like to as well, run the `download_the_pile.py` script in the `t5x` folder to download it (a minimal sketch is shown below). The download is approximately 1TB. 31 | 32 | For benchmarking, you can skip this step because our model script will automatically download part of the dataset. 
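For reference, here is a minimal sketch of this optional step. The exact invocation is an assumption: the `t5x` checkout comes from the Installation section below, and `TFDS_DATA_DIR` with the `/username/datasets/ThePile` path mirrors the `[dataset dir]` argument used by `quick_start.sh`; adjust both to your environment.
```
# Optional: fetch The Pile for pretraining (roughly 1TB of data).
cd t5x
python download_the_pile.py
# quick_start.sh later exports TFDS_DATA_DIR from its [dataset dir] argument, e.g.:
export TFDS_DATA_DIR=/username/datasets/ThePile
```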
33 | 34 | ## Installation 35 | 36 | Mark the `intel-extension-for-openxla` folder as \<WORKSPACE\>, then 37 | ``` 38 | cd <WORKSPACE>/example/t5/ 39 | git clone https://github.com/google-research/t5x.git 40 | bash install_xpu.sh 41 | pip install --upgrade intel-extension-for-openxla 42 | pip install -r ../../test/requirements.txt 43 | ``` 44 | 45 | ## Inference 46 | 47 | To fully utilize the hardware capabilities and achieve the best performance, you may consider setting the below ENV variables to enable our customized optimization strategies. 48 | 49 | | **ENV** | **Description** | **PVC Platform** | **ATSM/DG2 Platform** | 50 | | :---: | :---: | :---: |:---: | 51 | | ZE_AFFINITY_MASK | Run this model on single GPU tile |export ZE_AFFINITY_MASK=0 | export ZE_AFFINITY_MASK=0 | 52 | | XETLA_GEMM | Call the [XETLA](https://github.com/intel/xetla) library to run GEMMs, instead of using oneDNN.|export XETLA_GEMM=1 | NA | 53 | | LLM | Enable our customized optimization strategies for large language models (LLM) |export LLM=1 | export LLM=1 | 54 | | XLA_FLAGS | Customize xla debug options | export XLA_FLAGS="--xla_disable_hlo_passes=dot-merger" | export XLA_FLAGS="--xla_disable_hlo_passes=dot-merger" | 55 | 56 | ### Command Description 57 | ``` 58 | bash quick_start.sh [model size] [dataset dir] [model dir] [input_length] [output_length] [device type] 59 | ``` 60 | 61 | #### FLAN-T5-XL Example 62 | 63 | #### 32/32 64 | ``` 65 | bash quick_start.sh xl /username/datasets/ThePile /username/t5x_models 32 32 XPU 66 | ``` 67 | #### 1024/128 68 | ``` 69 | bash quick_start.sh xl /username/datasets/ThePile /username/t5x_models 1024 128 XPU 70 | ``` 71 | #### Performance Output 72 | ``` 73 | avg time:xxx s,avg throughput:xxx sentences/second 74 | ``` 75 | -------------------------------------------------------------------------------- /example/t5/install_xpu.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | pushd ./t5x 4 | 5 | git checkout 6699ad54480a0691c491fa2aa28d8b46daf1a377 6 | git apply ../patch/t5.patch 7 | 8 | ln -s /usr/local/bin/pip /usr/bin/pip 9 | pip uninstall tensorflow-metadata numba cudf -y 10 | 11 | conda install libstdcxx-ng -c conda-forge -y 12 | 13 | pip uninstall mdit-py-plugins jupytext -y 14 | pip install t5 15 | pip install -e . 16 | 17 | pip install orbax-checkpoint==0.3.2 18 | pip install zstandard==0.21.0 19 | pip install jsonlines==3.1.0 20 | 21 | pip uninstall tensorflow tensorflow-cpu tensorflow-text -y 22 | #TensorFlow Text will auto-install the right TensorFlow version as its dependency 23 | pip install tensorflow-text==2.18.1 24 | 25 | popd 26 | -------------------------------------------------------------------------------- /example/t5/quick_start.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # A script for single-node pile pretraining 3 | set -x 4 | 5 | if [ ! -n "$1" ];then 6 | echo "Error: the format is ./quick_start.sh [model size] [dataset dir] [model dir] [input_length] [output_length] [device type] [is_profile] [profile out dir]!"
7 | exit 1 8 | fi 9 | 10 | pushd ./t5x 11 | 12 | mkdir logs 13 | mkdir output 14 | 15 | export T5X_DIR=`pwd` 16 | export T5X_WORKSPACE_DIR=${T5X_DIR}/workspace 17 | LOG_DIR=${T5X_DIR}/logs 18 | export PYTHONPATH=${T5X_DIR} 19 | export TFDS_DATA_DIR=$2 20 | MODEL_SIZE=$1 21 | MODEL_DIR=$3 22 | INPUT_LENGTH=$4 23 | OUTPUT_LENGTH=$5 24 | export DEVICE_TYPE="XPU" # XPU or CUDA 25 | if [ -n "$6" ];then 26 | DEVICE_TYPE=$6 27 | fi 28 | export IS_PROFILE=$7 29 | export PROFILE_DIR=$8 30 | if [ ! -n "$4" ];then 31 | INPUT_LENGTH=32 32 | fi 33 | if [ ! -n "$5" ];then 34 | OUTPUT_LENGTH=32 35 | fi 36 | if [ ! -n "$8" ];then 37 | mkdir -p ${PROFILE_DIR} 38 | fi 39 | 40 | export PTI_ENABLE_COLLECTION=0 41 | 42 | 43 | sed -i 's/"inputs": .*, "targets": .*}/"inputs": '${INPUT_LENGTH}', "targets": '${OUTPUT_LENGTH}'}/g' ${T5X_DIR}/../xl_infer.gin 44 | sed -i 's/"inputs": .*, "targets": .*}/"inputs": '${INPUT_LENGTH}', "targets": '${OUTPUT_LENGTH}'}/g' ${T5X_DIR}/../xxl_infer.gin 45 | 46 | 47 | # Arguments 48 | PREC="bfloat16" # Precision (float32, float16, bfloat16) 49 | 50 | 51 | NUM_GPUS=1 # Number of GPUs (1, 2, 4, 8) 52 | BSIZE_PER_GPU=1 # Batch size per GPU (varies with model size) 53 | T5_NAME=flan-t5-$MODEL_SIZE 54 | GIN_FILE="${T5X_DIR}/../xl_infer.gin" 55 | REF_FILE="${T5X_DIR}/../reference.json" 56 | MODEL_PATH=${MODEL_DIR}/checkpoint_1138000 57 | 58 | if [ ${MODEL_SIZE} == "xxl" ];then 59 | GIN_FILE="${T5X_DIR}/../xxl_infer.gin" 60 | MODEL_PATH=${MODEL_DIR}/checkpoint_1114000 61 | fi 62 | 63 | echo $MODEL_PATH 64 | 65 | echo "Please make sure ${NUM_GPUS} is the number of visible CUDA devices you have" 66 | 67 | # Setting XLA flags 68 | export XLA_FLAGS="--xla_allow_excess_precision --xla_gpu_all_reduce_combine_threshold_bytes=136314880 ${XLA_FLAGS}" 69 | 70 | 71 | PREFIX="" 72 | if [ -n "${IS_PROFILE}" ];then 73 | echo " MODE: Profile" 74 | export XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_hlo_as_html ${XLA_FLAGS}" 75 | export XLA_FLAGS="--xla_dump_to=${PROFILE_DIR} ${XLA_FLAGS}" 76 | #export XLA_FLAGS="--xla_dump_hlo_pass_re=attention ${XLA_FLAGS}" 77 | if [ ${DEVICE_TYPE} == "XPU" ];then 78 | PREFIX='unitrace --conditional-collection --chrome-device-logging --demangle --output-dir-path '${PROFILE_DIR} 79 | echo ${PREFIX} 80 | #exit 81 | fi 82 | fi 83 | 84 | # Global batch size 85 | BSIZE=$(( NUM_GPUS * BSIZE_PER_GPU )) 86 | 87 | ${PREFIX} \ 88 | python3 -u ${T5X_DIR}/t5x/infer.py \ 89 | --gin_file=${GIN_FILE} \ 90 | --gin.CHECKPOINT_PATH=\"${MODEL_PATH}\" \ 91 | --gin.network.T5Config.dtype=\"${PREC}\" \ 92 | --gin.utils.DatasetConfig.batch_size=${BSIZE} \ 93 | --gin.INFER_OUTPUT_DIR=\"${T5X_DIR}/output\" \ 94 | --gin.seqio.SentencePieceVocabulary.sentencepiece_model_file=\"${MODEL_DIR}/sentencepiece.model\" \ 95 | --gin.REFERENCE_FILE=\"${REF_FILE}\" \ 96 | |& tee ${LOG_DIR}/${T5_NAME}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}.log 97 | 98 | popd 99 | -------------------------------------------------------------------------------- /example/t5/reference.json: -------------------------------------------------------------------------------- 1 | { 2 | "xl": { 3 | "32" : { 4 | "1" : [24348], 5 | "32" : [24348, 9713, 49, 6, 8127, 12704, 9713, 49, 30045, 6, 14546, 10, 6 | 96, 8639, 391, 1427, 15, 3207, 6, 67, 9713, 49, 46, 19800, 7 | 6, 745, 289, 25931, 5, 1, 0, 0] 8 | }, 9 | "1024" : { 10 | "1": [24348], 11 | "128" : [24348, 9713, 49, 6, 8127, 12704, 9713, 49, 30045, 6, 14546, 10, 12 | 96, 8639, 391, 1427, 15, 3207, 6, 67, 9713, 49, 46, 19800, 13 | 6, 745, 289, 25931, 5, 1, 0, 0, 0, 0, 0, 0, 14 | 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21 | 0, 0, 0, 0, 0, 0, 0, 0] 22 | } 23 | }, 24 | "xxl": { 25 | "1024" : { 26 | "1": [24348], 27 | "128" : [24348, 9713, 49, 6, 8127, 12704, 193, 9713, 49, 30045, 6, 3, 28 | 17473, 10, 96, 8639, 3207, 6, 67, 9713, 49, 30045, 46, 16004, 29 | 6, 745, 289, 22072, 5, 1, 0, 0, 0, 0, 0, 0, 30 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 31 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 33 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 34 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 36 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 37 | 0, 0, 0, 0, 0, 0, 0, 0] 38 | } 39 | } 40 | } -------------------------------------------------------------------------------- /example/t5/xl_infer.gin: -------------------------------------------------------------------------------- 1 | include 't5x/contrib/gpu/t5/t5_1_1/xl.gin' 2 | include 't5x/contrib/gpu/t5/configs/runs/infer.gin' 3 | 4 | 5 | # Register necessary SeqIO Tasks/Mixtures 6 | 7 | DROPOUT_RATE = 0.0 8 | #BATCH_SIZE = 8 9 | 10 | import t5.data.mixtures 11 | import t5x.contrib.gpu.scripts_gpu.seqio_tasks 12 | MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" 13 | TASK_FEATURE_LENGTHS = {"inputs": 1024, "targets": 128} 14 | 15 | 16 | partitioning.PjitPartitioner: 17 | model_parallel_submesh = (1, 1, 1, 1) 18 | 19 | #network.T5Config: 20 | # num_encoder_layers = 2 21 | # num_decoder_layers = 2 22 | -------------------------------------------------------------------------------- /example/t5/xxl_infer.gin: -------------------------------------------------------------------------------- 1 | include 't5x/contrib/gpu/t5/t5_1_1/xxl.gin' 2 | #include 't5x/contrib/gpu/t5/configs/runs/infer_from_tfexample_file.gin' 3 | include 't5x/contrib/gpu/t5/configs/runs/infer.gin' 4 | 5 | 6 | # Register necessary SeqIO Tasks/Mixtures 7 | 8 | DROPOUT_RATE = 0.0 9 | #BATCH_SIZE = 8 10 | 11 | import t5.data.mixtures 12 | import t5x.contrib.gpu.scripts_gpu.seqio_tasks 13 | #MIXTURE_OR_TASK_NAME = "the_pile_span_corruption" 14 | MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" 15 | #TASK_FEATURE_LENGTHS = {"inputs": 1024, "targets": 1} 16 | TASK_FEATURE_LENGTHS = {"inputs": 1024, "targets": 1} 17 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "sentencepiece.model" 18 | 19 | #FEATURE_LENGTHS = {"inputs": 1024, "targets": 1} 20 | #TF_EXAMPLE_FILE_TYPE='tfrecord' 21 | #TF_EXAMPLE_FILE_PATHS=['/dataset/*train.tfrecord*'] 22 | #create_task_from_tfexample_file.inputs_key='question' 23 | 24 | 25 | partitioning.PjitPartitioner: 26 | model_parallel_submesh = (1, 1, 1, 1) 27 | 28 | #network.T5Config: 29 | # num_encoder_layers = 1 30 | # num_decoder_layers = 1 31 | -------------------------------------------------------------------------------- /openxla_for_intel_gpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel/intel-extension-for-openxla/68b2112b3466d0fcb111369799246d39281c7452/openxla_for_intel_gpu.png -------------------------------------------------------------------------------- /security.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | Intel is committed to rapidly addressing security vulnerabilities affecting our customers and 
providing clear guidance on the solution, impact, severity and mitigation. 3 | 4 | ## Reporting a Vulnerability 5 | Please report any security vulnerabilities in this project utilizing the guidelines [here](https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html). 6 | -------------------------------------------------------------------------------- /test/BRANCH_NAME: -------------------------------------------------------------------------------- 1 | jax-v0.4.38 2 | -------------------------------------------------------------------------------- /test/requirements.txt: -------------------------------------------------------------------------------- 1 | jax==0.4.38 2 | jaxlib==0.4.38 3 | flax==0.10.0 4 | optax==0.2.4 5 | -------------------------------------------------------------------------------- /third_party/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | -------------------------------------------------------------------------------- /third_party/common.bzl: -------------------------------------------------------------------------------- 1 | # Rule for simple expansion of template files. This performs a simple 2 | # search over the template file for the keys in substitutions, 3 | # and replaces them with the corresponding values. 4 | # 5 | # Typical usage: 6 | # load("/tools/build_rules/template_rule", "expand_header_template") 7 | # template_rule( 8 | # name = "ExpandMyTemplate", 9 | # src = "my.template", 10 | # out = "my.txt", 11 | # substitutions = { 12 | # "$VAR1": "foo", 13 | # "$VAR2": "bar", 14 | # } 15 | # ) 16 | # 17 | # Args: 18 | # name: The name of the rule. 19 | # template: The template file to expand 20 | # out: The destination of the expanded file 21 | # substitutions: A dictionary mapping strings to their substitutions 22 | 23 | def template_rule_impl(ctx): 24 | ctx.actions.expand_template( 25 | template = ctx.file.src, 26 | output = ctx.outputs.out, 27 | substitutions = ctx.attr.substitutions, 28 | ) 29 | 30 | template_rule = rule( 31 | attrs = { 32 | "src": attr.label( 33 | mandatory = True, 34 | allow_single_file = True, 35 | ), 36 | "substitutions": attr.string_dict(mandatory = True), 37 | "out": attr.output(mandatory = True), 38 | }, 39 | # output_to_genfiles is required for header files. 40 | #output_to_genfiles = True, 41 | implementation = template_rule_impl, 42 | ) 43 | -------------------------------------------------------------------------------- /third_party/gpus/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | -------------------------------------------------------------------------------- /third_party/gpus/crosstool/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | -------------------------------------------------------------------------------- /third_party/gpus/crosstool/BUILD.sycl.tpl: -------------------------------------------------------------------------------- 1 | # This file is expanded from a template sycl_configure.bzl 2 | # Update sycl_configure.bzl#verify_build_defines when adding new variables. 
3 | 4 | load(":cc_toolchain_config.bzl", "cc_toolchain_config") 5 | 6 | licenses(["restricted"]) 7 | 8 | package(default_visibility = ["//visibility:public"]) 9 | 10 | toolchain( 11 | name = "toolchain-linux-x86_64", 12 | exec_compatible_with = [ 13 | "@bazel_tools//platforms:linux", 14 | "@bazel_tools//platforms:x86_64", 15 | ], 16 | target_compatible_with = [ 17 | "@bazel_tools//platforms:linux", 18 | "@bazel_tools//platforms:x86_64", 19 | ], 20 | toolchain = ":cc-compiler-local", 21 | toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", 22 | ) 23 | 24 | cc_toolchain_suite( 25 | name = "toolchain", 26 | toolchains = { 27 | "local|compiler": ":cc-compiler-local", 28 | "k8": ":cc-compiler-local", 29 | }, 30 | ) 31 | 32 | cc_toolchain( 33 | name = "cc-compiler-local", 34 | all_files = ":crosstool_wrapper_driver", 35 | compiler_files = ":crosstool_wrapper_driver", 36 | ar_files = ":crosstool_wrapper_driver", 37 | as_files = ":crosstool_wrapper_driver", 38 | dwp_files = ":empty", 39 | linker_files = ":crosstool_wrapper_driver", 40 | objcopy_files = ":empty", 41 | strip_files = ":empty", 42 | # To support linker flags that need to go to the start of command line 43 | # we need the toolchain to support parameter files. Parameter files are 44 | # last on the command line and contain all shared libraries to link, so all 45 | # regular options will be left of them. 46 | supports_param_files = 1, 47 | toolchain_identifier = "local_linux", 48 | toolchain_config = ":cc-compiler-local-config", 49 | ) 50 | 51 | cc_toolchain_config( 52 | name = "cc-compiler-local-config", 53 | cpu = "local", 54 | builtin_include_directories = [%{cxx_builtin_include_directories}], 55 | # extra_no_canonical_prefixes_flags = [%{extra_no_canonical_prefixes_flags}], 56 | host_compiler_path = "%{host_compiler_path}", 57 | host_compiler_prefix = "%{host_compiler_prefix}", 58 | host_unfiltered_compile_flags = [%{unfiltered_compile_flags}], 59 | linker_bin_path = "%{linker_bin_path}", 60 | compiler = "unknown", 61 | ar_path = "%{ar_path}", 62 | ) 63 | 64 | filegroup( 65 | name = "empty", 66 | srcs = [], 67 | ) 68 | 69 | filegroup( 70 | name = "crosstool_wrapper_driver", 71 | srcs = ["clang/bin/crosstool_wrapper_driver"] 72 | ) 73 | -------------------------------------------------------------------------------- /third_party/gpus/crosstool/clang/bin/BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel/intel-extension-for-openxla/68b2112b3466d0fcb111369799246d39281c7452/third_party/gpus/crosstool/clang/bin/BUILD -------------------------------------------------------------------------------- /third_party/gpus/crosstool/clang/bin/ar_driver.tpl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Create a temporary directory 4 | tmp_dir="$(mktemp -d -t tmp.XXXXXXXXXX)" 5 | input_object_file="" 6 | ar_flag="" 7 | output_file="" 8 | 9 | if [[ $# -eq 1 ]]; then 10 | arg="$1" 11 | shift 12 | if [[ $arg == "@"* ]]; then 13 | file_name=${arg#*@} 14 | 15 | { 16 | read -r ar_flag 17 | read -r output_file 18 | while IFS= read -r input_file; do 19 | if file "$input_file" | grep -q "current ar archive"; then 20 | ar x "$input_file" --output="$tmp_dir" 21 | else 22 | input_object_file="$input_object_file $input_file" 23 | fi 24 | done 25 | } < "$file_name" 26 | else 27 | echo "invalid argument" 28 | exit 1 29 | fi 30 | else 31 | ar_flag="$1" 32 | shift 33 | output_file="$1" 34 | shift 35 | 36 | for 
input_file in "$@"; do 37 | if file "$input_file" | grep -q "current ar archive"; then 38 | ar x "$input_file" --output="$tmp_dir" 39 | else 40 | input_object_file="$input_object_file $input_file" 41 | fi 42 | done 43 | fi 44 | 45 | if [[ $input_object_file != "" ]]; then 46 | ar "$ar_flag" "$output_file" $input_object_file 47 | else 48 | ar "$ar_flag" "$output_file" "$tmp_dir"/* 49 | fi 50 | 51 | # Remove the temporary directory 52 | rm -rf "$tmp_dir" 53 | -------------------------------------------------------------------------------- /third_party/gpus/sycl/BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel/intel-extension-for-openxla/68b2112b3466d0fcb111369799246d39281c7452/third_party/gpus/sycl/BUILD -------------------------------------------------------------------------------- /third_party/gpus/sycl/BUILD.tpl: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | config_setting( 4 | name = "using_sycl", 5 | values = { 6 | "define": "using_sycl=true", 7 | }, 8 | ) 9 | 10 | cc_library( 11 | name = "sycl_headers", 12 | hdrs = [ 13 | %{sycl_headers} 14 | ], 15 | includes = [ 16 | ".", 17 | "sycl/include", 18 | "sycl/include/sycl", 19 | ], 20 | visibility = ["//visibility:public"], 21 | ) 22 | 23 | cc_library( 24 | name = "mkl", 25 | srcs = [ 26 | "sycl/lib/%{mkl_intel_ilp64_lib}", 27 | "sycl/lib/%{mkl_sequential_lib}", 28 | "sycl/lib/%{mkl_core_lib}", 29 | %{mkl_sycl_libs} 30 | ], 31 | data = [ 32 | "sycl/lib/%{mkl_intel_ilp64_lib}", 33 | "sycl/lib/%{mkl_sequential_lib}", 34 | "sycl/lib/%{mkl_core_lib}", 35 | %{mkl_sycl_libs} 36 | ], 37 | includes = [ 38 | ".", 39 | "sycl/include", 40 | ], 41 | linkopts = ["-Wl,-Bstatic,-lsvml,-lirng,-limf,-lirc,-lirc_s,-Bdynamic"], 42 | linkstatic = 1, 43 | visibility = ["//visibility:public"], 44 | ) 45 | 46 | cc_library( 47 | name = "level_zero", 48 | srcs = [ 49 | %{level_zero_libs} 50 | ], 51 | data = [ 52 | %{level_zero_libs} 53 | ], 54 | hdrs = [ 55 | %{level_zero_headers} 56 | ], 57 | includes = [ 58 | ".", 59 | "level_zero/include", 60 | ], 61 | linkstatic = 1, 62 | visibility = ["//visibility:public"], 63 | ) 64 | 65 | %{copy_rules} -------------------------------------------------------------------------------- /third_party/gpus/sycl/build_defs.bzl.tpl: -------------------------------------------------------------------------------- 1 | # Macros for building SYCL code. 2 | def if_sycl(if_true, if_false = []): 3 | """Shorthand for select()'ing on whether we're building with SYCL. 4 | 5 | Returns a select statement which evaluates to if_true if we're building 6 | with SYCL enabled. Otherwise, the select statement evaluates to if_false. 7 | 8 | """ 9 | return select({ 10 | "@local_config_sycl//sycl:using_sycl": if_true, 11 | "//conditions:default": if_false, 12 | }) 13 | 14 | def sycl_default_copts(): 15 | """Default options for all SYCL compilations.""" 16 | return if_sycl(["-x", "sycl"]) 17 | 18 | def sycl_build_is_configured(): 19 | """Returns true if SYCL compiler was enabled during the configure process.""" 20 | return %{sycl_build_is_configured} 21 | 22 | def if_sycl_is_configured(x, no_sycl = []): 23 | """Tests if the SYCL was enabled during the configure process. 24 | 25 | Unlike if_sycl(), this does not require that we are building with 26 | --config=sycl. Used to allow non-SYCL code to depend on SYCL libraries. 
27 | """ 28 | if %{sycl_is_configured}: 29 | return select({"//conditions:default": x}) 30 | return select({"//conditions:default": []}) 31 | 32 | def if_sycl_build_is_configured(x, y): 33 | if sycl_build_is_configured(): 34 | return x 35 | return y 36 | 37 | def sycl_library(copts = [], **kwargs): 38 | """Wrapper over cc_library which adds default SYCL options.""" 39 | native.cc_library(copts = sycl_default_copts() + copts, **kwargs) -------------------------------------------------------------------------------- /third_party/llvm_spir/BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel/intel-extension-for-openxla/68b2112b3466d0fcb111369799246d39281c7452/third_party/llvm_spir/BUILD -------------------------------------------------------------------------------- /third_party/llvm_spir/llvm_spir.BUILD: -------------------------------------------------------------------------------- 1 | exports_files(["LICENSE"]) 2 | 3 | cc_library( 4 | name = "llvm_spir_translator", 5 | srcs = glob([ 6 | "lib/SPIRV/libSPIRV/*.cpp", 7 | "lib/SPIRV/libSPIRV/*.hpp", 8 | "lib/SPIRV/libSPIRV/*.h", 9 | "lib/SPIRV/Mangler/*.cpp", 10 | "lib/SPIRV/Mangler/*.h", 11 | "lib/SPIRV/*.cpp", 12 | "lib/SPIRV/*.hpp", 13 | "lib/SPIRV/*.h", 14 | ]), 15 | hdrs = glob(["include/*"]), 16 | includes = [ 17 | "include/", 18 | "lib/SPIRV/", 19 | "lib/SPIRV/Mangler/", 20 | "lib/SPIRV/libSPIRV/", 21 | ], 22 | visibility = ["//visibility:public"], 23 | deps = [ 24 | "@llvm-project//llvm:Analysis", 25 | "@llvm-project//llvm:BitWriter", 26 | "@llvm-project//llvm:CodeGen", 27 | "@llvm-project//llvm:Core", 28 | "@llvm-project//llvm:Demangle", 29 | "@llvm-project//llvm:IRReader", 30 | "@llvm-project//llvm:Linker", 31 | "@llvm-project//llvm:Passes", 32 | "@llvm-project//llvm:Support", 33 | "@llvm-project//llvm:TransformUtils", 34 | "@spir_headers//:spirv_cpp_headers", 35 | ], 36 | ) 37 | -------------------------------------------------------------------------------- /third_party/onednn/BUILD: -------------------------------------------------------------------------------- 1 | load("@bazel_skylib//:bzl_library.bzl", "bzl_library") 2 | load("@bazel_skylib//lib:selects.bzl", "selects") 3 | 4 | exports_files(["LICENSE"]) 5 | 6 | package( 7 | default_visibility = [ 8 | "//tensorflow:__subpackages__", 9 | ], 10 | licenses = ["notice"], 11 | ) 12 | 13 | py_binary( 14 | name = "gen_gpu_kernel_list", 15 | srcs = ["gen_gpu_kernel_list.py"], 16 | visibility = [ 17 | "@onednn_gpu//:__subpackages__", 18 | ], 19 | ) 20 | 21 | py_binary( 22 | name = "gen_onednn_version", 23 | srcs = ["gen_onednn_version.py"], 24 | visibility = [ 25 | "@onednn_gpu//:__subpackages__", 26 | ], 27 | ) 28 | -------------------------------------------------------------------------------- /third_party/onednn/gen_onednn_version.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import sys 4 | import subprocess 5 | 6 | 7 | def parse_args(argv): 8 | result = {} 9 | for arg in argv: 10 | k, v = arg.split("=") 11 | result[k] = v 12 | 13 | return result 14 | 15 | 16 | def parse_version(cmake): 17 | pattern = re.compile('set\\(PROJECT_VERSION "([0-9]+\\.[0-9]+\\.[0-9]+)"\\)') 18 | with open(os.path.expanduser(cmake)) as f: 19 | for line in f.readlines(): 20 | result = pattern.match(line) 21 | if result is not None: 22 | return result.group(1) 23 | 24 | sys.exit("Can't get the right version from ", cmake) 25 | 26 | 27 | def get_root(header_in): 28 
| """ 29 | This is an assumption that the root workspace should be the same depth 30 | with "include" folder. It will find start from right position, so sho- 31 | uld handle the include/**/itex/***/onednn/include/**. 32 | """ 33 | pos = header_in.rindex("include") 34 | root = header_in[:pos] 35 | return root 36 | 37 | 38 | def git_hash(header_in): 39 | root = get_root(header_in) 40 | commit_file = os.path.join(root, "COMMIT") 41 | with open(commit_file, 'r') as f: 42 | commit = f.readline().strip() 43 | 44 | return commit 45 | 46 | 47 | def get_cmake(header_in): 48 | root = get_root(header_in) 49 | cmake = os.path.join(root, "CMakeLists.txt") 50 | return cmake 51 | 52 | 53 | def generate_version(version, header_in, header_out): 54 | hash_value = git_hash(header_in) 55 | 56 | [major, minor, patch] = version.split(".") 57 | 58 | with open(os.path.expanduser(header_in)) as inf: 59 | content = inf.read() 60 | content = content.replace("@DNNL_VERSION_MAJOR@", major) 61 | content = content.replace("@DNNL_VERSION_MINOR@", minor) 62 | content = content.replace("@DNNL_VERSION_PATCH@", patch) 63 | content = content.replace("@DNNL_VERSION_HASH@", hash_value) 64 | 65 | header_out = os.path.expanduser(header_out) 66 | header_out_dir = os.path.dirname(header_out) 67 | if not os.path.exists(header_out_dir): 68 | os.makedirs(header_out_dir, exist_ok=True) 69 | 70 | with open(header_out, "w") as outf: 71 | outf.write(content) 72 | 73 | 74 | def main(): 75 | args = parse_args(sys.argv[1:]) 76 | cmake = get_cmake(args["--in"]) 77 | version = parse_version(cmake) 78 | generate_version(version, args["--in"], args["--out"]) 79 | 80 | 81 | if __name__ == "__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /third_party/onednn/onednn.bzl: -------------------------------------------------------------------------------- 1 | def convert_cl_to_cpp(name, src, cl_list, **kwargs): 2 | """Create a miniature of the src image. 3 | The generated file is prefixed with 'small_'. 4 | """ 5 | cpp_list = [cl.replace(".cl", "_kernel.cpp") for cl in cl_list] 6 | kernel_list = src.replace(".in", "") 7 | cpp_list.append(kernel_list) 8 | 9 | tool = "@intel_extension_for_openxla//third_party/onednn:gen_gpu_kernel_list" 10 | 11 | native.genrule( 12 | name = name, 13 | srcs = [src], 14 | outs = cpp_list, 15 | tools = [tool], 16 | cmd = "$(location {}) ".format(tool) + "--in=$< --out=$(@D) --header=False", 17 | **kwargs 18 | ) 19 | 20 | def convert_header_to_cpp(name, src, header_list, **kwargs): 21 | """Create a miniature of the src image. 22 | The generated file is prefixed with 'small_'. 
23 | """ 24 | cpp_list = [] 25 | h_list = [] 26 | for h in header_list: 27 | if h.endswith(".h"): 28 | h_list.append(h.replace(".h", "_header.cpp")) 29 | cpp_list.extend(h_list) 30 | 31 | tool = "@intel_extension_for_openxla//third_party/onednn:gen_gpu_kernel_list" 32 | 33 | native.genrule( 34 | name = name, 35 | srcs = [src], 36 | outs = cpp_list, 37 | tools = [tool], 38 | cmd = "$(location {}) ".format(tool) + "--in=$< --out=$(@D) --header=True", 39 | **kwargs 40 | ) 41 | 42 | def gen_onednn_version(name, header_in, header_out, **kwargs): 43 | tool = "@intel_extension_for_openxla//third_party/onednn:gen_onednn_version" 44 | 45 | native.genrule( 46 | name = name, 47 | srcs = [header_in], 48 | outs = [header_out], 49 | tools = [tool], 50 | cmd = "$(location {}) ".format(tool) + "--in=$< " + "--out=$@", 51 | **kwargs 52 | ) 53 | -------------------------------------------------------------------------------- /third_party/version_check.bzl: -------------------------------------------------------------------------------- 1 | """ Helpers to check minimum version of bazel.""" 2 | 3 | def _extract_version_number(bazel_version): 4 | """Extracts the semantic version number from a version string 5 | 6 | Args: 7 | bazel_version: the version string that begins with the semantic version 8 | e.g. "1.2.3rc1 abc1234" where "abc1234" is a commit hash. 9 | 10 | Returns: 11 | The semantic version string, like "1.2.3". 12 | """ 13 | for i in range(len(bazel_version)): 14 | c = bazel_version[i] 15 | if not (c.isdigit() or c == "."): 16 | return bazel_version[:i] 17 | return bazel_version 18 | 19 | # Parse the bazel version string from `native.bazel_version`. 20 | # e.g. 21 | # "0.10.0rc1 abc123d" => (0, 10, 0) 22 | # "0.3.0" => (0, 3, 0) 23 | def _parse_bazel_version(bazel_version): 24 | """Parses a version string into a 3-tuple of ints 25 | 26 | int tuples can be compared directly using binary operators (<, >). 27 | 28 | Args: 29 | bazel_version: the Bazel version string 30 | 31 | Returns: 32 | An int 3-tuple of a (major, minor, patch) version. 
33 | """ 34 | 35 | version = _extract_version_number(bazel_version) 36 | return tuple([int(n) for n in version.split(".")]) 37 | 38 | def check_bazel_version_at_least(minimum_bazel_version): 39 | if "bazel_version" not in dir(native): 40 | fail("\nCurrent Bazel version is lower than the minimum supported version: %s\n" % minimum_bazel_version) 41 | elif not native.bazel_version: 42 | print("\nCurrent Bazel is not a release version, cannot check for compatibility.") 43 | print("Make sure that you are running at least Bazel %s.\n" % minimum_bazel_version) 44 | return 45 | 46 | if _parse_bazel_version(native.bazel_version) < _parse_bazel_version(minimum_bazel_version): 47 | fail("\nCurrent Bazel version is {}, expected at least {}\n".format( 48 | native.bazel_version, 49 | minimum_bazel_version, 50 | )) 51 | 52 | parse_bazel_version = _parse_bazel_version 53 | -------------------------------------------------------------------------------- /third_party/xetla/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | cc_library( 4 | name = "xetla_header", 5 | hdrs = glob( 6 | include = [ 7 | "include/util/*.h", 8 | "include/tile/*.h", 9 | "include/reduction/*.h", 10 | "include/mha_core_attention/*.h", 11 | "include/layer_norm/*.h", 12 | "include/gemm/*.h", 13 | "include/data_transformer/*.h", 14 | "include/core/*.h", 15 | "include/brgemm/*.h", 16 | ], 17 | ), 18 | includes = ["include"], 19 | strip_include_prefix = "include", 20 | ) 21 | -------------------------------------------------------------------------------- /xla/BUILD: -------------------------------------------------------------------------------- 1 | cc_binary( 2 | name = "pjrt_plugin_xpu.so", 3 | linkopts = ["-Wl,-rpath,$$ORIGIN/../intel_extension_for_openxla/service/gpu"], 4 | linkshared = True, 5 | visibility = ["//visibility:public"], 6 | deps = [ 7 | "@xla//xla/pjrt/c:pjrt_c_api_gpu", 8 | "@xla//xla/service:gpu_plugin", 9 | "//xla/stream_executor:sycl_platform", 10 | ], 11 | ) 12 | -------------------------------------------------------------------------------- /xla/profiler/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | cc_library( 4 | name = "sycl_device_tracer", 5 | srcs = ["device_tracer_sycl.cc"], 6 | linkstatic = 1, 7 | visibility = ["//visibility:public"], 8 | deps = [ 9 | ":ze_tracer", 10 | "//xla/stream_executor/sycl:hw_info", 11 | "//xla/stream_executor/sycl:sycl_gpu_runtime", 12 | "@tsl//tsl/profiler/backends/cpu:annotation_stack", 13 | "@tsl//tsl/profiler/lib:profiler_factory", 14 | "@tsl//tsl/profiler/lib:profiler_interface", 15 | "@tsl//tsl/profiler/protobuf:trace_events_proto_cc", 16 | "@tsl//tsl/profiler/protobuf:xplane_proto_cc", 17 | "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc", 18 | "@tsl//tsl/profiler/utils:parse_annotation", 19 | "@tsl//tsl/profiler/utils:trace_utils", 20 | "@tsl//tsl/profiler/utils:tf_xplane_visitor", 21 | "@tsl//tsl/profiler/utils:xplane_builder", 22 | "@tsl//tsl/profiler/utils:xplane_utils", 23 | "@tsl//tsl/profiler/utils:xplane_schema", 24 | ], 25 | alwayslink = True, 26 | ) 27 | 28 | cc_library( 29 | name = "ze_tracer", 30 | hdrs = [ 31 | "trace_options.h", 32 | "tracing.h", 33 | "ze_api_collector.h", 34 | "ze_kernel_collector.h", 35 | "ze_tracer.h", 36 | "ze_utils.h", 37 | ], 38 | srcs = [ 39 | ":profiler_utils", 40 | ], 41 | visibility = ["//visibility:public"], 42 | deps = 
[ 43 | ":ze_correlator", 44 | "@tsl//tsl/platform:abi", 45 | ], 46 | ) 47 | 48 | cc_library( 49 | name = "ze_correlator", 50 | hdrs = [ 51 | "correlator.h", 52 | ], 53 | srcs = [ 54 | "correlator.cc", 55 | ":profiler_utils", 56 | ], 57 | visibility = ["//visibility:public"], 58 | deps = [ 59 | "@com_google_absl//absl/time", 60 | "@tsl//tsl/profiler/backends/cpu:annotation_stack", 61 | ], 62 | ) 63 | 64 | filegroup( 65 | name = "profiler_utils", 66 | srcs = [ 67 | "utils.h", 68 | ], 69 | visibility = ["//visibility:public"], 70 | ) 71 | -------------------------------------------------------------------------------- /xla/profiler/correlator.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | //============================================================== 17 | // Copyright (C) Intel Corporation 18 | // 19 | // SPDX-License-Identifier: MIT 20 | // ============================================================= 21 | 22 | #include "xla/profiler/correlator.h" 23 | 24 | thread_local uint64_t Correlator::kernel_id_ = 0; 25 | namespace xla{ 26 | namespace profiler { 27 | int64_t GetCurrentTimeNanos() { 28 | // absl::GetCurrentTimeNanos() is much faster than EnvTime::NowNanos(). 29 | // It is wrapped under xla::profiler::GetCurrentTimeNanos to avoid ODR 30 | // violation and to allow switching to yet another implementation if required. 31 | return absl::GetCurrentTimeNanos(); 32 | }; 33 | } // namespace profiler 34 | } // namespace xla 35 | // Returns the current CPU wallclock time in nanoseconds. 36 | -------------------------------------------------------------------------------- /xla/profiler/correlator.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021-2022 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | //============================================================== 17 | // Copyright (C) Intel Corporation 18 | // 19 | // SPDX-License-Identifier: MIT 20 | //============================================================== 21 | #ifndef XLA_PROFILER_CORRELATOR_H_ 22 | #define XLA_PROFILER_CORRELATOR_H_ 23 | 24 | #include 25 | 26 | #include 27 | #include 28 | #include 29 | 30 | #include "absl/time/clock.h" 31 | #include "xla/profiler/utils.h" 32 | 33 | namespace xla{ 34 | namespace profiler { 35 | // Returns the current CPU wallclock time in nanoseconds. 36 | int64_t GetCurrentTimeNanos(); 37 | 38 | } // namespace profiler 39 | } // namespace xla 40 | 41 | struct ApiCollectorOptions { 42 | bool call_tracing; 43 | bool need_tid; 44 | bool need_pid; 45 | }; 46 | 47 | class Correlator { 48 | public: 49 | Correlator() : base_time_(xla::profiler::GetCurrentTimeNanos()) { 50 | } 51 | 52 | uint64_t GetTimestamp() const { 53 | return xla::profiler::GetCurrentTimeNanos() - base_time_; 54 | } 55 | 56 | uint64_t GetStartPoint() const { return base_time_; } 57 | 58 | uint64_t GetKernelIdVector() const { return kernel_id_; } 59 | 60 | void SetKernelId(uint64_t kernel_id) { kernel_id_ = kernel_id; } 61 | 62 | std::vector GetKernelIdVector (ze_command_list_handle_t command_list) { 63 | if (kernel_id_map_.count(command_list) > 0) { 64 | return kernel_id_map_[command_list]; 65 | } else { 66 | return std::vector(); 67 | } 68 | } 69 | 70 | void CreateKernelIdList(ze_command_list_handle_t command_list) { 71 | kernel_id_map_[command_list] = std::vector(); 72 | } 73 | 74 | void RemoveKernelIdList(ze_command_list_handle_t command_list) { 75 | kernel_id_map_.erase(command_list); 76 | } 77 | 78 | void ResetKernelIdList(ze_command_list_handle_t command_list) { 79 | kernel_id_map_[command_list].clear(); 80 | } 81 | 82 | void AddKernelId(ze_command_list_handle_t command_list, uint64_t kernel_id) { 83 | kernel_id_map_[command_list].push_back(kernel_id); 84 | } 85 | 86 | std::vector GetCallIdVector(ze_command_list_handle_t command_list) { 87 | if (call_id_map_.count(command_list) > 0) { 88 | return call_id_map_[command_list]; 89 | } else { 90 | return std::vector(); 91 | } 92 | } 93 | 94 | void CreateCallIdList(ze_command_list_handle_t command_list) { 95 | call_id_map_[command_list] = std::vector(); 96 | } 97 | 98 | void RemoveCallIdList(ze_command_list_handle_t command_list) { 99 | call_id_map_.erase(command_list); 100 | } 101 | 102 | void ResetCallIdList(ze_command_list_handle_t command_list) { 103 | call_id_map_[command_list].clear(); 104 | } 105 | 106 | void AddCallId(ze_command_list_handle_t command_list, uint64_t call_id) { 107 | call_id_map_[command_list].push_back(call_id); 108 | } 109 | 110 | private: 111 | uint64_t base_time_; 112 | std::map > kernel_id_map_; 113 | std::map > call_id_map_; 114 | 115 | static thread_local uint64_t kernel_id_; 116 | }; 117 | 118 | #endif // XLA_PROFILER_CORRELATOR_H_ 119 | -------------------------------------------------------------------------------- /xla/profiler/trace_options.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | //============================================================== 17 | // Copyright (C) Intel Corporation 18 | // 19 | // SPDX-License-Identifier: MIT 20 | // ============================================================= 21 | 22 | #ifndef XLA_PROFILER_TRACE_OPTIONS_H_ 23 | #define XLA_PROFILER_TRACE_OPTIONS_H_ 24 | 25 | #include 26 | #include 27 | 28 | #include "xla/profiler/utils.h" 29 | 30 | #define TRACE_CALL_LOGGING 0 31 | #define TRACE_HOST_TIMING 1 32 | #define TRACE_DEVICE_TIMING 2 33 | #define TRACE_DEVICE_TIMING_VERBOSE 3 34 | #define TRACE_DEVICE_TIMELINE 4 35 | #define TRACE_CHROME_CALL_LOGGING 5 36 | #define TRACE_CHROME_DEVICE_TIMELINE 6 37 | #define TRACE_CHROME_DEVICE_STAGES 7 38 | #define TRACE_TID 8 39 | #define TRACE_PID 9 40 | #define TRACE_LOG_TO_FILE 10 41 | #define TRACE_HOST_RUNTIME_TIMING 11 42 | 43 | class TraceOptions { 44 | public: 45 | TraceOptions(uint32_t flags) 46 | : flags_(flags) { 47 | if (flags_ == 0) { 48 | flags_ |= (1 << TRACE_HOST_TIMING); 49 | flags_ |= (1 << TRACE_DEVICE_TIMING); 50 | } 51 | } 52 | 53 | bool CheckFlag(uint32_t flag) const { return (flags_ & (1 << flag)); } 54 | 55 | private: 56 | uint32_t flags_; 57 | }; 58 | 59 | #endif // XLA_PROFILER_TRACE_OPTIONS_H_ 60 | -------------------------------------------------------------------------------- /xla/python/BUILD: -------------------------------------------------------------------------------- 1 | load("@xla//xla/tsl:tsl.default.bzl", "tsl_pybind_extension") 2 | 3 | tsl_pybind_extension( 4 | name = "xpu_plugin_extension", 5 | srcs = ["xpu_plugin_extension.cc"], 6 | deps = [ 7 | "@com_google_absl//absl/status", 8 | "@nanobind", 9 | "@xla//xla:util", 10 | "@xla//xla/ffi/api:c_api", 11 | "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", 12 | "@xla//xla/pjrt/c:pjrt_c_api_hdrs", 13 | "@xla//xla/pjrt/c:pjrt_c_api_helpers", 14 | "@xla//xla/pjrt:status_casters", 15 | "@xla//xla/python:py_client_gpu", 16 | "@xla//xla/tsl/python/lib/core:numpy", 17 | ], 18 | ) 19 | -------------------------------------------------------------------------------- /xla/python/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Intel Corporation 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | '''Init file for register XPU backend''' 16 | 17 | import logging 18 | from pathlib import Path 19 | import os 20 | import platform 21 | import sys 22 | 23 | import jax._src.xla_bridge as xb 24 | from jax._src.lib import xla_client 25 | from jax_plugins.intel_extension_for_openxla.version import VersionClass 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | def initialize(): 30 | path = Path(__file__).resolve().parent / "pjrt_plugin_xpu.so" 31 | xla_extension_version = VersionClass() 32 | logger.warning("INFO: Intel Extension for OpenXLA version: %s, commit: %s", 33 | xla_extension_version.get_version(), 34 | xla_extension_version.get_hash()) 35 | if not path.exists(): 36 | logger.warning( 37 | f"WARNING: Native library {path} does not exist. " 38 | f"This most likely indicates an issue with how {__package__} " 39 | f"was built or installed.") 40 | 41 | options = dict() 42 | options['platform_name'] = 'sycl' 43 | allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() 44 | if allocator not in ('default', 'platform', 'bfc', 'cuda_async'): 45 | raise ValueError( 46 | 'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", ' 47 | '"bfc", or "cuda_async", got "%s"' % allocator 48 | ) 49 | options['allocator'] = allocator 50 | 51 | # xb.CUDA_VISIBLE_DEVICES is set by jax.distribute.initialize(local_device_ids). 52 | # xb.CUDA_VISIBLE_DEVICES would has default value 'all' if users not call 53 | # jax.distribute.initialize or call it without setting local_device_ids. 54 | visible_devices = xb.CUDA_VISIBLE_DEVICES.value 55 | if visible_devices != 'all': 56 | options['visible_devices'] = [int(x) for x in visible_devices.split(',')] 57 | 58 | c_api = xb.register_plugin("sycl", 59 | priority=500, 60 | library_path=str(path), 61 | options=options) 62 | 63 | try: 64 | import functools 65 | from .python import xpu_plugin_extension 66 | xla_client.register_custom_call_handler( 67 | "SYCL", 68 | functools.partial( 69 | xpu_plugin_extension.register_custom_call_target, c_api 70 | ), 71 | ) 72 | for _name, _value in xpu_plugin_extension.registrations().items(): 73 | xla_client.register_custom_call_target(_name, _value, platform="SYCL") 74 | except: 75 | raise RuntimeError("Fail to load xpu_plugin_extension.so.") 76 | -------------------------------------------------------------------------------- /xla/python/build_defs.bzl: -------------------------------------------------------------------------------- 1 | def gen_xla_version(name, header_in, header_out, **kwargs): 2 | tool = "//xla/python:gen_xla_version" 3 | 4 | native.genrule( 5 | name = name, 6 | srcs = [header_in], 7 | outs = [header_out], 8 | tools = [tool], 9 | cmd = "$(location {}) ".format(tool) + "--in=$< " + "--out=$@", 10 | stamp = True, 11 | **kwargs 12 | ) 13 | -------------------------------------------------------------------------------- /xla/python/gen_xla_version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2024 Intel Corporation 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import os 17 | import sys 18 | 19 | 20 | def parse_args(argv): 21 | result = {} 22 | for arg in argv: 23 | k, v = arg.split("=") 24 | result[k] = v 25 | 26 | return result 27 | 28 | 29 | def generate_version(header_in, hash_value, header_out): 30 | with open(os.path.expanduser(header_in), encoding='utf-8') as inf: 31 | content = inf.read() 32 | content = content.replace("@VERSION_HASH@", hash_value) 33 | 34 | header_out = os.path.expanduser(header_out) 35 | header_out_dir = os.path.dirname(header_out) 36 | if not os.path.exists(header_out_dir): 37 | os.makedirs(header_out_dir, exist_ok=True) 38 | 39 | with open(header_out, "w", encoding='utf-8') as outf: 40 | outf.write(content) 41 | 42 | 43 | def main(): 44 | args = parse_args(sys.argv[1:]) 45 | generate_version(args["--in"], args["--hash"], args["--out"]) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /xla/python/version.py.in: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2024 Intel Corporation 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | class VersionClass(object): 17 | def __init__(self): 18 | self.version = '0.6.0' 19 | self.hash = "@VERSION_HASH@" 20 | 21 | def get_version(self): 22 | return self.version 23 | 24 | def get_hash(self): 25 | return self.hash 26 | -------------------------------------------------------------------------------- /xla/python/xpu_plugin_extension.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include 17 | #include 18 | #include 19 | 20 | #include "absl/status/status.h" 21 | #include "nanobind/nanobind.h" 22 | #include "xla/ffi/api/c_api.h" 23 | #include "xla/pjrt/c/pjrt_c_api.h" 24 | #include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" 25 | #include "xla/pjrt/c/pjrt_c_api_helpers.h" 26 | #include "xla/python/py_client_gpu.h" 27 | #include "xla/pjrt/status_casters.h" 28 | #include "xla/tsl/python/lib/core/numpy.h" 29 | #include "xla/util.h" 30 | 31 | namespace nb = nanobind; 32 | 33 | namespace xla { 34 | namespace { 35 | absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, 36 | const char* fn_name_c_str, size_t fn_name_size, 37 | nb::capsule fn, int api_version, 38 | XLA_FFI_Handler_Traits traits) { 39 | if (c_api->extension_start == nullptr) { 40 | return Unimplemented("The plugin does not have extension."); 41 | } 42 | const PJRT_Extension_Base* next = 43 | reinterpret_cast(c_api->extension_start); 44 | while (next != nullptr && 45 | next->type != 46 | PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { 47 | next = next->next; 48 | } 49 | if (next == nullptr) { 50 | return Unimplemented("The plugin does not have a custom call extension."); 51 | } 52 | 53 | if (traits != 0) { 54 | return Unimplemented("The plugin does not support custom call traits."); 55 | } 56 | 57 | PJRT_Gpu_Register_Custom_Call_Args args; 58 | args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; 59 | args.function_name = fn_name_c_str; 60 | args.function_name_size = fn_name_size; 61 | args.api_version = api_version; 62 | args.custom_call_function = static_cast(fn.data()); 63 | RETURN_STATUS_IF_PJRT_ERROR( 64 | reinterpret_cast(next)->custom_call(&args), 65 | c_api); 66 | return absl::OkStatus(); 67 | } 68 | 69 | template 70 | nb::capsule EncapsulateFunction(T* fn) { 71 | return nb::capsule(absl::bit_cast(fn), 72 | "xla._CUSTOM_CALL_TARGET"); 73 | } 74 | 75 | nb::dict Registrations() { 76 | nb::dict dict; 77 | dict["xla_python_gpu_callback"] = 78 | EncapsulateFunction(xla::XlaPythonGpuCallback); 79 | return dict; 80 | } 81 | 82 | } // namespace 83 | 84 | NB_MODULE(xpu_plugin_extension, m) { 85 | tsl::ImportNumpy(); 86 | m.def( 87 | "register_custom_call_target", 88 | [](nb::capsule c_api, nb::object fn_name_py, nb::capsule fn, 89 | nb::str xla_platform_name, int api_version, 90 | XLA_FFI_Handler_Traits traits) { 91 | const char* fn_name_c_str; 92 | size_t fn_name_size; 93 | nb::str fn_name_bn_str; 94 | if (nb::try_cast(fn_name_py, fn_name_bn_str)) { 95 | fn_name_c_str = fn_name_bn_str.c_str(); 96 | fn_name_size = nb::len(fn_name_bn_str); 97 | } else{ 98 | nb::bytes bytes = nb::cast(fn_name_py); 99 | fn_name_c_str = bytes.c_str(); 100 | fn_name_size = bytes.size(); 101 | } 102 | xla::ThrowIfError(RegisterCustomCallTarget( 103 | static_cast(c_api.data()), fn_name_c_str, 104 | fn_name_size, std::move(fn), api_version, traits)); 105 | }, 106 | nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"), 107 | nb::arg("xla_platform_name"), nb::arg("api_version") = 0, 108 | nb::arg("traits") = 0); 109 | m.def("registrations", &Registrations); 110 | } 111 | } // namespace xla 112 | -------------------------------------------------------------------------------- /xla/service/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = [ 3 | "//visibility:public", 4 | ], 5 | licenses = ["notice"], 6 | ) 7 | 8 | cc_library( 9 | name = 
"onednn_util", 10 | hdrs = ["onednn_util.h"], 11 | deps = [ 12 | "@onednn_gpu//:onednn_gpu", 13 | "@tsl//tsl/platform:errors", 14 | "@tsl//tsl/platform:logging", 15 | ], 16 | ) 17 | -------------------------------------------------------------------------------- /xla/service/gpu/ccl_all_gather_thunk.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Copyright 2019 The OpenXLA Authors. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | #ifndef XLA_SERVICE_GPU_CCL_ALL_GATHER_THUNK_H_ 19 | #define XLA_SERVICE_GPU_CCL_ALL_GATHER_THUNK_H_ 20 | 21 | #include 22 | #include 23 | 24 | #include "absl/status/status.h" 25 | #include "absl/types/span.h" 26 | #include "xla/hlo/ir/hlo_instructions.h" 27 | #include "xla/service/collective_ops_utils.h" 28 | #include "xla/service/gpu/runtime/nccl_api.h" 29 | #include "xla/service/gpu/ccl_collective_thunk.h" 30 | #include "xla/stream_executor/stream.h" 31 | 32 | namespace xla { 33 | namespace gpu { 34 | 35 | struct NcclAllGatherConfig { 36 | NcclCollectiveConfig config; 37 | }; 38 | 39 | // Thunk that performs a NCCL-based All-Gather among CUDA GPU-based replicas. 40 | class NcclAllGatherStartThunk : public NcclCollectiveThunk { 41 | public: 42 | NcclAllGatherStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, 43 | const HloAllGatherInstruction* inst, 44 | std::vector buffers); 45 | 46 | static const char* GetHloOpName() { return "all-gather-start"; } 47 | 48 | static absl::Status CheckImplementable(const HloAllGatherInstruction* inst, 49 | int64_t replica_count, 50 | int64_t partition_count); 51 | 52 | static CollectiveOpGroupMode GetGroupMode( 53 | const HloAllGatherInstruction* inst); 54 | 55 | const NcclCollectiveConfig& config() const override { return config_.config; } 56 | absl::Span buffers() const { return buffers_; } 57 | 58 | protected: 59 | absl::Status RunNcclCollective(const ExecuteParams& params, 60 | se::Stream& stream, 61 | NcclApi::NcclCommHandle comm) override; 62 | 63 | private: 64 | const NcclAllGatherConfig config_; 65 | const std::vector buffers_; 66 | }; 67 | 68 | absl::Status RunAllGather(NcclApi* nccl_api, 69 | std::vector& buffers, 70 | se::Stream& stream, NcclApi::NcclCommHandle comm); 71 | 72 | } // namespace gpu 73 | } // namespace xla 74 | 75 | #endif // XLA_SERVICE_GPU_NCCL_ALL_GATHER_THUNK_H_ 76 | -------------------------------------------------------------------------------- /xla/service/gpu/ccl_all_to_all_thunk.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Copyright 2019 The OpenXLA Authors. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 
7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | #ifndef XLA_SERVICE_GPU_CCL_ALL_TO_ALL_THUNK_H_ 19 | #define XLA_SERVICE_GPU_CCL_ALL_TO_ALL_THUNK_H_ 20 | 21 | #include 22 | 23 | #include "xla/service/collective_ops_utils.h" 24 | #include "xla/service/gpu/ccl_collective_thunk.h" 25 | #include "xla/service/gpu/runtime/nccl_api.h" 26 | 27 | namespace xla { 28 | namespace gpu { 29 | 30 | struct NcclAllToAllConfig { 31 | NcclCollectiveConfig config; 32 | bool has_split_dimension; 33 | }; 34 | 35 | // Thunk that performs a NCCL-based All-to-All among CUDA GPU-based replicas. 36 | class NcclAllToAllStartThunk : public NcclCollectiveThunk { 37 | public: 38 | NcclAllToAllStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, 39 | const HloAllToAllInstruction* instr, 40 | std::vector buffers); 41 | 42 | // Returns whether the given instruction can be lowered to a nccl all-to-all 43 | // call. 44 | static absl::Status CheckImplementable(const HloAllToAllInstruction* instr, 45 | int64_t replica_count, 46 | int64_t partition_count); 47 | 48 | static const char* GetHloOpName() { return "all-to-all-start"; } 49 | 50 | static CollectiveOpGroupMode GetGroupMode( 51 | const HloAllToAllInstruction* instr); 52 | 53 | protected: 54 | const NcclCollectiveConfig& config() const override { return config_.config; } 55 | absl::Status RunNcclCollective(const ExecuteParams& params, 56 | se::Stream& stream, 57 | NcclApi::NcclCommHandle comm) override; 58 | 59 | private: 60 | const NcclAllToAllConfig config_; 61 | const std::vector buffers_; 62 | }; 63 | 64 | absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension, 65 | std::vector& buffers, 66 | se::Stream& stream, NcclApi::NcclCommHandle comm); 67 | 68 | } // namespace gpu 69 | } // namespace xla 70 | 71 | #endif // XLA_SERVICE_GPU_CCL_ALL_TO_ALL_THUNK_H_ 72 | -------------------------------------------------------------------------------- /xla/service/gpu/ccl_collective_broadcast_thunk.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Copyright 2024 The OpenXLA Authors. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 
16 | ==============================================================================*/ 17 | 18 | #include "xla/service/gpu/ccl_collective_broadcast_thunk.h" 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include "tsl/platform/errors.h" 26 | #include "tsl/platform/statusor.h" 27 | #include "xla/hlo/ir/hlo_instruction.h" 28 | #include "xla/hlo/ir/hlo_instructions.h" 29 | #include "xla/service/collective_ops_utils.h" 30 | #include "xla/service/gpu/runtime/ccl_api.h" 31 | #include "xla/service/gpu/runtime/nccl_api.h" 32 | #include "xla/service/gpu/runtime/thunk.h" 33 | #include "xla/stream_executor/device_memory.h" 34 | #include "xla/stream_executor/stream.h" 35 | #include "xla/xla_data.pb.h" 36 | 37 | namespace xla::gpu { 38 | 39 | NcclCollectiveBroadcastStartThunk::NcclCollectiveBroadcastStartThunk( 40 | ThunkInfo thunk_info, NcclApi* nccl_api, 41 | const HloCollectiveBroadcastInstruction* instr, std::vector buffers) 42 | : NcclCollectiveThunk(Thunk::kNcclCollectiveBroadcastStart, thunk_info, 43 | nccl_api, IsSyncCollective(instr)), 44 | config_(GetNcclCollectiveConfig(instr, std::nullopt)), 45 | buffers_(std::move(buffers)) {} 46 | 47 | /*static*/ absl::Status NcclCollectiveBroadcastStartThunk::CheckImplementable( 48 | const HloInstruction* instr, int64_t replica_count, 49 | int64_t partition_count) { 50 | return absl::OkStatus(); 51 | } 52 | 53 | /*static*/ CollectiveOpGroupMode 54 | NcclCollectiveBroadcastStartThunk::GetGroupMode( 55 | const HloCollectiveBroadcastInstruction* inst) { 56 | return GetNcclCollectiveConfig(inst, std::nullopt).group_mode; 57 | } 58 | 59 | absl::Status NcclCollectiveBroadcastStartThunk::RunNcclCollective( 60 | const ExecuteParams& params, se::Stream& stream, 61 | NcclApi::NcclCommHandle comm) { 62 | TF_ASSIGN_OR_RETURN( 63 | std::vector device_buffers, 64 | ConvertToDeviceBuffers(params, buffers_, config_.operand_element_type)); 65 | return ::xla::gpu::RunCollectiveBroadcast(device_buffers, stream, comm, 66 | nccl_api()); 67 | } 68 | 69 | absl::Status RunCollectiveBroadcast(std::vector& buffers, 70 | se::Stream& stream, 71 | NcclApi::NcclCommHandle comm, 72 | NcclApi* nccl_api) { 73 | for (auto buffer : buffers) { 74 | se::DeviceMemoryBase src_addr = buffer.source_buffer; 75 | se::DeviceMemoryBase dest_addr = buffer.destination_buffer; 76 | TF_RETURN_IF_ERROR(nccl_api->Broadcast( 77 | // Always use rank 0 since we always broadcast from the first id in 78 | // replica_groups 79 | src_addr, dest_addr, buffer.element_type, buffer.element_count, 0, comm, 80 | &stream)); 81 | } 82 | return absl::OkStatus(); 83 | } 84 | 85 | } // namespace xla::gpu 86 | -------------------------------------------------------------------------------- /xla/service/gpu/ccl_collective_broadcast_thunk.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Copyright 2024 The OpenXLA Authors. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 
16 | ==============================================================================*/ 17 | 18 | #ifndef XLA_SERVICE_CCL_COLLECTIVE_BROADCAST_THUNK_H_ 19 | #define XLA_SERVICE_CCL_COLLECTIVE_BROADCAST_THUNK_H_ 20 | 21 | #include 22 | 23 | #include "xla/service/collective_ops_utils.h" 24 | #include "xla/service/gpu/ccl_collective_thunk.h" 25 | 26 | namespace xla::gpu { 27 | // Thunk that performs a NCCL-based collective broadcast. 28 | class NcclCollectiveBroadcastStartThunk : public NcclCollectiveThunk { 29 | public: 30 | static absl::Status CheckImplementable(const HloInstruction* instr, 31 | int64_t replica_count, 32 | int64_t partition_count); 33 | 34 | static CollectiveOpGroupMode GetGroupMode( 35 | const HloCollectiveBroadcastInstruction* inst); 36 | 37 | const NcclCollectiveConfig& config() const override { return config_; } 38 | absl::Span buffers() const { return buffers_; } 39 | 40 | static const char* GetHloOpName() { return "collective-broadcast-start"; } 41 | 42 | NcclCollectiveBroadcastStartThunk( 43 | ThunkInfo thunk_info, NcclApi* nccl_api, 44 | const HloCollectiveBroadcastInstruction* instr, 45 | std::vector buffers); 46 | 47 | protected: 48 | absl::Status RunNcclCollective(const ExecuteParams& params, 49 | se::Stream& stream, 50 | NcclApi::NcclCommHandle comm) override; 51 | 52 | private: 53 | const NcclCollectiveConfig config_; 54 | const std::vector buffers_; 55 | }; 56 | 57 | absl::Status RunCollectiveBroadcast(std::vector& buffers, 58 | se::Stream& stream, 59 | NcclApi::NcclCommHandle comm, 60 | NcclApi* nccl_api); 61 | 62 | } // namespace xla::gpu 63 | 64 | #endif // XLA_SERVICE_GPU_CCL_COLLECTIVE_BROADCAST_THUNK_H_ 65 | -------------------------------------------------------------------------------- /xla/service/gpu/ccl_collective_permute_thunk.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Copyright 2021 The OpenXLA Authors. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | #ifndef XLA_SERVICE_GPU_CCL_COLLECTIVE_PERMUTE_THUNK_H_ 19 | #define XLA_SERVICE_GPU_CCL_COLLECTIVE_PERMUTE_THUNK_H_ 20 | 21 | #include 22 | 23 | #include "xla/service/collective_ops_utils.h" 24 | #include "xla/service/gpu/ccl_collective_thunk.h" 25 | #include "xla/service/gpu/ccl_p2p_thunk_common.h" 26 | 27 | namespace xla { 28 | namespace gpu { 29 | 30 | // Thunk that performs a NCCL-based collective permute. 
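// Illustrative note (added commentary, not part of the upstream header): a
// collective permute is driven by explicit {source, target} replica pairs.
// For example, with pairs {{0,1},{1,2},{2,0}}, replica 0 sends its buffer to
// replica 1, replica 1 to replica 2, and replica 2 to replica 0. Either side
// of a pair may be absent for a given id, which is why the per-id view in
// NcclP2PConfig::SourceTargetMapEntry (declared in ccl_p2p_thunk_common.h
// further below) stores both source and target as optionals.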
31 | class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk { 32 | public: 33 | static NcclP2PConfig GetNcclP2PConfig( 34 | const HloCollectivePermuteInstruction* instr, int64_t replica_count, 35 | int64_t partition_count); 36 | 37 | static bool IsDegenerate(const HloCollectivePermuteInstruction* instr, 38 | int64_t replica_count, int64_t partition_count); 39 | 40 | static CollectiveOpGroupMode GetGroupMode( 41 | const HloCollectivePermuteInstruction* instr); 42 | 43 | NcclCollectivePermuteStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, 44 | const HloCollectivePermuteInstruction* instr, 45 | int64_t replica_count, 46 | int64_t partition_count, const Buffer& buffer, 47 | bool p2p_memcpy_enabled); 48 | 49 | static const char* GetHloOpName() { return "collective-permute-start"; } 50 | 51 | protected: 52 | const NcclCollectiveConfig& config() const override { return config_.config; } 53 | absl::Status RunNcclCollective(const ExecuteParams& params, 54 | se::Stream& stream, 55 | NcclApi::NcclCommHandle comm) override; 56 | 57 | private: 58 | const NcclP2PConfig config_; 59 | const Buffer buffer_; 60 | bool p2p_memcpy_enabled_ = false; 61 | }; 62 | 63 | absl::Status RunCollectivePermute( 64 | NcclApi* nccl_api, NcclP2PConfig::SourceTargetMapEntry source_target, 65 | DeviceBufferPair& buffer, se::Stream& stream, NcclApi::NcclCommHandle comm, 66 | absl::string_view device_string, int64_t current_id); 67 | 68 | } // namespace gpu 69 | } // namespace xla 70 | 71 | #endif // XLA_SERVICE_GPU_CCL_COLLECTIVE_PERMUTE_THUNK_H_ 72 | -------------------------------------------------------------------------------- /xla/service/gpu/ccl_ops.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | #ifndef XLA_SERVICE_GPU_CCL_OPS_H_ 16 | #define XLA_SERVICE_GPU_CCL_OPS_H_ 17 | #include 18 | 19 | #include "xla/service/collective_ops_utils.h" 20 | #include "xla/stream_executor/gpu/gpu_types.h" 21 | 22 | namespace ccl { 23 | struct communicator { 24 | communicator(int nranks, int rank, const std::string id) 25 | : nranks(nranks), rank(rank), id(id) {} 26 | int nranks; 27 | int rank; 28 | const std::string id; 29 | }; 30 | } // namespace ccl 31 | 32 | using ncclComm_t = ccl::communicator*; 33 | #define MAX_RANK_SIZE 16 34 | 35 | #if !ITEX_USE_CCL 36 | 37 | namespace xla { 38 | namespace gpu { 39 | 40 | void sycl_allreduce(const void* send_buffer, void* recv_buffer, 41 | size_t element_count, PrimitiveType dtype, 42 | ReductionKind reduction_kind, 43 | se::gpu::GpuStreamHandle gpu_stream, ncclComm_t comm); 44 | 45 | void sycl_broadcast(const void* send_buffer, void* recv_buffer, 46 | size_t element_count, PrimitiveType dtype, size_t root, 47 | se::gpu::GpuStreamHandle gpu_stream, ncclComm_t comm); 48 | 49 | void sycl_allgather(const void* send_buffer, void* recv_buffer, 50 | size_t element_count, PrimitiveType dtype, 51 | se::gpu::GpuStreamHandle gpu_stream, ncclComm_t comm); 52 | 53 | void sycl_alltoall(std::vector send_buffer, 54 | std::vector recv_buffer, size_t element_count, 55 | PrimitiveType dtype, se::gpu::GpuStreamHandle gpu_stream, 56 | ncclComm_t comm); 57 | 58 | void sycl_alltoall_split(std::vector send_buffer, 59 | std::vector recv_buffer, size_t element_count, 60 | PrimitiveType dtype, 61 | se::gpu::GpuStreamHandle gpu_stream, ncclComm_t comm); 62 | 63 | void sycl_reduce_scatter(const void* send_buffer, void* recv_buffer, 64 | size_t element_count, PrimitiveType dtype, 65 | ReductionKind reduction_kind, 66 | se::gpu::GpuStreamHandle gpu_stream, ncclComm_t comm); 67 | 68 | void sycl_collective_permute(const void* send_buffer, void* recv_buffer, 69 | size_t element_count, PrimitiveType dtype, 70 | const std::optional& source_id, 71 | const std::optional& target_id, 72 | se::gpu::GpuStreamHandle gpu_stream, 73 | ncclComm_t comm); 74 | } // namespace gpu 75 | } // namespace xla 76 | 77 | #endif // ITEX_USE_CCL 78 | #endif // XLA_SERVICE_GPU_CCL_OPS_H_ 79 | -------------------------------------------------------------------------------- /xla/service/gpu/ccl_p2p_thunk_common.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Copyright 2023 The OpenXLA Authors. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 
16 | ==============================================================================*/ 17 | 18 | #ifndef XLA_SERVICE_GPU_CCL_P2P_THUNK_COMMON_H_ 19 | #define XLA_SERVICE_GPU_CCL_P2P_THUNK_COMMON_H_ 20 | 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | #include "absl/container/flat_hash_map.h" 27 | #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project 28 | #include "xla/service/collective_ops_utils.h" 29 | #include "xla/service/gpu/ccl_collective_thunk.h" 30 | 31 | namespace xla { 32 | namespace gpu { 33 | 34 | // Records the information for implementing CollectivePermute, Send and Recv. 35 | struct NcclP2PConfig { 36 | // Record the target ID for sending a data and the source ID from which to 37 | // receive a data. Either target or source can be optional. 38 | struct SourceTargetMapEntry { 39 | std::optional source; 40 | std::optional target; 41 | }; 42 | 43 | using IdToSourceTargetMap = 44 | absl::flat_hash_map; 45 | 46 | // Returns the source and target ID corresponding to the given ID (these IDs 47 | // are replica_ids for cross replica permute or partition_ids for cross 48 | // partition permute). The source ID is the id which will send data to this 49 | // ID and the target ID is the id to which this ID will send its data. Either 50 | // can be optional. 51 | static SourceTargetMapEntry GetSourceTarget( 52 | const IdToSourceTargetMap& id_to_source_target, int64_t id) { 53 | auto it = id_to_source_target.find(id); 54 | if (it != id_to_source_target.end()) return it->second; 55 | return SourceTargetMapEntry{}; 56 | } 57 | 58 | NcclCollectiveConfig config; 59 | IdToSourceTargetMap id_to_source_target; 60 | }; 61 | 62 | // Extracts source/target pairs for send/recv from frontend attributes. 63 | absl::StatusOr>> GetSourceTargetPairs( 64 | mlir::DictionaryAttr frontend_attributes); 65 | 66 | // Constructs the NcclP2PConfig for an HLO Send or Recv instruction. 67 | NcclP2PConfig GetNcclP2PConfigForSendRecv(const HloSendRecvInstruction* instr, 68 | const Shape& shape, 69 | int64_t replica_count, 70 | int64_t partition_count); 71 | 72 | } // namespace gpu 73 | } // namespace xla 74 | 75 | #endif // XLA_SERVICE_GPU_CCL_P2P_THUNK_COMMON_H_ 76 | -------------------------------------------------------------------------------- /xla/service/gpu/dot_expand_dims.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef XLA_SERVICE_GPU_DOT_EXPAND_DIMS_H_ 17 | #define XLA_SERVICE_GPU_DOT_EXPAND_DIMS_H_ 18 | 19 | #include "xla/hlo/ir/hlo_computation.h" 20 | #include "xla/hlo/ir/hlo_module.h" 21 | #include "xla/service/hlo_pass_interface.h" 22 | 23 | namespace xla { 24 | namespace gpu { 25 | 26 | // Expand dot dims for dimension 1 so that it can call onednn. 
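// Illustrative example (added commentary, not part of the upstream header):
// a dot whose lhs or rhs is effectively 1-D, e.g. f32[8] dot f32[8,16],
// cannot be handed to oneDNN's 2-D matmul directly; the pass is expected to
// expand the degenerate operand to rank 2 (f32[1,8] here) and reshape the
// f32[16] result back, leaving the HLO semantics unchanged. The shapes above
// are hypothetical; see the pass implementation for the precise rewrite.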
27 | class DotExpandDims : public HloModulePass { 28 | public: 29 | explicit DotExpandDims(); 30 | absl::string_view name() const override { return "dot-expand-dims"; } 31 | 32 | using HloPassInterface::Run; 33 | StatusOr Run( 34 | HloModule* module, 35 | const absl::flat_hash_set& execution_threads) override; 36 | }; 37 | 38 | } // namespace gpu 39 | } // namespace xla 40 | 41 | #endif // XLA_SERVICE_GPU_DOT_EXPAND_DIMS_H_ -------------------------------------------------------------------------------- /xla/service/gpu/gemm_impl_picker.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Copyright 2019 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | #ifndef XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ 19 | #define XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ 20 | 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | #include "absl/strings/string_view.h" 28 | #include "xla/hlo/ir/hlo_instructions.h" 29 | #include "xla/hlo/ir/hlo_module.h" 30 | #include "xla/service/gpu/autotuner_util.h" 31 | #include "xla/service/gpu/backend_configs.pb.h" 32 | #include "xla/service/gpu/matmul_utils.h" 33 | #include "xla/service/gpu/stream_executor_util.h" 34 | #include "xla/service/hlo_pass_interface.h" 35 | #include "xla/stream_executor/device_description.h" 36 | #include "xla/stream_executor/device_memory_allocator.h" 37 | #include "xla/stream_executor/stream_executor.h" 38 | 39 | namespace xla { 40 | namespace gpu { 41 | 42 | // GemmAlgorithmPicker supports two modes: device and deviceless. 43 | // In device mode, we run autotuning on the device and store autotune results. 44 | // In deviceless mode, we pass in some information related to the device and 45 | // use stored autotune results to rewrite Gemm instructions. If the required 46 | // autotune result is not stored, then algorithm is set to kRuntimeAutotuning. 47 | class GemmAlgorithmPicker : public HloModulePass { 48 | public: 49 | explicit GemmAlgorithmPicker(AutotuneConfig config) : config_(config) {} 50 | 51 | absl::string_view name() const override { return "gemm-impl-picker"; } 52 | 53 | using HloPassInterface::Run; 54 | absl::StatusOr Run( 55 | HloModule* module, 56 | const absl::flat_hash_set& execution_threads) override; 57 | 58 | private: 59 | AutotuneConfig config_; 60 | }; 61 | 62 | } // namespace gpu 63 | } // namespace xla 64 | 65 | #endif // XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ 66 | -------------------------------------------------------------------------------- /xla/service/gpu/matrix_descriptor.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | #ifndef XLA_SERVICE_GPU_MATRIX_DESCRIPTOR_H_ 19 | #define XLA_SERVICE_GPU_MATRIX_DESCRIPTOR_H_ 20 | 21 | #include "xla/stream_executor/blas.h" 22 | #include "xla/stream_executor/device_memory.h" 23 | 24 | namespace xla { 25 | namespace gpu { 26 | 27 | namespace se = ::stream_executor; 28 | 29 | // This struct contains the metadata of a matrix, e.g., its base address and 30 | // dimensions. 31 | struct MatrixDescriptor { 32 | se::DeviceMemoryBase data; 33 | se::blas::Transpose transpose; 34 | int64_t num_rows; 35 | int64_t num_cols; 36 | int64_t batch_stride; 37 | int64_t leading_dim_stride; 38 | 39 | int64_t reduced_dim() const { 40 | return transpose == se::blas::Transpose::kTranspose ? num_rows : num_cols; 41 | } 42 | 43 | template 44 | se::DeviceMemory cast() const { 45 | return se::DeviceMemory(data); 46 | } 47 | }; 48 | 49 | } // namespace gpu 50 | } // namespace xla 51 | 52 | #endif // XLA_SERVICE_GPU_MATRIX_DESCRIPTOR_H_ 53 | -------------------------------------------------------------------------------- /xla/service/gpu/onednn_gpu_conv_runner.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef XLA_SERVICE_GPU_ONEDNN_GPU_CONV_RUNNER_H_ 17 | #define XLA_SERVICE_GPU_ONEDNN_GPU_CONV_RUNNER_H_ 18 | 19 | #include 20 | 21 | #include "xla/ffi/ffi.h" 22 | #include "xla/ffi/ffi_api.h" 23 | #include "xla/service/gpu/gpu_conv_runner.h" 24 | #include "xla/service/gpu/runtime/thunk.h" 25 | #include "xla/service/onednn_util.h" 26 | 27 | namespace xla { 28 | 29 | namespace gpu { 30 | 31 | typedef struct OneDnnConvPrimitive { 32 | dnnl::memory src_memory; 33 | dnnl::memory filter_memory; 34 | dnnl::memory dst_memory; 35 | dnnl::memory internal_filter_memory; 36 | dnnl::memory scratchpad_memory; 37 | dnnl::memory bias_memory; 38 | dnnl::convolution_forward fwd_primitive; 39 | dnnl::convolution_backward_data bwd_input_primitive; 40 | dnnl::convolution_backward_weights bwd_filter_primitive; 41 | dnnl::reorder filter_reorder_primitive; 42 | 43 | std::unordered_map fwd_primitives_args; 44 | std::unordered_map bwd_input_primitive_args; 45 | std::unordered_map bwd_filter_primitive_args; 46 | 47 | std::unordered_map reorder_args; 48 | 49 | dnnl::engine engine; 50 | dnnl::stream stream; 51 | bool has_reorder = false; 52 | } OneDnnConvPrimitive; 53 | 54 | absl::StatusOr GetOrCreateOneDnnConvPrimitive( 55 | se::Stream*, const ffi::Dictionary& dict, 56 | const std::vector& operand_se_buffers, 57 | const ffi::BufferBase& result_buffer, 58 | se::ScratchAllocator* scratch_allocator, CudnnConvKind conv_kind); 59 | 60 | absl::Status RunGpuConv(const OneDnnConvPrimitive& onednn_primitive, 61 | const ffi::Dictionary& dict, 62 | absl::Span operand_buffers, 63 | ffi::BufferBase result_buffer, CudnnConvKind conv_kind); 64 | 65 | } // namespace gpu 66 | } // namespace xla 67 | 68 | #endif // XLA_SERVICE_GPU_ONEDNN_GPU_CONV_RUNNER_H_ -------------------------------------------------------------------------------- /xla/service/gpu/onednn_matmul_utils.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Copyright 2022 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 
16 | ==============================================================================*/ 17 | 18 | #ifndef XLA_SERVICE_GPU_ONEDNN_MATMUL_UTILS_H_ 19 | #define XLA_SERVICE_GPU_ONEDNN_MATMUL_UTILS_H_ 20 | 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | #include "xla/ffi/ffi.h" 27 | #include "xla/ffi/ffi_api.h" 28 | #include "xla/service/gpu/matmul_utils.h" 29 | #include "xla/service/gpu/scratch_allocator.h" 30 | 31 | namespace xla { 32 | namespace gpu { 33 | 34 | namespace SYCLGemm { 35 | enum class GemmBackendEpilogue { 36 | DEFAULT, 37 | RELU, 38 | GELU, 39 | BIAS, 40 | BIAS_RELU, 41 | BIAS_GELU, 42 | GELU_AUX, 43 | BIAS_GELU_AUX, 44 | }; 45 | 46 | absl::StatusOr EpilogueCast(std::string& epilogue); 47 | 48 | absl::StatusOr EpilogueCast(GemmBackendEpilogue epilogue); 49 | 50 | absl::StatusOr EpilogueAddsVectorBias(GemmBackendEpilogue epilogue); 51 | 52 | absl::StatusOr EpilogueHasAuxiliaryOutput(GemmBackendEpilogue epilogue); 53 | 54 | absl::StatusOr AsSYCLEpilogue( 55 | GemmBackendConfig_Epilogue epilogue); 56 | } // namespace SYCLGemm 57 | 58 | absl::Status RunGemm(const ffi::Dictionary& dict, 59 | se::DeviceMemoryBase lhs_buffer, 60 | se::DeviceMemoryBase rhs_buffer, 61 | se::DeviceMemoryBase add_buffer, 62 | se::DeviceMemoryBase output_buffer, 63 | se::DeviceMemoryBase bias_buffer, se::Stream* stream, 64 | SYCLGemm::GemmBackendEpilogue epilogue, 65 | se::ScratchAllocator* scratch_allocator = nullptr); 66 | 67 | absl::Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, 68 | se::DeviceMemoryBase rhs_buffer, 69 | se::DeviceMemoryBase add_buffer, 70 | se::DeviceMemoryBase output_buffer, 71 | se::DeviceMemoryBase bias_buffer, se::Stream* stream, 72 | SYCLGemm::GemmBackendEpilogue epilogue, 73 | se::ScratchAllocator* scratch_allocator = nullptr); 74 | 75 | } // namespace gpu 76 | } // namespace xla 77 | 78 | #endif // XLA_SERVICE_GPU_ONEDNN_MATMUL_UTILS_H_ -------------------------------------------------------------------------------- /xla/service/gpu/redundant_convert_mover.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "xla/service/gpu/redundant_convert_mover.h" 17 | 18 | #include "xla/hlo/ir/hlo_instruction.h" 19 | #include "xla/primitive_util.h" 20 | #include "xla/service/hlo_creation_utils.h" 21 | #include "xla/service/pattern_matcher.h" 22 | 23 | namespace xla { 24 | namespace gpu { 25 | 26 | namespace { 27 | namespace m = match; 28 | 29 | bool IsBitcastAsReshape(const HloInstruction* instr) { 30 | DCHECK(instr->opcode() == HloOpcode::kBitcast); 31 | return instr->shape().element_type() == 32 | instr->operand(0)->shape().element_type(); 33 | } 34 | 35 | template 36 | auto OptionalBitcast(Pattern pattern) { 37 | return m::AnyOf( 38 | m::Bitcast(pattern).WithPredicate(IsBitcastAsReshape), 39 | std::move(pattern)); 40 | } 41 | 42 | bool IsConvertNoLoss(const HloInstruction* instr) { 43 | DCHECK(instr->opcode() == HloOpcode::kConvert); 44 | return primitive_util::CastPreservesValues( 45 | instr->operand(0)->shape().element_type(), instr->shape().element_type()); 46 | } 47 | 48 | bool IsConvert(const HloInstruction* instr) { 49 | return instr->opcode() == HloOpcode::kConvert; 50 | } 51 | 52 | bool MatchDuplicateConvertPatterns(HloInstruction* instr, 53 | HloInstruction** bitcast_input) { 54 | // try to match convert(optionalbitcast(convert(optionalbitcast(input)))) 55 | // where input's shape and element type is same as the final output 56 | auto default_duplicate_convert_pattern = 57 | m::Op().WithPredicate(IsConvert).WithOneUse().WithOperand( 58 | 0, OptionalBitcast( 59 | m::Op() 60 | .WithOperand(0, OptionalBitcast(m::Op(bitcast_input))) 61 | .WithPredicate(IsConvert) 62 | .WithPredicate(IsConvertNoLoss) 63 | .WithOneUse())); 64 | if (Match(instr, default_duplicate_convert_pattern) && 65 | instr->shape() == (*bitcast_input)->shape()) { 66 | return true; 67 | } 68 | return false; 69 | } 70 | 71 | StatusOr RemoveRedundantConversion(HloInstruction* instr) { 72 | HloInstruction* bitcast_input = nullptr; 73 | if (MatchDuplicateConvertPatterns(instr, &bitcast_input)) { 74 | TF_RETURN_IF_ERROR( 75 | instr->parent()->ReplaceInstruction(instr, bitcast_input)); 76 | return true; 77 | } 78 | return false; 79 | } 80 | 81 | } // namespace 82 | 83 | StatusOr RedundantConvertMover::Run( 84 | HloModule* module, 85 | const absl::flat_hash_set& execution_threads) { 86 | bool any_changed = false; 87 | for (HloComputation* computation : 88 | module->MakeNonfusionComputations(execution_threads)) { 89 | for (HloInstruction* instr : computation->MakeInstructionPostOrder()) { 90 | bool changed = false; 91 | TF_ASSIGN_OR_RETURN(changed, RemoveRedundantConversion(instr)); 92 | any_changed |= changed; 93 | } 94 | } 95 | return any_changed; 96 | } 97 | 98 | } // namespace gpu 99 | } // namespace xla -------------------------------------------------------------------------------- /xla/service/gpu/redundant_convert_mover.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef XLA_SERVICE_GPU_REDUNDANT_CONVERT_MOVER_H_ 17 | #define XLA_SERVICE_GPU_REDUNDANT_CONVERT_MOVER_H_ 18 | 19 | #include 20 | #include 21 | 22 | #include "xla/service/hlo_pass_interface.h" 23 | 24 | namespace xla { 25 | namespace gpu { 26 | 27 | class RedundantConvertMover : public HloModulePass { 28 | public: 29 | RedundantConvertMover() = default; 30 | 31 | absl::string_view name() const override { return "redundant-convert-mover"; } 32 | StatusOr Run( 33 | HloModule* module, 34 | const absl::flat_hash_set& execution_threads) override; 35 | }; 36 | 37 | } // namespace gpu 38 | } // namespace xla 39 | 40 | #endif // XLA_SERVICE_GPU_REDUNDANT_CONVERT_MOVER_H_ 41 | -------------------------------------------------------------------------------- /xla/service/gpu/scratch_allocator.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "xla/service/gpu/scratch_allocator.h" 17 | 18 | namespace xla { 19 | namespace gpu { 20 | tsl::Status AllocateWorkspace( 21 | void** workspace, stream_executor::ScratchAllocator* scratch_allocator, 22 | size_t num_bytes) { 23 | TF_ASSIGN_OR_RETURN(stream_executor::DeviceMemory workspace_bytes, 24 | scratch_allocator->AllocateBytes(num_bytes)); 25 | *workspace = static_cast(workspace_bytes.opaque()); 26 | return tsl::OkStatus(); 27 | } 28 | 29 | } // namespace gpu 30 | } // namespace xla -------------------------------------------------------------------------------- /xla/service/gpu/scratch_allocator.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_ 17 | #define XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_ 18 | #include "xla/stream_executor/scratch_allocator.h" 19 | 20 | namespace xla { 21 | namespace gpu { 22 | tsl::Status AllocateWorkspace( 23 | void** workspace, stream_executor::ScratchAllocator* scratch_allocator, 24 | size_t num_bytes); 25 | 26 | } // namespace gpu 27 | } // namespace xla 28 | #endif // XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_ -------------------------------------------------------------------------------- /xla/service/gpu/sycl_onednn.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "xla/service/gpu/sycl_onednn.h" 17 | 18 | #include 19 | 20 | namespace xla { 21 | namespace gpu { 22 | 23 | absl::Status RunGpuConvCustomCall( 24 | se::Stream* stream, se::ScratchAllocator* scratch_allocator, 25 | std::vector& operand_se_buffers, 26 | ffi::BufferBase& result_buffer, const ffi::Dictionary& dict, 27 | CudnnConvKind conv_kind) { 28 | TF_ASSIGN_OR_RETURN(auto conv_primitive, 29 | GetOrCreateOneDnnConvPrimitive( 30 | stream, dict, operand_se_buffers, result_buffer, 31 | scratch_allocator, conv_kind)); 32 | TF_RETURN_IF_ERROR(RunGpuConv(conv_primitive, dict, 33 | absl::MakeSpan(operand_se_buffers), 34 | result_buffer, conv_kind)); 35 | return absl::OkStatus(); 36 | } 37 | 38 | absl::Status RunGemmCustomCall(ffi::BufferBase* lhs, ffi::BufferBase* rhs, 39 | ffi::BufferBase* add, ffi::BufferBase* output, 40 | ffi::BufferBase* bias, se::Stream* stream, 41 | const ffi::Dictionary& dict, 42 | SYCLGemm::GemmBackendEpilogue epilogue, 43 | se::ScratchAllocator* scratch_allocator) { 44 | se::DeviceMemoryBase lhs_data = lhs->data; 45 | se::DeviceMemoryBase rhs_data = rhs->data; 46 | se::DeviceMemoryBase output_data = output->data; 47 | se::DeviceMemoryBase add_data; 48 | se::DeviceMemoryBase bias_data; 49 | if (add != nullptr) add_data = add->data; 50 | if (bias != nullptr) bias_data = bias->data; 51 | 52 | return RunGemm(dict, lhs_data, rhs_data, add_data, output_data, bias_data, 53 | stream, epilogue, scratch_allocator); 54 | } 55 | 56 | } // namespace gpu 57 | } // namespace xla -------------------------------------------------------------------------------- /xla/service/gpu/sycl_onednn.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef XLA_SERVICE_GPU_SYCL_ONEDNN_H_ 17 | #define XLA_SERVICE_GPU_SYCL_ONEDNN_H_ 18 | 19 | #include "xla/service/gpu/onednn_gpu_conv_runner.h" 20 | #include "xla/service/gpu/onednn_matmul_utils.h" 21 | 22 | namespace xla { 23 | namespace gpu { 24 | 25 | absl::Status RunGpuConvCustomCall( 26 | se::Stream* stream, se::ScratchAllocator* scratch_allocator, 27 | std::vector& operand_se_buffers, 28 | ffi::BufferBase& result_buffer, const ffi::Dictionary& dict, 29 | CudnnConvKind conv_kind); 30 | 31 | absl::Status RunGemmCustomCall( 32 | ffi::BufferBase* lhs, ffi::BufferBase* rhs, 33 | ffi::BufferBase* add, ffi::BufferBase* output, 34 | ffi::BufferBase* bias, se::Stream* stream, 35 | const ffi::Dictionary& dict, 36 | SYCLGemm::GemmBackendEpilogue epilogue, 37 | se::ScratchAllocator* scratch_allocator = nullptr); 38 | 39 | } // namespace gpu 40 | } // namespace xla 41 | 42 | #endif // XLA_SERVICE_GPU_SYCL_ONEDNN_H_ -------------------------------------------------------------------------------- /xla/service/gpu/utils.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef XLA_SERVICE_GPU_UTILS 17 | #define XLA_SERVICE_GPU_UTILS 18 | 19 | #include 20 | 21 | #define UNROLL_ON_DEVICE _Pragma("unroll") 22 | 23 | // Represents an aligned array of N elements of T. Data pointers can be 24 | // reinterpreted as this type to generate vectorized loads/stores in a kernel. 
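// Usage sketch (added commentary, not part of the upstream header; the class
// is assumed to be parameterized by element type T and lane count N, with any
// remaining template parameters defaulted):
//
//   // Given a float pointer known to be suitably aligned, reinterpret it as
//   // a vector of 4 floats so each access becomes one vectorized load/store.
//   using Vec4f = AlignedVector<float, 4>;
//   auto* vec_ptr = reinterpret_cast<Vec4f*>(aligned_float_ptr);  // hypothetical pointer
//   Vec4f v = vec_ptr[i];   // single 16-byte vectorized load
//   v[0] += 1.0f;           // per-lane access via operator[]
//   vec_ptr[i] = v;         // single 16-byte vectorized store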
25 | template > 26 | class alignas(alignof(T) * N) AlignedVector { 27 | public: 28 | AlignedVector() = default; 29 | 30 | explicit AlignedVector(T uniform) { 31 | UNROLL_ON_DEVICE for (uint32_t i = 0; i < N; ++i) { values_[i] = uniform; } 32 | } 33 | 34 | template 35 | void Load(const AlignedVector& other) { 36 | UNROLL_ON_DEVICE for (uint32_t i = 0; i < N; ++i) { 37 | values_[i] = static_cast(other[i]); 38 | } 39 | } 40 | 41 | template 42 | void Accumulate(const AlignedVector& other) { 43 | UNROLL_ON_DEVICE for (uint32_t i = 0; i < N; ++i) { 44 | values_[i] = Func()(values_[i], static_cast(other[i])); 45 | } 46 | } 47 | 48 | template 49 | void Store(AlignedVector& other) { 50 | UNROLL_ON_DEVICE for (uint32_t i = 0; i < N; ++i) { 51 | other[i] = static_cast(values_[i]); 52 | } 53 | } 54 | 55 | template 56 | void PartialStore(AlignedVector& other, uint32_t num, 57 | uint32_t offset = 0) { 58 | UNROLL_ON_DEVICE for (uint32_t i = 0; i < N && i < num; ++i) { 59 | other[i] = static_cast(values_[i + offset]); 60 | } 61 | } 62 | 63 | T& operator[](uint32_t i) { return values_[i]; } 64 | const T& operator[](uint32_t i) const { return values_[i]; } 65 | 66 | private: 67 | T values_[N]; 68 | }; 69 | 70 | #undef UNROLL_ON_DEVICE 71 | 72 | #endif // XLA_SERVICE_GPU_UTILS 73 | -------------------------------------------------------------------------------- /xla/service/gpu/xetla/gemm/BUILD: -------------------------------------------------------------------------------- 1 | load("//xla:xla.bzl", "xetla_library") 2 | 3 | xetla_library( 4 | name = "gemm_common", 5 | hdrs = [ 6 | "gemm_common.h" 7 | ], 8 | copts = [ 9 | "-Wall", 10 | "-Wno-c++11-narrowing", 11 | ], 12 | visibility = ["//visibility:public"], 13 | deps = [ 14 | "//xla/service/gpu:matrix_descriptor", 15 | ], 16 | ) 17 | 18 | xetla_library( 19 | name = "gemm_dispatch", 20 | hdrs = [ 21 | "gemm_dispatch.h", 22 | "hgemm_impl.h", 23 | "epilogue_impl.h", 24 | ], 25 | copts = [ 26 | "-Wall", 27 | "-Wno-c++11-narrowing", 28 | ], 29 | visibility = ["//visibility:public"], 30 | deps = [ 31 | ":gemm_common", 32 | "//xla/service/gpu:matrix_descriptor", 33 | "//xla/stream_executor/sycl:sycl_executor", 34 | "@xetla//:xetla_header", 35 | "@com_google_absl//absl/strings", 36 | ], 37 | ) 38 | 39 | xetla_library( 40 | name = "dispatch_row_major", 41 | srcs = [ 42 | "dispatch_row_major.cc", 43 | ], 44 | hdrs = [ 45 | "dispatch_row_major.h", 46 | ], 47 | copts = [ 48 | "-Wall", 49 | "-Wno-c++11-narrowing", 50 | ], 51 | visibility = ["//visibility:public"], 52 | deps = [ 53 | ":gemm_dispatch", 54 | "//xla/stream_executor/sycl:sycl_executor", 55 | ], 56 | ) 57 | 58 | xetla_library( 59 | name = "dispatch_col_major", 60 | srcs = [ 61 | "dispatch_col_major.cc", 62 | ], 63 | hdrs = [ 64 | "dispatch_col_major.h", 65 | ], 66 | copts = [ 67 | "-Wall", 68 | "-Wno-c++11-narrowing", 69 | ], 70 | visibility = ["//visibility:public"], 71 | deps = [ 72 | ":gemm_dispatch", 73 | "//xla/stream_executor/sycl:sycl_executor", 74 | ], 75 | ) 76 | 77 | xetla_library( 78 | name = "gemm_kernel", 79 | srcs = [ 80 | "gemm.cc", 81 | ], 82 | hdrs = [ 83 | "gemm.h", 84 | ], 85 | copts = [ 86 | "-Wall", 87 | "-Wno-c++11-narrowing", 88 | ], 89 | visibility = ["//visibility:public"], 90 | deps = [ 91 | ":gemm_common", 92 | ":dispatch_row_major", 93 | ":dispatch_col_major", 94 | "//xla/service/gpu:matrix_descriptor", 95 | "//xla/stream_executor/sycl:sycl_executor", 96 | "@xetla//:xetla_header", 97 | "@com_google_absl//absl/strings", 98 | ], 99 | ) 100 | 
-------------------------------------------------------------------------------- /xla/service/gpu/xetla/gemm/dispatch_col_major.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "xla/service/gpu/xetla/gemm/dispatch_col_major.h" 17 | 18 | #include "xla/service/gpu/xetla/gemm/gemm_common.h" 19 | #include "xla/service/gpu/xetla/gemm/gemm_dispatch.h" 20 | #include "xla/stream_executor/gpu/gpu_types.h" 21 | 22 | namespace gpu { 23 | namespace xetla { 24 | 25 | template 26 | bool GemmColMajorDispatcher::run(se::gpu::GpuStreamHandle handle) { 27 | int WG_M = std::get<0>(selected_policy_id_); 28 | int WG_N = std::get<1>(selected_policy_id_); 29 | int SG_M = std::get<2>(selected_policy_id_); 30 | int SG_N = std::get<3>(selected_policy_id_); 31 | int SG_K = std::get<4>(selected_policy_id_); 32 | int SLM_KS = std::get<5>(selected_policy_id_); 33 | return gemm_policy::call(WG_M, WG_N, SG_M, SG_N, SG_K, SLM_KS, 34 | this, handle); 35 | } 36 | 37 | template 38 | template 39 | bool GemmColMajorDispatcher::dispatch( 40 | se::gpu::GpuStreamHandle handle) { 41 | return do_dispatch( 42 | handle, params_); 43 | } 44 | 45 | template class GemmColMajorDispatcher; 46 | template class GemmColMajorDispatcher; 47 | 48 | } // namespace xetla 49 | } // namespace gpu 50 | -------------------------------------------------------------------------------- /xla/service/gpu/xetla/gemm/dispatch_col_major.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef XLA_SERVICE_GPU_XETLA_GEMM_DISPATCH_COL_MAJOR_H_ 17 | #define XLA_SERVICE_GPU_XETLA_GEMM_DISPATCH_COL_MAJOR_H_ 18 | 19 | #include "xla/service/gpu/xetla/gemm/gemm_common.h" 20 | #include "xla/service/gpu/xetla/gemm/gemm_dispatch.h" 21 | #include "xla/stream_executor/gpu/gpu_types.h" 22 | 23 | namespace gpu { 24 | namespace xetla { 25 | 26 | template 27 | class GemmColMajorDispatcher { 28 | public: 29 | GemmColMajorDispatcher() = default; 30 | 31 | GemmColMajorDispatcher( 32 | DispatchParams* params, 33 | std::tuple selected_policy_id) 34 | : params_(params), selected_policy_id_(selected_policy_id) {} 35 | 36 | template 37 | bool dispatch(se::gpu::GpuStreamHandle handle); 38 | 39 | bool run(se::gpu::GpuStreamHandle handle); 40 | 41 | private: 42 | DispatchParams* params_; 43 | std::tuple selected_policy_id_; 44 | }; 45 | 46 | } // namespace xetla 47 | } // namespace gpu 48 | 49 | #endif // XLA_SERVICE_GPU_XETLA_GEMM_DISPATCH_COL_MAJOR_H_ -------------------------------------------------------------------------------- /xla/service/gpu/xetla/gemm/dispatch_row_major.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "xla/service/gpu/xetla/gemm/dispatch_row_major.h" 17 | 18 | #include "xla/service/gpu/xetla/gemm/gemm_common.h" 19 | #include "xla/service/gpu/xetla/gemm/gemm_dispatch.h" 20 | #include "xla/stream_executor/gpu/gpu_types.h" 21 | 22 | namespace gpu { 23 | namespace xetla { 24 | 25 | template 26 | bool GemmRowMajorDispatcher::run(se::gpu::GpuStreamHandle handle) { 27 | int WG_M = std::get<0>(selected_policy_id_); 28 | int WG_N = std::get<1>(selected_policy_id_); 29 | int SG_M = std::get<2>(selected_policy_id_); 30 | int SG_N = std::get<3>(selected_policy_id_); 31 | int SG_K = std::get<4>(selected_policy_id_); 32 | int SLM_KS = std::get<5>(selected_policy_id_); 33 | return gemm_policy::call(WG_M, WG_N, SG_M, SG_N, SG_K, SLM_KS, 34 | this, handle); 35 | } 36 | 37 | template 38 | template 39 | bool GemmRowMajorDispatcher::dispatch( 40 | se::gpu::GpuStreamHandle handle) { 41 | return do_dispatch( 42 | handle, params_); 43 | } 44 | 45 | template class GemmRowMajorDispatcher; 46 | template class GemmRowMajorDispatcher; 47 | 48 | } // namespace xetla 49 | } // namespace gpu 50 | -------------------------------------------------------------------------------- /xla/service/gpu/xetla/gemm/dispatch_row_major.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef XLA_SERVICE_GPU_XETLA_GEMM_DISPATCH_ROW_MAJOR_H_ 17 | #define XLA_SERVICE_GPU_XETLA_GEMM_DISPATCH_ROW_MAJOR_H_ 18 | 19 | #include "xla/service/gpu/xetla/gemm/gemm_common.h" 20 | #include "xla/service/gpu/xetla/gemm/gemm_dispatch.h" 21 | #include "xla/stream_executor/gpu/gpu_types.h" 22 | 23 | namespace gpu { 24 | namespace xetla { 25 | 26 | template 27 | class GemmRowMajorDispatcher { 28 | public: 29 | GemmRowMajorDispatcher() = default; 30 | 31 | GemmRowMajorDispatcher( 32 | DispatchParams* params, 33 | std::tuple selected_policy_id) 34 | : params_(params), selected_policy_id_(selected_policy_id) {} 35 | 36 | template 37 | bool dispatch(se::gpu::GpuStreamHandle handle); 38 | 39 | bool run(se::gpu::GpuStreamHandle handle); 40 | 41 | private: 42 | DispatchParams* params_; 43 | std::tuple selected_policy_id_; 44 | }; 45 | 46 | } // namespace xetla 47 | } // namespace gpu 48 | 49 | #endif // XLA_SERVICE_GPU_XETLA_GEMM_DISPATCH_ROW_MAJOR_H_ -------------------------------------------------------------------------------- /xla/service/gpu/xetla/gemm/gemm_common.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef XLA_SERVICE_GPU_XETLA_GEMM_GEMM_COMMON_H_ 17 | #define XLA_SERVICE_GPU_XETLA_GEMM_GEMM_COMMON_H_ 18 | 19 | #include "xla/service/gpu/matrix_descriptor.h" 20 | 21 | namespace se = ::stream_executor; 22 | const int kMaxNumEpilogues = 4; 23 | 24 | namespace gpu { 25 | namespace xetla { 26 | 27 | enum EpilogueType { 28 | BIAS = 0, 29 | RES_ADD, 30 | GELU, 31 | RES_MUL, 32 | SILU, 33 | }; 34 | 35 | } // namespace xetla 36 | } // namespace gpu 37 | 38 | #endif // XLA_SERVICE_GPU_XETLA_GEMM_GEMM_COMMON_H_ -------------------------------------------------------------------------------- /xla/service/gpu/xetla/sdp/BUILD: -------------------------------------------------------------------------------- 1 | load("//xla:xla.bzl", "xetla_library") 2 | 3 | # List all kernels here. 
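# (Added note, not part of the original BUILD file: "sdp" is the XeTLA
# scaled-dot-product-attention path; the two targets below wrap the fused
# multi-head-attention forward and backward kernels, each with fp16 and bf16
# entry points as defined in sdp_forward.cc / sdp_backward.cc further below.)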
4 | xetla_library( 5 | name = "sdp_forward_kernel", 6 | srcs = [ 7 | "sdp_forward.cc", 8 | ], 9 | hdrs = [ 10 | "sdp_forward.h", 11 | "fmha_forward.h", 12 | "fmha_policy.h", 13 | "fmha_utils.h", 14 | ], 15 | visibility = ["//visibility:public"], 16 | deps = [ 17 | "@xetla//:xetla_header", 18 | "@tsl//tsl/platform:logging", 19 | ], 20 | ) 21 | 22 | xetla_library( 23 | name = "sdp_backward_kernel", 24 | srcs = [ 25 | "sdp_backward.cc", 26 | ], 27 | hdrs = [ 28 | "sdp_backward.h", 29 | "fmha_backward.h", 30 | "fmha_policy.h", 31 | "fmha_utils.h", 32 | ], 33 | visibility = ["//visibility:public"], 34 | deps = [ 35 | "@xetla//:xetla_header", 36 | "@tsl//tsl/platform:logging", 37 | ], 38 | ) 39 | 40 | -------------------------------------------------------------------------------- /xla/service/gpu/xetla/sdp/sdp_backward.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "sdp_backward.h" 17 | 18 | #include "fmha_backward.h" 19 | #include "xetla.hpp" 20 | 21 | namespace gpu::xetla { 22 | 23 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 24 | [&] { \ 25 | if (COND) { \ 26 | constexpr static bool CONST_NAME = true; \ 27 | return __VA_ARGS__(); \ 28 | } else { \ 29 | constexpr static bool CONST_NAME = false; \ 30 | return __VA_ARGS__(); \ 31 | } \ 32 | }() 33 | 34 | void fmha_backward_kernel_fp16( 35 | sycl::queue& q, void* query, void* key, void* value, void* out, void* bias, 36 | void* grad_out, void* dp_sum, void* activation_ptr, void* grad_query, 37 | void* grad_query_accum, void* grad_key, void* grad_value, 38 | uint32_t num_batches, uint32_t num_heads, uint32_t head_size, 39 | uint32_t num_queries, uint32_t num_keys, float head_scale) { 40 | const bool use_dropout = false; 41 | bool use_bias = bias == nullptr ? false : true; 42 | BOOL_SWITCH(use_bias, kUseBias, [&] { 43 | BOOL_SWITCH(use_dropout, kIsDropout, [&] { 44 | fmha_backward( 45 | q, static_cast(query), static_cast(key), 46 | static_cast(value), static_cast(out), 47 | static_cast(bias), static_cast(grad_out), 48 | static_cast(dp_sum), static_cast(activation_ptr), 49 | static_cast(grad_query), static_cast(grad_query_accum), 50 | static_cast(grad_key), static_cast(grad_value), 51 | num_batches, num_heads, head_size, num_queries, num_keys, head_scale); 52 | }); 53 | }); 54 | } 55 | 56 | void fmha_backward_kernel_bf16( 57 | sycl::queue& q, void* query, void* key, void* value, void* out, void* bias, 58 | void* grad_out, void* dp_sum, void* activation_ptr, void* grad_query, 59 | void* grad_query_accum, void* grad_key, void* grad_value, 60 | uint32_t num_batches, uint32_t num_heads, uint32_t head_size, 61 | uint32_t num_queries, uint32_t num_keys, float head_scale) { 62 | const bool use_dropout = false; 63 | bool use_bias = bias == nullptr ? 
false : true; 64 | BOOL_SWITCH(use_bias, kUseBias, [&] { 65 | BOOL_SWITCH(use_dropout, kIsDropout, [&] { 66 | fmha_backward( 67 | q, static_cast(query), static_cast(key), 68 | static_cast(value), static_cast(out), 69 | static_cast(bias), static_cast(grad_out), 70 | static_cast(dp_sum), static_cast(activation_ptr), 71 | static_cast(grad_query), static_cast(grad_query_accum), 72 | static_cast(grad_key), static_cast(grad_value), 73 | num_batches, num_heads, head_size, num_queries, num_keys, head_scale); 74 | }); 75 | }); 76 | } 77 | 78 | #undef BOOL_SWITCH 79 | } // namespace gpu::xetla -------------------------------------------------------------------------------- /xla/service/gpu/xetla/sdp/sdp_backward.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include 19 | 20 | namespace gpu::xetla { 21 | 22 | void fmha_backward_kernel_fp16( 23 | sycl::queue& q, void* query, void* key, void* value, void* out, void* bias, 24 | void* grad_out, void* dp_sum, void* activation_ptr, void* grad_query, 25 | void* grad_query_accum, void* grad_key, void* grad_value, 26 | uint32_t num_batches, uint32_t num_heads, uint32_t head_size, 27 | uint32_t num_queries, uint32_t num_keys, float head_scale); 28 | 29 | void fmha_backward_kernel_bf16( 30 | sycl::queue& q, void* query, void* key, void* value, void* out, void* bias, 31 | void* grad_out, void* dp_sum, void* activation_ptr, void* grad_query, 32 | void* grad_query_accum, void* grad_key, void* grad_value, 33 | uint32_t num_batches, uint32_t num_heads, uint32_t head_size, 34 | uint32_t num_queries, uint32_t num_keys, float head_scale); 35 | 36 | } // namespace gpu::xetla -------------------------------------------------------------------------------- /xla/service/gpu/xetla/sdp/sdp_forward.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "sdp_forward.h" 17 | 18 | #include "fmha_forward.h" 19 | #include "xetla.hpp" 20 | 21 | namespace gpu::xetla { 22 | 23 | #define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ 24 | [&] { \ 25 | if (COND) { \ 26 | constexpr static bool CONST_NAME = true; \ 27 | return __VA_ARGS__(); \ 28 | } else { \ 29 | constexpr static bool CONST_NAME = false; \ 30 | return __VA_ARGS__(); \ 31 | } \ 32 | }() 33 | 34 | void fmha_forward_kernel_fp16(sycl::queue& q, void* query, void* key, 35 | void* value, void* bias, uint8_t* dropout, 36 | float dropout_prob, void* out, 37 | void* activation_ptr, uint32_t num_batches, 38 | uint32_t num_heads, uint32_t head_size, 39 | uint32_t num_queries, uint32_t num_keys, 40 | float head_scale, bool is_training) { 41 | const bool use_causal = false; 42 | const bool use_dropout = false; 43 | bool use_bias = bias == nullptr ? false : true; 44 | BOOL_SWITCH(use_causal, kIsCausal, [&] { 45 | BOOL_SWITCH(use_bias, kUseBias, [&] { 46 | BOOL_SWITCH(use_dropout, kIsDropout, [&] { 47 | BOOL_SWITCH(is_training, kIsTraining, [&] { 48 | fmha_forward( 49 | q, static_cast(query), static_cast(key), 50 | static_cast(value), static_cast(bias), dropout, 51 | dropout_prob, static_cast(out), 52 | static_cast(activation_ptr), num_batches, num_heads, 53 | head_size, num_queries, num_keys, head_scale); 54 | }); 55 | }); 56 | }); 57 | }); 58 | } 59 | 60 | void fmha_forward_kernel_bf16(sycl::queue& q, void* query, void* key, 61 | void* value, void* bias, uint8_t* dropout, 62 | float dropout_prob, void* out, 63 | void* activation_ptr, uint32_t num_batches, 64 | uint32_t num_heads, uint32_t head_size, 65 | uint32_t num_queries, uint32_t num_keys, 66 | float head_scale, bool is_training) { 67 | const bool use_causal = false; 68 | const bool use_dropout = false; 69 | bool use_bias = bias == nullptr ? false : true; 70 | BOOL_SWITCH(use_causal, kIsCausal, [&] { 71 | BOOL_SWITCH(use_bias, kUseBias, [&] { 72 | BOOL_SWITCH(use_dropout, kIsDropout, [&] { 73 | BOOL_SWITCH(is_training, kIsTraining, [&] { 74 | fmha_forward( 75 | q, static_cast(query), static_cast(key), 76 | static_cast(value), static_cast(bias), dropout, 77 | dropout_prob, static_cast(out), 78 | static_cast(activation_ptr), num_batches, num_heads, 79 | head_size, num_queries, num_keys, head_scale); 80 | }); 81 | }); 82 | }); 83 | }); 84 | } 85 | 86 | #undef BOOL_SWITCH 87 | } // namespace gpu::xetla -------------------------------------------------------------------------------- /xla/service/gpu/xetla/sdp/sdp_forward.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include 19 | 20 | namespace gpu::xetla { 21 | 22 | void fmha_forward_kernel_fp16(sycl::queue& q, void* query, void* key, 23 | void* value, void* bias, uint8_t* dropout, 24 | float dropout_prob, void* out, 25 | void* activation_ptr, uint32_t num_batches, 26 | uint32_t num_heads, uint32_t head_size, 27 | uint32_t num_queries, uint32_t num_keys, 28 | float head_scale, bool is_training); 29 | 30 | void fmha_forward_kernel_bf16(sycl::queue& q, void* query, void* key, 31 | void* value, void* bias, uint8_t* dropout, 32 | float dropout_prob, void* out, 33 | void* activation_ptr, uint32_t num_batches, 34 | uint32_t num_heads, uint32_t head_size, 35 | uint32_t num_queries, uint32_t num_keys, 36 | float head_scale, bool is_training); 37 | 38 | } // namespace gpu::xetla -------------------------------------------------------------------------------- /xla/service/gpu/xetla_gpu_fused_mha_runner.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Copyright 2023 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 
16 | ==============================================================================*/ 17 | 18 | #ifndef XLA_SERVICE_GPU_XETLA_GPU_FUSED_MHA_RUNNER_H_ 19 | #define XLA_SERVICE_GPU_XETLA_GPU_FUSED_MHA_RUNNER_H_ 20 | 21 | #include 22 | 23 | #include "xla/hlo/ir/hlo_instruction.h" 24 | #include "xla/hlo/ir/hlo_instructions.h" 25 | #include "xla/service/gpu/backend_configs.pb.h" 26 | #include "xla/service/gpu/cublas_cudnn.h" 27 | #include "xla/service/gpu/gpu_fused_mha_runner.h" 28 | #include "xla/stream_executor/dnn.h" 29 | #include "xla/stream_executor/stream_executor.h" 30 | #include "xla/types.h" 31 | #include "xla/xla_data.pb.h" 32 | 33 | namespace xla { 34 | namespace gpu { 35 | 36 | absl::Status RunXetlaGpuFMHA( 37 | const GpufMHAConfig& fmha_config, se::DeviceMemoryBase lhs_bmm1_buffer, 38 | se::DeviceMemoryBase rhs_bmm1_buffer, se::DeviceMemoryBase rhs_bmm2_buffer, 39 | se::DeviceMemoryBase output_buffer, se::DeviceMemoryBase scratch_buffer, 40 | std::optional bias_buffer, 41 | std::optional activation_buffer, se::Stream* stream); 42 | 43 | absl::Status RunXetlaGpuFMHABackward( 44 | const GpufMHABackwardConfig& fmha_config, 45 | se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, 46 | se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, 47 | se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, 48 | se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, 49 | se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer, 50 | se::DeviceMemoryBase d_bmm1_lhs_buffer, 51 | se::DeviceMemoryBase d_bmm1_rhs_buffer, 52 | se::DeviceMemoryBase d_bmm2_rhs_buffer, 53 | std::optional d_s_buffer, 54 | std::optional d_bias_buffer, 55 | std::optional fwd_output_buffer, 56 | std::optional bias_buffer, 57 | std::optional softmax_buffer, 58 | std::optional accum_buffer, se::Stream* stream); 59 | 60 | } // namespace gpu 61 | } // namespace xla 62 | #endif // XLA_SERVICE_GPU_XETLA_GPU_FUSED_MHA_RUNNER_H_ 63 | -------------------------------------------------------------------------------- /xla/service/onednn_util.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef XLA_SERVICE_ONEDNN_UTIL_H_ 17 | #define XLA_SERVICE_ONEDNN_UTIL_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #include "dnnl.hpp" // NOLINT(build/include_subdir) 25 | #include "dnnl_sycl.hpp" // NOLINT(build/include_subdir) 26 | #include "xla/tsl/util/env_var.h" 27 | #include "xla/stream_executor/gpu/gpu_types.h" 28 | 29 | namespace xla { 30 | inline dnnl::memory::dims CalculateTFStrides( 31 | const dnnl::memory::dims& dims_tf_order) { 32 | CHECK_GT(dims_tf_order.size(), 0); 33 | dnnl::memory::dims strides(dims_tf_order.size(), 1); 34 | for (int d = strides.size() - 2; d >= 0; d--) { 35 | strides[d] = strides[d + 1] * dims_tf_order[d + 1]; 36 | } 37 | return strides; 38 | } 39 | 40 | static dnnl::engine& FindOrCreateEngine(se::gpu::GpuStreamHandle stream) { 41 | static std::map stream_engine_map; 42 | auto iter = stream_engine_map.find(stream); 43 | if (iter != stream_engine_map.end()) return iter->second; 44 | 45 | dnnl::engine engine; 46 | engine = dnnl::sycl_interop::make_engine(stream->get_device(), 47 | stream->get_context()); 48 | return stream_engine_map 49 | .insert(std::pair(stream, engine)) 50 | .first->second; 51 | } 52 | 53 | inline dnnl::fpmath_mode GetFP32MathMode() { 54 | std::string fp32_math_mode = "fp32"; 55 | TF_CHECK_OK( 56 | tsl::ReadStringFromEnvVar("XLA_FP32_MATH_MODE", "fp32", &fp32_math_mode)); 57 | fp32_math_mode = tsl::str_util::Lowercase(fp32_math_mode); 58 | if (fp32_math_mode == "fp32") { 59 | return dnnl::fpmath_mode::strict; 60 | } 61 | if (fp32_math_mode == "tf32") { 62 | return dnnl::fpmath_mode::tf32; 63 | } 64 | if (fp32_math_mode == "bf32") { 65 | LOG(FATAL) << "Did not support BF32 math mode on GPU "; 66 | } 67 | LOG(FATAL) 68 | << "Invalid XLA_FP32_MATH_MODE, should be FP32, TF32 or BF32, but got " 69 | << fp32_math_mode; 70 | } 71 | 72 | inline dnnl::memory CreateDnnlMemory(const dnnl::memory::desc& md, 73 | const dnnl::engine& engine, 74 | void* data_handle = nullptr) { 75 | if (engine.get_kind() == dnnl::engine::kind::gpu) { 76 | auto kind = dnnl::sycl_interop::memory_kind::usm; 77 | if (data_handle == nullptr) 78 | return dnnl::sycl_interop::make_memory(md, engine, kind, 79 | DNNL_MEMORY_ALLOCATE); 80 | else 81 | return dnnl::sycl_interop::make_memory(md, engine, kind, data_handle); 82 | } 83 | 84 | // Default path, always assume it's CPU engine. 
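// (GPU engines above go through dnnl::sycl_interop with USM memory; passing
// DNNL_MEMORY_ALLOCATE when no data_handle is supplied lets oneDNN own the
// allocation. Any other engine kind falls through to the plain CPU path below.)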
85 | CHECK(engine.get_kind() == dnnl::engine::kind::cpu) 86 | << "Create oneDNN memory for unsupported engine."; 87 | if (data_handle == nullptr) 88 | return dnnl::memory(md, engine); 89 | else 90 | return dnnl::memory(md, engine, data_handle); 91 | } 92 | } // namespace xla 93 | #endif // XLA_SERVICE_ONEDNN_UTIL_H_ -------------------------------------------------------------------------------- /xla/stream_executor/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = [ 3 | "//visibility:public", 4 | ], 5 | licenses = ["notice"], 6 | ) 7 | 8 | alias( 9 | name = "sycl_platform", 10 | actual = "//xla/stream_executor/sycl:all_runtime", 11 | ) 12 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/hw_info.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef XLA_STREAM_EXECUTOR_SYCL_HW_INFO_H_ 17 | #define XLA_STREAM_EXECUTOR_SYCL_HW_INFO_H_ 18 | 19 | #include "xla/stream_executor/sycl/sycl_gpu_runtime.h" 20 | 21 | bool IsXeHPC(const sycl::device* device_ptr = nullptr); 22 | 23 | bool IsXeHPG(const sycl::device* device_ptr = nullptr); 24 | 25 | bool HasXMX(const sycl::device* device_ptr = nullptr); 26 | 27 | bool IsXetlaHardwareSupport(); 28 | 29 | bool IsARC(const sycl::device* device_ptr = nullptr); 30 | 31 | uint64_t GetMaxAllocateLimitByte(sycl::device* device_ptr = nullptr); 32 | #endif // XLA_STREAM_EXECUTOR_SYCL_HW_INFO_H_ 33 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_blas.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Copyright 2015 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 
16 | ==============================================================================*/ 17 | 18 | #ifndef XLA_STREAM_EXECUTOR_SYCL_BLAS_H_ 19 | #define XLA_STREAM_EXECUTOR_SYCL_BLAS_H_ 20 | 21 | #if __has_include() 22 | #include 23 | #elif __has_include() 24 | #include 25 | #else 26 | #error "Unsupported compiler" 27 | #endif 28 | 29 | #include "absl/base/thread_annotations.h" 30 | #include "absl/synchronization/mutex.h" 31 | #include "oneapi/mkl/blas.hpp" 32 | #include "oneapi/mkl/dfti.hpp" 33 | #include "oneapi/mkl/exceptions.hpp" 34 | #include "oneapi/mkl/lapack.hpp" 35 | #include "xla/stream_executor/blas.h" 36 | 37 | namespace stream_executor { 38 | 39 | class Stream; 40 | 41 | namespace gpu { 42 | class GpuExecutor; 43 | } // namespace gpu 44 | 45 | using syclStream_t = ::sycl::queue *; 46 | 47 | namespace sycl { 48 | // Thread-safe post-initialization. 49 | class SYCLBlas : public blas::BlasSupport { 50 | public: 51 | explicit SYCLBlas(gpu::GpuExecutor *parent); 52 | 53 | bool Init(); 54 | 55 | ~SYCLBlas() override; 56 | 57 | TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES 58 | 59 | gpu::BlasLt *GetBlasLt() override { return nullptr; } 60 | 61 | private: 62 | bool SetStream(Stream *stream) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); 63 | 64 | syclStream_t SYCLStream(Stream *stream); 65 | 66 | absl::Mutex mu_; 67 | 68 | gpu::GpuExecutor *parent_; 69 | 70 | void *blas_ ABSL_GUARDED_BY(mu_); 71 | 72 | void *blas_it_; 73 | 74 | SYCLBlas(const SYCLBlas &) = delete; 75 | void operator=(const SYCLBlas &) = delete; 76 | }; 77 | 78 | } // namespace sycl 79 | } // namespace stream_executor 80 | 81 | #endif // XLA_STREAM_EXECUTOR_SYCL_BLAS_H_ 82 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_collectives.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Copyright 2024 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 
16 | ==============================================================================*/ 17 | 18 | #include 19 | 20 | #include "absl/status/status.h" 21 | #include "absl/status/statusor.h" 22 | #include "xla/stream_executor/gpu/gpu_driver.h" 23 | #include "xla/stream_executor/gpu/gpu_collectives.h" 24 | 25 | namespace stream_executor::gpu { 26 | 27 | absl::StatusOr GpuCollectives::CollectiveMemoryAllocate( 28 | GpuContext* context, uint64_t bytes) { 29 | return absl::UnimplementedError( 30 | "Feature not supported on SYCL platform (CollectiveMemoryAllocate)"); 31 | } 32 | 33 | absl::Status GpuCollectives::CollectiveMemoryDeallocate(GpuContext* context, 34 | void* location) { 35 | return absl::UnimplementedError( 36 | "Feature not supported on SYCL platform (CollectiveMemoryDeallocate)"); 37 | } 38 | 39 | } // namespace stream_executor::gpu 40 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_conditional_kernels.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Copyright 2023 The OpenXLA Authors. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | namespace stream_executor { 19 | namespace sycl { 20 | namespace { 21 | 22 | void SetCondition() {} 23 | 24 | } // namespace 25 | } // namespace sycl 26 | 27 | namespace gpu { 28 | void* GetSetIfConditionKernel() { 29 | return reinterpret_cast(&sycl::SetCondition); 30 | } 31 | void* GetSetIfElseConditionKernel() { 32 | return reinterpret_cast(&sycl::SetCondition); 33 | } 34 | void* GetSetCaseConditionKernel() { 35 | return reinterpret_cast(&sycl::SetCondition); 36 | } 37 | void* GetSetForConditionKernel() { 38 | return reinterpret_cast(&sycl::SetCondition); 39 | } 40 | void* GetSetWhileConditionKernel() { 41 | return reinterpret_cast(&sycl::SetCondition); 42 | } 43 | } // namespace gpu 44 | 45 | } // namespace stream_executor 46 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_dnn.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Copyright 2024 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 
16 | ==============================================================================*/ 17 | 18 | #include "xla/stream_executor/sycl/sycl_dnn.h" 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | 34 | #include "absl/base/optimization.h" 35 | #include "absl/base/thread_annotations.h" 36 | #include "absl/container/flat_hash_map.h" 37 | #include "absl/container/flat_hash_set.h" 38 | #include "absl/memory/memory.h" 39 | #include "absl/status/status.h" 40 | #include "absl/strings/str_cat.h" 41 | #include "absl/strings/str_format.h" 42 | #include "xla/stream_executor/sycl/sycl_platform_id.h" 43 | #include "xla/stream_executor/dnn.h" 44 | #include "xla/stream_executor/gpu/gpu_executor.h" 45 | #include "xla/stream_executor/gpu/gpu_timer.h" 46 | #include "xla/stream_executor/numeric_options.h" 47 | #include "xla/stream_executor/platform/initialize.h" 48 | #include "xla/stream_executor/plugin_registry.h" 49 | #include "xla/stream_executor/scratch_allocator.h" 50 | #include "xla/stream_executor/stream.h" 51 | #include "xla/stream_executor/stream_executor.h" 52 | #include "absl/strings/string_view.h" 53 | 54 | 55 | namespace stream_executor { 56 | namespace gpu { 57 | 58 | OnednnSupport::OnednnSupport(GpuExecutor* parent) : parent_(parent) {} 59 | 60 | absl::Status OnednnSupport::Init() { 61 | return absl::OkStatus(); 62 | } 63 | 64 | absl::StatusOr OnednnSupport::GetVersion() { 65 | return dnn::VersionInfo(0, 0, 0); 66 | } 67 | 68 | } // namespace gpu 69 | 70 | void initialize_onednn() { 71 | absl::Status status = 72 | PluginRegistry::Instance()->RegisterFactory( 73 | sycl::kSyclPlatformId, "oneDNN", 74 | [](StreamExecutor* parent) -> dnn::DnnSupport* { 75 | gpu::GpuExecutor* sycl_executor = 76 | dynamic_cast(parent); 77 | if (sycl_executor == nullptr) { 78 | LOG(ERROR) << "Attempting to initialize an instance of the oneDNN " 79 | << "support library with a non-SYCL StreamExecutor"; 80 | return nullptr; 81 | } 82 | 83 | gpu::OnednnSupport* dnn = new gpu::OnednnSupport(sycl_executor); 84 | if (!dnn->Init().ok()) { 85 | // Note: Init() will log a more specific error. 86 | delete dnn; 87 | return nullptr; 88 | } 89 | return dnn; 90 | }); 91 | 92 | if (!status.ok()) { 93 | LOG(ERROR) << "Unable to register oneDNN factory: " << status.message(); 94 | } 95 | } 96 | 97 | } // namespace stream_executor 98 | 99 | STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(register_onednn, { 100 | stream_executor::initialize_onednn(); 101 | }); 102 | 103 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_dnn.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2024 Intel Corporation 2 | 3 | Copyright 2024 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 
16 | ==============================================================================*/ 17 | 18 | // oneDNN library support, implementing the general DnnSupport interface. 19 | 20 | #ifndef XLA_STREAM_EXECUTOR_SYCL_SYCL_DNN_H_ 21 | #define XLA_STREAM_EXECUTOR_SYCL_SYCL_DNN_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | #include "absl/base/thread_annotations.h" 29 | #include "absl/status/status.h" 30 | #include "absl/types/span.h" 31 | #include "xla/stream_executor/dnn.h" 32 | #include "xla/stream_executor/plugin_registry.h" 33 | 34 | namespace stream_executor { 35 | namespace gpu { 36 | 37 | class GpuExecutor; 38 | 39 | // onednn-library based DNN support. For details on overridden interface 40 | // functions, see dnn.h. 41 | class OnednnSupport : public dnn::DnnSupport { 42 | public: 43 | explicit OnednnSupport(GpuExecutor* parent); 44 | 45 | absl::Status Init() override; 46 | absl::StatusOr GetVersion() override; 47 | 48 | absl::Status DoConvolve( 49 | dnn::ConvolutionKind kind, dnn::DataType element_type, 50 | dnn::DataType output_type, Stream* stream, 51 | const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, 52 | const dnn::FilterDescriptor& filter_descriptor, 53 | DeviceMemoryBase filter_data, 54 | const dnn::BatchDescriptor& output_descriptor, 55 | DeviceMemoryBase output_data, 56 | const dnn::ConvolutionDescriptor& convolution_descriptor, 57 | dnn::AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, 58 | dnn::ProfileResult* output_profile_result) override { 59 | return absl::UnimplementedError( 60 | "DnnSupport::DoConvolve not implemented on this platform."); 61 | } 62 | 63 | absl::Status DoPoolForward(dnn::DataType element_type, Stream* stream, 64 | const dnn::PoolingDescriptor& pooling_dimensions, 65 | const dnn::BatchDescriptor& input_dimensions, 66 | DeviceMemoryBase input_data, 67 | const dnn::BatchDescriptor& output_dimensions, 68 | DeviceMemoryBase output_data, 69 | ScratchAllocator* workspace_allocator) override { 70 | return absl::UnimplementedError( 71 | "DnnSupport::DoPoolForward not implemented on this platform."); 72 | } 73 | 74 | absl::Status DoPoolBackward(dnn::DataType element_type, Stream* stream, 75 | const dnn::PoolingDescriptor& pooling_dimensions, 76 | const dnn::BatchDescriptor& input_dimensions, 77 | DeviceMemoryBase input_data, 78 | const dnn::BatchDescriptor& output_dimensions, 79 | DeviceMemoryBase output_data, 80 | DeviceMemoryBase input_diff_data, 81 | DeviceMemoryBase output_diff_data, 82 | ScratchAllocator* workspace_allocator) override { 83 | return absl::UnimplementedError( 84 | "DnnSupport::DoPoolBackward not implemented on this platform."); 85 | } 86 | 87 | private: 88 | GpuExecutor* parent_; // Parent executor object. Not owned. 89 | 90 | OnednnSupport(const OnednnSupport&) = delete; 91 | void operator=(const OnednnSupport&) = delete; 92 | }; 93 | 94 | } // namespace gpu 95 | } // namespace stream_executor 96 | 97 | #endif // XLA_STREAM_EXECUTOR_SYCL_SYCL_DNN_H_ 98 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_event.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Copyright 2015 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 
7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | #include "xla/stream_executor/sycl/sycl_event.h" 19 | 20 | #include "tsl/platform/statusor.h" 21 | // #include "xla/stream_executor/sycl/sycl_executor.h" 22 | #include "xla/stream_executor/sycl/sycl_gpu_runtime.h" 23 | #include "xla/stream_executor/sycl/sycl_stream.h" 24 | 25 | namespace stream_executor { 26 | namespace gpu { 27 | 28 | namespace sycl = ::sycl; 29 | 30 | Event::Status SYCLEvent::PollForStatus() { 31 | auto* event = gpu_event()->event; 32 | auto event_status = 33 | event->get_info(); 34 | 35 | switch (event_status) { 36 | case sycl::info::event_command_status::submitted: 37 | case sycl::info::event_command_status::running: 38 | return Event::Status::kPending; 39 | case sycl::info::event_command_status::complete: 40 | return Event::Status::kComplete; 41 | default: 42 | return Event::Status::kUnknown; 43 | } 44 | } 45 | 46 | } // namespace gpu 47 | } // namespace stream_executor 48 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_event.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Copyright 2019 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | #ifndef XLA_STREAM_EXECUTOR_SYCL_SYCL_EVENT_H_ 19 | #define XLA_STREAM_EXECUTOR_SYCL_SYCL_EVENT_H_ 20 | 21 | #include "xla/stream_executor/gpu/gpu_event.h" 22 | 23 | namespace stream_executor::gpu { 24 | 25 | // This class implements Event::PollForStatus for CUDA devices. 26 | class SYCLEvent : public GpuEvent { 27 | public: 28 | explicit SYCLEvent(GpuExecutor *executor) : GpuEvent(executor) {} 29 | 30 | Event::Status PollForStatus() override; 31 | }; 32 | 33 | } // namespace stream_executor::gpu 34 | 35 | #endif // XLA_STREAM_EXECUTOR_SYCL_SYCL_EVENT_H_ 36 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_executor.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Copyright 2019 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 
7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | // The SYCL implementation of the StreamExecutorInterface functionality. 19 | // SYCL inclusions are ideally confined to this implementation file. 20 | // 21 | // The notions from the StreamExecutor basically correspond to the SYCL streams 22 | // programming model provided by the libcuda.so driver APIs, so we don't have 23 | // to do much more than wrap the calls to the libraries appropriately. 24 | 25 | #ifndef XLA_STREAM_EXECUTOR_SYCL_SYCL_EXECUTOR_H_ 26 | #define XLA_STREAM_EXECUTOR_SYCL_SYCL_EXECUTOR_H_ 27 | 28 | #include "xla/stream_executor/gpu/gpu_executor.h" 29 | 30 | namespace stream_executor { 31 | namespace sycl { 32 | 33 | using SYCLExecutor = gpu::GpuExecutor; 34 | 35 | } // namespace sycl 36 | } // namespace stream_executor 37 | #endif // XLA_STREAM_EXECUTOR_SYCL_SYCL_EXECUTOR_H_ 38 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_kernel.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Copyright 2019 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | #include "xla/stream_executor/sycl/sycl_kernel.h" 19 | 20 | namespace stream_executor { 21 | namespace gpu { 22 | 23 | absl::StatusOr GpuKernel::GetMaxOccupiedBlocksPerCore( 24 | ThreadDim threads, size_t dynamic_shared_memory_bytes) const { 25 | int32_t threads_per_block = threads.x * threads.y * threads.z; 26 | VLOG(3) << "Get kernel block occupancy: " << name_ 27 | << "; threads_per_block: " << threads_per_block 28 | << "; dynamic_shared_memory_bytes: " << dynamic_shared_memory_bytes; 29 | 30 | return GpuDriver::GetMaxOccupiedBlocksPerCore(gpu_context_, gpu_function_, 31 | threads_per_block, 32 | dynamic_shared_memory_bytes); 33 | } 34 | 35 | } // namespace gpu 36 | } // namespace stream_executor 37 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_kernel.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Copyright 2019 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 
7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | // The CUDA implementation of the StreamExecutorInterface functionality. 19 | // CUDA inclusions are ideally confined to this implementation file. 20 | // 21 | // The notions from the StreamExecutor basically correspond to the CUDA streams 22 | // programming model provided by the libcuda.so driver APIs, so we don't have 23 | // to do much more than wrap the calls to the libraries appropriately. 24 | #ifndef XLA_STREAM_EXECUTOR_SYCL_SYCL_KERNEL_H_ 25 | #define XLA_STREAM_EXECUTOR_SYCL_SYCL_KERNEL_H_ 26 | 27 | #include "xla/stream_executor/gpu/gpu_kernel.h" 28 | 29 | namespace stream_executor { 30 | namespace sycl { 31 | 32 | using SYCLKernel = gpu::GpuKernel; 33 | 34 | } // namespace sycl 35 | } // namespace stream_executor 36 | 37 | #endif // XLA_STREAM_EXECUTOR_SYCL_SYCL_KERNEL_H_ 38 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_platform.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Copyright 2015 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | #ifndef XLA_STREAM_EXECUTOR_SYCL_SYCL_PLATFORM_H_ 19 | #define XLA_STREAM_EXECUTOR_SYCL_SYCL_PLATFORM_H_ 20 | 21 | #include 22 | #include 23 | 24 | #include "absl/base/thread_annotations.h" 25 | #include "xla/stream_executor/executor_cache.h" 26 | #include "xla/stream_executor/platform_manager.h" 27 | #include "xla/stream_executor/platform.h" 28 | #include "xla/stream_executor/platform/port.h" 29 | #include "xla/stream_executor/stream_executor.h" 30 | #include "tsl/platform/statusor.h" 31 | 32 | namespace stream_executor { 33 | namespace sycl { 34 | // Opaque and unique identifier for the SYCL platform plugin. 35 | // This is needed so that plugins can refer to/identify this platform without 36 | // instantiating a SyclPlatform object. 37 | extern const Platform::Id kSyclPlatformId; 38 | } // namespace sycl 39 | 40 | namespace gpu { 41 | // Cuda-specific platform plugin, registered as a singleton value via module 42 | // initializer. 43 | class SyclPlatform : public Platform { 44 | public: 45 | SyclPlatform(); 46 | ~SyclPlatform() override; 47 | 48 | // SyclPlatform-specific functionality 49 | // Returns the number of distinct buses / NUMA nodes on the machine. 
50 | int BusCount(); 51 | 52 | // Returns the bus/NUMA node for the specified device ordinal. 53 | int DeviceToBus(int device_ordinal); 54 | 55 | // Returns the lowest-ordinal-number StreamExecutor on the specified bus. 56 | absl::StatusOr FirstExecutorForBus(int bus_ordinal); 57 | 58 | // Platform interface implementation: 59 | // Returns the same value as kCudaPlatform above. 60 | Platform::Id id() const override; 61 | 62 | // Returns -1 as a sentinel on internal failure (and logs the error). 63 | int VisibleDeviceCount() const override; 64 | 65 | const std::string& Name() const override; 66 | 67 | absl::StatusOr> DescriptionForDevice( 68 | int ordinal) const override; 69 | 70 | absl::StatusOr ExecutorForDevice(int ordinal) override; 71 | 72 | absl::StatusOr GetExecutor( 73 | const StreamExecutorConfig& config) override; 74 | 75 | absl::StatusOr> GetUncachedExecutor( 76 | const StreamExecutorConfig& config) override; 77 | 78 | private: 79 | // Determines the number of NUMA nodes and the assignment of executor to each. 80 | void InspectNumaNodes(); 81 | 82 | // This platform's name. 83 | std::string name_; 84 | 85 | // Cache of created executors. 86 | ExecutorCache executor_cache_; 87 | 88 | // The smallest NUMA node value for any device managed by this machine 89 | // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus 90 | // ordinals. The NUMA node space occupied by GPUs is assumed to be dense./ 91 | int min_numa_node_; 92 | 93 | // Larger than the NUMA node value for any device managed by this machine 94 | // manager. 95 | int limit_numa_node_; 96 | 97 | SyclPlatform(const SyclPlatform&) = delete; 98 | void operator=(const SyclPlatform&) = delete; 99 | }; 100 | 101 | } // namespace gpu 102 | 103 | } // namespace stream_executor 104 | 105 | #endif // XLA_STREAM_EXECUTOR_SYCL_SYCL_PLATFORM_H_ 106 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_platform_id.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Copyright 2015 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | #include "xla/stream_executor/sycl/sycl_platform_id.h" 19 | 20 | namespace stream_executor { 21 | namespace sycl { 22 | 23 | PLATFORM_DEFINE_ID(kSyclPlatformId); 24 | 25 | } // namespace sycl 26 | } // namespace stream_executor 27 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_platform_id.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Copyright 2015 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 
7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | #ifndef XLA_STREAM_EXECUTOR_SYCL_SYCL_PLATFORM_ID_H_ 19 | #define XLA_STREAM_EXECUTOR_SYCL_SYCL_PLATFORM_ID_H_ 20 | 21 | #include "xla/stream_executor/platform.h" 22 | 23 | namespace stream_executor { 24 | namespace sycl { 25 | 26 | // Opaque and unique identifier for the cuda platform. 27 | // This is needed so that plugins can refer to/identify this platform without 28 | // instantiating a CudaPlatform object. 29 | // This is broken out here to avoid a circular dependency between CudaPlatform 30 | // and CudaExecutor. 31 | extern const Platform::Id kSyclPlatformId; 32 | 33 | } // namespace sycl 34 | } // namespace stream_executor 35 | 36 | #endif // XLA_STREAM_EXECUTOR_SYCL_SYCL_PLATFORM_ID_H_ 37 | -------------------------------------------------------------------------------- /xla/stream_executor/sycl/sycl_stream.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Intel Corporation 2 | 3 | Copyright 2019 The TensorFlow Authors. All Rights Reserved. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | ==============================================================================*/ 17 | 18 | // Defines the GpuStream type - the CUDA-specific implementation of the generic 19 | // StreamExecutor Stream interface. 20 | 21 | #ifndef XLA_STREAM_EXECUTOR_SYCL_SYCL_STREAM_H_ 22 | #define XLA_STREAM_EXECUTOR_SYCL_SYCL_STREAM_H_ 23 | 24 | #include "xla/stream_executor/gpu/gpu_stream.h" 25 | 26 | namespace stream_executor { 27 | namespace sycl { 28 | 29 | using SYCLStream = gpu::GpuStream; 30 | 31 | inline SYCLStream* AsSYCLStream(Stream* stream) { 32 | return gpu::AsGpuStream(stream); 33 | } 34 | 35 | } // namespace sycl 36 | } // namespace stream_executor 37 | 38 | #endif // XLA_STREAM_EXECUTOR_SYCL_SYCL_STREAM_H_ 39 | -------------------------------------------------------------------------------- /xla/tools/pip_package/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # Tools for building the TensorFlow pip package. 
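# The wheel pulls in the XPU PJRT plugin (pjrt_plugin_xpu.so), the SYCL/oneDNN
# runtime library (sycl_onednn.so) and the Python plugin extension via
# COMMON_PIP_DEPS below.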
3 | 4 | package(default_visibility = ["//visibility:private"]) 5 | 6 | load("//xla:xla.bzl", "transitive_hdrs") 7 | 8 | COMMON_PIP_DEPS = [ 9 | "MANIFEST.in", 10 | "README.md", 11 | "xla_setup.py", 12 | "//xla:pjrt_plugin_xpu.so", 13 | "//xla/service/gpu:sycl_onednn.so", 14 | "//xla/python:xpu_plugin_extension.so" 15 | ] 16 | 17 | py_binary( 18 | name = "simple_console", 19 | srcs = ["simple_console.py"], 20 | srcs_version = "PY2AND3", 21 | deps = [], 22 | ) 23 | 24 | sh_binary( 25 | name = "build_pip_package", 26 | srcs = ["build_pip_package.sh"], 27 | data = ["simple_console"] + COMMON_PIP_DEPS, 28 | ) 29 | -------------------------------------------------------------------------------- /xla/tools/pip_package/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README 2 | recursive-include * *.py 3 | recursive-include * *.pyd 4 | recursive-include * *.pd 5 | recursive-include * *.so 6 | recursive-include * *.so.[0-9] 7 | recursive-include * *.dylib 8 | recursive-include * *.dll 9 | recursive-include * *.lib 10 | recursive-include * *.csv 11 | recursive-include * *.h 12 | recursive-include * *.hpp -------------------------------------------------------------------------------- /xla/tools/pip_package/README.md: -------------------------------------------------------------------------------- 1 | # Intel® Extension for OpenXLA* 2 | 3 | [![Python](https://img.shields.io/pypi/pyversions/intel_extension_for_openxla)](https://badge.fury.io/py/intel-extension-for-openxla) 4 | [![PyPI version](https://badge.fury.io/py/intel-extension-for-openxla.svg)](https://badge.fury.io/py/intel-extension-for-openxla) 5 | [![version](https://img.shields.io/github/v/release/intel/intel-extension-for-openxla?color=brightgreen)](https://github.com/intel/intel-extension-for-openxla/releases) 6 | 7 | The [OpenXLA](https://github.com/openxla/xla) Project brings together a community of developers and leading AI/ML teams to accelerate ML and address infrastructure fragmentation across ML frameworks and hardware. 8 | 9 | Intel® Extension for OpenXLA* includes PJRT plugin implementation, which seamlessly runs JAX models on Intel GPU. The PJRT API simplified the integration, which allowed the Intel GPU plugin to be developed separately and quickly integrated into JAX. 10 | 11 | ## Installation 12 | 13 | The following table tracks intel-extension-for-openxla versions and compatible versions of jax, jaxlib. 14 | | **intel-extension-for-openxla** | **jaxlib** | **jax** | 15 | |:-:|:-:|:-:| 16 | | 0.6.0 | 0.4.38 | 0.4.38 | 17 | | 0.5.0 | 0.4.30 | >= 0.4.30, <= 0.4.31| 18 | | 0.4.0 | 0.4.26 | >= 0.4.26, <= 0.4.27| 19 | | 0.3.0 | 0.4.24 | >= 0.4.24, <= 0.4.27| 20 | | 0.2.1 | 0.4.20 | >= 0.4.20, <= 0.4.26| 21 | | 0.2.0 | 0.4.20 | >= 0.4.20, <= 0.4.26| 22 | | 0.1.0 | 0.4.13 | >= 0.4.13, <= 0.4.14| 23 | 24 | 25 | ``` 26 | pip install --upgrade intel-extension-for-openxla 27 | ``` 28 | 29 | ## Security 30 | See Intel's [Security Center](https://www.intel.com/content/www/us/en/security-center/default.html) for information on how to report a potential security issue or vulnerability. 
31 | 32 | See also: [Security Policy](https://github.com/intel/intel-extension-for-openxla/blob/main/security.md) 33 | -------------------------------------------------------------------------------- /xla/tools/pip_package/simple_console.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Intel Corporation 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Start a simple interactive console with TensorFlow available.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import code 23 | import sys 24 | 25 | 26 | def main(_): 27 | """Run an interactive console.""" 28 | code.interact() 29 | return 0 30 | 31 | 32 | if __name__ == '__main__': 33 | sys.exit(main(sys.argv)) 34 | -------------------------------------------------------------------------------- /xla/workspace.bzl: -------------------------------------------------------------------------------- 1 | load("//third_party/gpus:sycl_configure.bzl", "sycl_configure") 2 | load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_repository") 3 | 4 | def workspace(path_prefix = "", tf_repo_name = ""): 5 | sycl_configure(name = "local_config_sycl") 6 | 7 | new_git_repository( 8 | name = "onednn_gpu", 9 | # v3.7 10 | commit = "5e9224036021433d2577548ed0539fe9a53256bc", 11 | remote = "https://github.com/oneapi-src/oneDNN.git", 12 | build_file = "//third_party/onednn:onednn_gpu.BUILD", 13 | verbose = True, 14 | patch_cmds = [ 15 | "git log -1 --format=%H > COMMIT", 16 | ], 17 | ) 18 | 19 | git_repository( 20 | name = "spir_headers", 21 | commit = "4f7b471f1a66b6d06462cd4ba57628cc0cd087d7", 22 | remote = "https://github.com/KhronosGroup/SPIRV-Headers.git", 23 | verbose = True, 24 | patch_cmds = [ 25 | "git log -1 --format=%H > COMMIT", 26 | ], 27 | ) 28 | 29 | new_git_repository( 30 | name = "llvm_spir", 31 | commit = "98729ef13ca970b98a149c18318746df8d921f20", 32 | remote = "https://github.com/KhronosGroup/SPIRV-LLVM-Translator.git", 33 | build_file = "//third_party/llvm_spir:llvm_spir.BUILD", 34 | verbose = True, 35 | patches = ["//third_party/llvm_spir:llvm_spir.patch"], 36 | patch_args = ["-p1"], 37 | ) 38 | 39 | new_git_repository( 40 | name = "xetla", 41 | patch_args = ["-p1"], 42 | patches = ["//third_party/xetla:xetla.patch"], 43 | # v0.3.7.2 44 | commit = "ae46a690bac364a93437e248418636c2a8423134", 45 | remote = "https://github.com/intel/xetla.git", 46 | verbose = True, 47 | build_file = "//third_party/xetla:BUILD", 48 | patch_cmds = [ 49 | "git log -1 --format=%H > COMMIT", 50 | ], 51 | ) 52 | -------------------------------------------------------------------------------- /xla/xla.bzl: -------------------------------------------------------------------------------- 1 | # Return the options to use for a C++ library or binary build. 
2 | # Uses the ":optmode" config_setting to pick the options. 3 | load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl") 4 | 5 | def if_linux_x86_64(a, otherwise = []): 6 | return select({ 7 | "//conditons:default": otherwise, 8 | }) 9 | 10 | def tf_copts(android_optimization_level_override = "-O2", is_external = False): 11 | # For compatibility reasons, android_optimization_level_override 12 | # is currently only being set for Android. 13 | # To clear this value, and allow the CROSSTOOL default 14 | # to be used, pass android_optimization_level_override=None 15 | return ( 16 | [ 17 | "-Wno-sign-compare", 18 | "-Wno-unknown-pragmas", 19 | "-ftemplate-depth=900", 20 | "-msse3", 21 | "-pthread", 22 | ] 23 | ) 24 | 25 | def xetla_library(name, srcs = [], hdrs = [], deps = [], *argc, **kwargs): 26 | kwargs["copts"] = kwargs.get("copts", []) + if_sycl(["--xetla", "-sycl_compile"]) 27 | kwargs["linkopts"] = kwargs.get("linkopts", []) + if_sycl(["--xetla", "-link_stage"]) 28 | kwargs["alwayslink"] = True 29 | native.cc_library( 30 | name = name, 31 | srcs = srcs, 32 | hdrs = hdrs, 33 | deps = deps, 34 | **kwargs 35 | ) 36 | 37 | def xpu_library(name, srcs = [], hdrs = [], deps = [], *argc, **kwargs): 38 | kwargs["copts"] = kwargs.get("copts", []) + if_sycl(["-sycl_compile"]) 39 | kwargs["linkopts"] = kwargs.get("linkopts", []) + if_sycl(["-link_stage"]) 40 | kwargs["alwayslink"] = True 41 | native.cc_library( 42 | name = name, 43 | srcs = srcs, 44 | hdrs = hdrs, 45 | deps = deps, 46 | **kwargs 47 | ) 48 | 49 | def _get_transitive_headers(hdrs, deps): 50 | return depset( 51 | hdrs, 52 | transitive = [dep[CcInfo].compilation_context.headers for dep in deps], 53 | ) 54 | 55 | def _transitive_hdrs_impl(ctx): 56 | outputs = _get_transitive_headers([], ctx.attr.deps) 57 | return struct(files = outputs) 58 | 59 | _transitive_hdrs = rule( 60 | attrs = { 61 | "deps": attr.label_list( 62 | allow_files = True, 63 | providers = [CcInfo], 64 | ), 65 | }, 66 | implementation = _transitive_hdrs_impl, 67 | ) 68 | 69 | def transitive_hdrs(name, deps = [], **kwargs): 70 | _transitive_hdrs(name = name + "_gather", deps = deps) 71 | native.filegroup(name = name, srcs = [":" + name + "_gather"]) 72 | 73 | def cc_header_only_library(name, deps = [], includes = [], extra_deps = [], **kwargs): 74 | _transitive_hdrs(name = name + "_gather", deps = deps) 75 | native.cc_library( 76 | name = name, 77 | srcs = [":" + name + "_gather"], 78 | hdrs = includes, 79 | deps = extra_deps, 80 | **kwargs 81 | ) 82 | --------------------------------------------------------------------------------
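As a closing illustration of how the pieces above fit together, here is a minimal sketch of checking that the PJRT plugin built from these targets is visible to JAX. It assumes intel-extension-for-openxla and a compatible jax/jaxlib are installed (see the version table in the pip package README above); the "xpu" platform name is inferred from targets such as //xla:pjrt_plugin_xpu.so and is an assumption, not something stated in the files themselves.

```python
# Minimal sketch: confirm the Intel GPU (XPU) PJRT plugin is picked up by JAX.
# Assumptions: intel-extension-for-openxla and a matching jax/jaxlib are
# installed; the "xpu" platform name is inferred from pjrt_plugin_xpu.so.
import jax
import jax.numpy as jnp

print(jax.devices())  # expected to include the plugin's devices when present

x = jnp.ones((1024, 1024))
y = jnp.ones((1024, 1024))
# A simple matmul runs on the plugin device when one is available.
print(jnp.dot(x, y)[0, 0])
```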