├── .tool-versions ├── .formatter.exs ├── .dockerignore ├── RELEASE.md ├── lib ├── xla │ ├── checksumer.ex │ └── utils.ex ├── mix │ └── tasks │ │ ├── xla.info.ex │ │ └── xla.checksum.ex └── xla.ex ├── .github ├── scripts │ └── upload_artifact.sh └── workflows │ └── release.yml ├── extension ├── patches │ ├── cuda_ncrtc_builtins.patch │ └── apply.sh └── BUILD ├── .gitignore ├── mix.exs ├── Makefile ├── mix.lock ├── CHANGELOG.md ├── README.md └── LICENSE /.tool-versions: -------------------------------------------------------------------------------- 1 | bazel 7.4.1 2 | python 3.11.11 3 | -------------------------------------------------------------------------------- /.formatter.exs: -------------------------------------------------------------------------------- 1 | # Used by "mix format" 2 | [ 3 | inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] 4 | ] 5 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # Mirrors .gitignore 2 | 3 | /_build/ 4 | /cover/ 5 | /deps/ 6 | /doc/ 7 | /.fetch 8 | erl_crash.dump 9 | *.ez 10 | xla-*.tar 11 | /tmp/ 12 | /cache/ 13 | /builds/output/ 14 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | # Releasing XLA 2 | 3 | 1. Update version in `mix.exs` and update CHANGELOG. 4 | 2. Run `git tag x.y.z` and `git push --tags`. 5 | 1. Wait for CI to precompile all artifacts. 6 | 2. Build the remaining artifacts off-CI and upload to the draft GH release. 7 | 3. Publish GH release with copied changelog notes (CI creates a draft, we need to publish it to compute the checksum). 8 | 4. Run `mix xla.checksum`. 9 | 5. Run `mix hex.publish`. 10 | -------------------------------------------------------------------------------- /lib/xla/checksumer.ex: -------------------------------------------------------------------------------- 1 | defmodule XLA.Checksumer do 2 | @moduledoc false 3 | 4 | defstruct algorithm: :sha256 5 | 6 | defimpl Collectable do 7 | def into(checksumer) do 8 | state = :crypto.hash_init(checksumer.algorithm) 9 | 10 | collector = fn 11 | state, {:cont, chunk} when is_binary(chunk) -> 12 | :crypto.hash_update(state, chunk) 13 | 14 | state, :done -> 15 | hash = :crypto.hash_final(state) 16 | Base.encode16(hash, case: :lower) 17 | 18 | _state, :halt -> 19 | :ok 20 | end 21 | 22 | {state, collector} 23 | end 24 | end 25 | end 26 | -------------------------------------------------------------------------------- /.github/scripts/upload_artifact.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ex 4 | 5 | cd "$(dirname "$0")/../.." 6 | 7 | sha="$1" 8 | release_name="$2" 9 | files="${@:3}" 10 | 11 | # Create the release, if not present. 12 | if ! gh release view $release_name; then 13 | gh release create $release_name --title $release_name --draft --target $sha 14 | fi 15 | 16 | # Uploading is the final action after several hour long build, so in 17 | # case of any temporary network failures we want to retry a number 18 | # of times. 
19 | for i in {1..10}; do
20 |   gh release upload $release_name $files --clobber && break
21 |   echo "Upload failed, retrying in 30s"
22 |   sleep 30
23 | done
--------------------------------------------------------------------------------
/lib/mix/tasks/xla.info.ex:
--------------------------------------------------------------------------------
1 | defmodule Mix.Tasks.Xla.Info do
2 |   @moduledoc false
3 |   # Returns relevant information about the XLA archive.
4 | 
5 |   use Mix.Task
6 | 
7 |   @impl true
8 |   def run(["archive_filename"]) do
9 |     Mix.shell().info(XLA.archive_filename_with_target())
10 |   end
11 | 
12 |   def run(["release_tag"]) do
13 |     Mix.shell().info("v" <> XLA.version())
14 |   end
15 | 
16 |   def run(["build_archive_dir"]) do
17 |     Mix.shell().info(XLA.build_archive_dir())
18 |   end
19 | 
20 |   def run(_args) do
21 |     Mix.shell().error("""
22 |     Usage:
23 |       mix xla.info archive_filename
24 |       mix xla.info release_tag
25 |       mix xla.info build_archive_dir\
26 |     """)
27 |   end
28 | end
--------------------------------------------------------------------------------
/extension/patches/cuda_ncrtc_builtins.patch:
--------------------------------------------------------------------------------
1 | diff --git a/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl
2 | index 7c0399a..17b1d78 100644
3 | --- a/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl
4 | +++ b/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl
5 | @@ -10,16 +10,11 @@ cc_import(
6 |      shared_library = "lib/libnvrtc.so.%{libnvrtc_version}",
7 |  )
8 |  
9 | -cc_import(
10 | -    name = "nvrtc_builtins",
11 | -    shared_library = "lib/libnvrtc-builtins.so.%{libnvrtc-builtins_version}",
12 | -)
13 |  %{multiline_comment}
14 |  cc_library(
15 |      name = "nvrtc",
16 |      %{comment}deps = [
17 |      %{comment}":nvrtc_main",
18 | -    %{comment}":nvrtc_builtins",
19 |      %{comment}],
20 |      %{comment}linkopts = cuda_rpath_flags("nvidia/cuda_nvrtc/lib"),
21 |      visibility = ["//visibility:public"],
--------------------------------------------------------------------------------
/extension/patches/apply.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | set -ex
4 | 
5 | dir="$(cd "$(dirname "$0")"; pwd)"
6 | arch="$(uname -m)"
7 | 
8 | # if [[ $arch == 'aarch64' ]]; then
9 | #   # ...
10 | # fi
11 | 
12 | # XLA build links against a major version of CUDA libraries, so the
13 | # build should be compatible with CUDA installations across all minor
14 | # versions. However, currently it also links against a specific minor
15 | # version of nvrtc-builtins. That library is for debugging, it does not
16 | # maintain compatibility across minor versions, and libraries should
17 | # not link against it. Looks like they only use symbols from that
18 | # library for tests. The below patch changes the Bazel XLA build
19 | # definitions to not link against nvrtc-builtins.
20 | #
21 | # See https://github.com/tensorflow/tensorflow/pull/86413 and the
22 | # referenced threads.
23 | git apply $dir/cuda_ncrtc_builtins.patch
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # The directory Mix will write compiled artifacts to.
2 | /_build/
3 | 
4 | # If you run "mix test --cover", coverage assets end up here.
5 | /cover/
6 | 
7 | # The directory Mix downloads your dependencies sources to.
8 | /deps/ 9 | 10 | # Where third-party dependencies like ExDoc output generated docs. 11 | /doc/ 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | # Ignore package tarball (built via "mix hex.build"). 23 | xla-*.tar 24 | 25 | # Temporary files, for example, from tests. 26 | /tmp/ 27 | 28 | # Ignore all the archives 29 | /cache/ 30 | 31 | # Output from compilations within Docker 32 | /builds/output/ 33 | 34 | # Archive checksums file 35 | /checksum.txt 36 | -------------------------------------------------------------------------------- /lib/mix/tasks/xla.checksum.ex: -------------------------------------------------------------------------------- 1 | defmodule Mix.Tasks.Xla.Checksum do 2 | @moduledoc false 3 | # Generates a checksum file for all precompiled artifacts. 4 | 5 | use Mix.Task 6 | 7 | @impl true 8 | def run(_args) do 9 | XLA.Utils.start_inets_profile() 10 | 11 | Mix.shell().info("Downloading and computing checksums...") 12 | 13 | checksums = 14 | XLA.precompiled_files() 15 | |> Task.async_stream( 16 | fn {filename, url} -> 17 | {filename, download_checksum!(url)} 18 | end, 19 | timeout: :infinity, 20 | ordered: false 21 | ) 22 | |> Map.new(fn {:ok, {filename, checksum}} -> {filename, checksum} end) 23 | 24 | XLA.write_checksums!(checksums) 25 | 26 | Mix.shell().info("Checksums written") 27 | after 28 | XLA.Utils.stop_inets_profile() 29 | end 30 | 31 | defp download_checksum!(url) do 32 | case with_retry(fn -> XLA.Utils.download(url, %XLA.Checksumer{}) end, 3) do 33 | {:ok, checksum} -> 34 | checksum 35 | 36 | {:error, message} -> 37 | Mix.raise("failed to download archive from #{url}, reason: #{message}") 38 | end 39 | end 40 | 41 | defp with_retry(fun, retries) when retries > 0 do 42 | first_try = fun.() 43 | 44 | Enum.reduce_while(1..retries//1, first_try, fn n, result -> 45 | case result do 46 | {:ok, _} -> 47 | {:halt, result} 48 | 49 | {:error, message} -> 50 | Mix.shell().info("Retrying request, attempt #{n} failed with reason: #{message}") 51 | 52 | wait_in_ms = :rand.uniform(n * 2_000) 53 | Process.sleep(wait_in_ms) 54 | 55 | {:cont, fun.()} 56 | end 57 | end) 58 | end 59 | end 60 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule XLA.MixProject do 2 | use Mix.Project 3 | 4 | @version "0.9.1" 5 | 6 | def project do 7 | [ 8 | app: :xla, 9 | version: @version, 10 | description: "Precompiled XLA binaries", 11 | elixir: "~> 1.12", 12 | aliases: aliases(), 13 | deps: deps(), 14 | compilers: Mix.compilers() ++ if(build?(), do: [:elixir_make], else: []), 15 | make_env: &XLA.make_env/0, 16 | package: package(), 17 | docs: docs() 18 | ] 19 | end 20 | 21 | def application do 22 | [extra_applications: [:logger, :inets, :ssl, :public_key, :crypto]] 23 | end 24 | 25 | def aliases do 26 | [ 27 | "hex.publish": [&ensure_checksum_file/1, "hex.publish"] 28 | ] 29 | end 30 | 31 | defp ensure_checksum_file(_) do 32 | valid? = 33 | case File.read("checksum.txt") do 34 | {:ok, content} -> content =~ @version 35 | {:error, :enoent} -> false 36 | end 37 | 38 | if not valid? 
do 39 | raise "run mix xla.checksum before releasing" 40 | end 41 | end 42 | 43 | defp deps do 44 | [ 45 | {:elixir_make, "~> 0.4", runtime: false}, 46 | {:ex_doc, "~> 0.25", only: :dev, runtime: false} 47 | ] 48 | end 49 | 50 | def package do 51 | [ 52 | licenses: ["Apache-2.0"], 53 | links: %{ 54 | "GitHub" => "https://github.com/elixir-nx/xla" 55 | }, 56 | files: ~w(extension lib Makefile mix.exs README.md LICENSE CHANGELOG.md checksum.txt) 57 | ] 58 | end 59 | 60 | def docs do 61 | [ 62 | main: "XLA", 63 | source_url: "https://github.com/elixir-nx/xla", 64 | source_ref: "v#{@version}" 65 | ] 66 | end 67 | 68 | defp build?() do 69 | System.get_env("XLA_BUILD") in ~w(1 true) 70 | end 71 | end 72 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Environment variables passed via elixir_make 2 | # ROOT_DIR 3 | # BUILD_INTERNAL_FLAGS 4 | # BUILD_ARCHIVE 5 | # BUILD_ARCHIVE_DIR 6 | # BUILD_CACHE_DIR 7 | 8 | # Public configuration 9 | BUILD_MODE ?= opt # can also be dbg 10 | OPENXLA_GIT_REPO ?= https://github.com/openxla/xla.git 11 | 12 | OPENXLA_GIT_REV ?= 870d90fd098c480fb8a426126bd02047adb2bc20 13 | 14 | # Private configuration 15 | BAZEL_FLAGS = --define "framework_shared_object=false" -c $(BUILD_MODE) 16 | 17 | OPENXLA_NS = xla-$(OPENXLA_GIT_REV) 18 | OPENXLA_DIR = $(BUILD_CACHE_DIR)/$(OPENXLA_NS) 19 | OPENXLA_XLA_EXTENSION_NS = xla/extension 20 | OPENXLA_XLA_EXTENSION_DIR = $(OPENXLA_DIR)/$(OPENXLA_XLA_EXTENSION_NS) 21 | OPENXLA_XLA_BUILD_ARCHIVE = $(OPENXLA_DIR)/bazel-bin/$(OPENXLA_XLA_EXTENSION_NS)/xla_extension.tar.gz 22 | 23 | $(BUILD_ARCHIVE): $(OPENXLA_DIR) extension/BUILD 24 | rm -f $(OPENXLA_XLA_EXTENSION_DIR) && \ 25 | ln -s "$(ROOT_DIR)/extension" $(OPENXLA_XLA_EXTENSION_DIR) && \ 26 | cd $(OPENXLA_DIR) && \ 27 | bazel build $(BAZEL_FLAGS) $(BUILD_FLAGS) $(BUILD_INTERNAL_FLAGS) //$(OPENXLA_XLA_EXTENSION_NS):xla_extension && \ 28 | mkdir -p $(dir $(BUILD_ARCHIVE)) && \ 29 | cp -f $(OPENXLA_XLA_BUILD_ARCHIVE) $(BUILD_ARCHIVE) 30 | 31 | # Clones OPENXLA 32 | $(OPENXLA_DIR): 33 | mkdir -p $(OPENXLA_DIR) && \ 34 | cp -r extension/patches $(OPENXLA_DIR) && \ 35 | cd $(OPENXLA_DIR) && \ 36 | git init && \ 37 | git remote add origin $(OPENXLA_GIT_REPO) && \ 38 | git fetch --depth 1 origin $(OPENXLA_GIT_REV) && \ 39 | git checkout FETCH_HEAD && \ 40 | bash patches/apply.sh && \ 41 | rm $(OPENXLA_DIR)/.bazelversion 42 | 43 | # Print OPENXLA Dir 44 | PTD: 45 | @ echo $(OPENXLA_DIR) 46 | 47 | clean: 48 | cd $(OPENXLA_DIR) && bazel clean --expunge 49 | rm -f $(OPENXLA_XLA_EXTENSION_DIR) 50 | rm -rf $(OPENXLA_DIR) 51 | rm -rf $(TARGET_DIR) 52 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, 3 | "elixir_make": {:hex, :elixir_make, "0.6.2", "7dffacd77dec4c37b39af867cedaabb0b59f6a871f89722c25b28fcd4bd70530", [:mix], [], "hexpm", "03e49eadda22526a7e5279d53321d1cced6552f344ba4e03e619063de75348d9"}, 4 | "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", 
[hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, 5 | "makeup": {:hex, :makeup, "1.2.1", "e90ac1c65589ef354378def3ba19d401e739ee7ee06fb47f94c687016e3713d1", [:mix], [{:nimble_parsec, "~> 1.4", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "d36484867b0bae0fea568d10131197a4c2e47056a6fbe84922bf6ba71c8d17ce"}, 6 | "makeup_elixir": {:hex, :makeup_elixir, "1.0.0", "74bb8348c9b3a51d5c589bf5aebb0466a84b33274150e3b6ece1da45584afc82", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "49159b7d7d999e836bedaf09dcf35ca18b312230cf901b725a64f3f42e407983"}, 7 | "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, 8 | "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, 9 | } 10 | -------------------------------------------------------------------------------- /lib/xla/utils.ex: -------------------------------------------------------------------------------- 1 | defmodule XLA.Utils do 2 | @moduledoc false 3 | 4 | @doc """ 5 | Downloads resource at the given URL into `collectable`. 6 | 7 | If collectable raises an error, it is rescued and an error tuple 8 | is returned. 
9 | 10 | ## Options 11 | 12 | * `:headers` - request headers 13 | 14 | """ 15 | @spec download(String.t(), Collectable.t(), keyword()) :: 16 | {:ok, Collectable.t()} | {:error, String.t()} 17 | def download(url, collectable, opts \\ []) do 18 | headers = build_headers(opts[:headers] || []) 19 | 20 | request = {url, headers} 21 | http_opts = [ssl: http_ssl_opts()] 22 | 23 | caller = self() 24 | 25 | receiver = fn reply_info -> 26 | request_id = elem(reply_info, 0) 27 | 28 | # Cancel the request if the caller terminates 29 | if Process.alive?(caller) do 30 | send(caller, {:http, reply_info}) 31 | else 32 | :httpc.cancel_request(request_id, :xla) 33 | end 34 | end 35 | 36 | opts = [stream: :self, sync: false, receiver: receiver] 37 | 38 | {:ok, request_id} = :httpc.request(:get, request, http_opts, opts, :xla) 39 | 40 | try do 41 | {acc, collector} = Collectable.into(collectable) 42 | 43 | try do 44 | download_loop(%{request_id: request_id, acc: acc, collector: collector}) 45 | catch 46 | kind, reason -> 47 | collector.(acc, :halt) 48 | :httpc.cancel_request(request_id, :xla) 49 | exception = Exception.normalize(kind, reason, __STACKTRACE__) 50 | {:error, Exception.message(exception)} 51 | else 52 | {:ok, state} -> 53 | acc = state.collector.(state.acc, :done) 54 | {:ok, acc} 55 | 56 | {:error, message} -> 57 | collector.(acc, :halt) 58 | :httpc.cancel_request(request_id, :xla) 59 | {:error, message} 60 | end 61 | catch 62 | kind, reason -> 63 | :httpc.cancel_request(request_id, :xla) 64 | exception = Exception.normalize(kind, reason, __STACKTRACE__) 65 | {:error, Exception.message(exception)} 66 | end 67 | end 68 | 69 | defp build_headers(entries) do 70 | headers = 71 | Enum.map(entries, fn {key, value} -> 72 | {to_charlist(key), to_charlist(value)} 73 | end) 74 | 75 | [{~c"user-agent", ~c"elixir-nx/xla"} | headers] 76 | end 77 | 78 | defp download_loop(state) do 79 | receive do 80 | {:http, reply_info} when elem(reply_info, 0) == state.request_id -> 81 | download_receive(state, reply_info) 82 | end 83 | end 84 | 85 | defp download_receive(_state, {_, {:error, error}}) do 86 | {:error, "reason: #{inspect(error)}"} 87 | end 88 | 89 | defp download_receive(state, {_, {{_, 200, _}, _headers, body}}) do 90 | acc = state.collector.(state.acc, {:cont, body}) 91 | {:ok, %{state | acc: acc}} 92 | end 93 | 94 | defp download_receive(_state, {_, {{_, status, _}, _headers, _body}}) do 95 | {:error, "got HTTP status #{status}"} 96 | end 97 | 98 | defp download_receive(state, {_, :stream_start, _headers}) do 99 | download_loop(state) 100 | end 101 | 102 | defp download_receive(state, {_, :stream, body_part}) do 103 | acc = state.collector.(state.acc, {:cont, body_part}) 104 | download_loop(%{state | acc: acc}) 105 | end 106 | 107 | defp download_receive(state, {_, :stream_end, _headers}) do 108 | {:ok, state} 109 | end 110 | 111 | defp http_ssl_opts() do 112 | # Use secure options, see https://gist.github.com/jonatanklosko/5e20ca84127f6b31bbe3906498e1a1d7 113 | [ 114 | cacerts: :public_key.cacerts_get(), 115 | verify: :verify_peer, 116 | customize_hostname_check: [ 117 | match_fun: :public_key.pkix_verify_hostname_match_fun(:https) 118 | ] 119 | ] 120 | end 121 | 122 | @doc false 123 | def start_inets_profile() do 124 | # Starting an HTTP client profile allows us to scope the httpc 125 | # configuration options, such as proxy options 126 | {:ok, _pid} = :inets.start(:httpc, profile: :xla) 127 | set_proxy_options() 128 | end 129 | 130 | @doc false 131 | def stop_inets_profile() do 132 | 
:inets.stop(:httpc, :xla) 133 | end 134 | 135 | defp set_proxy_options() do 136 | http_proxy = System.get_env("HTTP_PROXY") || System.get_env("http_proxy") 137 | https_proxy = System.get_env("HTTPS_PROXY") || System.get_env("https_proxy") 138 | 139 | no_proxy = 140 | if no_proxy = System.get_env("NO_PROXY") || System.get_env("no_proxy") do 141 | no_proxy 142 | |> String.split(",") 143 | |> Enum.map(&String.to_charlist/1) 144 | else 145 | [] 146 | end 147 | 148 | set_proxy_option(:proxy, http_proxy, no_proxy) 149 | set_proxy_option(:https_proxy, https_proxy, no_proxy) 150 | end 151 | 152 | defp set_proxy_option(proxy_scheme, proxy, no_proxy) do 153 | uri = URI.parse(proxy || "") 154 | 155 | if uri.host && uri.port do 156 | host = String.to_charlist(uri.host) 157 | :httpc.set_options([{proxy_scheme, {{host, uri.port}, no_proxy}}], :xla) 158 | end 159 | end 160 | end 161 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). 6 | 7 | ## [v0.9.1](https://github.com/elixir-nx/xla/tree/v0.9.1) (2025-07-03) 8 | 9 | ### Changed 10 | 11 | * Lowered the requirement from glibc 2.35+ back to 2.31+ ([#116](https://github.com/elixir-nx/xla/pull/116)) 12 | 13 | ## [v0.9.0](https://github.com/elixir-nx/xla/tree/v0.9.0) (2025-06-16) 14 | 15 | ### Changed 16 | 17 | * Bumped XLA version ([#111](https://github.com/elixir-nx/xla/pull/111)) 18 | * CUDA build now requires Nvidia NCCL (`libnccl2`) to be installed 19 | * Compiling XLA from source now requires Clang instead of GCC 20 | 21 | ## [v0.8.0](https://github.com/elixir-nx/xla/tree/v0.8.0) (2024-08-17) 22 | 23 | ### Added 24 | 25 | * Integrity verification when downloading the precompiled binaries ([#94](https://github.com/elixir-nx/xla/pull/94)) 26 | 27 | ### Changed 28 | 29 | * Bumped the version requirement for CUDA 12 to cuDNN 9.1+ ([#93](https://github.com/elixir-nx/xla/pull/93)) 30 | * Archive file names to include the release version 31 | * Dropped the requirement for either `wget` or `curl` to be installed ([#94](https://github.com/elixir-nx/xla/pull/94)) 32 | 33 | ### Removed 34 | 35 | * Removed the `XLA_HTTP_HEADERS` environment variable ([#94](https://github.com/elixir-nx/xla/pull/94)) 36 | 37 | ### Fixed 38 | 39 | * Download failures due to GitHub API rate limiting on CI ([#94](https://github.com/elixir-nx/xla/pull/94)) 40 | 41 | ## [v0.7.1](https://github.com/elixir-nx/xla/tree/v0.7.1) (2024-07-01) 42 | 43 | ### Changed 44 | 45 | * `XLA_TARGET` to default to a matching target when CUDA installation is detected ([#88](https://github.com/elixir-nx/xla/pull/88)) 46 | 47 | ## [v0.7.0](https://github.com/elixir-nx/xla/tree/v0.7.0) (2024-05-21) 48 | 49 | ### Changed 50 | 51 | * Bumped XLA version ([#83](https://github.com/elixir-nx/xla/pull/83)) 52 | * Renamed the recognised XLA_TARGET "cuda120" to "cuda12" ([#84](https://github.com/elixir-nx/xla/pull/84)) 53 | 54 | ### Removed 55 | 56 | * Dropped support for CUDA 11.8+, now 12.1+ is required ([#84](https://github.com/elixir-nx/xla/pull/84)) 57 | 58 | ## [v0.6.0](https://github.com/elixir-nx/xla/tree/v0.6.0) (2023-11-10) 59 | 60 | ### Changed 61 | 62 | * Bumped XLA version ([#62](https://github.com/elixir-nx/xla/pull/62)) 63 | 64 | ## [v0.5.1](https://github.com/elixir-nx/xla/tree/v0.5.1) (2023-09-14) 
65 | 66 | ### Changed 67 | 68 | * Bumped the version requirement for CUDA 12 to CUDA 12.1 and cuDNN 8.9 ([#54](https://github.com/elixir-nx/xla/pull/54)) 69 | 70 | ## [v0.5.0](https://github.com/elixir-nx/xla/tree/v0.5.0) (2023-08-13) 71 | 72 | ### Added 73 | 74 | * Support for custom http headers ([#44](https://github.com/elixir-nx/xla/pull/44)) 75 | * Support for CUDA 12 76 | 77 | ### Changed 78 | 79 | * Migrated to OpenXLA source code ([#45](https://github.com/elixir-nx/xla/pull/45)) 80 | 81 | ### Removed 82 | 83 | * Dropped precompiled binary for CUDA 11.1 and CUDA 11.4 84 | * Dropped precompiled binary for Linux musl 85 | 86 | ## [v0.4.4](https://github.com/elixir-nx/xla/tree/v0.4.4) (2023-02-17) 87 | 88 | ### Added 89 | 90 | * Sorting library ([#39](https://github.com/elixir-nx/xla/pull/39)) 91 | 92 | ## [v0.4.3](https://github.com/elixir-nx/xla/tree/v0.4.3) (2022-12-15) 93 | 94 | ### Fixed 95 | 96 | * Building with `XLA_BUILD` (regression from v0.4.2) ([#33](https://github.com/elixir-nx/xla/pull/33)) 97 | 98 | ## [v0.4.2](https://github.com/elixir-nx/xla/tree/v0.4.2) (2022-12-15) 99 | 100 | ### Added 101 | 102 | * Precompiled binaries for Linux musl ([#31](https://github.com/elixir-nx/xla/pull/31)) 103 | 104 | ### Fixed 105 | 106 | * Partially fixed building for ROCm, see [notes](https://github.com/elixir-nx/xla/blob/e0352a1769ecdb93f7c829f7f184fd2b81d6ad3f/README.md#notes-for-rocm) ([#30](https://github.com/elixir-nx/xla/pull/30)) 107 | 108 | ## [v0.4.1](https://github.com/elixir-nx/xla/tree/v0.4.1) (2022-12-08) 109 | 110 | ### Added 111 | 112 | * Precompiled binaries for CUDA 11.4+ (cuDNN 8.2+) and CUDA 11.8+ (cuDNN 8.6+) ([#27](https://github.com/elixir-nx/xla/pull/27)) 113 | 114 | ### Changed 115 | 116 | * Precompiled binaries to assume glibc 2.31+ ([#27](https://github.com/elixir-nx/xla/pull/27)) 117 | 118 | ## [v0.4.0](https://github.com/elixir-nx/xla/tree/v0.4.0) (2022-11-20) 119 | 120 | ### Changed 121 | 122 | * Bumped XLA (Tensorflow) version to 2.11.0 ([#25](https://github.com/elixir-nx/xla/pull/25)) 123 | 124 | ## [v0.3.0](https://github.com/elixir-nx/xla/tree/v0.3.0) (2022-02-17) 125 | 126 | ### Changed 127 | 128 | * Bumped XLA (Tensorflow) version to 2.8.0 ([#15](https://github.com/elixir-nx/xla/pull/15)) 129 | 130 | ### Removed 131 | 132 | * Dropped support for CUDA 10.2 and 11.0, now 11.1+ is required ([#17](https://github.com/elixir-nx/xla/pull/17)) 133 | 134 | ## [v0.2.0](https://github.com/elixir-nx/xla/tree/v0.2.0) (2021-09-23) 135 | 136 | ### Added 137 | 138 | * Added support for Apple Silicon ([#9](https://github.com/elixir-nx/xla/pull/9)) 139 | 140 | ### Changed 141 | 142 | * Bumped XLA (Tensorflow) version to 2.6.0 ([#9](https://github.com/elixir-nx/xla/pull/9)) 143 | 144 | ## [v0.1.1](https://github.com/elixir-nx/xla/tree/v0.1.1) (2021-09-16) 145 | 146 | ### Changed 147 | 148 | * Build for older glibc versions ([#3](https://github.com/elixir-nx/xla/pull/3)) 149 | 150 | ## [v0.1.0](https://github.com/elixir-nx/xla/tree/v0.1.0) (2021-09-16) 151 | 152 | Initial release. 153 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | push: 4 | tags: 5 | - "v*.*.*" 6 | 7 | env: 8 | NX_XLA_SHA: ${{ github.sha }} 9 | NX_XLA_RELEASE_NAME: ${{ github.ref_name }} 10 | 11 | # Build envs for non-Docker jobs. 
12 | USE_BAZEL_VERSION: 7.4.1 13 | XLA_BUILD: true 14 | XLA_CACHE_DIR: tmp/cache 15 | 16 | jobs: 17 | linux-cpu: 18 | name: "x86_64-linux-gnu-cpu" 19 | runs-on: ubuntu-24.04 20 | steps: 21 | # Free up space, see https://github.com/orgs/community/discussions/25678#discussioncomment-5242449 22 | - run: rm -rf /opt/hostedtoolcache 23 | - uses: actions/checkout@v4 24 | with: 25 | ref: ${{ env.NX_XLA_SHA }} 26 | - run: builds/build.sh cpu 27 | - run: .github/scripts/upload_artifact.sh ${{ env.NX_XLA_SHA }} ${{ env.NX_XLA_RELEASE_NAME }} builds/output/*/cache/*/build/* 28 | env: 29 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 30 | 31 | linux-cuda: 32 | name: "x86_64-linux-gnu-cuda12" 33 | runs-on: ubuntu-24.04 34 | steps: 35 | # Free up space, see https://github.com/orgs/community/discussions/25678#discussioncomment-5242449 36 | - run: rm -rf /opt/hostedtoolcache 37 | - uses: actions/checkout@v4 38 | with: 39 | ref: ${{ env.NX_XLA_SHA }} 40 | - run: builds/build.sh cuda12 41 | - run: .github/scripts/upload_artifact.sh ${{ env.NX_XLA_SHA }} ${{ env.NX_XLA_RELEASE_NAME }} builds/output/*/cache/*/build/* 42 | env: 43 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 44 | 45 | linux-tpu: 46 | name: "x86_64-linux-gnu-tpu" 47 | runs-on: ubuntu-24.04 48 | steps: 49 | # Free up space, see https://github.com/orgs/community/discussions/25678#discussioncomment-5242449 50 | - run: rm -rf /opt/hostedtoolcache 51 | - uses: actions/checkout@v4 52 | with: 53 | ref: ${{ env.NX_XLA_SHA }} 54 | - run: builds/build.sh tpu 55 | - run: .github/scripts/upload_artifact.sh ${{ env.NX_XLA_SHA }} ${{ env.NX_XLA_RELEASE_NAME }} builds/output/*/cache/*/build/* 56 | env: 57 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 58 | 59 | linux-arm-cpu: 60 | name: "aarch64-linux-gnu-cpu" 61 | runs-on: ubuntu-24.04-arm 62 | steps: 63 | # Free up space, see https://github.com/orgs/community/discussions/25678#discussioncomment-5242449 64 | - run: rm -rf /opt/hostedtoolcache 65 | - uses: actions/checkout@v4 66 | with: 67 | ref: ${{ env.NX_XLA_SHA }} 68 | - run: builds/build.sh cpu 69 | - run: .github/scripts/upload_artifact.sh ${{ env.NX_XLA_SHA }} ${{ env.NX_XLA_RELEASE_NAME }} builds/output/*/cache/*/build/* 70 | env: 71 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 72 | 73 | linux-arm-cuda: 74 | name: "aarch64-linux-gnu-cuda12" 75 | runs-on: ubuntu-24.04-arm 76 | steps: 77 | # Free up space, see https://github.com/orgs/community/discussions/25678#discussioncomment-5242449 78 | - run: rm -rf /opt/hostedtoolcache 79 | - uses: actions/checkout@v4 80 | with: 81 | ref: ${{ env.NX_XLA_SHA }} 82 | - run: builds/build.sh cuda12 83 | - run: .github/scripts/upload_artifact.sh ${{ env.NX_XLA_SHA }} ${{ env.NX_XLA_RELEASE_NAME }} builds/output/*/cache/*/build/* 84 | env: 85 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 86 | 87 | macos: 88 | name: "x86_64-darwin-cpu" 89 | runs-on: macos-13 90 | steps: 91 | - uses: actions/checkout@v4 92 | with: 93 | ref: ${{ env.NX_XLA_SHA }} 94 | - run: brew install elixir 95 | - run: mix local.hex --force 96 | # Setup the compilation environment 97 | - uses: bazel-contrib/setup-bazel@0.14.0 98 | - uses: actions/setup-python@v5 99 | with: 100 | python-version: "3.11" 101 | - run: python -m pip install --upgrade pip numpy 102 | # Build and upload the archive 103 | - run: mix deps.get 104 | - run: mix compile 105 | env: 106 | XLA_TARGET: cpu 107 | # This runner comes with Clang 14, which does not support the -mavxvnniint8 108 | # CLI flag. 
We can install a newer Clang; however, at some point Bazel toolchains
109 |           # invoke xcrun clang, which always uses the system version from Xcode, ignoring
110 |           # whichever version we installed ourselves. With the flag below, we make sure
111 |           # this flag is not passed in the first place.
112 |           # See https://github.com/tensorflow/tensorflow/pull/87514
113 |           BUILD_FLAGS: "--define=xnn_enable_avxvnniint8=false"
114 |       - run: .github/scripts/upload_artifact.sh ${{ env.NX_XLA_SHA }} ${{ env.NX_XLA_RELEASE_NAME }} tmp/cache/*/build/*
115 |         env:
116 |           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
117 | 
118 |   macos-arm:
119 |     name: "aarch64-darwin-cpu"
120 |     runs-on: macos-14
121 |     steps:
122 |       - uses: actions/checkout@v4
123 |         with:
124 |           ref: ${{ env.NX_XLA_SHA }}
125 |       - run: brew install elixir
126 |       - run: mix local.hex --force
127 |       # Setup the compilation environment
128 |       - uses: bazel-contrib/setup-bazel@0.14.0
129 |       - uses: actions/setup-python@v5
130 |         with:
131 |           python-version: "3.11"
132 |       - run: python -m pip install --upgrade pip numpy
133 |       # Build and upload the archive
134 |       - run: mix deps.get
135 |       - run: mix compile
136 |         env:
137 |           XLA_TARGET: cpu
138 |       - run: .github/scripts/upload_artifact.sh ${{ env.NX_XLA_SHA }} ${{ env.NX_XLA_RELEASE_NAME }} tmp/cache/*/build/*
139 |         env:
140 |           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # XLA
2 | 
3 | 
4 | 
5 | Precompiled [XLA](https://github.com/openxla/xla) binaries for [EXLA](https://github.com/elixir-nx/nx/tree/main/exla).
6 | 
7 | Currently supports UNIX systems, including macOS (although there is no built-in support for Apple Metal).
8 | Windows platforms are only supported upstream via [WSL](https://en.wikipedia.org/wiki/Windows_Subsystem_for_Linux).
9 | 
10 | ## Usage
11 | 
12 | EXLA already depends on this package, so you generally don't need to install it yourself.
13 | There are, however, a number of environment variables that you may want to use in order to
14 | customize the variant of the XLA binary.
15 | 
16 | The binaries are always built/downloaded to match the current configuration, so you should
17 | set the environment variables in `.bash_profile` or a similar configuration file so you don't
18 | need to export them in every shell session.
19 | 
20 | #### `XLA_TARGET`
21 | 
22 | The default value is usually `cpu`, which implies the final binary supports targeting
23 | only the host CPU. If a matching CUDA version is detected, the target is set to CUDA accordingly.
24 | 
25 | | Value | Target environment |
26 | | --- | --- |
27 | | cpu | |
28 | | tpu | libtpu |
29 | | cuda12 | CUDA >= 12.1, cuDNN >= 9.1 and < 10.0 |
30 | | cuda | CUDA x.y, cuDNN (building from source only) |
31 | | rocm | ROCm (building from source only) |
32 | 
33 | To use XLA with an Nvidia GPU you need [CUDA](https://developer.nvidia.com/cuda-downloads)
34 | and [cuDNN](https://developer.nvidia.com/cudnn) compatible with your GPU drivers.
35 | See [the installation instructions](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html)
36 | and [the cuDNN support matrix](https://docs.nvidia.com/deeplearning/cudnn/support-matrix/index.html)
37 | for version compatibility. To use precompiled XLA binaries, specify a target matching
38 | your CUDA version (like `cuda12`). You can find your CUDA version by running `nvcc --version`
39 | (note that `nvidia-smi` shows the highest supported CUDA version, not the installed one).
40 | When building from source it's enough to specify `cuda` as the target.
41 | 
42 | Note that all precompiled Linux binaries assume glibc 2.31 or newer.
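
For example, on a Linux machine with CUDA 12 installed you could persist the target in your
shell profile so that every build picks it up (a minimal sketch; adjust the value and the
profile file to your setup):

```shell
# Always use the precompiled CUDA 12 variant of XLA
echo 'export XLA_TARGET=cuda12' >> ~/.bash_profile
source ~/.bash_profile
```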
43 | 
44 | ##### Notes for ROCm
45 | 
46 | For GPU support, we primarily rely on CUDA, because of the popularity and availability
47 | in the cloud. In case you use ROCm and it does not work, please open up an issue and
48 | we will be happy to help.
49 | 
50 | In addition to building in a local environment, you can build the ROCm binary using
51 | the Docker-based scripts in [`builds/`](https://github.com/elixir-nx/xla/tree/main/builds). You may want to adjust the ROCm
52 | version in `rocm.Dockerfile` accordingly.
53 | 
54 | When you encounter errors at runtime, you may want to set `ROCM_PATH=/opt/rocm-6.0.0`
55 | and `LD_LIBRARY_PATH="/opt/rocm-6.0.0/lib"` (with your respective version). For further
56 | issues, feel free to open an issue.
57 | 
58 | #### `XLA_BUILD`
59 | 
60 | Defaults to `false`. If `true` the binary is built locally, which may be intended
61 | if no precompiled binary is available for your target environment. Once set, you
62 | must run `mix deps.clean xla --build` explicitly to force XLA to recompile.
63 | Building has a number of dependencies, see *Building from source* below.
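
As a rough example, in a project that already depends on XLA (for instance through EXLA),
forcing a local CUDA build could look like the following sketch (the exact dependency
workflow may differ in your project):

```shell
# Rebuild XLA from source on the next compilation (this can take hours)
export XLA_BUILD=true
export XLA_TARGET=cuda
mix deps.clean xla --build
mix deps.get
mix compile
```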
64 | 
65 | #### `XLA_ARCHIVE_URL`
66 | 
67 | A URL pointing to a specific build of the `.tar.gz` archive. When using this option
68 | you need to make sure the build matches your OS, CPU architecture and the XLA target.
69 | 
70 | #### `XLA_ARCHIVE_PATH`
71 | 
72 | Just like `XLA_ARCHIVE_URL`, but pointing to a local `.tar.gz` archive file.
73 | 
74 | #### `XLA_CACHE_DIR`
75 | 
76 | The directory to store the downloaded and built archives in. Defaults to the standard
77 | cache location for the given operating system.
78 | 
79 | #### `XLA_TARGET_PLATFORM`
80 | 
81 | The target triplet describing the target platform, such as `aarch64-linux-gnu`. By default
82 | this target is inferred for the host; however, you may want to override it when cross-compiling
83 | the project using Nerves.
84 | 
85 | ## Building from source
86 | 
87 | > Note: currently only macOS and Linux are supported. When on Windows, the best option
88 | > to use XLA and EXLA is to run inside WSL.
89 | 
90 | To build the XLA binaries locally you need to set `XLA_BUILD=true` and possibly `XLA_TARGET`.
91 | Keep in mind that the compilation usually takes a very long time.
92 | 
93 | You will need the following installed on your system for the compilation:
94 | 
95 | * [Git](https://git-scm.com/) for fetching XLA source
96 | * [Bazel v7.4.1](https://bazel.build/) for compiling XLA
97 | * [Clang 18](https://clang.llvm.org/) for compiling XLA
98 | * [Python3](https://python.org) with NumPy installed for compiling XLA
99 | 
100 | ### Common issues
101 | 
102 | #### Bazel version
103 | 
104 | Use `bazel --version` to check your Bazel version and make sure you are using v7.4.1.
105 | Most binaries are available on [GitHub](https://github.com/bazelbuild/bazel/releases),
106 | but it can also be installed with `asdf`:
107 | 
108 | ```shell
109 | asdf plugin add bazel
110 | asdf install bazel 7.4.1
111 | asdf set -u bazel 7.4.1
112 | ```
113 | 
114 | #### Clang
115 | 
116 | XLA builds are known to work with Clang 18. On macOS, Clang comes as part of the Xcode SDK
117 | and the version may be older, though for macOS we have precompiled archives, so you
118 | most likely don't need to worry about it.
119 | 
120 | #### Python and asdf
121 | 
122 | `Bazel` cannot find `python` installed via the `asdf` version manager by default. `asdf` uses a
123 | function to look up the specified version of a given binary, and this approach prevents `Bazel` from
124 | being able to correctly build XLA. The error is `unknown command: python. Perhaps you have to reshim?`.
125 | There are two known workarounds:
126 | 
127 | 1. Explicitly change your `$PATH` to point to a Python installation (note the build process
128 |    looks for `python`, not `python3`). For example:
129 | 
130 |    ```shell
131 |    # Point directly to a specific Python version
132 |    export PATH=$HOME/.asdf/installs/python/3.10.8/bin:$PATH
133 |    ```
134 | 
135 | 2. Use the [`asdf direnv`](https://github.com/asdf-community/asdf-direnv) plugin to install [`direnv 2.20.0`](https://direnv.net).
136 |    `direnv` along with the `asdf-direnv` plugin will explicitly set the paths for any binary specified
137 |    in your project's `.tool-versions` files.
138 | 
139 | If you still get the error, you can also try setting `PYTHON_BIN_PATH`, like `export PYTHON_BIN_PATH=/usr/bin/python3.11`.
140 | 
141 | After doing any of the steps above, it may be necessary to clear the build cache by removing `~/.cache/xla_build`
142 | (or the corresponding OS-specific cache location).
143 | 
144 | ### GPU support
145 | 
146 | To build binaries with GPU support, you need all the GPU-specific dependencies (CUDA, ROCm);
147 | then you can build with either `XLA_TARGET=cuda` or `XLA_TARGET=rocm`. See the `XLA_TARGET`
148 | section above for more details.
149 | 
150 | ### TPU support
151 | 
152 | All you need is to set `XLA_TARGET=tpu`.
153 | 
154 | ### Compilation-specific environment variables
155 | 
156 | You can use the following env vars to customize your build:
157 | 
158 | * `BUILD_CACHE` - controls where to store XLA source and builds
159 | 
160 | * `BUILD_FLAGS` - additional flags passed to Bazel
161 | 
162 | * `BUILD_MODE` - controls whether to compile `opt` (default) or `dbg` artifacts, for example: `BUILD_MODE=dbg`
163 | 
164 | ## Runtime flags
165 | 
166 | You can further configure XLA runtime options with `XLA_FLAGS`,
167 | see: [xla/debug_options_flags.cc](https://github.com/openxla/xla/blob/main/xla/debug_options_flags.cc)
168 | for the list of available flags.
169 | 
170 | 
171 | 
172 | ## Release process
173 | 
174 | To publish a new version of this package:
175 | 
176 | 1. Update version in `mix.exs`.
177 | 2. Create and push a new tag.
178 | 3. Wait for the release workflow to build all the binaries.
179 | 4. Publish the release from draft.
180 | 5. Publish the package to Hex.
181 | 
182 | ## License
183 | 
184 | Note that the build artifacts are a result of compiling XLA, hence they are under
185 | the respective license. See [XLA](https://github.com/openxla/xla).
186 | 
187 | ```text
188 | Copyright (c) 2020 Sean Moriarity
189 | 
190 | Licensed under the Apache License, Version 2.0 (the "License");
191 | you may not use this file except in compliance with the License.
192 | You may obtain a copy of the License at [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)
193 | 
194 | Unless required by applicable law or agreed to in writing, software
195 | distributed under the License is distributed on an "AS IS" BASIS,
196 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
197 | See the License for the specific language governing permissions and 198 | limitations under the License. 199 | ``` 200 | -------------------------------------------------------------------------------- /extension/BUILD: -------------------------------------------------------------------------------- 1 | load("//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm",) 2 | load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda",) 3 | load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm",) 4 | load("//xla/tsl:tsl.bzl", "if_with_tpu_support") 5 | load("//xla/tsl:tsl.bzl", "tsl_grpc_cc_dependencies",) 6 | load("//xla/tsl:tsl.bzl", "transitive_hdrs",) 7 | load("@rules_pkg//pkg:tar.bzl", "pkg_tar") 8 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_file") 9 | 10 | package(default_visibility=["//visibility:private"]) 11 | 12 | # Shared library which contains the subset of XLA required for EXLA 13 | cc_binary( 14 | name = "libxla_extension.so", 15 | deps = [ 16 | "//xla:xla_proto_cc_impl", 17 | "//xla:xla_data_proto_cc_impl", 18 | "//xla:autotune_results_proto_cc_impl", 19 | "//xla:autotuning_proto_cc_impl", 20 | "//xla/service:hlo_proto_cc_impl", 21 | "//xla/service/memory_space_assignment:memory_space_assignment_proto_cc_impl", 22 | "//xla/service:buffer_assignment_proto_cc_impl", 23 | "//xla/service/gpu:backend_configs_cc_impl", 24 | "//xla/service/gpu/model:hlo_op_profile_proto_cc_impl", 25 | "//xla/stream_executor:device_description_proto_cc_impl", 26 | "//xla/stream_executor:stream_executor_impl", 27 | "//xla/stream_executor/gpu:gpu_init_impl", 28 | "//xla/stream_executor/host:host_platform", 29 | "//xla:literal", 30 | "//xla:shape_util", 31 | "//xla/tsl/platform:status", 32 | "//xla/tsl/platform:statusor", 33 | "//xla/tsl/concurrency:async_value", 34 | "@tsl//tsl/platform:platform_port", 35 | "//xla/service:custom_call_target_registry", 36 | "@llvm-project//llvm:config", 37 | "//xla:types", 38 | "//xla:util", 39 | "//xla/mlir/utils:error_util", 40 | "//xla/mlir_hlo", 41 | "//xla/mlir_hlo:all_passes", 42 | "//xla/pjrt:mlir_to_hlo", 43 | "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", 44 | "//xla/hlo/builder:xla_computation", 45 | "//xla/hlo/builder/lib:lu_decomposition", 46 | "//xla/hlo/builder/lib:math", 47 | "//xla/hlo/builder/lib:qr", 48 | "//xla/hlo/builder/lib:svd", 49 | "//xla/hlo/builder/lib:self_adjoint_eig", 50 | "//xla/hlo/builder/lib:sorting", 51 | "//xla/pjrt:pjrt_client", 52 | "//xla/pjrt:pjrt_compiler", 53 | "//xla/pjrt:tfrt_cpu_pjrt_client", 54 | "//xla/pjrt:pjrt_c_api_client", 55 | "//xla/pjrt/distributed", 56 | "//xla/pjrt/gpu:se_gpu_pjrt_client", 57 | "//xla/pjrt/distributed:client", 58 | "//xla/pjrt/distributed:service", 59 | "//xla/service:metrics_proto_cc_impl", 60 | "//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl", 61 | "@com_google_absl//absl/types:span", 62 | "@com_google_absl//absl/types:optional", 63 | "@com_google_absl//absl/base:log_severity", 64 | "@com_google_protobuf//:protobuf", 65 | "@llvm-project//llvm:Support", 66 | "@llvm-project//mlir:FuncDialect", 67 | "@llvm-project//mlir:IR", 68 | "@llvm-project//mlir:Parser", 69 | "@llvm-project//mlir:Pass", 70 | "@llvm-project//mlir:ReconcileUnrealizedCasts", 71 | "@llvm-project//mlir:SparseTensorDialect", 72 | "@tsl//tsl/platform:errors", 73 | "@tsl//tsl/platform:fingerprint", 74 | "@ml_dtypes_py//ml_dtypes:float8", 75 | "@ml_dtypes_py//ml_dtypes:intn", 76 | "@ml_dtypes_py//ml_dtypes:mxfloat", 77 | "@tsl//tsl/platform:statusor", 78 | "@tsl//tsl/platform:env_impl", 79 | 
"@tsl//tsl/platform:tensor_float_32_utils", 80 | "//xla/tsl/profiler/utils:time_utils_impl", 81 | "//xla/tsl/profiler/backends/cpu:annotation_stack_impl", 82 | "//xla/tsl/profiler/backends/cpu:traceme_recorder_impl", 83 | "//xla/tsl/protobuf:protos_all_cc_impl", 84 | "//xla/tsl/protobuf:dnn_proto_cc_impl", 85 | "//xla/tsl/framework:allocator", 86 | "//xla/tsl/framework:allocator_registry_impl", 87 | "//xla/tsl/util:determinism", 88 | ] 89 | # GRPC Dependencies (needed for PjRt distributed) 90 | + tsl_grpc_cc_dependencies() 91 | + if_cuda_or_rocm([ 92 | "//xla/service:gpu_plugin", 93 | ]) 94 | + if_cuda([ 95 | "//xla/stream_executor:cuda_platform" 96 | ]) 97 | + if_rocm([ 98 | "//xla/stream_executor:rocm_platform" 99 | ]), 100 | copts = ["-fvisibility=default"], 101 | linkopts= select({ 102 | "//xla/tsl:macos": [ 103 | # We set the install_name, such that the library is looked up 104 | # in the RPATH at runtime, otherwise the install_name is an 105 | # arbitrary path within bazel workspace 106 | "-Wl,-install_name,@rpath/libxla_extension.so", 107 | # We set RPATH to the same dir as libxla_extension.so, so that 108 | # loading PjRt plugins in the same directory works out of the box 109 | "-Wl,-rpath,@loader_path/", 110 | ], 111 | "//conditions:default": [ 112 | "-Wl,-soname,libxla_extension.so", 113 | "-Wl,-rpath='$$ORIGIN'", 114 | ], 115 | }), 116 | features = ["-use_header_modules"], 117 | linkshared = 1, 118 | ) 119 | 120 | # Transitive hdrs gets all headers required by deps, including 121 | # transitive dependencies, it seems though it generates a lot 122 | # of unused headers as well 123 | transitive_hdrs( 124 | name = "xla_extension_dep_headers", 125 | deps = [ 126 | ":libxla_extension.so", 127 | ] 128 | ) 129 | 130 | # This is the genrule used by TF install headers to correctly 131 | # map headers into a directory structure 132 | genrule( 133 | name = "xla_extension_headers", 134 | srcs = [ 135 | ":xla_extension_dep_headers", 136 | ], 137 | outs = ["include.tar.gz"], 138 | cmd = """ 139 | HEADERS_DIR=$$(mktemp -d) 140 | for f in $(SRCS); do 141 | d="$${f%/*}" 142 | d="$${d#bazel-out/*/genfiles/}" 143 | d="$${d#bazel-out/*/bin/}" 144 | if [[ $${d} == *local_config_* ]]; then 145 | continue 146 | fi 147 | if [[ $${d} == external* ]]; then 148 | extname="$${d#*external/}" 149 | extname="$${extname%%/*}" 150 | if [[ $${TF_SYSTEM_LIBS:-} == *$${extname}* ]]; then 151 | continue 152 | fi 153 | d="$${d#*external/farmhash_archive/src}" 154 | d="$${d#*external/$${extname}/}" 155 | fi 156 | # Remap third party paths 157 | d="$${d/third_party\\/llvm_derived\\/include\\/llvm_derived/llvm_derived}" 158 | # Remap llvm paths 159 | d="$${d/llvm\\/include\\/llvm/llvm}" 160 | d="$${d/llvm\\/include\\/llvm-c/llvm-c}" 161 | # Remap mlir paths 162 | d="$${d/mlir\\/include\\/mlir/mlir}" 163 | # Remap google path 164 | d="$${d/src\\/google/google}" 165 | # Remap grpc paths 166 | d="$${d/include\\/grpc/grpc}" 167 | # Remap tfrt paths 168 | d="$${d/include\\/tfrt/tfrt}" 169 | # Remap ml_dtypes paths 170 | d="$${d/_virtual_includes\\/intn\\/ml_dtypes/ml_dtypes}" 171 | d="$${d/_virtual_includes\\/float8\\/ml_dtypes/ml_dtypes}" 172 | mkdir -p "$${HEADERS_DIR}/$${d}" 173 | cp "$${f}" "$${HEADERS_DIR}/$${d}/" 174 | done 175 | # Files in xla/mlir_hlo include sibling headers from mhlo, so we 176 | # need to mirror them in includes 177 | cp -r $${HEADERS_DIR}/xla/mlir_hlo/mhlo $${HEADERS_DIR} 178 | # Create the final tarball 179 | tar czf "$@" -C $${HEADERS_DIR} . 
180 | rm -rf $${HEADERS_DIR} 181 | """, 182 | ) 183 | 184 | genrule( 185 | name = "libtpu_whl", 186 | outs = ["libtpu.whl"], 187 | cmd = """ 188 | libtpu_version="0.1.dev20231102" 189 | libtpu_storage_path="https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-$${libtpu_version}-py3-none-any.whl" 190 | wget -O "$@" "$$libtpu_storage_path" 191 | """ 192 | ) 193 | 194 | genrule( 195 | name = "libtpu_so", 196 | srcs = [ 197 | ":libtpu_whl" 198 | ], 199 | outs = ["libtpu.so"], 200 | cmd = """ 201 | unzip -p "$(SRCS)" libtpu/libtpu.so > "$@" 202 | """ 203 | ) 204 | 205 | # This genrule remaps libxla_extension.so to lib/libxla_extension.so 206 | genrule( 207 | name = "xla_extension_lib", 208 | srcs = [ 209 | ":libxla_extension.so", 210 | ] 211 | + if_with_tpu_support([ 212 | ":libtpu_so" 213 | ]), 214 | outs = ["lib.tar.gz"], 215 | cmd = """ 216 | LIB_DIR=$$(mktemp -d) 217 | mv $(SRCS) $${LIB_DIR}/ 218 | tar czf "$@" -C $${LIB_DIR} . 219 | rm -rf $${LIB_DIR} 220 | """ 221 | ) 222 | 223 | # See https://github.com/bazelbuild/rules_pkg/issues/517#issuecomment-1492917994 224 | genrule( 225 | name = "xla_extension", 226 | outs = ["xla_extension.tar.gz"], 227 | srcs = [ 228 | ":xla_extension_lib", 229 | ":xla_extension_headers", 230 | ], 231 | cmd = """ 232 | mkdir xla_extension 233 | # Extract the lib tarball 234 | mkdir xla_extension/lib 235 | tar xzf $(location :xla_extension_lib) -C xla_extension/lib 236 | # Extract the headers tarball 237 | mkdir xla_extension/include 238 | tar xzf $(location :xla_extension_headers) -C xla_extension/include 239 | # Create the final package tarball 240 | tar czf "$@" xla_extension 241 | """ 242 | ) 243 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /lib/xla.ex: -------------------------------------------------------------------------------- 1 | defmodule XLA do 2 | @external_resource "README.md" 3 | 4 | [_, readme_docs, _] = 5 | "README.md" 6 | |> File.read!() 7 | |> String.split("<!-- Docs -->") 8 | 9 | @moduledoc readme_docs 10 | 11 | require Logger 12 | 13 | @version Mix.Project.config()[:version] 14 | 15 | @base_url "https://github.com/elixir-nx/xla/releases/download/v#{@version}" 16 | 17 | @precompiled_targets [ 18 | "x86_64-darwin-cpu", 19 | "aarch64-darwin-cpu", 20 | "x86_64-linux-gnu-cpu", 21 | "aarch64-linux-gnu-cpu", 22 | "x86_64-linux-gnu-cuda12", 23 | "aarch64-linux-gnu-cuda12", 24 | "x86_64-linux-gnu-tpu" 25 | ] 26 | 27 | @supported_xla_targets ["cpu", "cuda", "rocm", "tpu", "cuda12"] 28 | 29 | @doc """ 30 | Returns the path to the precompiled XLA archive. 31 | 32 | Depending on the environment variable configuration, 33 | the path points to either a locally built or a downloaded file. 34 | If not found locally, the file is downloaded when calling 35 | this function.
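
For example, on an x86_64 Linux machine with the default "cpu" target, the
return value may look like the following (the home directory, cache
location, and version shown here are hypothetical placeholders):

    XLA.archive_path!()
    # Hypothetical output on x86_64 Linux with the default cache location:
    #=> "/home/user/.cache/xla/<version>/download/xla_extension-<version>-x86_64-linux-gnu-cpu.tar.gz"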
36 | """ 37 | @spec archive_path!() :: Path.t() 38 | def archive_path!() do 39 | XLA.Utils.start_inets_profile() 40 | 41 | cond do 42 | build?() -> 43 | # The archive should have already been built by this point 44 | archive_path_for_build() 45 | 46 | path = xla_archive_path() -> 47 | path 48 | 49 | url = xla_archive_url() -> 50 | path = archive_path_for_external_download(url) 51 | unless File.exists?(path), do: download_external!(url, path) 52 | path 53 | 54 | true -> 55 | path = archive_path_for_precompiled_download() 56 | unless File.exists?(path), do: download_precompiled!(path) 57 | path 58 | end 59 | after 60 | XLA.Utils.stop_inets_profile() 61 | end 62 | 63 | defp build?() do 64 | System.get_env("XLA_BUILD") in ~w(1 true) 65 | end 66 | 67 | defp xla_archive_path() do 68 | System.get_env("XLA_ARCHIVE_PATH") 69 | end 70 | 71 | defp xla_archive_url() do 72 | System.get_env("XLA_ARCHIVE_URL") 73 | end 74 | 75 | defp xla_target() do 76 | target = System.get_env("XLA_TARGET") || infer_xla_target() || "cpu" 77 | 78 | supported_xla_targets = @supported_xla_targets 79 | 80 | unless target in supported_xla_targets do 81 | listing = supported_xla_targets |> Enum.map(&inspect/1) |> Enum.join(", ") 82 | raise "expected XLA_TARGET to be one of #{listing}, but got: #{inspect(target)}" 83 | end 84 | 85 | target 86 | end 87 | 88 | defp infer_xla_target() do 89 | with nvcc when nvcc != nil <- System.find_executable("nvcc"), 90 | {output, 0} <- System.cmd(nvcc, ["--version"]) do 91 | if output =~ "release 12.", do: "cuda12" 92 | else 93 | _ -> nil 94 | end 95 | end 96 | 97 | defp xla_cache_dir() do 98 | # The directory where we store all the archives 99 | if dir = System.get_env("XLA_CACHE_DIR") do 100 | Path.expand(dir) 101 | else 102 | :filename.basedir(:user_cache, "xla") 103 | end 104 | end 105 | 106 | defp target() do 107 | case target_triplet() do 108 | {arch, os, nil} -> "#{arch}-#{os}-#{xla_target()}" 109 | {arch, os, abi} -> "#{arch}-#{os}-#{abi}-#{xla_target()}" 110 | end 111 | end 112 | 113 | defp target_triplet() do 114 | if target = System.get_env("XLA_TARGET_PLATFORM") do 115 | case String.split(target, "-") do 116 | [arch, os, abi] -> 117 | {arch, os, abi} 118 | 119 | [arch, os] -> 120 | {arch, os, nil} 121 | 122 | _other -> 123 | raise "expected XLA_TARGET_PLATFORM to be either ARCHITECTURE-OS-ABI or ARCHITECTURE-OS, got: #{target}" 124 | end 125 | else 126 | :erlang.system_info(:system_architecture) 127 | |> List.to_string() 128 | |> String.split("-") 129 | |> case do 130 | ["arm" <> _, _vendor, "darwin" <> _ | _] -> {"aarch64", "darwin", nil} 131 | [arch, _vendor, "darwin" <> _ | _] -> {arch, "darwin", nil} 132 | [arch, _vendor, os, abi] -> {arch, os, abi} 133 | [arch, _vendor, os] -> {arch, os, nil} 134 | ["win32"] -> {"x86_64", "windows", nil} 135 | end 136 | end 137 | end 138 | 139 | defp archive_path_for_build() do 140 | filename = archive_filename(target()) 141 | cache_path(["build", filename]) 142 | end 143 | 144 | defp archive_path_for_external_download(url) do 145 | hash = url |> :erlang.md5() |> Base.encode32(case: :lower, padding: false) 146 | filename = "xla_extension-#{hash}.tar.gz" 147 | cache_path(["external", filename]) 148 | end 149 | 150 | defp archive_path_for_precompiled_download() do 151 | filename = archive_filename(target()) 152 | cache_path(["download", filename]) 153 | end 154 | 155 | defp archive_filename(target) do 156 | "xla_extension-#{@version}-#{target}.tar.gz" 157 | end 158 | 159 | defp cache_path(parts) do 160 | base_dir = xla_cache_dir() 161 | 
Path.join([base_dir, @version | parts]) 162 | end 163 | 164 | defp download_external!(url, archive_path) do 165 | Logger.info("Downloading XLA archive from #{url}") 166 | 167 | case download_archive(url, archive_path) do 168 | :ok -> 169 | Logger.info("Successfully downloaded the XLA archive") 170 | 171 | {:error, message} -> 172 | File.rm(archive_path) 173 | raise message 174 | end 175 | end 176 | 177 | defp download_precompiled!(archive_path) do 178 | expected_filename = Path.basename(archive_path) 179 | 180 | target = target() 181 | precompiled_targets = precompiled_targets() 182 | 183 | if target not in precompiled_targets do 184 | listing = Enum.map_join(precompiled_targets, "\n", &(" * " <> &1)) 185 | 186 | raise """ 187 | no precompiled XLA archive available for this target: #{target}. 188 | 189 | The available targets are: 190 | 191 | #{listing} 192 | 193 | You can compile XLA locally by setting an environment variable: XLA_BUILD=true\ 194 | """ 195 | end 196 | 197 | Logger.info("Downloading a precompiled XLA archive for target #{target}") 198 | 199 | url = release_file_url(expected_filename) 200 | 201 | with :ok <- download_archive(url, archive_path), 202 | :ok <- verify_integrity(archive_path) do 203 | Logger.info("Successfully downloaded the XLA archive") 204 | else 205 | {:error, message} -> 206 | File.rm(archive_path) 207 | raise message 208 | end 209 | end 210 | 211 | defp release_file_url(filename) do 212 | @base_url <> "/" <> filename 213 | end 214 | 215 | defp download_archive(url, archive_path) do 216 | File.mkdir_p!(Path.dirname(archive_path)) 217 | 218 | file = File.stream!(archive_path) 219 | 220 | case XLA.Utils.download(url, file) do 221 | {:ok, _file} -> 222 | :ok 223 | 224 | {:error, message} -> 225 | {:error, "failed to download the XLA archive from #{url}, reason: #{message}"} 226 | end 227 | end 228 | 229 | defp verify_integrity(path) do 230 | filename = Path.basename(path) 231 | checksum = compute_file_checksum!(path) 232 | 233 | case read_checksums!() do 234 | %{^filename => ^checksum} -> 235 | :ok 236 | 237 | %{^filename => _} -> 238 | {:error, "the integrity check failed for file #{filename}, the checksum does not match"} 239 | 240 | %{} -> 241 | {:error, "no entry for file #{filename} in the checksum file"} 242 | end 243 | end 244 | 245 | @doc false 246 | def write_checksums!(%{} = checksums) do 247 | content = 248 | checksums 249 | |> Enum.sort() 250 | |> Enum.map_join("", fn {filename, checksum} -> 251 | checksum <> " " <> filename <> "\n" 252 | end) 253 | 254 | File.write!(checksum_path(), content) 255 | end 256 | 257 | defp read_checksums!() do 258 | content = File.read!(checksum_path()) 259 | 260 | for line <- String.split(content, "\n", trim: true), into: %{} do 261 | [checksum, filename] = String.split(line, " ") 262 | {filename, checksum} 263 | end 264 | end 265 | 266 | defp compute_file_checksum!(path) do 267 | path 268 | |> File.stream!([], 64_000) 269 | |> Enum.into(%XLA.Checksumer{}) 270 | end 271 | 272 | defp checksum_path() do 273 | # Note that this path points to the project source, which normally 274 | # may not be available at runtime (in releases). 
However, we expect 275 | # XLA to be called only during compilation, in which case this path 276 | # is still available. 277 | Path.expand("../checksum.txt", __DIR__) 278 | end 279 | 280 | defp precompiled_targets(), do: @precompiled_targets 281 | 282 | # Used by tasks 283 | 284 | @doc false 285 | def build_archive_dir() do 286 | Path.dirname(archive_path_for_build()) 287 | end 288 | 289 | @doc false 290 | def version(), do: @version 291 | 292 | @doc false 293 | def archive_filename_with_target() do 294 | archive_filename(target()) 295 | end 296 | 297 | @doc false 298 | def precompiled_files() do 299 | for target <- @precompiled_targets do 300 | filename = archive_filename(target) 301 | url = release_file_url(filename) 302 | {filename, url} 303 | end 304 | end 305 | 306 | # Configuration for elixir_make 307 | 308 | @doc false 309 | def make_env() do 310 | bazel_build_flags_accelerator = 311 | case xla_target() do 312 | "cuda" <> _ -> 313 | [ 314 | # See https://github.com/google/jax/blob/66a92c41f6bac74960159645158e8d932ca56613/.bazelrc#L68 315 | "--config=cuda", 316 | # XLA downloads and uses the configured hermetic versions. 317 | ~s/--repo_env=HERMETIC_CUDA_VERSION="12.8.0"/, 318 | ~s/--repo_env=HERMETIC_CUDNN_VERSION="9.8.0"/, 319 | ~s/--action_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120"/, 320 | # See https://github.com/jax-ml/jax/blob/f2188786c225c7d16d8a7effd852470b2ad1b229/.bazelrc#L174-L176 321 | # (by default Jax compiles CUDA code with NVCC, so we do the same) 322 | ~s/--action_env=TF_NVCC_CLANG="1"/, 323 | "--@local_config_cuda//:cuda_compiler=nvcc" 324 | ] 325 | 326 | "rocm" <> _ -> 327 | [ 328 | "--config=rocm", 329 | "--action_env=HIP_PLATFORM=hcc", 330 | # See https://github.com/google/jax/blob/66a92c41f6bac74960159645158e8d932ca56613/.bazelrc#L128 331 | ~s/--action_env=TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100,gfx1200,gfx1201"/ 332 | ] 333 | 334 | "tpu" <> _ -> 335 | ["--define=with_tpu_support=true"] 336 | 337 | _ -> 338 | [] 339 | end 340 | 341 | bazel_build_flags_shared = [ 342 | # Always use Clang 343 | "--repo_env=CC=clang", 344 | "--repo_env=CXX=clang++", 345 | # See https://github.com/tensorflow/tensorflow/issues/62459#issuecomment-2043942557 346 | "--copt=-Wno-error=unused-command-line-argument", 347 | # See https://github.com/jax-ml/jax/blob/0842cc6f386a20aa20ed20691fb78a43f6c4a307/.bazelrc#L127-L138 348 | "--copt=-Wno-gnu-offsetof-extensions", 349 | "--copt=-Qunused-arguments", 350 | "--copt=-Wno-error=c23-extensions" 351 | ] 352 | 353 | bazel_build_flags = Enum.join(bazel_build_flags_accelerator ++ bazel_build_flags_shared, " ") 354 | 355 | # Additional environment variables passed to make 356 | %{ 357 | "BUILD_INTERNAL_FLAGS" => bazel_build_flags, 358 | "ROOT_DIR" => Path.expand("..", __DIR__), 359 | "BUILD_ARCHIVE" => archive_path_for_build(), 360 | "BUILD_ARCHIVE_DIR" => build_archive_dir(), 361 | "BUILD_CACHE_DIR" => :filename.basedir(:user_cache, "xla_build") 362 | } 363 | end 364 | end 365 | --------------------------------------------------------------------------------
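
The snippet below is an illustrative sketch, not part of the repository: it shows how the environment variables read by lib/xla.ex select where the archive comes from. The specific paths, URL, and target values are hypothetical examples; in practice these variables are exported in the shell before compiling a project that depends on :xla.

    # Default behaviour: detect the target (for example "x86_64-linux-gnu-cpu")
    # and download the matching precompiled archive into the user cache.
    XLA.archive_path!()

    # Build the archive locally instead of downloading it (elixir_make then runs
    # the Bazel build with the flags returned by XLA.make_env/0).
    System.put_env("XLA_BUILD", "true")

    # Pick a GPU target and a custom cache location (hypothetical values).
    System.put_env("XLA_TARGET", "cuda12")
    System.put_env("XLA_CACHE_DIR", "/tmp/xla_cache")

    # Or point directly at an archive already on disk, or at an externally hosted one.
    System.put_env("XLA_ARCHIVE_PATH", "/opt/archives/xla_extension.tar.gz")
    System.put_env("XLA_ARCHIVE_URL", "https://example.com/xla_extension.tar.gz")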