├── config ├── dev.exs ├── prod.exs ├── config.exs └── test.exs ├── test ├── test_helper.exs └── beaver │ ├── defn │ ├── vulkan_expr_test.exs │ ├── attention_test.exs │ ├── loc_emb_test.exs │ ├── attn_vulkan_test.exs │ └── expr_test.exs │ └── nx_test.exs ├── lib └── manx │ ├── nx │ ├── defn_env.ex │ ├── type.ex │ ├── linalg.ex │ ├── slice.ex │ ├── interoperability.ex │ ├── batcher.ex │ ├── compiler.ex │ └── defn.ex │ ├── flag.ex │ ├── pass.ex │ ├── application.ex │ ├── memref_allocator.ex │ ├── lowering │ ├── vulkan.ex │ └── cpu.ex │ ├── assert.ex │ └── manx.ex ├── .formatter.exs ├── .gitignore ├── .github └── workflows │ └── elixir.yml ├── mix.exs ├── examples └── attn.exs ├── README.md └── mix.lock /config/dev.exs: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/prod.exs: -------------------------------------------------------------------------------- 1 | import Config 2 | -------------------------------------------------------------------------------- /test/test_helper.exs: -------------------------------------------------------------------------------- 1 | ExUnit.start() 2 | -------------------------------------------------------------------------------- /config/config.exs: -------------------------------------------------------------------------------- 1 | import Config 2 | 3 | import_config "#{config_env()}.exs" 4 | -------------------------------------------------------------------------------- /lib/manx/nx/defn_env.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Defn.Env do 2 | defstruct block: nil, ctx: nil, gen_op: nil, gen_type: nil 3 | end 4 | -------------------------------------------------------------------------------- /lib/manx/flag.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Flags do 2 | def print_ir?() do 3 | System.get_env("MANX_PRINT_IR") == "1" 4 | end 5 | end 6 | -------------------------------------------------------------------------------- /.formatter.exs: -------------------------------------------------------------------------------- 1 | # Used by "mix format" 2 | [ 3 | import_deps: [:nx], 4 | inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] 5 | ] 6 | -------------------------------------------------------------------------------- /config/test.exs: -------------------------------------------------------------------------------- 1 | import Config 2 | 3 | config :beaver, 4 | skip_dialects: ~w{nvgpu 5 | x86vector 6 | vector 7 | omp 8 | emitc 9 | sparse_tensor 10 | amdgpu 11 | async 12 | llvm 13 | transform 14 | ml_program 15 | amx 16 | arm_neon 17 | spv 18 | quant 19 | arm_sve 20 | rocdl 21 | acc 22 | shape 23 | nvvm} 24 | -------------------------------------------------------------------------------- /lib/manx/pass.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Lowering.Vulkan.PutSPVAttrPass do 2 | alias Beaver.MLIR 3 | import MLIR.Sigils 4 | 5 | use Beaver.MLIR.Pass, on: "gpu.func" 6 | 7 | @impl true 8 | def run(op) do 9 | [ 10 | "gpu.kernel": Beaver.MLIR.Attribute.unit(), 11 | "spirv.entry_point_abi": ~a{#spirv.entry_point_abi} 12 | ] 13 | |> Enum.each(fn {name, attr} -> 14 | ctx = MLIR.CAPI.mlirOperationGetContext(op) 15 | attr = Beaver.Deferred.create(attr, ctx) 16 | MLIR.CAPI.mlirOperationSetAttributeByName(op, MLIR.StringRef.create(name), attr) 
17 | end) 18 | 19 | :ok 20 | end 21 | end 22 | -------------------------------------------------------------------------------- /lib/manx/application.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Application do 2 | # See https://hexdocs.pm/elixir/Application.html 3 | # for more information on OTP Applications 4 | @moduledoc false 5 | 6 | use Application 7 | 8 | @impl true 9 | def start(_type, _args) do 10 | children = [ 11 | # Starts a worker by calling: Manx.Worker.start_link(arg) 12 | # {Manx.Worker, arg} 13 | Manx.MemrefAllocator 14 | ] 15 | 16 | # See https://hexdocs.pm/elixir/Supervisor.html 17 | # for other strategies and supported options 18 | opts = [strategy: :one_for_one, name: Manx.Supervisor] 19 | Supervisor.start_link(children, opts) 20 | end 21 | end 22 | -------------------------------------------------------------------------------- /lib/manx/memref_allocator.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.MemrefAllocator do 2 | @moduledoc """ 3 | MemrefAllocator is an Agent managing memrefs. 4 | """ 5 | use Agent 6 | 7 | def start_link(_) do 8 | Agent.start_link( 9 | fn -> :ets.new(__MODULE__, [:named_table, :public, read_concurrency: true]) end, 10 | name: __MODULE__ 11 | ) 12 | end 13 | 14 | def add(memref) do 15 | :ets.insert(__MODULE__, {memref}) 16 | end 17 | 18 | def delete(memref) do 19 | found = :ets.lookup(__MODULE__, memref) 20 | 21 | if length(found) >= 1 do 22 | :ets.delete(__MODULE__, memref) 23 | :ok 24 | else 25 | :already_deallocated 26 | end 27 | end 28 | end 29 | -------------------------------------------------------------------------------- /lib/manx/nx/type.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Type do 2 | require Beaver.MLIR 3 | alias Beaver.MLIR.Type 4 | 5 | @moduledoc """ 6 | Helper functions for defining functions and operations. 7 | """ 8 | 9 | def gen_type({:u, size}), do: Type.i(size) 10 | def gen_type({:s, size}), do: Type.i(size) 11 | def gen_type({:f, size}), do: Type.f(size) 12 | def gen_type({:c, size}), do: Type.complex(Type.f(div(size, 2))) 13 | 14 | def gen_type(%Nx.Tensor{shape: shape, type: type}) do 15 | Tuple.to_list(shape) 16 | |> Type.ranked_tensor(gen_type(type)) 17 | end 18 | 19 | def gen_type(tuple) when is_tuple(tuple) do 20 | Tuple.to_list(tuple) 21 | |> Enum.map(&gen_type/1) 22 | |> Type.tuple() 23 | end 24 | end 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # The directory Mix will write compiled artifacts to. 2 | /_build/ 3 | 4 | # If you run "mix test --cover", coverage assets end up here. 5 | /cover/ 6 | 7 | # The directory Mix downloads your dependencies sources to. 8 | /deps/ 9 | 10 | # Where third-party dependencies like ExDoc output generated docs. 11 | /doc/ 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | # Ignore package tarball (built via "mix hex.build"). 23 | manx-*.tar 24 | 25 | # Temporary files, for example, from tests. 
26 | /tmp/ 27 | -------------------------------------------------------------------------------- /.github/workflows/elixir.yml: -------------------------------------------------------------------------------- 1 | name: Elixir CI 2 | 3 | on: 4 | pull_request: 5 | branches: ["main"] 6 | 7 | permissions: 8 | contents: read 9 | 10 | concurrency: 11 | group: manx-build-and-test-${{ github.ref }} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | build: 16 | name: Build and test 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | otp: ["25.0"] 21 | elixir: ["1.14.0"] 22 | steps: 23 | - uses: actions/checkout@v2 24 | - uses: erlef/setup-beam@v1 25 | with: 26 | otp-version: ${{matrix.otp}} 27 | elixir-version: ${{matrix.elixir}} 28 | - name: Restore dependencies cache 29 | uses: actions/cache@v3 30 | with: 31 | path: deps 32 | key: ${{ runner.os }}-mix-${{ hashFiles('**/mix.lock') }} 33 | restore-keys: ${{ runner.os }}-mix- 34 | - name: Install dependencies 35 | run: mix deps.get 36 | - name: Run tests 37 | run: | 38 | mix test --exclude vulkan --exclude todo --exclude runtime 39 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule Manx.MixProject do 2 | use Mix.Project 3 | 4 | def project do 5 | [ 6 | app: :manx, 7 | version: "0.1.5-dev", 8 | elixir: "~> 1.13", 9 | start_permanent: Mix.env() == :prod, 10 | deps: deps(), 11 | description: description(), 12 | package: package() 13 | ] 14 | end 15 | 16 | # Run "mix help compile.app" to learn about applications. 17 | def application do 18 | [ 19 | extra_applications: [:logger], 20 | mod: {Manx.Application, []} 21 | ] 22 | end 23 | 24 | # Run "mix help deps" to learn about dependencies. 
25 | defp deps do 26 | [ 27 | {:nx, "~> 0.7"}, 28 | {:beaver, "~> 0.3.5"}, 29 | {:ex_doc, ">= 0.0.0", only: :dev, runtime: false} 30 | ] 31 | end 32 | 33 | defp description() do 34 | "MLIR backend for Nx" 35 | end 36 | 37 | defp package() do 38 | [ 39 | licenses: ["Apache-2.0", "MIT"], 40 | links: %{"GitHub" => "https://github.com/beaver-project/beaver"}, 41 | files: ~w{ 42 | lib .formatter.exs mix.exs README* 43 | } 44 | ] 45 | end 46 | end 47 | -------------------------------------------------------------------------------- /examples/attn.exs: -------------------------------------------------------------------------------- 1 | Nx.Defn.default_options(compiler: Manx.Compiler, default_backends: {Manx, device: :vulkan}) 2 | defmodule ManxVulkanAttention do 3 | import Nx.Defn 4 | 5 | 6 | defn softmax(t) do 7 | Nx.exp(t) / Nx.sum(Nx.exp(t), axes: [-1], keep_axes: true) 8 | end 9 | 10 | defn batched_dot(t1, t2) do 11 | Nx.dot(t1, [2], [0], t2, [1], [0]) 12 | end 13 | 14 | @doc """ 15 | dim is the dimension of each head 16 | """ 17 | defn scaled_dot_product_attention(query, key, value, dim) do 18 | score = Nx.dot(query, [2], [0], key, [2], [0]) / Nx.sqrt(dim) 19 | attn = softmax(score) 20 | Nx.dot(attn, [2], [0], value, [1], [0]) 21 | end 22 | end 23 | 24 | 25 | query = Nx.iota({4, 3, 2}, type: {:f, 32}) |> Nx.divide(10.0) 26 | query = Nx.backend_transfer(query, {Manx, device: :vulkan}) 27 | key = Nx.iota({4, 3, 2}, type: {:f, 32}) |> Nx.divide(10.0) 28 | key = Nx.backend_transfer(key, {Manx, device: :vulkan}) 29 | value = Nx.iota({4, 3, 2}, type: {:f, 32}) |> Nx.divide(10.0) 30 | value = Nx.backend_transfer(value, {Manx, device: :vulkan}) 31 | 32 | for i <- 1..100 do 33 | r = try do 34 | ManxVulkanAttention.scaled_dot_product_attention(query, key, value, 12) 35 | :ok 36 | rescue e -> 37 | :error 38 | end 39 | case r do 40 | :ok -> 41 | nil 42 | raise "ok" 43 | :error -> 44 | nil 45 | end 46 | end 47 | -------------------------------------------------------------------------------- /test/beaver/defn/vulkan_expr_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Manx.VulkanExprTest do 2 | use ExUnit.Case, async: true 3 | import Nx.Defn 4 | import Manx.Assert 5 | 6 | @moduletag :nx 7 | @moduletag :vulkan 8 | setup do 9 | Nx.Defn.default_options(compiler: Manx.Compiler, default_backends: {Manx, device: :vulkan}) 10 | :ok 11 | end 12 | 13 | describe "unary float ops" do 14 | @float_tensor Nx.tensor([1.0, 2.0, 3.0]) 15 | defn unary_sin(t) do 16 | Nx.sin(t) 17 | end 18 | 19 | test "sin" do 20 | t_vulkan = Nx.backend_transfer(@float_tensor, {Manx, device: :vulkan}) 21 | assert_all_close(unary_sin(t_vulkan), evaluate(&unary_sin/1, [@float_tensor])) 22 | end 23 | 24 | @int_tensor_a Nx.tensor([1, 2, 3], type: {:u, 32}) 25 | @int_tensor_b Nx.tensor([4, 5, 6], type: {:u, 32}) 26 | defn binary_add(a, b) do 27 | Nx.add(a, b) 28 | end 29 | 30 | test "add" do 31 | a = Nx.backend_transfer(@int_tensor_a, {Manx, device: :vulkan}) 32 | b = Nx.backend_transfer(@int_tensor_b, {Manx, device: :vulkan}) 33 | assert [5, 7, 9] == binary_add(a, b) |> Nx.to_flat_list() 34 | end 35 | 36 | defn unary_add(a) do 37 | Nx.add(a, @int_tensor_b) 38 | end 39 | 40 | test "add a constant" do 41 | a = Nx.backend_transfer(@int_tensor_a, {Manx, device: :vulkan}) 42 | assert [5, 7, 9] == unary_add(a) |> Nx.to_flat_list() 43 | end 44 | end 45 | end 46 | -------------------------------------------------------------------------------- /test/beaver/defn/attention_test.exs: 
-------------------------------------------------------------------------------- 1 | defmodule Manx.AttentionTest do 2 | use ExUnit.Case, async: true 3 | import Nx.Defn 4 | import Manx.Assert 5 | 6 | @moduletag :nx 7 | @moduletag :attention 8 | setup do 9 | Nx.Defn.default_options(compiler: Manx.Compiler) 10 | :ok 11 | end 12 | 13 | # original implementation from: https://github.com/sooftware/attentions/blob/master/attentions.py 14 | describe "attention" do 15 | defn(softmax(t), do: Nx.exp(t) / Nx.sum(Nx.exp(t), axes: [-1], keep_axes: true)) 16 | 17 | defn(batched_dot(t1, t2), do: Nx.dot(t1, [2], [0], t2, [1], [0])) 18 | 19 | @doc """ 20 | dim is the dimension of each head 21 | """ 22 | defn scaled_dot_product_attention(dim, query, key, value) do 23 | score = Nx.dot(query, [2], [0], key, [2], [0]) / Nx.sqrt(dim) 24 | attn = softmax(score) 25 | Nx.dot(attn, [2], [0], value, [1], [0]) 26 | end 27 | 28 | test "dot product attention" do 29 | # do a divide to prevent overflow 30 | query = Nx.iota({4, 3, 2}, type: {:f, 32}) |> Nx.divide(10.0) 31 | key = Nx.iota({4, 3, 2}, type: {:f, 32}) |> Nx.divide(10.0) 32 | value = Nx.iota({4, 3, 2}, type: {:f, 32}) |> Nx.divide(10.0) 33 | 34 | assert_all_close( 35 | scaled_dot_product_attention(12, query, key, value), 36 | Nx.tensor([ 37 | [[0.2008, 0.3008], [0.2038, 0.3038], [0.2069, 0.3069]], 38 | [[0.8100, 0.9100], [0.8131, 0.9131], [0.8161, 0.9161]], 39 | [[1.4192, 1.5192], [1.4222, 1.5222], [1.4253, 1.5253]], 40 | [[2.0283, 2.1283], [2.0313, 2.1313], [2.0343, 2.1343]] 41 | ]) 42 | ) 43 | end 44 | end 45 | end 46 | -------------------------------------------------------------------------------- /test/beaver/defn/loc_emb_test.exs: -------------------------------------------------------------------------------- 1 | defmodule TestEmbedding do 2 | @moduledoc """ 3 | embedding implementation from bumblebee 4 | """ 5 | import Nx.Defn 6 | 7 | defn timestep_sinusoidal_embedding_impl(timestep, opts \\ []) do 8 | opts = 9 | keyword!(opts, [ 10 | :embedding_size, 11 | flip_sin_to_cos: false, 12 | frequency_correction_term: 1, 13 | scale: 1, 14 | max_period: 10_000, 15 | mode: :train 16 | ]) 17 | 18 | embedding_size = opts[:embedding_size] 19 | max_period = opts[:max_period] 20 | frequency_correction_term = opts[:frequency_correction_term] 21 | 22 | if rem(embedding_size, 2) != 0 do 23 | raise ArgumentError, 24 | "expected embedding size to an even number, but got: #{inspect(embedding_size)}" 25 | end 26 | 27 | half_size = div(embedding_size, 2) 28 | 29 | frequency = 30 | Nx.exp(-Nx.log(max_period) * Nx.iota({half_size}) / (half_size - frequency_correction_term)) 31 | 32 | angle = Nx.new_axis(timestep, -1) * Nx.new_axis(frequency, 0) 33 | angle = opts[:scale] * angle 34 | 35 | if opts[:flip_sin_to_cos] do 36 | Nx.concatenate([Nx.cos(angle), Nx.sin(angle)], axis: -1) 37 | else 38 | Nx.concatenate([Nx.sin(angle), Nx.cos(angle)], axis: -1) 39 | end 40 | end 41 | end 42 | 43 | defmodule Beaver.Defn.LocEmbTest do 44 | use ExUnit.Case, async: true 45 | 46 | @moduletag :nx 47 | @moduletag :runtime 48 | setup do 49 | Nx.Defn.default_options(compiler: Manx.Compiler) 50 | :ok 51 | end 52 | 53 | test "time emb" do 54 | TestEmbedding.timestep_sinusoidal_embedding_impl(100, embedding_size: 10) 55 | end 56 | end 57 | -------------------------------------------------------------------------------- /test/beaver/defn/attn_vulkan_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Manx.VulkanAttentionTest do 2 | use ExUnit.Case, 
async: true 3 | import Nx.Defn 4 | import Manx.Assert 5 | 6 | @moduletag :nx 7 | @moduletag :attention 8 | @moduletag :vulkan 9 | setup do 10 | Nx.Defn.default_options(compiler: Manx.Compiler, default_backends: {Manx, device: :vulkan}) 11 | :ok 12 | end 13 | 14 | # original implementation from: https://github.com/sooftware/attentions/blob/master/attentions.py 15 | describe "attention" do 16 | defn softmax(t) do 17 | Nx.exp(t) / Nx.sum(Nx.exp(t), axes: [-1], keep_axes: true) 18 | end 19 | 20 | defn batched_dot(t1, t2) do 21 | Nx.dot(t1, [2], [0], t2, [1], [0]) 22 | end 23 | 24 | @doc """ 25 | dim is the dimension of each head 26 | """ 27 | defn scaled_dot_product_attention(query, key, value, dim) do 28 | score = Nx.dot(query, [2], [0], key, [2], [0]) / Nx.sqrt(dim) 29 | attn = softmax(score) 30 | Nx.dot(attn, [2], [0], value, [1], [0]) 31 | end 32 | 33 | test "dot product attention" do 34 | # do a divide to prevent overflow 35 | query = Nx.iota({4, 3, 2}, type: {:f, 32}) |> Nx.divide(10.0) 36 | query = Nx.backend_transfer(query, {Manx, device: :vulkan}) 37 | key = Nx.iota({4, 3, 2}, type: {:f, 32}) |> Nx.divide(10.0) 38 | key = Nx.backend_transfer(key, {Manx, device: :vulkan}) 39 | value = Nx.iota({4, 3, 2}, type: {:f, 32}) |> Nx.divide(10.0) 40 | value = Nx.backend_transfer(value, {Manx, device: :vulkan}) 41 | 42 | assert_raise RuntimeError, ~r"Unexpected failure running passes", fn -> 43 | assert_all_close( 44 | scaled_dot_product_attention(query, key, value, 12), 45 | Nx.tensor([ 46 | [[0.2008, 0.3008], [0.2038, 0.3038], [0.2069, 0.3069]], 47 | [[0.8100, 0.9100], [0.8131, 0.9131], [0.8161, 0.9161]], 48 | [[1.4192, 1.5192], [1.4222, 1.5222], [1.4253, 1.5253]], 49 | [[2.0283, 2.1283], [2.0313, 2.1313], [2.0343, 2.1343]] 50 | ]) 51 | ) 52 | end 53 | end 54 | end 55 | end 56 | -------------------------------------------------------------------------------- /test/beaver/nx_test.exs: -------------------------------------------------------------------------------- 1 | defmodule BeaverNxTest do 2 | @moduledoc """ 3 | Tests for compliance with the Nx backend behavior. 
Many of these tests are adapted from EXLA. 4 | """ 5 | use ExUnit.Case, async: true 6 | doctest Manx 7 | 8 | @moduletag :nx 9 | 10 | setup do 11 | Nx.default_backend(Manx) 12 | :ok 13 | end 14 | 15 | test "Nx.to_binary/1" do 16 | t = Nx.tensor([1, 2, 3, 4], backend: Manx) 17 | assert Nx.to_binary(t) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>> 18 | assert Nx.to_binary(t, limit: 2) == <<1::64-native, 2::64-native>> 19 | assert Nx.to_binary(t, limit: 6) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>> 20 | end 21 | 22 | test "Nx.backend_transfer/1" do 23 | t = Nx.tensor([1, 2, 3, 4]) 24 | 25 | et = Nx.backend_transfer(t, {Manx, device_id: 0}) 26 | assert %Manx{memory: %{}} = et.data 27 | 28 | nt = Nx.backend_transfer(et) 29 | assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>> 30 | 31 | assert_raise RuntimeError, ~r"called on deleted or donated buffer", fn -> 32 | Nx.backend_transfer(et) 33 | end 34 | end 35 | 36 | test "Nx.backend_copy/1" do 37 | t = Nx.tensor([1, 2, 3, 4]) 38 | 39 | et = Nx.backend_transfer(t, Manx) 40 | assert %Manx{memory: %{} = old_buffer} = et.data 41 | 42 | # Copy to the same client/device_id still makes a copy 43 | et = Nx.backend_copy(t, Manx) 44 | assert %Manx{memory: %{} = new_buffer} = et.data 45 | assert old_buffer != new_buffer 46 | 47 | nt = Nx.backend_copy(et) 48 | assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>> 49 | 50 | nt = Nx.backend_copy(et) 51 | assert Nx.to_binary(nt) == <<1::64-native, 2::64-native, 3::64-native, 4::64-native>> 52 | end 53 | 54 | test "Kernel.inspect/2" do 55 | t = Nx.tensor([1, 2, 3, 4], backend: Manx) 56 | 57 | assert inspect(t) == 58 | """ 59 | #Nx.Tensor< 60 | s64[4] 61 | [1, 2, 3, 4] 62 | >\ 63 | """ 64 | end 65 | end 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🐱 Manx 🐈 2 | 3 | **M**LIR-**A**ccelerated-**Nx**. An MLIR compiler/backend for [Nx](https://github.com/elixir-nx/nx/tree/main/nx#readme). 4 | 5 | ## What does this library do? 6 | 7 | You can think of Manx as IREE implemented in Elixir; unlike IREE, which ships its own dedicated runtime, Manx uses the BEAM as its runtime. Nx's expressions are very close to XLA's MHLO, so Manx borrows many conversion/lowering implementations from [XLA](https://github.com/openxla/xla) and [IREE](https://github.com/iree-org/iree). 8 | 9 | ## Why do we need it? 10 | 11 | - Instead of repurposing compilers built for Python, Manx is about building an Nx compiler written in Elixir and tailored for Elixir. 12 | - With Manx, the "tensor compiler" is no longer a giant black box to the Erlang world. A non-Python programming language should have its own full-stack data/ML solution so that it can be truly maintainable. 13 | - Tighter integration with the BEAM. We can build passes and optimizations for Elixir and the BEAM, and even generate LLVM instructions that send messages or allocate memory with Erlang's allocator. 14 | - There is a great gap between how "distributed system" is understood in ML and in non-ML applications (MPI vs. fault tolerance). With Manx we can narrow that gap by implementing an ML compiler in a programming language with strong fault-tolerance capabilities. 15 | 16 | ## Compared to EXLA 17 | 18 | - [EXLA](https://github.com/elixir-nx/nx/tree/main/exla) is the Nx backend for XLA. 19 | - In the short run, Manx's performance won't be on par with XLA/EXLA's.
20 | 21 | - EXLA's lowering: 22 | 23 | ``` 24 | Nx |> EXLA |> XLA |> MLIR |> LLVM |> hardware 25 | ``` 26 | 27 | - Manx's lowering 28 | 29 | ``` 30 | Nx |> Manx |> MLIR |> LLVM |> hardware 31 | ``` 32 | 33 | ## Installation 34 | 35 | If [available in Hex](https://hex.pm/docs/publish), the package can be installed 36 | by adding `manx` to your list of dependencies in `mix.exs`: 37 | 38 | ```elixir 39 | def deps do 40 | [ 41 | {:manx, "~> 0.1.0"} 42 | ] 43 | end 44 | ``` 45 | 46 | Documentation can be generated with [ExDoc](https://github.com/elixir-lang/ex_doc) 47 | and published on [HexDocs](https://hexdocs.pm). Once published, the docs can 48 | be found at . 49 | -------------------------------------------------------------------------------- /lib/manx/lowering/vulkan.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Lowering.Vulkan do 2 | alias Beaver.MLIR 3 | import MLIR.{Transforms, Conversion} 4 | 5 | def lower(op) do 6 | op 7 | |> MLIR.Operation.verify!(dump_if_fail: true) 8 | |> canonicalize 9 | |> MLIR.Pass.Composer.nested( 10 | "func.func", 11 | ~w{tosa-make-broadcastable llvm-request-c-wrappers tosa-layerwise-constant-fold} 12 | ) 13 | |> cse 14 | |> tosa_to_arith 15 | |> tosa_to_tensor() 16 | |> convert_tensor_to_linalg() 17 | |> MLIR.Pass.Composer.nested("func.func", [ 18 | tosa_to_linalg_named(), 19 | tosa_to_linalg(), 20 | linalg_generalize_named_ops(), 21 | linalg_fuse_elementwise_ops(), 22 | linalg_bufferize(), 23 | convert_linalg_to_parallel_loops(), 24 | gpu_map_parallel_loops() 25 | ]) 26 | |> MLIR.Pass.Composer.append("arith-bufferize,func-bufferize") 27 | |> convert_parallel_loops_to_gpu() 28 | |> gpu_launch_sink_index_computations() 29 | |> gpu_kernel_outlining() 30 | |> MLIR.Pass.Composer.nested("gpu.module", [ 31 | { 32 | :nested, 33 | "gpu.func", 34 | [ 35 | lower_affine(), 36 | MLIR.ExternalPass.create(__MODULE__.PutSPVAttrPass) 37 | ] 38 | } 39 | ]) 40 | |> MLIR.Pass.Composer.nested("func.func", "tensor-bufferize") 41 | |> MLIR.Pass.Composer.nested("gpu.module", [ 42 | { 43 | :nested, 44 | "gpu.func", 45 | [ 46 | convert_memref_to_spirv(), 47 | convert_math_to_spirv(), 48 | convert_arith_to_spirv(), 49 | convert_cf_to_spirv(), 50 | convert_tensor_to_spirv(), 51 | convert_vector_to_spirv(), 52 | convert_func_to_spirv(), 53 | convert_scf_to_spirv() 54 | ] 55 | } 56 | ]) 57 | |> convert_gpu_to_spirv() 58 | |> MLIR.Pass.Composer.nested( 59 | "spirv.module", 60 | ~w{spirv-lower-abi-attrs spirv-update-vce} 61 | ) 62 | |> convert_gpu_launch_to_vulkan_launch 63 | |> MLIR.Pass.Composer.append("expand-strided-metadata") 64 | |> MLIR.Pass.Composer.append("finalize-memref-to-llvm") 65 | |> MLIR.Pass.Composer.nested("func.func", "llvm-request-c-wrappers") 66 | |> convert_complex_to_standard() 67 | |> convert_vector_to_llvm 68 | |> convert_complex_to_llvm() 69 | |> convert_func_to_llvm 70 | |> reconcile_unrealized_casts 71 | |> launch_func_to_vulkan 72 | |> MLIR.Pass.Composer.run(dump_if_fail: false, print: Manx.Flags.print_ir?()) 73 | end 74 | end 75 | -------------------------------------------------------------------------------- /lib/manx/assert.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Assert do 2 | import Nx.Defn 3 | 4 | @moduledoc """ 5 | Tensor assertions. Original implementation is from EXLA. 
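  A minimal usage sketch (illustrative only; the tolerances shown are the defaults of `all_close/3`, `atol: 1.0e-4` and `rtol: 1.0e-4`):

      import Manx.Assert
      assert_all_close(Nx.tensor([1.0, 2.0]), Nx.tensor([1.0, 2.0001]))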
6 | """ 7 | 8 | defmacro assert_equal(left, right) do 9 | # Assert against binary backend tensors to show diff on failure 10 | quote do 11 | assert unquote(left) |> Manx.Assert.to_binary_backend() == 12 | unquote(right) |> Manx.Assert.to_binary_backend() 13 | end 14 | end 15 | 16 | defmacro assert_not_equal(left, right) do 17 | # Assert against binary backend tensors to show diff on failure 18 | quote do 19 | assert unquote(left) |> Manx.Assert.to_binary_backend() != 20 | unquote(right) |> Manx.Assert.to_binary_backend() 21 | end 22 | end 23 | 24 | def to_binary_backend(tensor) do 25 | Nx.backend_copy(tensor, Nx.BinaryBackend) 26 | end 27 | 28 | defn all_close_jit(a, b, opts \\ []) do 29 | import Nx 30 | opts = keyword!(opts, equal_nan: false, rtol: 1.0e-5, atol: 1.0e-8, both_integer: false) 31 | both_integer = opts[:both_integer] 32 | rtol = opts[:rtol] 33 | atol = opts[:atol] 34 | 35 | a = to_tensor(a) 36 | b = to_tensor(b) 37 | 38 | finite_entries = less_equal(Nx.abs(subtract(a, b)), add(atol, multiply(rtol, Nx.abs(b)))) 39 | 40 | if both_integer do 41 | all(finite_entries) 42 | else 43 | # inf - inf is a nan, however, they are equal, 44 | # so we explicitly check for equal entries. 45 | inf_a = is_infinity(a) 46 | inf_b = is_infinity(b) 47 | inf_entries = select(logical_or(inf_a, inf_b), equal(a, b), finite_entries) 48 | 49 | if opts[:equal_nan] do 50 | nan_a = is_nan(a) 51 | nan_b = is_nan(b) 52 | nan_entries = logical_and(nan_a, nan_b) 53 | all(select(nan_entries, 1, inf_entries)) 54 | else 55 | all(inf_entries) 56 | end 57 | end 58 | end 59 | 60 | # def all_close(left, right, opts \\ []) do 61 | # true 62 | # end 63 | 64 | def all_close(left, right, opts \\ []) do 65 | atol = Keyword.get(opts, :atol, 1.0e-4) 66 | rtol = Keyword.get(opts, :rtol, 1.0e-4) 67 | 68 | equals = 69 | left 70 | |> all_close_jit(right, 71 | atol: atol, 72 | rtol: rtol, 73 | both_integer: Nx.Type.integer?(left.type) and Nx.Type.integer?(right.type) 74 | ) 75 | |> Nx.backend_transfer(Nx.BinaryBackend) 76 | 77 | if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do 78 | raise(""" 79 | expected 80 | 81 | #{inspect(left)} 82 | 83 | to be within tolerance of 84 | 85 | #{inspect(right)} 86 | """) 87 | end 88 | end 89 | 90 | defmacro assert_all_close(left, right, opts \\ []) do 91 | quote bind_quoted: [ 92 | left: left, 93 | right: right, 94 | opts: opts 95 | ] do 96 | all_close(left, right, opts) 97 | end 98 | end 99 | 100 | def evaluate(fun, args) do 101 | fun |> Nx.Defn.jit(compiler: Nx.Defn.Evaluator) |> apply(args) 102 | end 103 | end 104 | -------------------------------------------------------------------------------- /lib/manx/nx/linalg.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Linalg do 2 | import Beaver.MLIR.Sigils 3 | alias Beaver.MLIR 4 | alias Beaver.MLIR.Attribute 5 | @moduledoc false 6 | 7 | def expand_for_output(input_shape, output_shape) 8 | when tuple_size(output_shape) >= tuple_size(input_shape) do 9 | output_rank = tuple_size(output_shape) 10 | rank = tuple_size(input_shape) 11 | expanded = List.duplicate(1, output_rank - rank) ++ Tuple.to_list(input_shape) 12 | List.to_tuple(expanded) 13 | end 14 | 15 | defp gen_identity(shape), do: &MLIR.CAPI.mlirAffineMapMultiDimIdentityGet(&1, tuple_size(shape)) 16 | 17 | defp gen_broadcast_minor_identity(in_shape, out_shape) do 18 | rank = tuple_size(out_shape) 19 | rank_diff = rank - tuple_size(in_shape) 20 | 21 | zipped = 22 | in_shape 23 | |> expand_for_output(out_shape) 24 | |> 
Tuple.to_list() 25 | |> Enum.zip(Tuple.to_list(out_shape)) 26 | 27 | exprs = 28 | for {{in_dim, out_dim}, index} <- zipped |> Enum.with_index(), index >= rank_diff do 29 | case {in_dim, out_dim} do 30 | {1, out_dim} when out_dim != 1 -> 31 | 0 32 | 33 | _ -> 34 | MLIR.AffineMap.dim(index) 35 | end 36 | end 37 | 38 | MLIR.AffineMap.create(rank, 0, exprs) 39 | end 40 | 41 | defp maps_to_attr(maps) do 42 | maps 43 | |> Enum.map(&MLIR.Attribute.affine_map/1) 44 | |> Attribute.array() 45 | end 46 | 47 | # unary, always identity 48 | defp do_gen_indexing_maps(shape, shape) do 49 | do_gen_indexing_maps([shape], shape) 50 | end 51 | 52 | defp do_gen_indexing_maps([shape], shape) do 53 | gen_identity(shape) 54 | |> List.duplicate(2) 55 | end 56 | 57 | # binary+, might broadcast 58 | defp do_gen_indexing_maps(input_shapes, out_shape) 59 | when is_list(input_shapes) and length(input_shapes) > 1 do 60 | Enum.map(input_shapes, &gen_broadcast_minor_identity(&1, out_shape)) ++ 61 | [gen_identity(out_shape)] 62 | end 63 | 64 | def gen_indexing_maps(input_shapes, out_shape) do 65 | do_gen_indexing_maps(input_shapes, out_shape) |> maps_to_attr 66 | end 67 | 68 | def gen_indexing_maps(out_shape) do 69 | [gen_identity(out_shape)] |> maps_to_attr 70 | end 71 | 72 | def gen_iterator_types({}, {}) do 73 | ~a{[]} 74 | end 75 | 76 | def gen_iterator_types({_}, {_}) do 77 | ~a{[#linalg.iterator_type]} 78 | end 79 | 80 | def gen_iterator_types(input, output) when input == output do 81 | case tuple_size(input) do 82 | 1 -> 83 | ~a{[#linalg.iterator_type]} 84 | 85 | 2 -> 86 | ~a{[#linalg.iterator_type, #linalg.iterator_type]} 87 | end 88 | end 89 | 90 | def gen_iterator_types({}, {}, _output) do 91 | ~a{[]} 92 | end 93 | 94 | def gen_iterator_types(input1, _input2, output) do 95 | input1 = expand_for_output(input1, output) 96 | 97 | case tuple_size(input1) do 98 | 1 -> 99 | ~a{[#linalg.iterator_type]} 100 | 101 | 2 -> 102 | ~a{[#linalg.iterator_type, #linalg.iterator_type]} 103 | end 104 | end 105 | 106 | def gen_iterator_types(output) do 107 | for _ <- output |> Tuple.to_list() do 108 | ~a{#linalg.iterator_type} 109 | end 110 | |> Attribute.array() 111 | end 112 | end 113 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "beaver": {:hex, :beaver, "0.3.5", "c5847ca24a56f2ac0bbd311f3b4b219f26b462de791fe6eb75f3b44a9b9ce029", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kinda, "~> 0.8.1", [hex: :kinda, repo: "hexpm", optional: false]}, {:llvm_config, "~> 0.1.0", [hex: :llvm_config, repo: "hexpm", optional: false]}], "hexpm", "a189c7fd23370f86ba5c7bb1c75e53662ef517f87ef7066d92f55b8a9bca80b0"}, 3 | "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, 4 | "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"}, 5 | "elixir_make": {:hex, :elixir_make, "0.8.3", "d38d7ee1578d722d89b4d452a3e36bcfdc644c618f0d063b874661876e708683", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", 
"5c99a18571a756d4af7a4d89ca75c28ac899e6103af6f223982f09ce44942cc9"}, 6 | "ex_doc": {:hex, :ex_doc, "0.31.1", "8a2355ac42b1cc7b2379da9e40243f2670143721dd50748bf6c3b1184dae2089", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3178c3a407c557d8343479e1ff117a96fd31bafe52a039079593fb0524ef61b0"}, 7 | "kinda": {:hex, :kinda, "0.8.1", "89be40b4d07369db18f19f076d10082344e75566b217caae0cf15f9aad148e1d", [:mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "d33ac401b3044f82af49d8e0d2dbf40235de224c842298a2a535e1878a344221"}, 8 | "llvm_config": {:hex, :llvm_config, "0.1.1", "2c4a1e16c51d18528014c35783bc7779cb8b8b0d0ed2a5f3b4d35709a819e492", [:mix], [], "hexpm", "026376818206f1ff91f48946aa0272cac2bae2d00e81916d1275b3f9ca72a7ea"}, 9 | "makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"}, 10 | "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, 11 | "makeup_erlang": {:hex, :makeup_erlang, "0.1.5", "e0ff5a7c708dda34311f7522a8758e23bfcd7d8d8068dc312b5eb41c6fd76eba", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "94d2e986428585a21516d7d7149781480013c56e30c6a233534bedf38867a59a"}, 12 | "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, 13 | "nx": {:hex, :nx, "0.7.1", "5f6376e3d18408116e8a84b8f4ac851fb07dfe61764a5410ebf0b5dcb69c1b7e", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e3ddd6a3f2a9bac79c67b3933368c25bb5ec814a883fc68aba8fd8a236751777"}, 14 | "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, 15 | } 16 | -------------------------------------------------------------------------------- /lib/manx/nx/slice.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Slice do 2 | @moduledoc false 3 | alias Manx.Defn.Env 4 | use Beaver 5 | alias Beaver.MLIR 6 | import Beaver, only: :macros 7 | require Beaver.MLIR 8 | alias MLIR.{Type, Attribute} 9 | alias MLIR.Dialect.{TOSA, Arith, Tensor} 10 | 11 | def static_slice( 12 | %Env{block: block, ctx: ctx, gen_op: gen_op, gen_type: gen_type} = env, 13 | %Nx.Tensor{ 14 | data: %Nx.Defn.Expr{ 15 | op: :slice, 16 | args: [tensor, start_indices, lengths, strides] 17 | } 18 | } = t 19 | ) do 20 | mlir block: block, ctx: ctx do 21 | input_value = gen_op.(env, tensor) 22 | 23 
| sizes = 24 | for {start, length, stride} <- Enum.zip([start_indices, lengths, strides]) do 25 | limit = start + length 26 | Integer.floor_div(limit - 1 - start + stride, stride) 27 | end 28 | |> Attribute.dense_array(Beaver.Native.I64) 29 | 30 | offsets = Attribute.dense_array(start_indices, Beaver.Native.I64) 31 | 32 | if Enum.all?(strides, &Kernel.==(&1, 1)) do 33 | TOSA.slice(input_value, start: offsets, size: sizes) >>> gen_type.(t) 34 | else 35 | Tensor.extract_slice(input_value, 36 | static_offsets: offsets, 37 | static_sizes: sizes, 38 | static_strides: Attribute.dense_array(strides, Beaver.Native.I64), 39 | operand_segment_sizes: ODS.operand_segment_sizes([1, 0, 0, 0]) 40 | ) >>> gen_type.(t) 41 | end 42 | end 43 | end 44 | 45 | def dynamic_slice( 46 | %Env{block: block, ctx: ctx, gen_op: gen_op, gen_type: gen_type} = env, 47 | %Nx.Tensor{ 48 | data: %Nx.Defn.Expr{ 49 | op: :slice, 50 | args: [tensor, start_indices, lengths, strides] 51 | } 52 | } = t 53 | ) do 54 | mlir block: block, ctx: ctx do 55 | input_value = gen_op.(env, tensor) 56 | 57 | start_indices = 58 | for {{start, length}, index} <- Enum.zip([start_indices, lengths]) |> Enum.with_index() do 59 | start_value = gen_op.(env, start) 60 | extracted = Tensor.extract(start_value) >>> gen_type.(start.type) 61 | 62 | start_index = 63 | case start.type do 64 | {:s, _} -> 65 | Arith.index_castui(extracted) >>> Type.index() 66 | 67 | {:f, _} -> 68 | Arith.index_cast(extracted) >>> Type.index() 69 | end 70 | 71 | mn = Arith.constant(value: Attribute.index(0)) >>> Type.index() 72 | dim = Arith.constant(value: Attribute.index(index)) >>> Type.index() 73 | mx = Tensor.dim(input_value, dim) >>> Type.index() 74 | size = Arith.constant(value: Attribute.index(length)) >>> Type.index() 75 | mx = Arith.subi(mx, size) >>> Type.index() 76 | start_index = Arith.maxsi(start_index, mn) >>> Type.index() 77 | Arith.minsi(start_index, mx) >>> Type.index() 78 | end 79 | 80 | sizes = lengths |> Attribute.dense_array(Beaver.Native.I64) 81 | strides = strides |> Attribute.dense_array(Beaver.Native.I64) 82 | 83 | offsets = 84 | Attribute.dense_array( 85 | List.duplicate( 86 | Beaver.MLIR.CAPI.mlirShapedTypeGetDynamicStrideOrOffset(), 87 | length(lengths) 88 | ), 89 | Beaver.Native.I64 90 | ) 91 | 92 | Tensor.extract_slice( 93 | input_value, 94 | start_indices, 95 | static_offsets: offsets, 96 | static_sizes: sizes, 97 | static_strides: strides, 98 | operand_segment_sizes: ODS.operand_segment_sizes([1, length(lengths), 0, 0]) 99 | ) >>> 100 | gen_type.(t) 101 | end 102 | end 103 | end 104 | -------------------------------------------------------------------------------- /lib/manx/nx/interoperability.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Nx.Interoperability do 2 | @moduledoc """ 3 | Functions for interoperability between Elixir/NX and LLVM/MLIR. For instance the data transfer between NX tensor and MemRef. 4 | """ 5 | 6 | alias Beaver.MLIR 7 | 8 | @doc """ 9 | - If it is a tensor, return a memref 10 | - If it is a tuple, recursively pack them into one struct. 
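  For example (an illustrative sketch; it assumes the tensor already lives on the Manx backend, so its data carries a memref):

      %Beaver.Native.Memory{} = memref_from_tensor(Nx.tensor([1.0, 2.0], backend: Manx))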
11 | """ 12 | def memref_from_tensor(f) when is_function(f), do: f.() |> memref_from_tensor 13 | def memref_from_tensor(%Nx.Tensor{data: %Manx{memory: memory}}), do: memory 14 | 15 | def memref_from_tensor( 16 | %Nx.Tensor{ 17 | data: %Nx.BinaryBackend{state: binary} 18 | } = tensor 19 | ) do 20 | Manx.from_binary(tensor, binary, []) |> memref_from_tensor 21 | end 22 | 23 | def memref_from_tensor(%Nx.Tensor{shape: shape, data: %Nx.TemplateBackend{}}) do 24 | # TODO: generate a magical deadbeef pointer for this 25 | Beaver.Native.Memory.new(nil, sizes: shape |> Tuple.to_list(), type: Beaver.Native.F32) 26 | end 27 | 28 | def memref_from_tensor({}) do 29 | raise "can't extract memref from an empty tuple" 30 | end 31 | 32 | def memref_from_tensor(tuple) when is_tuple(tuple) do 33 | mems = 34 | Tuple.to_list(tuple) 35 | |> Enum.map(&memref_from_tensor/1) 36 | 37 | # TODO: support array of memref descriptor of different kinds 38 | first = mems |> List.first() 39 | kind = first.descriptor.descriptor_kind 40 | 41 | refs = 42 | mems 43 | |> Enum.map(fn %Beaver.Native.Memory{descriptor: %Beaver.Native.Memory.Descriptor{ref: ref}} -> 44 | ref 45 | end) 46 | 47 | # TODO: add a raw NIF beaver_raw_create_heterogeneous_array, using union maybe 48 | mut_array = Beaver.Native.forward(kind, :mut_array, [refs]) 49 | 50 | struct!(Beaver.Native.Array, 51 | element_kind: kind, 52 | ref: mut_array 53 | ) 54 | end 55 | 56 | @doc """ 57 | - If it is a tensor, return a memref 58 | - If it is a tuple, recursively unpack each member from the nested struct. 59 | """ 60 | def populate_tensor_from_memref(%Nx.Tensor{data: %Manx{}} = tensor, memory) do 61 | %{tensor | data: %Manx{memory: memory}} 62 | end 63 | 64 | def populate_tensor_from_memref( 65 | tuple, 66 | %Beaver.Native.Array{element_kind: element_kind} = nested_struct 67 | ) 68 | when is_tuple(tuple) do 69 | nested_struct_ptr = nested_struct |> Beaver.Native.Memory.descriptor_ptr() 70 | 71 | {tensors, _offset} = 72 | Enum.reduce(tuple |> Tuple.to_list(), {[], 0}, fn x, {acc, offset} -> 73 | {ref, size} = 74 | Beaver.Native.OpaquePtr.to_resource( 75 | element_kind, 76 | nested_struct_ptr, 77 | offset 78 | ) 79 | 80 | mem = %Beaver.Native.Memory{ 81 | descriptor: %Beaver.Native.Memory.Descriptor{ 82 | ref: ref, 83 | descriptor_kind: element_kind 84 | } 85 | } 86 | 87 | {acc ++ [populate_tensor_from_memref(x, mem)], offset + size} 88 | end) 89 | 90 | tensors |> List.to_tuple() 91 | end 92 | 93 | def loc_from_stack_trace({:current_stacktrace, frames}, ctx) do 94 | loc_from_stack_trace(frames, ctx) 95 | end 96 | 97 | def loc_from_stack_trace(frames, ctx) do 98 | stacktrace_locs = 99 | for {_, _, _, f} <- frames do 100 | f 101 | end 102 | |> Stream.map(&[name: to_string(&1[:file]), line: &1[:line], ctx: ctx]) 103 | |> Stream.reject(&String.starts_with?(&1[:name], "lib/process.ex")) 104 | |> Stream.map(&MLIR.Location.file(&1)) 105 | |> Enum.to_list() 106 | 107 | MLIR.CAPI.mlirLocationFusedGet( 108 | ctx, 109 | length(stacktrace_locs), 110 | Beaver.Native.array(stacktrace_locs, MLIR.Location), 111 | MLIR.Attribute.null() 112 | ) 113 | end 114 | end 115 | -------------------------------------------------------------------------------- /lib/manx/lowering/cpu.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Lowering.CPU do 2 | alias Beaver.MLIR 3 | import MLIR.{Transforms, Conversion} 4 | 5 | defp one_shot(op) do 6 | op 7 | |> MLIR.Operation.verify!() 8 | |> canonicalize 9 | |> MLIR.Pass.Composer.nested("func.func", 
"tosa-make-broadcastable") 10 | |> MLIR.Pass.Composer.nested("func.func", "tosa-layerwise-constant-fold") 11 | |> cse 12 | |> tosa_to_scf 13 | |> tosa_to_arith 14 | |> tosa_to_tensor() 15 | |> convert_tensor_to_linalg() 16 | |> MLIR.Pass.Composer.nested("func.func", [ 17 | tosa_to_linalg_named(), 18 | tosa_to_linalg(), 19 | "empty-tensor-to-alloc-tensor", 20 | linalg_fuse_elementwise_ops(), 21 | "tosa-layerwise-constant-fold", 22 | lower_affine() 23 | ]) 24 | |> MLIR.Pass.Composer.nested("func.func", "empty-tensor-to-alloc-tensor") 25 | |> MLIR.Pass.Composer.append("one-shot-bufferize{allow-return-allocs create-deallocs=false}") 26 | |> MLIR.Pass.Composer.append("func-bufferize,arith-bufferize") 27 | |> MLIR.Pass.Composer.nested("func.func", [ 28 | convert_linalg_to_loops(), 29 | convert_scf_to_cf(), 30 | "arith-expand", 31 | convert_arith_to_llvm(), 32 | convert_math_to_llvm() 33 | ]) 34 | |> MLIR.Pass.Composer.nested("func.func", "llvm-request-c-wrappers") 35 | |> convert_math_to_libm 36 | |> convert_complex_to_standard() 37 | |> convert_vector_to_llvm 38 | |> MLIR.Pass.Composer.nested("func.func", "expand-strided-metadata,memref-expand") 39 | |> MLIR.Pass.Composer.append("expand-strided-metadata") 40 | |> MLIR.Pass.Composer.append("finalize-memref-to-llvm") 41 | |> convert_complex_to_llvm() 42 | |> convert_func_to_llvm 43 | |> reconcile_unrealized_casts 44 | |> MLIR.Pass.Composer.run(print: Manx.Flags.print_ir?(), debug: false) 45 | end 46 | 47 | defp do_lower(op) do 48 | op 49 | |> MLIR.Operation.verify!() 50 | |> MLIR.Pass.Composer.nested("func.func", "tosa-make-broadcastable") 51 | |> MLIR.Pass.Composer.nested("func.func", "tosa-layerwise-constant-fold") 52 | |> cse 53 | |> tosa_to_scf 54 | |> tosa_to_arith 55 | |> tosa_to_tensor() 56 | |> convert_tensor_to_linalg() 57 | |> MLIR.Pass.Composer.nested("func.func", [ 58 | tosa_to_linalg_named(), 59 | tosa_to_linalg(), 60 | linalg_fuse_elementwise_ops(), 61 | "tosa-layerwise-constant-fold", 62 | linalg_bufferize(), 63 | convert_linalg_to_loops(), 64 | "affine-expand-index-ops", 65 | lower_affine(), 66 | convert_math_to_llvm(), 67 | convert_arith_to_llvm(), 68 | convert_scf_to_cf(), 69 | "arith-expand" 70 | ]) 71 | |> MLIR.Pass.Composer.nested("func.func", "empty-tensor-to-alloc-tensor") 72 | |> MLIR.Pass.Composer.append("arith-bufferize,func-bufferize") 73 | |> MLIR.Pass.Composer.nested("func.func", "tensor-bufferize") 74 | |> MLIR.Pass.Composer.nested("func.func", "llvm-request-c-wrappers") 75 | |> MLIR.Pass.Composer.nested("func.func", "expand-strided-metadata,memref-expand") 76 | |> convert_math_to_libm 77 | |> convert_complex_to_standard() 78 | |> convert_vector_to_llvm 79 | |> MLIR.Pass.Composer.append("expand-strided-metadata") 80 | |> MLIR.Pass.Composer.append("finalize-memref-to-llvm") 81 | |> convert_complex_to_llvm() 82 | |> convert_func_to_llvm 83 | |> reconcile_unrealized_casts 84 | |> MLIR.Pass.Composer.run(print: Manx.Flags.print_ir?()) 85 | end 86 | 87 | @doc """ 88 | Run passes to compile IR generated from Nx expressions, mostly in TOSA and some LinAlg. The results should be in LLVM. 
89 | """ 90 | def lower(op, opts \\ []) do 91 | one_shot = opts[:one_shot] || false 92 | # canonicalize it first to fold operations of index but result type fixed in Nx expression 93 | case op 94 | |> canonicalize 95 | |> MLIR.Pass.Composer.run(print: Manx.Flags.print_ir?()) do 96 | {:ok, op} -> 97 | if one_shot do 98 | one_shot(op) 99 | else 100 | do_lower(op) 101 | end 102 | 103 | result -> 104 | result 105 | end 106 | end 107 | end 108 | -------------------------------------------------------------------------------- /lib/manx/nx/batcher.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Nx.Batcher do 2 | alias Beaver.MLIR 3 | require Beaver.MLIR 4 | alias MLIR.Attribute 5 | import MLIR.Sigils 6 | defstruct [:tensor, :contract_axes, :batch_axes] 7 | 8 | def from_args([ 9 | %Nx.Tensor{} = a, 10 | contract_axes1, 11 | batch_axes1, 12 | %Nx.Tensor{} = b, 13 | contract_axes2, 14 | batch_axes2 15 | ]) do 16 | batched_a = %Manx.Nx.Batcher{ 17 | tensor: a, 18 | contract_axes: contract_axes1, 19 | batch_axes: batch_axes1 20 | } 21 | 22 | batched_b = %Manx.Nx.Batcher{ 23 | tensor: b, 24 | contract_axes: contract_axes2, 25 | batch_axes: batch_axes2 26 | } 27 | 28 | {batched_a, batched_b} 29 | end 30 | 31 | # [[CONTRACT DIMS]...[BATCH DIMS]...[OUTER DIMS]...] 32 | defmodule AffineMapAcc do 33 | defstruct exprs: [], contract_index: 0, batch_index: nil, outer_index: nil 34 | end 35 | 36 | defp gen_input_affine_map( 37 | %__MODULE__{ 38 | tensor: %{shape: shape}, 39 | batch_axes: batch_axes, 40 | contract_axes: contract_axes 41 | }, 42 | {maps, outer_index}, 43 | output_rank: output_rank 44 | ) do 45 | rank = tuple_size(shape) 46 | 47 | acc = 48 | Enum.reduce( 49 | Range.new(0, rank - 1, 1), 50 | %AffineMapAcc{outer_index: outer_index, batch_index: length(contract_axes)}, 51 | fn dim, acc -> 52 | case {dim in contract_axes, dim in batch_axes} do 53 | {true, false} -> 54 | %{ 55 | acc 56 | | exprs: acc.exprs ++ [MLIR.AffineMap.dim(acc.contract_index)], 57 | contract_index: acc.contract_index + 1 58 | } 59 | 60 | {false, true} -> 61 | %{ 62 | acc 63 | | exprs: acc.exprs ++ [MLIR.AffineMap.dim(acc.batch_index)], 64 | batch_index: acc.batch_index + 1 65 | } 66 | 67 | {false, false} -> 68 | %{ 69 | acc 70 | | exprs: acc.exprs ++ [MLIR.AffineMap.dim(acc.outer_index)], 71 | outer_index: acc.outer_index + 1 72 | } 73 | end 74 | end 75 | ) 76 | 77 | {maps ++ [MLIR.AffineMap.create(output_rank, 0, acc.exprs)], acc.outer_index} 78 | end 79 | 80 | defp gen_output_affine_map(contract_axes_length, rank) when is_integer(rank) do 81 | out_rank = contract_axes_length + rank 82 | 83 | exprs = 84 | for dim <- Range.new(0, out_rank - 1, 1), dim >= contract_axes_length do 85 | MLIR.AffineMap.dim(dim) 86 | end 87 | 88 | MLIR.AffineMap.create(out_rank, 0, exprs) 89 | end 90 | 91 | def gen_indexing_maps( 92 | %__MODULE__{batch_axes: batch_axes_a, contract_axes: contract_axes_a} = a, 93 | %__MODULE__{batch_axes: batch_axes_b, contract_axes: contract_axes_b} = b, 94 | c 95 | ) 96 | when length(contract_axes_a) == length(contract_axes_b) and 97 | length(batch_axes_a) == length(batch_axes_b) do 98 | contract_axes_length = length(contract_axes_a) 99 | batch_axes_length = length(batch_axes_b) 100 | 101 | output_rank = tuple_size(c.shape) + contract_axes_length 102 | 103 | {input_maps, _} = 104 | Enum.reduce( 105 | [a, b], 106 | {[], contract_axes_length + batch_axes_length}, 107 | &gen_input_affine_map(&1, &2, output_rank: output_rank) 108 | ) 109 | 110 | 
Enum.concat(input_maps, [ 111 | gen_output_affine_map(contract_axes_length, tuple_size(c.shape)) 112 | ]) 113 | |> Enum.map(&MLIR.Attribute.affine_map/1) 114 | |> Attribute.array() 115 | end 116 | 117 | def gen_iterator_types( 118 | %__MODULE__{contract_axes: contract_axes_a}, 119 | %__MODULE__{contract_axes: contract_axes_b}, 120 | c 121 | ) 122 | when length(contract_axes_a) == length(contract_axes_b) do 123 | contract_iterator_types = 124 | ~a{#linalg.iterator_type} 125 | |> List.duplicate(length(contract_axes_a)) 126 | 127 | outer_iterator_types = 128 | ~a{#linalg.iterator_type} 129 | |> List.duplicate(tuple_size(c.shape)) 130 | 131 | Enum.concat(contract_iterator_types, outer_iterator_types) 132 | |> Attribute.array() 133 | end 134 | end 135 | -------------------------------------------------------------------------------- /lib/manx/nx/compiler.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Compiler do 2 | use Beaver 3 | alias Beaver.MLIR 4 | import MLIR.Sigils 5 | import Beaver, only: :macros 6 | require Beaver.MLIR 7 | alias Beaver.MLIR.Dialect.{Func} 8 | require Func 9 | @behaviour Nx.Defn.Compiler 10 | 11 | defp eval_arg(f) when is_function(f), do: f.() 12 | defp eval_arg(list) when is_list(list), do: Enum.map(list, &eval_arg/1) 13 | defp eval_arg(a), do: a 14 | 15 | defp runtime_libs() do 16 | case LLVMConfig.lib_dir() do 17 | {:ok, llvm_lib_dir} -> 18 | [ 19 | llvm_lib_dir |> Path.join("libmlir_c_runner_utils.dylib") 20 | ] 21 | 22 | _ -> 23 | [] 24 | end 25 | end 26 | 27 | defp vulkan_runtime_libs() do 28 | case LLVMConfig.lib_dir() do 29 | {:ok, llvm_lib_dir} -> 30 | [ 31 | llvm_lib_dir |> Path.join("libvulkan-runtime-wrappers.dylib") 32 | ] 33 | 34 | _ -> 35 | [] 36 | end 37 | end 38 | 39 | # Invoke MLIR JIT with Nx tensors. If there are tuples their memrefs will be packed into a single C struct. 
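  # The packing/unpacking is handled by `Manx.Nx.Interoperability`: `memref_from_tensor/1`
  # collects each tensor's memref descriptor (packing tuple returns into one
  # `Beaver.Native.Array`), and `populate_tensor_from_memref/2` unpacks the filled
  # descriptors back into tensors once the JIT call returns.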
40 | defp invoke(return, args, jit, symbol) do 41 | import Manx.Nx.Interoperability 42 | # pack the tensor tuples into a C struct 43 | jit_args = 44 | [return_struct | _] = 45 | [return | args] 46 | |> Enum.map(&memref_from_tensor/1) 47 | 48 | if List.improper?(jit_args), do: raise("jit arguments is not a proper list") 49 | 50 | MLIR.ExecutionEngine.invoke!( 51 | jit, 52 | symbol, 53 | jit_args |> Enum.map(&Beaver.Native.Memory.descriptor_ptr/1) 54 | ) 55 | 56 | # unpack the C struct into tensor tuples 57 | populate_tensor_from_memref(return, return_struct) 58 | |> Manx.add_allocated_memory() 59 | end 60 | 61 | defp module_attrs([tensor | _]), do: module_attrs(tensor) 62 | 63 | defp module_attrs(%Nx.Tensor{data: %Manx{device: :vulkan}}) do 64 | [ 65 | "spirv.target_env": 66 | ~a"#spirv.target_env<#spirv.vce, #spirv.resource_limits<>>" 67 | ] 68 | end 69 | 70 | defp module_attrs(_), do: [] 71 | 72 | defp lower(ir, []), do: {Manx.Lowering.CPU.lower(ir), runtime_libs()} 73 | defp lower(ir, [tensor | _]), do: lower(ir, tensor) 74 | 75 | defp lower(ir, %Nx.Tensor{data: %Nx.BinaryBackend{}}) do 76 | {Manx.Lowering.CPU.lower(ir), runtime_libs()} 77 | end 78 | 79 | defp lower(ir, %Nx.Tensor{data: %Manx{device: :host}}) do 80 | {Manx.Lowering.CPU.lower(ir), runtime_libs()} 81 | end 82 | 83 | defp lower(ir, %Nx.Tensor{data: %Manx{device: :vulkan}}) do 84 | {Manx.Lowering.Vulkan.lower(ir), vulkan_runtime_libs()} 85 | end 86 | 87 | @doc false 88 | @impl Nx.Defn.Compiler 89 | def __jit__(key, vars, fun, args_list, options) do 90 | __compile__(key, vars, fun, options).(args_list) 91 | end 92 | 93 | @doc false 94 | @impl Nx.Defn.Compiler 95 | def __compile__(key, vars, fun, _options) do 96 | # call fun to generate expression tree 97 | tree = fun.(vars) 98 | info = Function.info(key) 99 | uniq = info |> Keyword.get(:uniq) 100 | module = info |> Keyword.get(:module) 101 | name = info |> Keyword.get(:name) 102 | symbol = Module.concat([module, name, "#{uniq}"]) |> Atom.to_string() 103 | 104 | # generate ir 105 | entry_types = 106 | Enum.reduce(vars, [], fn 107 | tuple, acc when is_tuple(tuple) -> 108 | acc ++ Enum.map(Tuple.to_list(tuple), &Manx.Defn.gen_type/1) 109 | 110 | t, acc -> 111 | acc ++ [Manx.Defn.gen_type(t)] 112 | end) 113 | 114 | fn args_list -> 115 | args_list = args_list |> Enum.map(&eval_arg/1) 116 | 117 | for args <- args_list do 118 | ctx = MLIR.Context.create() 119 | Beaver.Diagnostic.attach(ctx) 120 | 121 | ir = 122 | mlir ctx: ctx do 123 | module(module_attrs(args)) do 124 | function_type = 125 | Type.function( 126 | entry_types, 127 | Manx.Defn.gen_root_types(tree) 128 | ) 129 | 130 | stacktrace_loc = 131 | Process.info(self(), :current_stacktrace) 132 | |> Manx.Nx.Interoperability.loc_from_stack_trace(ctx) 133 | 134 | Func.func manx_main( 135 | sym_name: "\"#{symbol}\"", 136 | function_type: function_type, 137 | loc: stacktrace_loc 138 | ) do 139 | region do 140 | locs = List.duplicate(stacktrace_loc, length(entry_types)) 141 | 142 | entry = 143 | MLIR.Block.create( 144 | entry_types |> Enum.map(&Beaver.Deferred.create(&1, Beaver.Env.context())), 145 | locs |> Enum.map(&Beaver.Deferred.create(&1, Beaver.Env.context())) 146 | ) 147 | 148 | mlir block: entry do 149 | case Manx.Defn.gen_op(%Manx.Defn.Env{block: entry, ctx: ctx}, tree) do 150 | ret = %Beaver.MLIR.Value{} -> 151 | Func.return(ret, loc: stacktrace_loc) >>> [] 152 | 153 | tuple_ret when is_tuple(tuple_ret) -> 154 | Func.return(Tuple.to_list(tuple_ret), loc: stacktrace_loc) >>> [] 155 | end 156 | end 157 | 158 | 
Beaver.Env.region() 159 | |> Beaver.MLIR.CAPI.mlirRegionAppendOwnedBlock(entry) 160 | end 161 | end 162 | end 163 | end 164 | 165 | case lower(ir, args) do 166 | {{:ok, mod}, libs} -> 167 | jit = 168 | mod 169 | |> MLIR.ExecutionEngine.create!(shared_lib_paths: libs) 170 | 171 | # invoke jit and setting return for tree 172 | tree_return = 173 | tree 174 | |> Manx.tensor_of_null_memref() 175 | |> invoke(args, jit, symbol) 176 | 177 | MLIR.Module.destroy(mod) 178 | MLIR.Context.destroy(ctx) 179 | tree_return 180 | 181 | {{:error, msg}, _} -> 182 | MLIR.Context.destroy(ctx) 183 | raise msg 184 | end 185 | end 186 | end 187 | end 188 | 189 | @doc false 190 | @impl Nx.Defn.Compiler 191 | def __stream__( 192 | _key, 193 | _input, 194 | _acc, 195 | _vars, 196 | _fun, 197 | _args_list, 198 | _opts 199 | ), 200 | do: raise("not implemented") 201 | 202 | @doc false 203 | @impl Nx.Defn.Compiler 204 | def __to_backend__(_keyword), do: raise("not implemented") 205 | 206 | @doc false 207 | @impl Nx.Defn.Compiler 208 | def __partitions_options__(_keyword), do: raise("not implemented") 209 | end 210 | -------------------------------------------------------------------------------- /lib/manx/manx.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx do 2 | @moduledoc """ 3 | `Manx` is a MLIR backend for the `Nx`. It mainly targets TOSA/Linalg dialect and will generate LLVM/CUDA/Vulkan code for different configurations. 4 | """ 5 | 6 | @enforce_keys [:memory] 7 | defstruct memory: nil, device: :host 8 | 9 | @behaviour Nx.Backend 10 | 11 | alias Nx.Tensor, as: T 12 | alias __MODULE__, as: B 13 | 14 | @impl Nx.Backend 15 | def init(keyword) do 16 | ctx = Beaver.MLIR.Context.create() 17 | Beaver.Diagnostic.attach(ctx) 18 | keyword ++ [ctx: ctx] 19 | end 20 | 21 | @impl Nx.Backend 22 | def constant(out, constant, backend_options) do 23 | binary_tensor = Nx.BinaryBackend.constant(out, constant, []) 24 | Nx.BinaryBackend.backend_transfer(binary_tensor, __MODULE__, backend_options) 25 | end 26 | 27 | @impl Nx.Backend 28 | def from_binary(%T{shape: shape, type: type} = tensor, binary, backend_options) do 29 | shape = Tuple.to_list(shape) 30 | device = Keyword.get(backend_options, :device, :host) 31 | 32 | memory = 33 | Beaver.Native.Memory.new( 34 | binary, 35 | sizes: shape, 36 | type: type 37 | ) 38 | 39 | memory |> Manx.MemrefAllocator.add() 40 | put_in(tensor.data, %B{memory: memory, device: device}) 41 | end 42 | 43 | @impl Nx.Backend 44 | def to_binary(%T{shape: _shape, data: %B{memory: memory}} = tensor, limit) do 45 | Beaver.Native.Memory.aligned(memory) 46 | |> Beaver.Native.OpaquePtr.to_binary(limit * div(element_size(tensor), 8)) 47 | end 48 | 49 | @impl Nx.Backend 50 | def inspect(%T{} = tensor, inspect_opts) do 51 | limit = if inspect_opts.limit == :infinity, do: :infinity, else: inspect_opts.limit + 1 52 | 53 | tensor 54 | |> to_binary(min(limit, Nx.size(tensor))) 55 | |> then(&Nx.Backend.inspect(tensor, &1, inspect_opts)) 56 | end 57 | 58 | defp element_size(%T{type: {_, size}}), do: size 59 | 60 | @impl Nx.Backend 61 | def backend_copy(tensor, Nx.Tensor, backend_options) do 62 | backend_copy(tensor, Nx.BinaryBackend, backend_options) 63 | end 64 | 65 | # TODO: Support direct transfers without going through Elixir 66 | def backend_copy( 67 | %T{shape: shape, data: %B{memory: memory}} = tensor, 68 | backend, 69 | backend_options 70 | ) do 71 | binary_len = Enum.reduce(Tuple.to_list(shape), 1, &*/2) * div(element_size(tensor), 8) 72 | 73 | 
backend.from_binary( 74 |       tensor, 75 |       Beaver.Native.Memory.aligned(memory) 76 |       |> Beaver.Native.OpaquePtr.to_binary(binary_len), 77 |       backend_options 78 |     ) 79 |   end 80 | 81 |   @impl Nx.Backend 82 |   def backend_transfer( 83 |         %T{data: %B{memory: memory}} = tensor, 84 |         backend, 85 |         backend_options 86 |       ) do 87 |     if backend == __MODULE__ do 88 |       # TODO: support tensor on device memory like CUDA 89 |       tensor 90 |     else 91 |       tensor = backend_copy(tensor, backend, backend_options) 92 | 93 |       with :ok <- Manx.MemrefAllocator.delete(memory) do 94 |         tensor 95 |       else 96 |         :already_deallocated -> raise "called on a deleted or donated buffer" 97 |       end 98 |     end 99 |   end 100 | 101 |   @impl Nx.Backend 102 |   def backend_deallocate(%T{data: %B{memory: memory}}) do 103 |     memory |> Manx.MemrefAllocator.delete() 104 |   end 105 | 106 |   @doc """ 107 |   Create a new tensor with a null-pointer memref. This should be used as the return tensor of a JIT function. 108 |   """ 109 |   def tensor_of_null_memref(%T{shape: shape, type: _type} = tensor) do 110 |     shape = Tuple.to_list(shape) 111 | 112 |     memory = Beaver.Native.Memory.new(nil, sizes: shape, type: Beaver.Native.F32) 113 | 114 |     put_in(tensor.data, %B{memory: memory}) 115 |   end 116 | 117 |   def tensor_of_null_memref(tuple) when is_tuple(tuple) do 118 |     for t <- tuple |> Tuple.to_list() do 119 |       tensor_of_null_memref(t) 120 |     end 121 |     |> List.to_tuple() 122 |   end 123 | 124 |   # TODO: check if argument is returned by JIT function 125 |   @doc """ 126 |   Add the returned memref to the allocator. 127 |   """ 128 |   def add_allocated_memory(%T{data: %B{memory: memory}} = tensor) do 129 |     memory |> Manx.MemrefAllocator.add() 130 |     put_in(tensor.data.memory, memory) 131 |   end 132 | 133 |   def add_allocated_memory(tuple) when is_tuple(tuple) do 134 |     for t <- Tuple.to_list(tuple) do 135 |       add_allocated_memory(t) 136 |     end 137 |     |> List.to_tuple() 138 |   end 139 | 140 |   require Nx.Defn.Expr 141 |   ## JIT callbacks 142 | 143 |   @impl Nx.Backend 144 |   def concatenate(out, tensors, axis) do 145 |     out = Nx.to_template(out) 146 | 147 |     expr_fun = fn tensors -> 148 |       Nx.Defn.Expr.concatenate(out, Tuple.to_list(tensors), axis) 149 |     end 150 | 151 |     jit(expr_fun, [List.to_tuple(tensors)]) 152 |   end 153 | 154 |   @impl Nx.Backend 155 |   def slice(out, tensor, start_indices, lengths, strides) do 156 |     out = Nx.to_template(out) 157 | 158 |     if Enum.all?(start_indices, &is_integer/1) do 159 |       expr_fun = fn tensor -> 160 |         Nx.Defn.Expr.slice(out, tensor, start_indices, lengths, strides) 161 |       end 162 | 163 |       jit(expr_fun, [tensor]) 164 |     else 165 |       expr_fun = fn tensor, start_indices -> 166 |         Nx.Defn.Expr.slice(out, tensor, Tuple.to_list(start_indices), lengths, strides) 167 |       end 168 | 169 |       jit(expr_fun, [tensor, List.to_tuple(start_indices)]) 170 |     end 171 |   end 172 | 173 |   @impl Nx.Backend 174 |   def put_slice(out, tensor, start_indices, slice) do 175 |     out = Nx.to_template(out) 176 | 177 |     if Enum.all?(start_indices, &is_integer/1) do 178 |       expr_fun = fn tensor, slice -> 179 |         Nx.Defn.Expr.put_slice(out, tensor, start_indices, slice) 180 |       end 181 | 182 |       jit(expr_fun, [tensor, slice]) 183 |     else 184 |       expr_fun = fn tensor, start_indices, slice -> 185 |         Nx.Defn.Expr.put_slice(out, tensor, Tuple.to_list(start_indices), slice) 186 |       end 187 | 188 |       jit(expr_fun, [tensor, List.to_tuple(start_indices), slice]) 189 |     end 190 |   end 191 | 192 |   @impl Nx.Backend 193 |   def optional(_name, args, fun) do 194 |     # Here we take the leading tensor arguments and pass them as JIT arguments 195 |     {tensors, rest} = 
Enum.split_while(args, &is_struct(&1, Nx.Tensor)) 196 | 197 | wrapper_fun = fn tensors -> 198 | tensors = Tuple.to_list(tensors) 199 | apply(fun, tensors ++ rest) 200 | end 201 | 202 | jit(wrapper_fun, [List.to_tuple(tensors)]) 203 | end 204 | 205 | binary_ops = 206 | [:add, :subtract, :multiply, :pow, :remainder, :divide, :atan2, :min, :max, :quotient] ++ 207 | [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift, :right_shift] ++ 208 | [:equal, :not_equal, :greater, :less, :greater_equal, :less_equal] ++ 209 | [:logical_and, :logical_or, :logical_xor] 210 | 211 | unary_ops = 212 | [:exp, :expm1, :log, :log1p, :sigmoid, :cos, :sin, :tan] ++ 213 | [:cosh, :sinh, :tanh, :acos, :asin, :atan, :acosh, :asinh, :atanh] ++ 214 | [:sqrt, :rsqrt, :cbrt, :is_nan, :is_infinity, :erf, :erfc, :erf_inv] ++ 215 | [:abs, :bitwise_not, :ceil, :conjugate, :floor, :negate, :round, :sign] ++ 216 | [:count_leading_zeros, :population_count, :real, :imag] 217 | 218 | callbacks = 219 | [ 220 | {:eye, [:backend_options], []}, 221 | {:iota, [:axis, :backend_options], []}, 222 | {:as_type, [:tensor], [:tensor]}, 223 | {:bitcast, [:tensor], [:tensor]}, 224 | {:reshape, [:tensor], [:tensor]}, 225 | {:squeeze, [:tensor, :axes], [:tensor]}, 226 | {:broadcast, [:tensor, :shape, :axes], [:tensor]}, 227 | {:transpose, [:tensor, :axes], [:tensor]}, 228 | {:pad, [:tensor, :pad_value, :padding_config], [:tensor, :pad_value]}, 229 | {:reverse, [:tensor, :axes], [:tensor]}, 230 | {:dot, [:left, :c1, :b1, :right, :c2, :b2], [:left, :right]}, 231 | {:clip, [:tensor, :min, :max], [:tensor, :min, :max]}, 232 | {:take, [:tensor, :indices, :axis], [:tensor, :indices]}, 233 | {:take_along_axis, [:tensor, :indices, :axis], [:tensor, :indices]}, 234 | {:gather, [:input, :indices, :opts], [:input, :indices]}, 235 | {:select, [:pred, :on_true, :on_false], [:pred, :on_true, :on_false]}, 236 | {:conv, [:tensor, :kernel, :opts], [:tensor, :kernel]}, 237 | {:all, [:tensor, :opts], [:tensor]}, 238 | {:any, [:tensor, :opts], [:tensor]}, 239 | {:sum, [:tensor, :opts], [:tensor]}, 240 | {:product, [:tensor, :opts], [:tensor]}, 241 | {:reduce_max, [:tensor, :opts], [:tensor]}, 242 | {:reduce_min, [:tensor, :opts], [:tensor]}, 243 | {:argmax, [:tensor, :opts], [:tensor]}, 244 | {:argmin, [:tensor, :opts], [:tensor]}, 245 | {:reduce, [:tensor, :acc, :opts, :fun], [:tensor, :acc]}, 246 | {:window_reduce, [:tensor, :acc, :shape, :opts, :fun], [:tensor, :acc]}, 247 | {:window_sum, [:tensor, :shape, :opts], [:tensor]}, 248 | {:window_product, [:tensor, :shape, :opts], [:tensor]}, 249 | {:window_max, [:tensor, :shape, :opts], [:tensor]}, 250 | {:window_min, [:tensor, :shape, :opts], [:tensor]}, 251 | {:map, [:tensor, :opts, :fun], [:tensor]}, 252 | {:sort, [:tensor, :opts], [:tensor]}, 253 | {:argsort, [:tensor, :opts], [:tensor]}, 254 | {:window_scatter_max, [:tensor, :source, :init_value, :window_dims, :opts], 255 | [:tensor, :source, :init_value]}, 256 | {:window_scatter_min, [:tensor, :source, :init_value, :window_dims, :opts], 257 | [:tensor, :source, :init_value]}, 258 | {:indexed_add, [:tensor, :indices, :updates, :opts], [:tensor, :indices, :updates]}, 259 | {:indexed_put, [:tensor, :indices, :updates, :opts], [:tensor, :indices, :updates]}, 260 | {:lu, [:tensor, :opts], [:tensor]}, 261 | {:triangular_solve, [:a, :b, :opts], [:a, :b]}, 262 | {:fft, [:tensor, :opts], [:tensor]}, 263 | {:ifft, [:tensor, :opts], [:tensor]} 264 | ] ++ 265 | for(op <- binary_ops, do: {op, [:left, :right], [:left, :right]}) ++ 266 | for(op <- unary_ops, do: {op, 
[:tensor], [:tensor]}) 267 | 268 | for {name, args, tensor_args} <- callbacks do 269 | args = Enum.map(args, &Macro.var(&1, __MODULE__)) 270 | tensor_args = Enum.map(tensor_args, &Macro.var(&1, __MODULE__)) 271 | 272 | @impl Nx.Backend 273 | def unquote(name)(out, unquote_splicing(args)) do 274 | out = Nx.to_template(out) 275 | 276 | expr_fun = fn unquote_splicing(tensor_args) -> 277 | Nx.Defn.Expr.unquote(name)(out, unquote_splicing(args)) 278 | end 279 | 280 | jit(expr_fun, [unquote_splicing(tensor_args)]) 281 | end 282 | end 283 | 284 | defp jit(function, args, options \\ []) do 285 | Nx.Defn.jit_apply(function, args, Keyword.put(options, :compiler, __MODULE__)) 286 | end 287 | 288 | @impl Nx.Backend 289 | def to_batched(_out, _tensor, _keyword), do: raise("not implemented") 290 | end 291 | -------------------------------------------------------------------------------- /test/beaver/defn/expr_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Manx.ExprTest do 2 | use ExUnit.Case, async: true 3 | import Nx.Defn 4 | import Manx.Assert 5 | 6 | @moduletag :nx 7 | setup do 8 | Nx.Defn.default_options(compiler: Manx.Compiler) 9 | :ok 10 | end 11 | 12 | describe "tuples" do 13 | defn add_subtract_tuple(a, b), do: {a + b, a - b} 14 | 15 | test "on results" do 16 | assert_equal(add_subtract_tuple(2, 3), {Nx.tensor(5), Nx.tensor(-1)}) 17 | 18 | assert_equal( 19 | add_subtract_tuple(Nx.tensor([-1, 0, 1]), Nx.tensor([10, 10, 10])), 20 | {Nx.tensor([9, 10, 11]), Nx.tensor([-11, -10, -9])} 21 | ) 22 | 23 | assert_equal( 24 | add_subtract_tuple(Nx.tensor([-1, 0, 1]), 10), 25 | {Nx.tensor([9, 10, 11]), Nx.tensor([-11, -10, -9])} 26 | ) 27 | end 28 | 29 | defn pattern_tuple({a, b}), do: a + b 30 | 31 | test "on patterns" do 32 | assert_equal(pattern_tuple({2, 3}), Nx.tensor(5)) 33 | 34 | assert_equal( 35 | pattern_tuple({Nx.tensor([1, 2]), Nx.tensor([[3], [4]])}), 36 | Nx.tensor([[4, 5], [5, 6]]) 37 | ) 38 | end 39 | 40 | defn calls_pattern_tuple(a, b), do: pattern_tuple({a, b}) 41 | 42 | test "on inlined tuples" do 43 | assert_equal(calls_pattern_tuple(2, 3), Nx.tensor(5)) 44 | 45 | assert_equal( 46 | calls_pattern_tuple(Nx.tensor([1, 2]), Nx.tensor([[3], [4]])), 47 | Nx.tensor([[4, 5], [5, 6]]) 48 | ) 49 | end 50 | end 51 | 52 | describe "tensor constants" do 53 | @two 2 54 | defn constants, do: @two 55 | defn add_two_attribute(t), do: t + @two 56 | 57 | @two_per_two Nx.tensor([[1, 2], [3, 4]]) 58 | defn add_2x2_attribute(t), do: t + @two_per_two 59 | defn add_2x2_constant(), do: @two_per_two + @two_per_two 60 | defn add_2x2_constant(_), do: @two_per_two + @two_per_two 61 | 62 | test "handles tensors as constants" do 63 | assert_equal(constants(), Nx.tensor(2)) 64 | end 65 | 66 | test "expands module attributes to scalars" do 67 | assert_equal(add_two_attribute(1), Nx.tensor(3)) 68 | assert_equal(add_two_attribute(Nx.tensor([1, 2, 3])), Nx.tensor([3, 4, 5])) 69 | end 70 | 71 | test "expands module attributes to tensors" do 72 | assert_equal(add_2x2_attribute(1), Nx.tensor([[2, 3], [4, 5]])) 73 | assert_equal(add_2x2_attribute(Nx.tensor([1, 2])), Nx.tensor([[2, 4], [4, 6]])) 74 | end 75 | 76 | test "constants should be folded" do 77 | assert_equal(add_2x2_constant(), Nx.tensor([[2, 4], [6, 8]])) 78 | assert_equal(add_2x2_constant(1), Nx.tensor([[2, 4], [6, 8]])) 79 | end 80 | end 81 | 82 | describe "non finite" do 83 | defn infinity, do: Nx.Constants.infinity() 84 | defn neg_infinity, do: Nx.Constants.neg_infinity() 85 | defn nan, do: 
Nx.Constants.nan() 86 | 87 | test "handles non-finite constants correctly" do 88 | assert_equal(infinity(), Nx.Constants.infinity()) 89 | assert_equal(neg_infinity(), Nx.Constants.neg_infinity()) 90 | # TODO: fix this 91 | {nan(), Nx.Constants.nan()} 92 | end 93 | 94 | defn negate_infinity, do: Nx.negate(Nx.Constants.infinity()) 95 | defn negate_neg_infinity, do: Nx.negate(Nx.Constants.infinity()) 96 | 97 | test "sanity check constants" do 98 | assert_equal(negate_infinity(), Nx.Constants.neg_infinity()) 99 | assert_equal(infinity(), Nx.Constants.infinity()) 100 | end 101 | end 102 | 103 | describe "float16" do 104 | defn return_float, do: Nx.tensor(1, type: {:f, 16}) 105 | 106 | test "supports float16 return types" do 107 | assert_equal(return_float(), Nx.tensor(1, type: {:f, 16})) 108 | end 109 | end 110 | 111 | describe "complex" do 112 | defn return_complex, do: Nx.complex(1, 2) 113 | defn return_complex_tensor, do: Nx.broadcast(Nx.complex(1, 2), {3, 3, 3}) 114 | 115 | test "supports complex return types" do 116 | assert_equal(return_complex(), Nx.tensor(Complex.new(1, 2))) 117 | assert_equal(return_complex_tensor(), Nx.broadcast(Complex.new(1, 2), {3, 3, 3})) 118 | end 119 | end 120 | 121 | describe "conjugate" do 122 | defn conjugate(x), do: Nx.conjugate(x) 123 | 124 | test "correctly returns complex conjugate" do 125 | assert_equal(conjugate(Nx.tensor(Complex.new(1, 2))), Nx.tensor(Complex.new(1, -2))) 126 | # This differs from the Nx doctest, which I believe should also return -0 127 | assert_equal(conjugate(Nx.tensor(1)), Nx.tensor(Complex.new(1, -0.0))) 128 | 129 | assert_equal( 130 | conjugate(Nx.tensor([Complex.new(1, 2), Complex.new(2, -4)])), 131 | Nx.tensor([Complex.new(1, -2), Complex.new(2, 4)]) 132 | ) 133 | end 134 | end 135 | 136 | describe "imag" do 137 | defn imag(x), do: Nx.imag(x) 138 | 139 | test "correctly returns imaginary part of complex" do 140 | assert_equal(imag(Nx.tensor(Complex.new(1, 2))), Nx.tensor(2.0)) 141 | assert_equal(imag(Nx.tensor(1)), Nx.tensor(0.0)) 142 | 143 | assert_equal( 144 | imag(Nx.tensor([Complex.new(1, 2), Complex.new(2, -4)])), 145 | Nx.tensor([2.0, -4.0]) 146 | ) 147 | end 148 | end 149 | 150 | describe "+/2" do 151 | @describetag :plus 152 | defn add_two(a, b), do: a + b 153 | 154 | test "same shape and type" do 155 | assert_equal(add_two(1.0, 2.0), Nx.tensor(3.0)) 156 | assert_equal(add_two(1, 2), Nx.tensor(3)) 157 | 158 | assert_equal(add_two(Nx.tensor([1, 2]), Nx.tensor([3, 4])), Nx.tensor([4, 6])) 159 | assert_equal(add_two(Nx.tensor([1.0, 2.0]), Nx.tensor([3.0, 4.0])), Nx.tensor([4.0, 6.0])) 160 | end 161 | 162 | test "different types" do 163 | tensors = [ 164 | {1, 2}, 165 | {1.0, 2}, 166 | {1.0, 3.0}, 167 | {Nx.tensor([1, 2], type: {:u, 8}), 3}, 168 | {Nx.tensor([1, 2], type: {:u, 8}), -3}, 169 | {Nx.tensor([1, 2], type: {:u, 8}), 3.0}, 170 | {Nx.tensor([1, 2], type: {:s, 8}), 3}, 171 | {Nx.tensor([1, 2], type: {:s, 8}), 3.0}, 172 | {Nx.tensor([1, 2], type: {:f, 32}), 3}, 173 | {Nx.tensor([1, 2], type: {:f, 32}), 3.0}, 174 | {Nx.tensor([1, 2], type: {:u, 8}), Nx.tensor(3, type: {:u, 16})}, 175 | {Nx.tensor([1, 2], type: {:u, 8}), Nx.tensor(-3, type: {:s, 16})}, 176 | {Nx.tensor([1, 2], type: {:u, 8}), Nx.tensor(3.0, type: {:f, 32})}, 177 | {Nx.tensor([1, 2], type: {:s, 8}), Nx.tensor(3, type: {:s, 16})}, 178 | {Nx.tensor([1, 2], type: {:s, 8}), Nx.tensor(3.0, type: {:f, 32})}, 179 | {Nx.tensor([1, 2], type: {:f, 32}), Nx.tensor(3, type: {:u, 16})}, 180 | {Nx.tensor([1, 2], type: {:f, 32}), Nx.tensor(3, type: {:s, 16})} 181 | # 
{Nx.tensor([1, 2], type: {:f, 32}), Nx.tensor(3.0, type: {:f, 64})} 182 | ] 183 | 184 | for {left, right} <- tensors do 185 | assert_all_close(add_two(left, right), evaluate(&add_two/2, [left, right])) 186 | assert_all_close(add_two(right, left), evaluate(&add_two/2, [right, left])) 187 | end 188 | end 189 | 190 | defn add_two_int(t), do: t + 2 191 | defn add_two_float(t), do: t + 2.0 192 | 193 | test "constants" do 194 | tensors = [ 195 | Nx.tensor([1, 2], type: {:u, 8}), 196 | Nx.tensor([1, 2], type: {:u, 16}), 197 | Nx.tensor([1, 2], type: {:u, 32}), 198 | Nx.tensor([1, 2], type: {:s, 8}), 199 | Nx.tensor([1, 2], type: {:s, 32}), 200 | Nx.tensor([1, 2], type: {:f, 32}) 201 | # Nx.tensor([1, 2], type: {:f, 64}) 202 | ] 203 | 204 | for t <- tensors do 205 | assert_equal(add_two_int(t), Nx.add(t, 2)) 206 | assert_equal(add_two_float(t), Nx.add(t, 2.0)) 207 | end 208 | end 209 | 210 | test "broadcast" do 211 | tensors = [ 212 | {Nx.tensor([1, 2]), Nx.tensor([[1, 2], [3, 4]])}, 213 | {Nx.tensor([1, 2]), Nx.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]])}, 214 | {Nx.tensor([[1], [2]]), Nx.tensor([[10, 20]])}, 215 | {Nx.tensor([[10, 20]]), Nx.tensor([[1], [2]])}, 216 | {Nx.tensor([[[10], [20]]]), Nx.tensor([[[1, 2]], [[3, 4]]])}, 217 | {Nx.tensor([[[100], [200], [300]]]), 218 | Nx.tensor([[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]])}, 219 | {Nx.tensor([[[[1]]]]), Nx.tensor([[1, 2], [3, 4]])}, 220 | {Nx.tensor([[[[1]]]]), Nx.tensor([1, 2])}, 221 | {Nx.tensor([[[10], [20]], [[30], [40]]]), Nx.tensor([[1, 2]])}, 222 | {Nx.tensor([[[[10], [20]], [[30], [40]]]]), Nx.tensor([[[1, 2]], [[3, 4]]])}, 223 | {Nx.tensor([[[[10], [20]], [[30], [40]]]]), Nx.tensor([[[[1, 2]]], [[[3, 4]]]])}, 224 | {Nx.tensor([[[10], [20]], [[30], [40]]]), Nx.tensor([[[1, 2]], [[3, 4]]])} 225 | ] 226 | 227 | for {left, right} <- tensors do 228 | assert_all_close(add_two(left, right), evaluate(&add_two/2, [left, right])) 229 | assert_all_close(add_two(right, left), evaluate(&add_two/2, [right, left])) 230 | end 231 | end 232 | 233 | test "names" do 234 | left = Nx.tensor([[10, 20]], names: [nil, :tens]) 235 | right = Nx.tensor([[1], [2]], names: [:ones, nil]) 236 | assert add_two(left, right).names == [:ones, :tens] 237 | end 238 | end 239 | 240 | describe "//2" do 241 | defn divide_two(a, b), do: a / b 242 | 243 | test "parameters" do 244 | tensors = [ 245 | {1, 2}, 246 | {1, Nx.tensor([1.0, 2.0, 3.0])}, 247 | {Nx.tensor([1, 2, 3]), 1}, 248 | {Nx.tensor([[1], [2]]), Nx.tensor([[10, 20]])}, 249 | {Nx.tensor([[1], [2]], type: {:s, 8}), Nx.tensor([[10, 20]], type: {:s, 8})}, 250 | {Nx.tensor([[1], [2]], type: {:f, 32}), Nx.tensor([[10, 20]], type: {:f, 32})} 251 | ] 252 | 253 | for {left, right} <- tensors do 254 | assert_all_close(divide_two(left, right), Nx.divide(left, right)) 255 | assert_all_close(divide_two(right, left), Nx.divide(right, left)) 256 | end 257 | end 258 | 259 | defn divide_two_int(t), do: t / 2 260 | defn divide_two_float(t), do: t / 2.0 261 | 262 | test "constants" do 263 | tensors = [ 264 | Nx.tensor([1, 2], type: {:u, 8}), 265 | Nx.tensor([1, 2], type: {:u, 16}), 266 | Nx.tensor([1, 2], type: {:u, 32}), 267 | Nx.tensor([1, 2], type: {:s, 8}), 268 | Nx.tensor([1, 2], type: {:s, 32}), 269 | Nx.tensor([1, 2], type: {:f, 32}) 270 | # Nx.tensor([1, 2], type: {:f, 64}) 271 | ] 272 | 273 | for t <- tensors do 274 | assert_all_close(divide_two_int(t), Nx.divide(t, 2)) 275 | assert_all_close(divide_two_float(t), Nx.divide(t, 2.0)) 276 | end 277 | end 278 | end 279 | 280 | describe "remainder" do 281 | 
defn remainder(a, b), do: Nx.remainder(a, b) 282 | 283 | test "integers" do 284 | left = Nx.tensor([-1023, 1023]) 285 | right = Nx.tensor([[-4], [4]]) 286 | assert Nx.shape(remainder(left, right)) == {2, 2} 287 | assert_all_close(remainder(left, right), Nx.remainder(left, right)) 288 | end 289 | 290 | test "floats" do 291 | left = Nx.tensor([-8.3, -8.4, -8.5, 8.3, 8.4, 8.5]) 292 | right = Nx.tensor([[-4.2], [-4.1], [-4.0], [4.0], [4.1], [4.2]]) 293 | assert Nx.shape(remainder(left, right)) == {6, 6} 294 | assert_all_close(remainder(left, right), Nx.remainder(left, right)) 295 | end 296 | end 297 | 298 | describe "element-wise arith operators" do 299 | @tensors [ 300 | {1, 2}, 301 | {1, Nx.tensor([1.0, 2.0, 3.0])}, 302 | {Nx.tensor([1, 2, 3]), 1}, 303 | {Nx.tensor([[1], [2]]), Nx.tensor([[10, 20]])}, 304 | {Nx.tensor([[1], [2]], type: {:s, 8}), Nx.tensor([[10, 20]], type: {:s, 8})}, 305 | {Nx.tensor([[1], [2]], type: {:f, 32}), Nx.tensor([[10, 20]], type: {:f, 32})} 306 | ] 307 | 308 | defn subtract_two(a, b), do: a - b 309 | 310 | test "-" do 311 | for {left, right} <- @tensors do 312 | assert_all_close(subtract_two(left, right), Nx.subtract(left, right)) 313 | assert_all_close(subtract_two(right, left), Nx.subtract(right, left)) 314 | end 315 | end 316 | 317 | defn multiply_two(a, b), do: a * b 318 | 319 | test "*" do 320 | for {left, right} <- @tensors do 321 | assert_all_close(multiply_two(left, right), Nx.multiply(left, right)) 322 | assert_all_close(multiply_two(right, left), Nx.multiply(right, left)) 323 | end 324 | end 325 | 326 | defn unary_minus(a), do: -a 327 | 328 | test "negate" do 329 | for t <- [ 330 | Nx.tensor([-1, 0, 1], type: {:u, 8}), 331 | Nx.tensor([-1, 0, 1]), 332 | Nx.tensor([-1.0, 1.0]) 333 | ] do 334 | assert_equal(unary_minus(t), Nx.negate(t)) 335 | end 336 | end 337 | 338 | defn max_two(a, b), do: max(a, b) 339 | 340 | test "max" do 341 | for {left, right} <- @tensors do 342 | assert_all_close(max_two(left, right), Nx.max(left, right)) 343 | assert_all_close(max_two(right, left), Nx.max(right, left)) 344 | end 345 | end 346 | 347 | defn min_two(a, b), do: min(a, b) 348 | 349 | test "min" do 350 | for {left, right} <- @tensors do 351 | assert_all_close(min_two(left, right), Nx.min(left, right)) 352 | assert_all_close(min_two(right, left), Nx.min(right, left)) 353 | end 354 | end 355 | 356 | defn power_two(a, b), do: Nx.pow(a, b) 357 | 358 | test "power" do 359 | for {left, right} <- @tensors do 360 | case left do 361 | %{type: {_, 8}} -> 362 | nil 363 | 364 | i when is_integer(i) -> 365 | nil 366 | 367 | %{type: {:s, _}} -> 368 | nil 369 | 370 | _ -> 371 | assert_all_close(power_two(left, right), Nx.pow(left, right)) 372 | assert_all_close(power_two(right, left), Nx.pow(right, left)) 373 | end 374 | end 375 | end 376 | 377 | defn atan2_two(a, b), do: Nx.atan2(a, b) 378 | 379 | test "atan2" do 380 | <> = <<0x8000000000000000::64>> 381 | left = Nx.tensor([-1.0, neg_zero, 0.0, 1.0]) 382 | right = Nx.tensor([[-1.0], [neg_zero], [0.0], [1.0]]) 383 | 384 | assert_all_close(atan2_two(left, right), Nx.atan2(left, right)) 385 | assert_all_close(atan2_two(right, left), Nx.atan2(right, left)) 386 | end 387 | 388 | defn quotient_two(a, b), do: Nx.quotient(a, b) 389 | 390 | test "quotient" do 391 | int_tensors = [ 392 | {1, 2}, 393 | {1, Nx.tensor([1, 2, 3])}, 394 | {Nx.tensor([1, 2, 3]), 1}, 395 | {Nx.tensor([[1], [2]]), Nx.tensor([[10, 20]])}, 396 | {Nx.tensor([[1], [2]], type: {:s, 8}), Nx.tensor([[10, 20]], type: {:s, 8})}, 397 | {Nx.tensor([[1], [2]], type: {:s, 8}), 
Nx.tensor([[10, 20]], type: {:s, 32})} 398 | ] 399 | 400 | for {left, right} <- int_tensors do 401 | assert_all_close(quotient_two(left, right), Nx.quotient(left, right)) 402 | assert_all_close(quotient_two(right, left), Nx.quotient(right, left)) 403 | end 404 | end 405 | end 406 | 407 | describe "element-wise bitwise operators" do 408 | @left Nx.tensor([-2, -1, 0, 1, 2]) 409 | @right Nx.tensor([[-2], [-1], [0], [1], [2]]) 410 | 411 | defn bitwise_and(a, b), do: a &&& b 412 | 413 | test "bitwise_and" do 414 | assert Nx.shape(bitwise_and(@left, @right)) == {5, 5} 415 | assert_equal(bitwise_and(@left, @right), Nx.bitwise_and(@left, @right)) 416 | end 417 | 418 | defn bitwise_or(a, b), do: a ||| b 419 | 420 | test "bitwise_or" do 421 | assert Nx.shape(bitwise_or(@left, @right)) == {5, 5} 422 | assert_equal(bitwise_or(@left, @right), Nx.bitwise_or(@left, @right)) 423 | end 424 | 425 | defn bitwise_not(a), do: ~~~a 426 | 427 | test "bitwise_not" do 428 | assert Nx.shape(bitwise_not(@left)) == {5} 429 | assert_equal(bitwise_not(@left), Nx.bitwise_not(@left)) 430 | end 431 | 432 | defn bitwise_pc(a), do: Nx.population_count(a) 433 | 434 | test "population_count" do 435 | assert Nx.shape(bitwise_pc(@left)) == {5} 436 | assert_equal(bitwise_pc(@left), Nx.population_count(@left)) 437 | end 438 | 439 | defn bitwise_clz(a), do: Nx.count_leading_zeros(a) 440 | 441 | test "count_leading_zeros" do 442 | assert Nx.shape(bitwise_clz(@left)) == {5} 443 | assert_equal(bitwise_clz(@left), Nx.count_leading_zeros(@left)) 444 | end 445 | 446 | @left Nx.tensor([-2, -1, 0, 1, 2]) 447 | @right Nx.tensor([[0], [1], [2], [3], [4]]) 448 | 449 | defn left_shift(a, b), do: a <<< b 450 | 451 | test "left_shift" do 452 | assert Nx.shape(left_shift(@left, @right)) == {5, 5} 453 | assert_equal(left_shift(@left, @right), Nx.left_shift(@left, @right)) 454 | end 455 | 456 | @left_signed Nx.tensor([-128, -127, -2, -1, 0, 1, 2, 126, 127], type: {:s, 8}) 457 | @right_signed Nx.tensor([[0], [1], [2], [3], [4], [5], [6], [7], [8]], type: {:s, 8}) 458 | 459 | @left_unsigned Nx.tensor([0, 1, 2, 253, 254, 255], type: {:u, 8}) 460 | @right_unsigned Nx.tensor([[0], [1], [2], [3], [4], [5]], type: {:u, 8}) 461 | 462 | defn right_shift(a, b), do: a >>> b 463 | 464 | test "right_shift" do 465 | assert Nx.shape(right_shift(@left_signed, @right_signed)) == {9, 9} 466 | 467 | assert_equal( 468 | right_shift(@left_signed, @right_signed), 469 | Nx.right_shift(@left_signed, @right_signed) 470 | ) 471 | 472 | assert Nx.shape(right_shift(@left_unsigned, @right_unsigned)) == {6, 6} 473 | 474 | assert_equal( 475 | right_shift(@left_unsigned, @right_unsigned), 476 | Nx.right_shift(@left_unsigned, @right_unsigned) 477 | ) 478 | end 479 | end 480 | 481 | describe "exp" do 482 | defn exp(t), do: Nx.exp(t) 483 | 484 | test "computes the exp across types" do 485 | assert_all_close( 486 | Nx.tensor([1, 2, 3]) |> exp(), 487 | Nx.tensor([2.718281828459045, 7.38905609893065, 20.085536923187668]) 488 | ) 489 | 490 | assert_all_close( 491 | Nx.tensor([1, 2, 3], type: {:s, 8}) |> exp(), 492 | Nx.tensor([2.718281828459045, 7.38905609893065, 20.085536923187668], type: {:f, 32}) 493 | ) 494 | 495 | assert_all_close( 496 | Nx.tensor([1, 2, 3], type: {:u, 8}) |> exp(), 497 | Nx.tensor([2.718281828459045, 7.38905609893065, 20.085536923187668], type: {:f, 32}) 498 | ) 499 | 500 | assert_all_close( 501 | Nx.tensor([1.0, 2.0, 3.0]) |> exp(), 502 | Nx.tensor([2.718281828459045, 7.38905609893065, 20.085536923187668]) 503 | ) 504 | 505 | assert_all_close( 506 | 
Nx.tensor([1.0, 2.0, 3.0], type: {:f, 32}) |> exp(), 507 | Nx.tensor([2.718281828459045, 7.38905609893065, 20.085536923187668], type: {:f, 32}) 508 | ) 509 | end 510 | end 511 | 512 | describe "equal" do 513 | defn equal(a, b), do: Nx.equal(a, b) 514 | 515 | test "computes equality of scalars" do 516 | assert_equal(equal(Nx.tensor(1), Nx.tensor(2)), Nx.tensor(0, type: {:u, 8})) 517 | end 518 | 519 | test "computes equality with broadcasting" do 520 | assert_equal( 521 | equal(Nx.tensor(1), Nx.tensor([1, 2, 3])), 522 | Nx.tensor([1, 0, 0], type: {:u, 8}) 523 | ) 524 | end 525 | 526 | test "computes equality with mixed types" do 527 | assert_equal( 528 | equal(Nx.tensor([1, 2, 3]), Nx.tensor([1.0, 2.0, 3.0])), 529 | Nx.tensor([1, 1, 1], type: {:u, 8}) 530 | ) 531 | end 532 | 533 | defn successive_compare(y_true, y_pred) do 534 | y_pred 535 | |> Nx.equal(y_pred) 536 | |> Nx.equal(y_true) 537 | end 538 | 539 | @tag :todo 540 | # TODO: track https://github.com/llvm/llvm-project/issues/57951 541 | test "computes successive comparisons" do 542 | {successive_compare(Nx.tensor(1), Nx.tensor(1)), Nx.tensor(1, type: {:u, 8})} 543 | end 544 | end 545 | 546 | describe "not equal" do 547 | defn not_equal(a, b), do: Nx.not_equal(a, b) 548 | 549 | test "computes equality of scalars" do 550 | assert_equal(not_equal(Nx.tensor(1), Nx.tensor(2)), Nx.tensor(1, type: {:u, 8})) 551 | end 552 | 553 | test "computes equality with broadcasting" do 554 | assert_equal( 555 | not_equal(Nx.tensor(1), Nx.tensor([1, 2, 3])), 556 | Nx.tensor([0, 1, 1], type: {:u, 8}) 557 | ) 558 | end 559 | 560 | test "computes equality with mixed types" do 561 | assert_equal( 562 | not_equal(Nx.tensor([1, 2, 3]), Nx.tensor([1.0, 2.0, 3.0])), 563 | Nx.tensor([0, 0, 0], type: {:u, 8}) 564 | ) 565 | end 566 | end 567 | 568 | describe "less" do 569 | defn less(a, b), do: Nx.less(a, b) 570 | 571 | test "compares scalars" do 572 | assert_equal(less(Nx.tensor(1), Nx.tensor(2)), Nx.tensor(1, type: {:u, 8})) 573 | end 574 | 575 | test "compares with broadcasting" do 576 | assert_equal(less(Nx.tensor(1), Nx.tensor([1, 2, 3])), Nx.tensor([0, 1, 1], type: {:u, 8})) 577 | end 578 | 579 | test "compares with mixed types" do 580 | assert_equal( 581 | less(Nx.tensor([1, 2, 3]), Nx.tensor([1.0, 2.0, 3.0])), 582 | Nx.tensor([0, 0, 0], type: {:u, 8}) 583 | ) 584 | end 585 | end 586 | 587 | describe "greater" do 588 | defn greater(a, b), do: Nx.greater(a, b) 589 | 590 | test "compares scalars" do 591 | assert_equal(greater(Nx.tensor(1), Nx.tensor(2)), Nx.tensor(0, type: {:u, 8})) 592 | end 593 | 594 | test "compares with broadcasting" do 595 | assert_equal( 596 | greater(Nx.tensor(1), Nx.tensor([1, 2, 3])), 597 | Nx.tensor([0, 0, 0], type: {:u, 8}) 598 | ) 599 | end 600 | 601 | test "compares with mixed types" do 602 | assert_equal( 603 | greater(Nx.tensor([1, 2, 3]), Nx.tensor([1.0, 2.0, 3.0])), 604 | Nx.tensor([0, 0, 0], type: {:u, 8}) 605 | ) 606 | end 607 | end 608 | 609 | describe "less equal" do 610 | defn less_equal(a, b), do: Nx.less_equal(a, b) 611 | 612 | test "compares scalars" do 613 | assert_equal(less_equal(Nx.tensor(1), Nx.tensor(2)), Nx.tensor(1, type: {:u, 8})) 614 | end 615 | 616 | test "compares with broadcasting" do 617 | assert_equal( 618 | less_equal(Nx.tensor(1), Nx.tensor([1, 2, 3])), 619 | Nx.tensor([1, 1, 1], type: {:u, 8}) 620 | ) 621 | end 622 | 623 | test "compares with mixed types" do 624 | assert_equal( 625 | less_equal(Nx.tensor([1, 2, 3]), Nx.tensor([1.0, 2.0, 3.0])), 626 | Nx.tensor([1, 1, 1], type: {:u, 8}) 627 | ) 628 | 
end 629 | end 630 | 631 | describe "greater equal" do 632 | defn greater_equal(a, b), do: Nx.greater_equal(a, b) 633 | 634 | test "compares scalars" do 635 | assert_equal(greater_equal(Nx.tensor(1), Nx.tensor(2)), Nx.tensor(0, type: {:u, 8})) 636 | end 637 | 638 | test "compares with broadcasting" do 639 | assert_equal( 640 | greater_equal(Nx.tensor(1), Nx.tensor([1, 2, 3])), 641 | Nx.tensor([1, 0, 0], type: {:u, 8}) 642 | ) 643 | end 644 | 645 | test "compares with mixed types" do 646 | assert_equal( 647 | greater_equal(Nx.tensor([1, 2, 3]), Nx.tensor([1.0, 2.0, 3.0])), 648 | Nx.tensor([1, 1, 1], type: {:u, 8}) 649 | ) 650 | end 651 | end 652 | 653 | describe "logical" do 654 | defn logical_and(a, b), do: Nx.logical_and(a, b) 655 | 656 | test "and" do 657 | assert_equal( 658 | logical_and(Nx.tensor([-1, 0, 1]), Nx.tensor([[-1], [0], [1]])), 659 | Nx.tensor( 660 | [ 661 | [1, 0, 1], 662 | [0, 0, 0], 663 | [1, 0, 1] 664 | ], 665 | type: {:u, 8} 666 | ) 667 | ) 668 | 669 | assert_equal( 670 | logical_and(Nx.tensor([-1.0, 0.0, 1.0]), Nx.tensor([[-1], [0], [1]])), 671 | Nx.tensor( 672 | [ 673 | [1, 0, 1], 674 | [0, 0, 0], 675 | [1, 0, 1] 676 | ], 677 | type: {:u, 8} 678 | ) 679 | ) 680 | end 681 | 682 | defn logical_or(a, b), do: Nx.logical_or(a, b) 683 | 684 | test "or" do 685 | assert_equal( 686 | logical_or(Nx.tensor([-1, 0, 1]), Nx.tensor([[-1], [0], [1]])), 687 | Nx.tensor( 688 | [ 689 | [1, 1, 1], 690 | [1, 0, 1], 691 | [1, 1, 1] 692 | ], 693 | type: {:u, 8} 694 | ) 695 | ) 696 | 697 | assert_equal( 698 | logical_or(Nx.tensor([-1.0, 0.0, 1.0]), Nx.tensor([[-1], [0], [1]])), 699 | Nx.tensor( 700 | [ 701 | [1, 1, 1], 702 | [1, 0, 1], 703 | [1, 1, 1] 704 | ], 705 | type: {:u, 8} 706 | ) 707 | ) 708 | end 709 | 710 | defn logical_xor(a, b), do: Nx.logical_xor(a, b) 711 | 712 | test "xor" do 713 | assert_equal( 714 | logical_xor(Nx.tensor([-1, 0, 1]), Nx.tensor([[-1], [0], [1]])), 715 | Nx.tensor( 716 | [ 717 | [0, 1, 0], 718 | [1, 0, 1], 719 | [0, 1, 0] 720 | ], 721 | type: {:u, 8} 722 | ) 723 | ) 724 | 725 | assert_equal( 726 | logical_xor(Nx.tensor([-1.0, 0.0, 1.0]), Nx.tensor([[-1], [0], [1]])), 727 | Nx.tensor( 728 | [ 729 | [0, 1, 0], 730 | [1, 0, 1], 731 | [0, 1, 0] 732 | ], 733 | type: {:u, 8} 734 | ) 735 | ) 736 | end 737 | 738 | defn logical_not(a), do: Nx.logical_not(a) 739 | 740 | test "not" do 741 | assert_equal( 742 | logical_not(Nx.tensor([-2, -1, 0, 1, 2])), 743 | Nx.tensor([0, 0, 1, 0, 0], type: {:u, 8}) 744 | ) 745 | end 746 | end 747 | 748 | describe "select" do 749 | defn select(pred, x, y), do: Nx.select(pred, x, y) 750 | 751 | test "selects one or the other with a scalar" do 752 | assert_equal( 753 | select(Nx.tensor(1), Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6])), 754 | Nx.tensor([1, 2, 3]) 755 | ) 756 | end 757 | 758 | test "selects with type" do 759 | assert_equal( 760 | select( 761 | Nx.tensor(1), 762 | Nx.tensor([1, 2, 3], type: {:u, 8}), 763 | Nx.tensor([4, 5, 6], type: {:u, 8}) 764 | ), 765 | Nx.tensor([1, 2, 3], type: {:u, 8}) 766 | ) 767 | 768 | assert_equal( 769 | select( 770 | Nx.tensor(1), 771 | Nx.tensor([1, 2, 3], type: {:u, 8}), 772 | Nx.tensor([4, 5, 6], type: {:f, 32}) 773 | ), 774 | Nx.tensor([1, 2, 3], type: {:f, 32}) 775 | ) 776 | end 777 | 778 | test "selects with broadcasting" do 779 | assert_equal( 780 | select(Nx.tensor([1, 0, 1, 0, 1]), Nx.tensor([10]), Nx.tensor([1, 2, 3, 4, 5])), 781 | Nx.tensor([10, 2, 10, 4, 10]) 782 | ) 783 | 784 | assert_equal( 785 | select(Nx.tensor([-2, -1, 0, 1, 2]), Nx.tensor([10]), Nx.tensor([1, 2, 3, 4, 5])), 786 | 
Nx.tensor([10, 10, 3, 10, 10]) 787 | ) 788 | end 789 | end 790 | 791 | describe "unary float ops" do 792 | @int_tensor Nx.tensor([1, 2, 3]) 793 | @float_tensor Nx.tensor([1.0, 2.0, 3.0]) 794 | float_ops = 795 | ([ 796 | :exp, 797 | :expm1, 798 | :log, 799 | :log1p, 800 | :sigmoid, 801 | :cos, 802 | :sin, 803 | :tanh, 804 | :sqrt, 805 | :rsqrt, 806 | :cbrt, 807 | :is_nan 808 | ] ++ 809 | [:is_infinity, :tan, :acosh, :asinh, :cosh, :sinh, :erf, :erfc]) 810 | |> Enum.reject(fn x -> x in [:erfc, :asinh, :sinh, :acosh, :cosh] end) 811 | 812 | for fun <- float_ops do 813 | defn_fun = :"unary_#{fun}" 814 | defn_var = Macro.var(defn_fun, __MODULE__) 815 | defn unquote(defn_fun)(t), do: Nx.unquote(fun)(t) 816 | 817 | test "#{fun}" do 818 | assert_all_close( 819 | unquote(defn_fun)(@float_tensor), 820 | evaluate(&(unquote(defn_var) / 1), [@float_tensor]) 821 | ) 822 | 823 | assert_all_close( 824 | unquote(defn_fun)(@int_tensor), 825 | evaluate(&(unquote(defn_var) / 1), [@int_tensor]) 826 | ) 827 | end 828 | end 829 | end 830 | 831 | describe "softmax" do 832 | defn softmax(t), do: Nx.exp(t) / Nx.sum(Nx.exp(t)) 833 | 834 | test "computes softmax" do 835 | assert_all_close( 836 | softmax(Nx.tensor([1.0, 2.0, 3.0, 4.0])), 837 | Nx.tensor([ 838 | 0.03205860328008499, 839 | 0.08714431874203257, 840 | 0.23688281808991013, 841 | 0.6439142598879722 842 | ]) 843 | ) 844 | end 845 | end 846 | 847 | describe "dot product" do 848 | defn dot(a, b), do: Nx.dot(a, b) 849 | 850 | test "computes the dot product of scalars" do 851 | assert_equal(dot(Nx.tensor(2), Nx.tensor(2)), Nx.tensor(4)) 852 | assert_equal(dot(Nx.tensor(2.0), Nx.tensor(2.0)), Nx.tensor(4.0)) 853 | assert_equal(dot(Nx.tensor(-2.0), Nx.tensor(-2)), Nx.tensor(4.0)) 854 | end 855 | 856 | test "computes the dot product of vectors" do 857 | assert_equal( 858 | dot(Nx.tensor([1, 2, 3], type: {:s, 32}), Nx.tensor([4, 5, 6], type: {:s, 32})), 859 | Nx.tensor(32, type: {:s, 32}) 860 | ) 861 | 862 | assert_equal( 863 | dot(Nx.tensor([1.0, 2.0, 3.0], type: {:f, 32}), Nx.tensor([4, 5, 6])), 864 | Nx.tensor(32.0) 865 | ) 866 | 867 | assert_equal(dot(Nx.tensor([1.0, 2.0, 3.0]), Nx.tensor([4.0, 5.0, 6.0])), Nx.tensor(32.0)) 868 | end 869 | 870 | test "computes the dot product of matrices" do 871 | assert_equal( 872 | dot( 873 | Nx.tensor([[1, 2, 3], [4, 5, 6]], type: {:s, 32}), 874 | Nx.tensor([[7, 8], [9, 10], [11, 12]], type: {:s, 32}) 875 | ), 876 | Nx.tensor([[58, 64], [139, 154]], type: {:s, 32}) 877 | ) 878 | 879 | assert_equal( 880 | dot( 881 | Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 882 | Nx.tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]) 883 | ), 884 | Nx.tensor([[58.0, 64.0], [139.0, 154.0]]) 885 | ) 886 | 887 | assert_equal( 888 | dot( 889 | Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 890 | Nx.tensor([[7, 8], [9, 10], [11, 12]]) 891 | ), 892 | Nx.tensor([[58.0, 64.0], [139.0, 154.0]]) 893 | ) 894 | end 895 | 896 | test "computes the dot product of tensors" do 897 | assert_equal( 898 | dot( 899 | Nx.tensor( 900 | [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]], 901 | type: {:s, 32}, 902 | names: [:a, :b, :c] 903 | ), 904 | Nx.tensor( 905 | [[[1, 2, 3], [3, 4, 5], [5, 6, 7]]], 906 | type: {:s, 32}, 907 | names: [:e, :f, :g] 908 | ) 909 | ), 910 | Nx.tensor( 911 | [ 912 | [[[22, 28, 34]], [[49, 64, 79]], [[76, 100, 124]]], 913 | [[[22, 28, 34]], [[49, 64, 79]], [[76, 100, 124]]] 914 | ], 915 | type: {:s, 32}, 916 | names: [:a, :b, :e, :g] 917 | ) 918 | ) 919 | end 920 | 921 | defn batched_dot(t1, t2), do: Nx.dot(t1, [1], 
[0], t2, [1], [0]) 922 | 923 | test "computes a batched dot product" do 924 | assert_equal( 925 | batched_dot(Nx.iota({3, 2, 3}, type: {:f, 32}), Nx.iota({3, 2, 2}, type: {:f, 32})), 926 | Nx.tensor([ 927 | [[6.0, 9.0], [8.0, 13.0], [10.0, 17.0]], 928 | [[78.0, 93.0], [88.0, 105.0], [98.0, 117.0]], 929 | [[246.0, 273.0], [264.0, 293.0], [282.0, 313.0]] 930 | ]) 931 | ) 932 | end 933 | 934 | defn general_dot(t1, t2), do: Nx.dot(t1, [0, 1], [], t2, [1, 2], []) 935 | 936 | test "computes a general dot product" do 937 | assert_equal( 938 | general_dot(Nx.iota({4, 5, 2}, type: {:f, 32}), Nx.iota({2, 4, 5}, type: {:f, 32})), 939 | Nx.tensor([[4940.0, 12540.0], [5130.0, 13130.0]]) 940 | ) 941 | end 942 | end 943 | 944 | describe "creation" do 945 | defn iota, do: Nx.iota({10, 10}) 946 | 947 | test "iota" do 948 | %{data: %Manx{}} = tensor = iota() 949 | assert Nx.backend_transfer(tensor) == Nx.iota({10, 10}, backend: Nx.BinaryBackend) 950 | end 951 | 952 | defn iota0, do: Nx.iota({10, 10}, axis: 0) 953 | 954 | test "iota0" do 955 | %{data: %Manx{}} = tensor = iota0() 956 | assert Nx.backend_transfer(tensor) == Nx.iota({10, 10}, axis: 0, backend: Nx.BinaryBackend) 957 | end 958 | 959 | defn iota1, do: Nx.iota({10, 10}, axis: 1) 960 | 961 | test "iota1" do 962 | %{data: %Manx{}} = tensor = iota1() 963 | assert Nx.backend_transfer(tensor) == Nx.iota({10, 10}, axis: 1, backend: Nx.BinaryBackend) 964 | end 965 | end 966 | 967 | describe "as_type" do 968 | defn to_float(t), do: Nx.as_type(t, {:f, 32}) 969 | 970 | test "converts tensor type" do 971 | assert_equal(to_float(Nx.tensor([1, 2, 3])), Nx.tensor([1.0, 2.0, 3.0], type: {:f, 32})) 972 | end 973 | 974 | defn generic_as_type(t, template), do: Nx.as_type(t, template.type) 975 | 976 | test "converts non-finite types" do 977 | non_finite = 978 | Nx.stack([Nx.Constants.infinity(), Nx.Constants.nan(), Nx.Constants.neg_infinity()]) 979 | 980 | # TODO: fix this, they should be equal to each other 981 | assert_not_equal( 982 | generic_as_type(non_finite, Nx.template({}, {:u, 8})), 983 | Nx.tensor([255, 0, 0], type: {:u, 8}) 984 | ) 985 | 986 | assert_equal( 987 | generic_as_type(non_finite, Nx.template({}, {:s, 16})), 988 | Nx.tensor([32767, 0, -32768], type: {:s, 16}) 989 | ) 990 | end 991 | end 992 | 993 | describe "bitcast" do 994 | defn bitcast_to_float(t), do: Nx.bitcast(t, {:f, 32}) 995 | 996 | test "converts tensor type" do 997 | assert_equal( 998 | bitcast_to_float(Nx.tensor([0, 0, 0], type: {:s, 32})), 999 | Nx.tensor([0.0, 0.0, 0.0]) 1000 | ) 1001 | end 1002 | end 1003 | 1004 | describe "squeeze" do 1005 | defn squeeze(t), do: Nx.squeeze(t) 1006 | defn squeeze2(t), do: Nx.squeeze(t, axes: [0, 1]) 1007 | 1008 | test "with scalar" do 1009 | assert_equal(squeeze(Nx.tensor(1)), Nx.tensor(1)) 1010 | end 1011 | 1012 | test "with tensors" do 1013 | assert_equal(squeeze(Nx.tensor([[1, 2, 3]])), Nx.tensor([1, 2, 3])) 1014 | assert_equal(squeeze(Nx.tensor([[[[[1]]]]])), Nx.tensor(1)) 1015 | assert_equal(squeeze2(Nx.tensor([[[[[1]]]]])), Nx.tensor([[[1]]])) 1016 | end 1017 | end 1018 | 1019 | describe "slicing" do 1020 | @describetag :runtime 1021 | defn slice1(t), do: Nx.slice(t, [0, 6, 2], [2, 1, 3]) 1022 | 1023 | defn slice1_dynamic(t), do: Nx.slice(t, [Nx.tensor(0), Nx.tensor(6), Nx.tensor(2)], [2, 1, 3]) 1024 | 1025 | defn slice2(t), do: Nx.slice(t, [1, 4, 10], [1, 1, 10], strides: [1, 2, 3]) 1026 | 1027 | defn slice2_dynamic(t), 1028 | do: Nx.slice(t, [Nx.tensor(1), Nx.tensor(4), Nx.tensor(10)], [1, 1, 10], strides: [1, 2, 3]) 1029 | 1030 | defn 
slice3(t), do: Nx.slice(t, [0, 4, 11], [2, 3, 9], strides: [2, 1, 3]) 1031 | 1032 | defn slice3_dynamic(t), 1033 | do: Nx.slice(t, [Nx.tensor(0), Nx.tensor(4), Nx.tensor(11)], [2, 3, 9], strides: [2, 1, 3]) 1034 | 1035 | test "works without stride" do 1036 | t = Nx.iota({900}) 1037 | t = Nx.reshape(t, {2, 15, 30}) 1038 | assert_equal(slice1(t), Nx.tensor([[[182, 183, 184]], [[632, 633, 634]]])) 1039 | assert_equal(slice1_dynamic(t), Nx.tensor([[[182, 183, 184]], [[632, 633, 634]]])) 1040 | end 1041 | 1042 | test "works with stride" do 1043 | t = Nx.iota({900}) 1044 | t = Nx.reshape(t, {2, 15, 30}) 1045 | assert_equal(slice2(t), Nx.tensor([[[580, 583, 586, 589]]])) 1046 | assert_equal(slice2_dynamic(t), Nx.tensor([[[580, 583, 586, 589]]])) 1047 | 1048 | assert_equal( 1049 | slice3(t), 1050 | Nx.tensor([ 1051 | [ 1052 | [131, 134, 137], 1053 | [161, 164, 167], 1054 | [191, 194, 197] 1055 | ] 1056 | ]) 1057 | ) 1058 | 1059 | assert_equal( 1060 | slice3_dynamic(t), 1061 | Nx.tensor([ 1062 | [ 1063 | [131, 134, 137], 1064 | [161, 164, 167], 1065 | [191, 194, 197] 1066 | ] 1067 | ]) 1068 | ) 1069 | end 1070 | end 1071 | end 1072 | -------------------------------------------------------------------------------- /lib/manx/nx/defn.ex: -------------------------------------------------------------------------------- 1 | defmodule Manx.Defn do 2 | alias __MODULE__.Env 3 | use Beaver 4 | alias Beaver.MLIR 5 | import MLIR.Sigils 6 | import Beaver, only: :macros 7 | require Beaver.MLIR 8 | alias MLIR.{Type, Attribute} 9 | alias MLIR.Dialect.{TOSA, Linalg, Arith, Tensor, Bufferization, Math, SCF, MemRef} 10 | 11 | defdelegate gen_type(tensor), to: Manx.Type 12 | 13 | @doc """ 14 | In upstream MLIR, there is no lower-able Op packing multiple values into a tuple. 15 | If the Nx root type is a tuple, it should be converted to multi-results. 
16 | This function should always return a list of types 17 | """ 18 | def gen_root_types(tuple) when is_tuple(tuple) do 19 | Tuple.to_list(tuple) 20 | |> Enum.map(&gen_type/1) 21 | end 22 | 23 | def gen_root_types(type), do: [gen_type(type)] 24 | 25 | def gen_op( 26 | %Env{block: block}, 27 | %Nx.Tensor{ 28 | data: %Nx.Defn.Expr{op: :parameter, args: [pos]} 29 | } 30 | ) 31 | when is_integer(pos) do 32 | arg_cnt = Beaver.Walker.arguments(block) |> Enum.count() 33 | 34 | if pos >= arg_cnt do 35 | raise "argument ##{pos} out of bound, argument count: #{arg_cnt}" 36 | end 37 | 38 | arg = block |> Beaver.MLIR.CAPI.mlirBlockGetArgument(pos) 39 | 40 | if MLIR.is_null(arg) do 41 | raise "argument ##{pos} not found" 42 | end 43 | 44 | arg 45 | end 46 | 47 | def gen_op( 48 | %Env{block: block, ctx: ctx}, 49 | %Nx.Tensor{ 50 | data: %Nx.Defn.Expr{op: :constant, args: [:nan]}, 51 | shape: {}, 52 | type: {:f, 32} 53 | } = t 54 | ) do 55 | mlir block: block, ctx: ctx do 56 | TOSA.const({:value, ~a{dense<0x7F800001> : tensor}}) >>> gen_type(t) 57 | end 58 | end 59 | 60 | def gen_op( 61 | %Env{block: block, ctx: ctx}, 62 | %Nx.Tensor{ 63 | data: %Nx.Defn.Expr{op: :constant, args: [:infinity]}, 64 | shape: {}, 65 | type: {:f, 32} 66 | } = t 67 | ) do 68 | mlir block: block, ctx: ctx do 69 | TOSA.const({:value, ~a{dense<0x7F800000> : tensor}}) >>> 70 | gen_type(t) 71 | end 72 | end 73 | 74 | def gen_op( 75 | %Env{block: block, ctx: ctx}, 76 | %Nx.Tensor{ 77 | data: %Nx.Defn.Expr{op: :constant, args: [:neg_infinity]}, 78 | shape: {}, 79 | type: {:f, 32} 80 | } = t 81 | ) do 82 | mlir block: block, ctx: ctx do 83 | _r = 84 | TOSA.const({:value, ~a{dense<0xFF800000> : tensor}}) >>> 85 | gen_type(t) 86 | end 87 | end 88 | 89 | def gen_op( 90 | %Env{block: block, ctx: ctx}, 91 | %Nx.Tensor{ 92 | data: %Nx.Defn.Expr{op: :constant, args: [value]}, 93 | shape: {} 94 | } = t 95 | ) 96 | when is_integer(value) or is_float(value) do 97 | mlir block: block, ctx: ctx do 98 | t_str = gen_type(t) |> Beaver.Deferred.create(ctx) |> MLIR.to_string() 99 | 100 | TOSA.const({:value, ~a{dense<#{value}> : #{t_str}}}) >>> 101 | gen_type(t) 102 | end 103 | end 104 | 105 | def gen_op( 106 | %Env{block: block, ctx: ctx}, 107 | %Nx.Tensor{ 108 | data: %Nx.Defn.Expr{op: :constant, args: [%Complex{im: im, re: re}]}, 109 | type: {:c, 64} 110 | } = t 111 | ) do 112 | mlir block: block, ctx: ctx do 113 | t_str = gen_type(t) |> Beaver.Deferred.create(ctx) |> MLIR.to_string() 114 | 115 | Arith.constant({:value, ~a[dense<(#{re}, #{im})> : #{t_str}]}) >>> 116 | gen_type(t) 117 | end 118 | end 119 | 120 | def gen_op( 121 | %Env{block: block, ctx: ctx}, 122 | %Nx.Tensor{ 123 | data: %Nx.Defn.Expr{ 124 | args: [%Nx.Tensor{data: %Nx.BinaryBackend{state: binary}}], 125 | op: :tensor 126 | } 127 | } = t 128 | ) do 129 | mlir block: block, ctx: ctx do 130 | tensor_attr = 131 | MLIR.CAPI.mlirDenseElementsAttrRawBufferGet( 132 | gen_type(t) |> Beaver.Deferred.create(ctx), 133 | byte_size(binary), 134 | MLIR.StringRef.create(binary) 135 | |> then(fn s -> 136 | %{ref: MLIR.CAPI.beaverStringRefGetData(s), element_kind: Beaver.Native.U8} 137 | |> Beaver.Native.Array.as_opaque() 138 | |> Map.get(:ref) 139 | end) 140 | ) 141 | 142 | if MLIR.Attribute.is_null(tensor_attr), do: raise("fail to parse tensor dense elements") 143 | 144 | TOSA.const({:value, tensor_attr}) >>> gen_type(t) 145 | end 146 | end 147 | 148 | # unary tosa 149 | def gen_op( 150 | %Env{block: block, ctx: ctx} = env, 151 | %Nx.Tensor{data: %Nx.Defn.Expr{op: op, args: [input1]}} = t 152 | ) 
153 | when op in [ 154 | :negate, 155 | :abs, 156 | :bitwise_not, 157 | :exp, 158 | :log, 159 | :tanh, 160 | :rsqrt, 161 | :is_nan, 162 | :is_infinity, 163 | :sigmoid 164 | ] do 165 | mlir block: block, ctx: ctx do 166 | input1_value = gen_op(env, input1) 167 | input1_value = TOSA.cast(input1_value) >>> gen_type(%{input1 | type: t.type}) 168 | 169 | case op do 170 | :negate -> 171 | TOSA.negate(input1_value) >>> gen_type(t) 172 | 173 | :abs -> 174 | TOSA.abs(input1_value) >>> gen_type(t) 175 | 176 | :bitwise_not -> 177 | TOSA.bitwise_not(input1_value) >>> gen_type(t) 178 | 179 | :exp -> 180 | TOSA.exp(input1_value) >>> gen_type(t) 181 | 182 | :log -> 183 | TOSA.log(input1_value) >>> gen_type(t) 184 | 185 | :tanh -> 186 | TOSA.tanh(input1_value) >>> gen_type(t) 187 | 188 | :rsqrt -> 189 | TOSA.rsqrt(input1_value) >>> gen_type(t) 190 | 191 | :sigmoid -> 192 | TOSA.sigmoid(input1_value) >>> gen_type(t) 193 | 194 | :is_nan -> 195 | c = TOSA.equal(input1_value, input1_value) >>> gen_type(%{t | type: {:u, 1}}) 196 | c = TOSA.logical_not(c) >>> gen_type(%{t | type: {:u, 1}}) 197 | TOSA.cast(c) >>> gen_type(t) 198 | 199 | :is_infinity -> 200 | input1_value = gen_op(env, input1) 201 | input1_type_str = gen_type(input1) |> Beaver.Deferred.create(ctx) |> MLIR.to_string() 202 | 203 | inf = 204 | TOSA.const({:value, ~a{dense<0x7F800000> : #{input1_type_str}}}) >>> gen_type(input1) 205 | 206 | abs = TOSA.abs(input1_value) >>> gen_type(input1) 207 | equal = TOSA.equal(inf, abs) >>> gen_type(%{t | type: {:u, 1}}) 208 | TOSA.cast(equal) >>> gen_type(t) 209 | end 210 | end 211 | end 212 | 213 | def gen_op( 214 | env, 215 | %Nx.Tensor{shape: {}, data: %Nx.Defn.Expr{op: :all, args: [%{shape: {}} = input1, _]}} 216 | ) do 217 | gen_op(env, input1) 218 | end 219 | 220 | def gen_op( 221 | %Env{block: block, ctx: ctx} = env, 222 | %Nx.Tensor{ 223 | data: %Nx.Defn.Expr{ 224 | op: :squeeze, 225 | args: [input, _axes] 226 | } 227 | } = t 228 | ) do 229 | mlir block: block, ctx: ctx do 230 | input_value = gen_op(env, input) 231 | source_type = gen_type(input) |> Beaver.Deferred.create(ctx) 232 | target_type = gen_type(t) |> Beaver.Deferred.create(ctx) 233 | reassociation = Tensor.reassociation_for_reshape(source_type, target_type) 234 | 235 | if MLIR.is_null(reassociation) do 236 | raise "fail to create reassociation" 237 | end 238 | 239 | Tensor.collapse_shape(input_value, reassociation: reassociation) >>> target_type 240 | end 241 | end 242 | 243 | def gen_op( 244 | %Env{} = env, 245 | %Nx.Tensor{ 246 | data: %Nx.Defn.Expr{ 247 | op: :slice, 248 | args: [_tensor, start_indices, _lengths, _strides] 249 | } 250 | } = t 251 | ) do 252 | env = %Env{env | gen_op: &gen_op/2, gen_type: &gen_type/1} 253 | 254 | if Enum.all?(start_indices, &is_integer/1) do 255 | Manx.Slice.static_slice(env, t) 256 | else 257 | Manx.Slice.dynamic_slice(env, t) 258 | end 259 | end 260 | 261 | def gen_op( 262 | %Env{block: block, ctx: ctx} = env, 263 | %Nx.Tensor{ 264 | data: %Nx.Defn.Expr{ 265 | op: op, 266 | args: [%{shape: in_shape} = input1, [axes: axes, keep_axes: keep_axes]] 267 | } 268 | } = t 269 | ) 270 | when is_list(axes) and op in [:all, :sum] do 271 | mlir block: block, ctx: ctx do 272 | input1 = gen_op(env, input1) 273 | 274 | input1 = 275 | case op do 276 | :all -> 277 | TOSA.cast(input1) >>> gen_type(%{t | shape: in_shape, type: {:u, 1}}) 278 | 279 | :sum -> 280 | input1 281 | end 282 | 283 | {in_shape, mlir_value} = 284 | Enum.reduce( 285 | axes, 286 | {Tuple.to_list(in_shape), input1}, 287 | fn axis, {in_shape, mlir_value} 
-> 288 | out_shape = List.replace_at(in_shape, axis, 1) 289 | 290 | reduce_attr = [axis: Attribute.integer(Type.i32(), axis)] 291 | 292 | reduced = 293 | case op do 294 | :all -> 295 | TOSA.reduce_all(mlir_value, reduce_attr) >>> 296 | gen_type(%{t | shape: List.to_tuple(out_shape), type: {:u, 1}}) 297 | 298 | :sum -> 299 | TOSA.reduce_sum(mlir_value, reduce_attr, 300 | loc: 301 | Manx.Nx.Interoperability.loc_from_stack_trace( 302 | Process.info(self(), :current_stacktrace), 303 | ctx 304 | ) 305 | ) >>> 306 | gen_type(%{t | shape: List.to_tuple(out_shape)}) 307 | end 308 | 309 | {out_shape, reduced} 310 | end 311 | ) 312 | 313 | mlir_value = TOSA.cast(mlir_value) >>> gen_type(%{t | shape: List.to_tuple(in_shape)}) 314 | 315 | if keep_axes do 316 | mlir_value 317 | else 318 | Tensor.collapse_shape(mlir_value, reassociation: Tensor.reassociation([])) >>> gen_type(t) 319 | end 320 | end 321 | end 322 | 323 | def gen_op( 324 | %Env{block: block, ctx: ctx} = env, 325 | %Nx.Tensor{ 326 | data: 327 | %Nx.Defn.Expr{ 328 | op: op, 329 | args: [%{shape: in_shape} = input1, [axes: nil, keep_axes: keep_axes]] 330 | } = expr 331 | } = t 332 | ) 333 | when op in [:sum, :all] do 334 | # if axes is nil, replace it with a list of every axis 335 | mlir block: block, ctx: ctx do 336 | rank = tuple_size(in_shape) 337 | axes = Range.new(0, rank - 1, 1) |> Enum.to_list() 338 | 339 | expr = %{ 340 | expr 341 | | args: [input1, [axes: axes, keep_axes: keep_axes]] 342 | } 343 | 344 | gen_op(env, %{t | data: expr}) 345 | end 346 | end 347 | 348 | def gen_op( 349 | %Env{block: block, ctx: ctx} = env, 350 | %Nx.Tensor{ 351 | data: %Nx.Defn.Expr{ 352 | op: :conjugate, 353 | args: [%Nx.Tensor{type: {:c, 64}} = complex_tensor] 354 | }, 355 | shape: {} 356 | } = t 357 | ) do 358 | alias MLIR.Dialect.Complex 359 | 360 | mlir block: block, ctx: ctx do 361 | complex_tensor = gen_op(env, complex_tensor) 362 | complex_element = Tensor.extract(complex_tensor) >>> Type.complex(Type.f32()) 363 | conjugate_element = Complex.conj(complex_element) >>> Type.complex(Type.f32()) 364 | 365 | conjugate_tensor = 366 | Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0, 0])) >>> 367 | gen_type(t) 368 | 369 | Tensor.insert(conjugate_element, conjugate_tensor) >>> 370 | gen_type(t) 371 | end 372 | end 373 | 374 | def gen_op( 375 | %Env{block: block, ctx: ctx} = env, 376 | %Nx.Tensor{ 377 | data: %Nx.Defn.Expr{op: :conjugate, args: [%Nx.Tensor{} = real_tensor]}, 378 | shape: {}, 379 | type: complex_type = {:c, 64} 380 | } = t 381 | ) do 382 | alias MLIR.Dialect.Complex 383 | 384 | mlir block: block, ctx: ctx do 385 | real_tensor = gen_op(env, real_tensor) 386 | real_tensor = TOSA.cast(real_tensor) >>> Type.ranked_tensor([], Type.f32()) 387 | real = Tensor.extract(real_tensor) >>> Type.f32() 388 | 389 | conjugate_tensor = 390 | Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0, 0])) >>> 391 | gen_type(t) 392 | 393 | imaginary = Arith.constant(value: Attribute.float(Type.f32(), 0.0)) >>> Type.f32() 394 | 395 | complex_element_t = gen_type(complex_type) 396 | complex_element = Complex.create(real, imaginary) >>> complex_element_t 397 | conjugate_element = Complex.conj(complex_element) >>> complex_element_t 398 | 399 | _ = Tensor.insert(conjugate_element, conjugate_tensor) >>> gen_type(t) 400 | end 401 | end 402 | 403 | def gen_op( 404 | %Env{block: block, ctx: ctx} = env, 405 | %Nx.Tensor{ 406 | data: %Nx.Defn.Expr{op: :conjugate, args: [complex_tensor]}, 407 | shape: shape 408 | } 
= t 409 | ) do 410 | alias MLIR.Dialect.Complex 411 | 412 | mlir block: block, ctx: ctx do 413 | element_cnt = Enum.reduce(Tuple.to_list(shape), 1, &*/2) 414 | complex_tensor = gen_op(env, complex_tensor) 415 | lower = Arith.constant(value: Attribute.integer(Type.index(), 0)) >>> Type.index() 416 | upper = Arith.constant(value: Attribute.integer(Type.index(), element_cnt)) >>> Type.index() 417 | step = Arith.constant(value: Attribute.integer(Type.index(), 1)) >>> Type.index() 418 | 419 | conjugate_tensor = 420 | Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0, 0])) >>> 421 | gen_type(t) 422 | 423 | conjugate_memref = 424 | Bufferization.to_memref(conjugate_tensor) >>> 425 | Type.memref([2], Type.complex(Type.f32())) 426 | 427 | SCF.for [lower, upper, step] do 428 | region do 429 | block _(index >>> Type.index()) do 430 | complex_element = Tensor.extract(complex_tensor, index) >>> Type.complex(Type.f32()) 431 | conjugate_element = Complex.conj(complex_element) >>> Type.complex(Type.f32()) 432 | MemRef.store(conjugate_element, conjugate_memref, index) >>> [] 433 | SCF.yield() >>> [] 434 | end 435 | end 436 | end >>> [] 437 | 438 | conjugate_tensor 439 | end 440 | end 441 | 442 | def gen_op( 443 | %Env{block: block, ctx: ctx} = env, 444 | %Nx.Tensor{ 445 | data: %Nx.Defn.Expr{ 446 | op: :imag, 447 | args: [%Nx.Tensor{type: {:c, 64}, shape: in_shape} = in_tensor] 448 | }, 449 | shape: out_shape 450 | } = t 451 | ) do 452 | alias MLIR.Dialect.Complex 453 | 454 | mlir block: block, ctx: ctx do 455 | in_tensor = gen_op(env, in_tensor) 456 | 457 | out_tensor = 458 | Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0, 0])) >>> 459 | gen_type(t) 460 | 461 | Linalg.generic [ 462 | in_tensor, 463 | out_tensor, 464 | operand_segment_sizes: ODS.operand_segment_sizes([1, 1]), 465 | indexing_maps: Manx.Linalg.gen_indexing_maps(in_shape, out_shape), 466 | iterator_types: Manx.Linalg.gen_iterator_types(in_shape, out_shape) 467 | ] do 468 | region do 469 | block _(arg0 >>> Type.complex(Type.f32()), arg1 >>> Type.f(32)) do 470 | %MLIR.Value{} = arg1 471 | im = Complex.im(arg0) >>> Type.f32() 472 | Linalg.yield([im]) >>> [] 473 | end 474 | end 475 | end >>> gen_type(t) 476 | end 477 | end 478 | 479 | # unary linalg 480 | def gen_op( 481 | %Env{block: block, ctx: ctx} = env, 482 | %Nx.Tensor{type: type, data: %Nx.Defn.Expr{op: op, args: [input]}} = t 483 | ) 484 | when op in [ 485 | :population_count, 486 | :count_leading_zeros, 487 | :cos, 488 | :sin, 489 | :sqrt, 490 | :tan, 491 | :erf, 492 | :cbrt, 493 | :expm1, 494 | :log1p, 495 | :bitcast, 496 | :atan 497 | ] do 498 | mlir block: block, ctx: ctx do 499 | input_value = gen_op(env, input) 500 | input_value = TOSA.cast(input_value) >>> gen_type(t) 501 | 502 | out_tensor = 503 | Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0, 0])) >>> 504 | gen_type(t) 505 | 506 | Linalg.generic [ 507 | input_value, 508 | out_tensor, 509 | operand_segment_sizes: ODS.operand_segment_sizes([1, 1]), 510 | indexing_maps: Manx.Linalg.gen_indexing_maps(input.shape, t.shape), 511 | iterator_types: Manx.Linalg.gen_iterator_types(input.shape, t.shape) 512 | ] do 513 | region do 514 | block _(arg0 >>> gen_type(type), out >>> gen_type(type)) do 515 | %MLIR.Value{} = out 516 | 517 | result = 518 | case op do 519 | :population_count -> 520 | Math.ctpop(arg0) >>> gen_type(type) 521 | 522 | :count_leading_zeros -> 523 | Math.ctlz(arg0) >>> gen_type(type) 524 | 525 | :cos -> 526 | 
Math.cos(arg0) >>> gen_type(type) 527 | 528 | :sin -> 529 | Math.sin(arg0) >>> gen_type(type) 530 | 531 | :sqrt -> 532 | Math.sqrt(arg0) >>> gen_type(type) 533 | 534 | :tan -> 535 | Math.tan(arg0) >>> gen_type(type) 536 | 537 | :erf -> 538 | Math.erf(arg0) >>> gen_type(type) 539 | 540 | :bitcast -> 541 | Arith.bitcast(arg0) >>> gen_type(type) 542 | 543 | :cbrt -> 544 | abs = 545 | case type do 546 | {i_type, _} when i_type in [:i, :s] -> 547 | Math.absi(arg0) >>> gen_type(type) 548 | 549 | {f_type, _} when f_type in [:f] -> 550 | Math.absf(arg0) >>> gen_type(type) 551 | end 552 | 553 | third = 554 | Arith.constant(value: Attribute.float(gen_type(type), 0.333333343)) >>> 555 | gen_type(type) 556 | 557 | pow = Math.powf(abs, third) >>> gen_type(type) 558 | Math.copysign(pow, arg0) >>> gen_type(type) 559 | 560 | :expm1 -> 561 | Math.expm1(arg0) >>> gen_type(type) 562 | 563 | :log1p -> 564 | Math.log1p(arg0) >>> gen_type(type) 565 | 566 | :atan -> 567 | Math.atan(arg0) >>> gen_type(type) 568 | end 569 | 570 | Linalg.yield(result) >>> [] 571 | end 572 | end 573 | end >>> gen_type(t) 574 | end 575 | end 576 | 577 | # binary linalg 578 | def gen_op( 579 | %Env{block: block, ctx: ctx} = env, 580 | %Nx.Tensor{type: type, data: %Nx.Defn.Expr{op: op, args: [a, b]}} = t 581 | ) 582 | when op in [:remainder, :atan2, :pow] do 583 | mlir block: block, ctx: ctx do 584 | a_value = gen_op(env, a) 585 | b_value = gen_op(env, b) 586 | 587 | out_tensor = 588 | Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0, 0])) >>> 589 | gen_type(t) 590 | 591 | Linalg.generic [ 592 | a_value, 593 | b_value, 594 | out_tensor, 595 | operand_segment_sizes: ODS.operand_segment_sizes([2, 1]), 596 | indexing_maps: Manx.Linalg.gen_indexing_maps([a.shape, b.shape], t.shape), 597 | iterator_types: Manx.Linalg.gen_iterator_types(a.shape, b.shape, t.shape) 598 | ] do 599 | region do 600 | block _(arg0 >>> gen_type(type), arg1 >>> gen_type(type), out >>> gen_type(type)) do 601 | %MLIR.Value{} = out 602 | 603 | result = 604 | case op do 605 | :remainder -> 606 | case type do 607 | {:f, _} -> 608 | Arith.remf(arg0, arg1) >>> gen_type(type) 609 | 610 | {:i, _} -> 611 | Arith.remui(arg0, arg1) >>> gen_type(type) 612 | 613 | {:s, _} -> 614 | Arith.remsi(arg0, arg1) >>> gen_type(type) 615 | end 616 | 617 | :pow -> 618 | case type do 619 | {:f, _} -> 620 | Math.powf(arg0, arg1) >>> gen_type(type) 621 | 622 | {inter_type, _} when inter_type in [:i, :s] -> 623 | Math.ipowi(arg0, arg1) >>> gen_type(type) 624 | end 625 | 626 | :atan2 -> 627 | Math.atan2(arg0, arg1) >>> gen_type(type) 628 | end 629 | 630 | Linalg.yield(result) >>> [] 631 | end 632 | end 633 | end >>> gen_type(t) 634 | end 635 | end 636 | 637 | def gen_op(env, %Nx.Tensor{ 638 | data: %Nx.Defn.Expr{ 639 | op: :optional, 640 | args: alternatives 641 | } 642 | }) do 643 | tensor = 644 | alternatives 645 | |> Enum.find(fn 646 | %Nx.Tensor{data: %{op: :equal}} -> true 647 | %Nx.Tensor{data: %{op: :logical_not}} -> false 648 | _ -> true 649 | end) 650 | 651 | gen_op(env, tensor) 652 | end 653 | 654 | # dot product 655 | def gen_op( 656 | %Env{block: block, ctx: ctx} = env, 657 | %Nx.Tensor{ 658 | data: %Nx.Defn.Expr{ 659 | op: :dot, 660 | args: [ 661 | %Nx.Tensor{shape: {n}} = a, 662 | _, 663 | _, 664 | %Nx.Tensor{shape: {n}} = b, 665 | _, 666 | _ 667 | ] 668 | } 669 | } = t 670 | ) do 671 | mlir block: block, ctx: ctx do 672 | a_value = gen_op(env, a) 673 | b_value = gen_op(env, b) 674 | a_value = TOSA.cast(a_value) >>> gen_type(%{a | type: t.type}) 675 
| b_value = TOSA.cast(b_value) >>> gen_type(%{b | type: t.type}) 676 | c = TOSA.mul(a_value, b_value, shift: Attribute.integer(Type.i8(), 0)) >>> gen_type(a) 677 | 678 | c = 679 | TOSA.reduce_sum(c, axis: Attribute.integer(Type.i32(), 0)) >>> gen_type(%{t | shape: {1}}) 680 | 681 | Tensor.collapse_shape(c, reassociation: Tensor.reassociation([])) >>> gen_type(t) 682 | end 683 | end 684 | 685 | # standard batch matmul 686 | def gen_op( 687 | %Env{block: block, ctx: ctx} = env, 688 | %Nx.Tensor{ 689 | data: %Nx.Defn.Expr{ 690 | op: :dot, 691 | args: [ 692 | %Nx.Tensor{shape: a_shape} = a, 693 | [2], 694 | [0], 695 | %Nx.Tensor{shape: b_shape} = b, 696 | [1], 697 | [0] 698 | ] 699 | } 700 | } = t 701 | ) 702 | when tuple_size(a_shape) == 3 and tuple_size(b_shape) == 3 do 703 | mlir block: block, ctx: ctx do 704 | a_value = gen_op(env, a) 705 | b_value = gen_op(env, b) 706 | 707 | TOSA.matmul(a_value, b_value, 708 | loc: 709 | Manx.Nx.Interoperability.loc_from_stack_trace( 710 | Process.info(self(), :current_stacktrace), 711 | ctx 712 | ) 713 | ) >>> gen_type(t) 714 | end 715 | end 716 | 717 | # generic dot product 718 | def gen_op( 719 | %Env{block: block, ctx: ctx} = env, 720 | %Nx.Tensor{ 721 | data: %Nx.Defn.Expr{ 722 | op: :dot, 723 | args: 724 | [ 725 | %Nx.Tensor{shape: a_shape} = a, 726 | _contract_axes1, 727 | _batch_axes1, 728 | %Nx.Tensor{shape: b_shape} = b, 729 | _contract_axes2, 730 | _batch_axes2 731 | ] = args 732 | } 733 | } = t 734 | ) 735 | when tuple_size(a_shape) in [2, 3] or tuple_size(b_shape) in [2, 3] do 736 | mlir block: block, ctx: ctx do 737 | a_value = gen_op(env, a) 738 | b_value = gen_op(env, b) 739 | a_value = TOSA.cast(a_value) >>> gen_type(%{a | type: t.type}) 740 | b_value = TOSA.cast(b_value) >>> gen_type(%{b | type: t.type}) 741 | {batched_a, batched_b} = Manx.Nx.Batcher.from_args(args) 742 | 743 | output_type = gen_type(t) 744 | 745 | out_tensor = 746 | Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0, 0])) >>> 747 | output_type 748 | 749 | zero = 750 | case t.type do 751 | {:f, _} -> 752 | Arith.constant(value: Attribute.float(gen_type(t.type), 0.0)) >>> gen_type(t.type) 753 | 754 | _ -> 755 | Arith.constant(value: Attribute.integer(gen_type(t.type), 0)) >>> gen_type(t.type) 756 | end 757 | 758 | out_tensor = 759 | Linalg.fill [zero, out_tensor, operand_segment_sizes: ODS.operand_segment_sizes([1, 1])] do 760 | region do 761 | block _( 762 | arg >>> gen_type(t.type), 763 | res >>> gen_type(t.type) 764 | ) do 765 | %MLIR.Value{} = res 766 | Linalg.yield(arg) >>> [] 767 | end 768 | end 769 | end >>> output_type 770 | 771 | Linalg.generic [ 772 | a_value, 773 | b_value, 774 | out_tensor, 775 | operand_segment_sizes: ODS.operand_segment_sizes([2, 1]), 776 | indexing_maps: Manx.Nx.Batcher.gen_indexing_maps(batched_a, batched_b, t), 777 | iterator_types: Manx.Nx.Batcher.gen_iterator_types(batched_a, batched_b, t) 778 | ] do 779 | region do 780 | block _( 781 | left >>> gen_type(t.type), 782 | right >>> gen_type(t.type), 783 | sum >>> gen_type(t.type) 784 | ) do 785 | sum = 786 | case t.type do 787 | {:f, _} -> 788 | mul = Arith.mulf(left, right) >>> gen_type(t.type) 789 | Arith.addf(sum, mul) >>> gen_type(t.type) 790 | 791 | _ -> 792 | mul = Arith.muli(left, right) >>> gen_type(t.type) 793 | Arith.addi(sum, mul) >>> gen_type(t.type) 794 | end 795 | 796 | Linalg.yield(sum) >>> [] 797 | end 798 | end 799 | end >>> output_type 800 | end 801 | end 802 | 803 | def gen_op( 804 | %Env{block: block, ctx: ctx} = env, 805 | %Nx.Tensor{ 
806 | data: %Nx.Defn.Expr{ 807 | op: :concatenate, 808 | args: [ 809 | inputs, 810 | axis 811 | ] 812 | } 813 | } = t 814 | ) 815 | when is_list(inputs) do 816 | mlir block: block, ctx: ctx do 817 | inputs = inputs |> Enum.map(&gen_op(env, &1)) 818 | TOSA.concat(inputs, axis: Attribute.integer(Type.i32(), axis)) >>> gen_type(t) 819 | end 820 | end 821 | 822 | # binary tosa 823 | def gen_op( 824 | %Env{block: block, ctx: ctx} = env, 825 | %Nx.Tensor{data: %Nx.Defn.Expr{op: op, args: [%Nx.Tensor{} = a, %Nx.Tensor{} = b]}} = t 826 | ) do 827 | mlir block: block, ctx: ctx do 828 | a_t = %{a | type: t.type} |> gen_type 829 | b_t = %{b | type: t.type} |> gen_type 830 | a_value = gen_op(env, a) 831 | b_value = gen_op(env, b) 832 | 833 | {a_value, b_value} = 834 | case op do 835 | _ when op in [:equal, :greater_equal, :less_equal, :less, :greater, :not_equal] -> 836 | case {a.type, b.type} do 837 | {{int_type, _}, {:f, _}} when int_type in [:s, :u] -> 838 | a_value = TOSA.cast(a_value) >>> gen_type(%{a | type: b.type}) 839 | {a_value, b_value} 840 | 841 | {{:f, _}, {int_type, _}} when int_type in [:s, :u] -> 842 | b_value = TOSA.cast(b_value) >>> gen_type(%{b | type: a.type}) 843 | {a_value, b_value} 844 | 845 | {{_, width_a}, {_, width_b}} 846 | when width_a > width_b -> 847 | b_value = TOSA.cast(b_value) >>> gen_type(%{b | type: a.type}) 848 | {a_value, b_value} 849 | 850 | {{_, width_a}, {_, width_b}} 851 | when width_a < width_b -> 852 | a_value = TOSA.cast(a_value) >>> gen_type(%{a | type: b.type}) 853 | {a_value, b_value} 854 | 855 | _ -> 856 | b_value = TOSA.cast(b_value) >>> gen_type(%{b | type: a.type}) 857 | {a_value, b_value} 858 | end 859 | 860 | _ when op in [:logical_or, :logical_xor, :logical_and] -> 861 | a_value = TOSA.cast(a_value) >>> gen_type(%{a | type: {:u, 1}}) 862 | b_value = TOSA.cast(b_value) >>> gen_type(%{b | type: {:u, 1}}) 863 | {a_value, b_value} 864 | 865 | _ -> 866 | a_value = TOSA.cast(a_value) >>> a_t 867 | b_value = TOSA.cast(b_value) >>> b_t 868 | {a_value, b_value} 869 | end 870 | 871 | case op do 872 | :subtract -> 873 | TOSA.sub(a_value, b_value) >>> gen_type(t) 874 | 875 | :less_equal -> 876 | c = TOSA.greater_equal(b_value, a_value) >>> gen_type(%{t | type: {:u, 1}}) 877 | TOSA.cast(c) >>> gen_type(t) 878 | 879 | :greater_equal -> 880 | c = TOSA.greater_equal(a_value, b_value) >>> gen_type(%{t | type: {:u, 1}}) 881 | TOSA.cast(c) >>> gen_type(t) 882 | 883 | :less -> 884 | c = TOSA.greater(b_value, a_value) >>> gen_type(%{t | type: {:u, 1}}) 885 | TOSA.cast(c) >>> gen_type(t) 886 | 887 | :greater -> 888 | c = TOSA.greater(a_value, b_value) >>> gen_type(%{t | type: {:u, 1}}) 889 | TOSA.cast(c) >>> gen_type(t) 890 | 891 | :equal -> 892 | c = TOSA.equal(b_value, a_value) >>> gen_type(%{t | type: {:u, 1}}) 893 | TOSA.cast(c) >>> gen_type(t) 894 | 895 | :not_equal -> 896 | c = TOSA.equal(b_value, a_value) >>> gen_type(%{t | type: {:u, 1}}) 897 | c = TOSA.logical_not(c) >>> gen_type(%{t | type: {:u, 1}}) 898 | TOSA.cast(c) >>> gen_type(t) 899 | 900 | :logical_and -> 901 | c = TOSA.logical_and(a_value, b_value) >>> gen_type(%{t | type: {:u, 1}}) 902 | TOSA.cast(c) >>> gen_type(t) 903 | 904 | :logical_or -> 905 | c = TOSA.logical_or(a_value, b_value) >>> gen_type(%{t | type: {:u, 1}}) 906 | TOSA.cast(c) >>> gen_type(t) 907 | 908 | :logical_xor -> 909 | c = TOSA.logical_xor(a_value, b_value) >>> gen_type(%{t | type: {:u, 1}}) 910 | TOSA.cast(c) >>> gen_type(t) 911 | 912 | :add -> 913 | TOSA.add(a_value, b_value) >>> gen_type(t) 914 | 915 | :max -> 916 | 
TOSA.maximum(a_value, b_value) >>> gen_type(t) 917 | 918 | :min -> 919 | TOSA.minimum(a_value, b_value) >>> gen_type(t) 920 | 921 | :bitwise_and -> 922 | TOSA.bitwise_and(a_value, b_value) >>> gen_type(t) 923 | 924 | :bitwise_or -> 925 | TOSA.bitwise_or(a_value, b_value) >>> gen_type(t) 926 | 927 | :bitwise_xor -> 928 | TOSA.bitwise_xor(a_value, b_value) >>> gen_type(t) 929 | 930 | :left_shift -> 931 | TOSA.logical_left_shift(a_value, b_value) >>> gen_type(t) 932 | 933 | :right_shift -> 934 | case t.type do 935 | {:u, _} -> 936 | TOSA.logical_right_shift(a_value, b_value) >>> gen_type(t) 937 | 938 | {:s, _} -> 939 | TOSA.arithmetic_right_shift(a_value, b_value, round: Attribute.bool(false)) >>> 940 | gen_type(t) 941 | end 942 | 943 | :multiply -> 944 | TOSA.mul(a_value, b_value, shift: Attribute.integer(Type.i8(), 0)) >>> gen_type(t) 945 | 946 | :divide -> 947 | b_r = TOSA.reciprocal(b_value) >>> b_t 948 | TOSA.mul(a_value, b_r, shift: Attribute.integer(Type.i8(), 0)) >>> gen_type(t) 949 | 950 | :quotient -> 951 | a_value = TOSA.cast(a_value) >>> gen_type(%{a | type: {:u, 32}}) 952 | b_value = TOSA.cast(b_value) >>> gen_type(%{b | type: {:u, 32}}) 953 | result = TOSA.int_div(a_value, b_value) >>> gen_type(%{t | type: {:u, 32}}) 954 | TOSA.cast(result) >>> gen_type(t) 955 | 956 | _ -> 957 | raise "Unsupported binary op: #{inspect(t, structs: false, pretty: true)}" 958 | end 959 | end 960 | end 961 | 962 | def gen_op( 963 | %Env{block: block, ctx: ctx} = env, 964 | %Nx.Tensor{data: %Nx.Defn.Expr{op: :select, args: [pred, on_true, on_false]}} = t 965 | ) do 966 | mlir block: block, ctx: ctx do 967 | pred_value = gen_op(env, pred) 968 | pred_t = %{pred | type: {:u, 1}} 969 | pred_value = TOSA.cast(pred_value) >>> gen_type(pred_t) 970 | on_true_value = gen_op(env, on_true) 971 | on_false_value = gen_op(env, on_false) 972 | on_true_value = TOSA.cast(on_true_value) >>> gen_type(%{on_true | type: t.type}) 973 | on_false_value = TOSA.cast(on_false_value) >>> gen_type(%{on_false | type: t.type}) 974 | TOSA.select(pred_value, on_true_value, on_false_value) >>> gen_type(t) 975 | end 976 | end 977 | 978 | def gen_op( 979 | %Env{block: block, ctx: ctx} = env, 980 | %Nx.Tensor{data: %Nx.Defn.Expr{op: :reshape, args: [input]}} = t 981 | ) do 982 | mlir block: block, ctx: ctx do 983 | input = gen_op(env, input) 984 | 985 | new_shape = 986 | t.shape 987 | |> Tuple.to_list() 988 | |> Attribute.dense_array(Beaver.Native.I64) 989 | 990 | TOSA.reshape(input, new_shape: new_shape) >>> gen_type(t) 991 | end 992 | end 993 | 994 | def gen_op( 995 | %Env{block: block, ctx: ctx} = env, 996 | %Nx.Tensor{data: %Nx.Defn.Expr{op: :as_type, args: [input1]}} = t 997 | ) do 998 | mlir block: block, ctx: ctx do 999 | input1_value = gen_op(env, input1) 1000 | TOSA.cast(input1_value) >>> gen_type(t) 1001 | end 1002 | end 1003 | 1004 | def gen_op( 1005 | %Env{block: block, ctx: ctx} = env, 1006 | %Nx.Tensor{data: %Nx.Defn.Expr{op: :iota, args: [axis]} = expr, shape: out_shape} = t 1007 | ) do 1008 | if axis do 1009 | mlir block: block, ctx: ctx do 1010 | out_tensor = 1011 | Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0, 0])) >>> 1012 | gen_type(t) 1013 | 1014 | Linalg.generic [ 1015 | out_tensor, 1016 | operand_segment_sizes: ODS.operand_segment_sizes([0, 1]), 1017 | indexing_maps: Manx.Linalg.gen_indexing_maps(out_shape), 1018 | iterator_types: Manx.Linalg.gen_iterator_types(out_shape) 1019 | ] do 1020 | region do 1021 | block _(arg1 >>> gen_type(t.type)) do 1022 | %MLIR.Value{} = arg1 1023 
| index = Linalg.index(dim: Attribute.integer(Type.i64(), axis)) >>> Type.index() 1024 | cast = Arith.index_cast(index) >>> gen_type(t.type) 1025 | Linalg.yield(cast) >>> [] 1026 | end 1027 | end 1028 | end >>> gen_type(t) 1029 | end 1030 | else 1031 | mlir block: block, ctx: ctx do 1032 | dim = t.shape |> Tuple.to_list() |> Enum.reduce(1, &Kernel.*/2) 1033 | 1034 | permutation_1d = 1035 | gen_op( 1036 | env, 1037 | %{t | data: %{expr | args: [0]}, shape: {dim}} 1038 | ) 1039 | 1040 | new_shape = 1041 | t.shape 1042 | |> Tuple.to_list() 1043 | |> Attribute.dense_array(Beaver.Native.I64) 1044 | 1045 | # this should generate affine.apply for index 1046 | TOSA.reshape(permutation_1d, new_shape: new_shape) >>> gen_type(t) 1047 | end 1048 | end 1049 | end 1050 | 1051 | def gen_op(%Env{} = env, tuple) when is_tuple(tuple) do 1052 | tuple 1053 | |> Tuple.to_list() 1054 | |> Enum.map(&gen_op(env, &1)) 1055 | |> List.to_tuple() 1056 | end 1057 | 1058 | def gen_op(_, tensor) do 1059 | raise "op not supported: " <> inspect(tensor) 1060 | end 1061 | end 1062 | --------------------------------------------------------------------------------
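
The `gen_op/2` clauses above dispatch on the `Nx.Defn.Expr` op and lower each expression to TOSA or Linalg operations, so the compiler is exercised simply by running `defn` functions with Manx selected as the Nx.Defn compiler. Below is a minimal usage sketch, assuming the `Manx` module implements the `Nx.Defn.Compiler` behaviour (as `lib/manx/nx/compiler.ex` suggests) and is already available in the project; apart from the public `Nx` API, all names in it are illustrative rather than taken from this repository.

# A minimal sketch of driving the lowering above through Nx.Defn.
# Assumes the Manx compiler module is available; apart from the public
# Nx API, the module and function names here are illustrative.
defmodule ManxUsageSketch do
  import Nx.Defn

  # Nx.dot/2 on two rank-2 f32 tensors should hit the "generic dot
  # product" clause of gen_op/2, which fills a zero-initialized output
  # tensor and emits a linalg.generic multiply-accumulate.
  defn scaled_dot(a, b) do
    Nx.dot(a, b) |> Nx.multiply(0.5)
  end
end

# Select Manx as the Nx.Defn compiler for this process.
Nx.Defn.default_options(compiler: Manx)

a = Nx.iota({2, 3}, type: {:f, 32})
b = Nx.iota({3, 4}, type: {:f, 32})

# Tracing scaled_dot/2 builds the Nx.Defn.Expr tree that the gen_op/2
# clauses above translate into MLIR before execution.
ManxUsageSketch.scaled_dot(a, b) |> IO.inspect()

The trailing `Nx.multiply/2` in the sketch would, in turn, land in the binary TOSA clause above (`:multiply -> TOSA.mul(...)`), so even this small function touches both the Linalg and TOSA lowering paths.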