├── .formatter.exs ├── .github └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── README.md ├── config ├── config.exs └── runtime.exs ├── examples ├── distilbert │ ├── README.md │ ├── distilbert_classification.exs │ └── export.py └── stablelm │ ├── README.md │ ├── export.py │ └── stablelm.exs ├── lib ├── ortex.ex └── ortex │ ├── backend.ex │ ├── model.ex │ ├── native.ex │ ├── serving.ex │ └── util.ex ├── mix.exs ├── mix.lock ├── models └── tinymodel.onnx ├── native └── ortex │ ├── .cargo │ └── config.toml │ ├── .gitignore │ ├── Cargo.lock │ ├── Cargo.toml │ ├── README.md │ └── src │ ├── constants.rs │ ├── lib.rs │ ├── model.rs │ ├── tensor.rs │ └── utils.rs ├── python ├── export_resnet.py └── multi_input.py └── test ├── dtype └── dtype_test.exs ├── ortex_test.exs ├── shape ├── concat_test.exs ├── reshape_test.exs ├── slice_test.exs └── squeeze_test.exs └── test_helper.exs /.formatter.exs: -------------------------------------------------------------------------------- 1 | # Used by "mix format" 2 | [ 3 | inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] 4 | ] 5 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | on: 3 | push: 4 | branches: [main] 5 | pull_request: 6 | jobs: 7 | linux: 8 | runs-on: ubuntu-latest 9 | name: Linux ${{ matrix.elixir }}, ${{ matrix.otp }} 10 | strategy: 11 | matrix: 12 | elixir: ["1.17.3"] 13 | otp: ["27.1.2"] 14 | env: 15 | MIX_ENV: test 16 | steps: 17 | - uses: actions/checkout@v2 18 | - uses: erlef/setup-beam@v1 19 | with: 20 | otp-version: ${{ matrix.otp }} 21 | elixir-version: ${{ matrix.elixir }} 22 | - name: Install dependencies 23 | run: mix deps.get 24 | - name: Compile and check warnings 25 | run: mix compile --warnings-as-errors 26 | - name: Check formatting 27 | run: mix format --check-formatted 28 | - name: Run tests 29 | run: mix test 30 | 31 | 
macos: 32 | runs-on: macos-latest 33 | name: macOS 34 | env: 35 | MIX_ENV: test 36 | steps: 37 | - uses: actions/checkout@v2 38 | - name: Install 39 | run: | 40 | brew update 41 | brew install erlang@27 elixir 42 | - name: Install dependencies 43 | run: mix deps.get 44 | - name: Compile and check warnings 45 | run: mix compile --warnings-as-errors 46 | - name: Check formatting 47 | run: mix format --check-formatted 48 | - name: Run tests 49 | run: mix test 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # The directory Mix will write compiled artifacts to. 2 | /_build/ 3 | 4 | # If you run "mix test --cover", coverage assets end up here. 5 | /cover/ 6 | 7 | # The directory Mix downloads your dependencies sources to. 8 | /deps/ 9 | 10 | # Where third-party dependencies like ExDoc output generated docs. 11 | /doc/ 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | # Ignore package tarball (built via "mix hex.build"). 23 | ortex-*.tar 24 | 25 | # Temporary files, for example, from tests. 26 | /tmp/ 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Relay Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ortex 2 | 3 | `Ortex` is a wrapper around [ONNX Runtime](https://onnxruntime.ai/) (implemented as 4 | bindings to [`ort`](https://github.com/pykeio/ort)). Ortex leverages 5 | [`Nx.Serving`](https://hexdocs.pm/nx/Nx.Serving.html) to easily deploy ONNX models 6 | that run concurrently and distributed in a cluster. Ortex also provides a storage-only 7 | tensor implementation for ease of use. 8 | 9 | ONNX models are a standard machine learning model format that can be exported from most ML 10 | libraries like PyTorch and TensorFlow. 
Ortex allows for easy loading and fast inference of 11 | ONNX models using different backends available to ONNX Runtime such as CUDA, TensorRT, Core 12 | ML, and ARM Compute Library. 13 | 14 | ## Examples 15 | 16 | TL;DR: 17 | 18 | ```elixir 19 | iex> model = Ortex.load("./models/resnet50.onnx") 20 | #Ortex.Model< 21 | inputs: [{"input", "Float32", [nil, 3, 224, 224]}] 22 | outputs: [{"output", "Float32", [nil, 1000]}]> 23 | iex> {output} = Ortex.run(model, Nx.broadcast(0.0, {1, 3, 224, 224})) 24 | iex> output |> Nx.backend_transfer() |> Nx.argmax 25 | #Nx.Tensor< 26 | s64 27 | 499 28 | > 29 | ``` 30 | 31 | Inspecting a model shows the expected inputs, outputs, data types, and shapes. Axes with 32 | `nil` represent a dynamic size. 33 | 34 | To see more real world examples see the `examples` folder. 35 | 36 | ### Serving 37 | 38 | `Ortex` also implements `Nx.Serving` behaviour. To use it in your application's 39 | supervision tree consult the `Nx.Serving` docs. 40 | 41 | ```elixir 42 | iex> serving = Nx.Serving.new(Ortex.Serving, model) 43 | iex> batch = Nx.Batch.stack([{Nx.broadcast(0.0, {3, 224, 224})}]) 44 | iex> {result} = Nx.Serving.run(serving, batch) 45 | iex> result |> Nx.backend_transfer() |> Nx.argmax(axis: 1) 46 | #Nx.Tensor< 47 | s64[1] 48 | [499] 49 | > 50 | ``` 51 | 52 | ## Installation 53 | 54 | `Ortex` can be installed by adding `ortex` to your list of dependencies in `mix.exs`: 55 | 56 | ```elixir 57 | def deps do 58 | [ 59 | {:ortex, "~> 0.1.10"} 60 | ] 61 | end 62 | ``` 63 | 64 | You will need [Rust](https://www.rust-lang.org/tools/install) for compilation to succeed. 65 | -------------------------------------------------------------------------------- /config/config.exs: -------------------------------------------------------------------------------- 1 | import Config 2 | # Something is setting this to IEx.Pry so we're overriding it for now. 
Remove 3 | # if you need to do real debugging 4 | config :elixir, :dbg_callback, {Macro, :dbg, []} 5 | 6 | config :ortex, 7 | add_backend_on_inspect: config_env() != :test 8 | 9 | # Set the cargo feature flags required to use the matching execution provider 10 | # based on the OS we're running on 11 | ortex_features = 12 | case :os.type() do 13 | {:win32, _} -> ["directml"] 14 | {:unix, :darwin} -> ["coreml"] 15 | {:unix, _} -> ["cuda", "tensorrt"] 16 | end 17 | 18 | config :ortex, Ortex.Native, features: ortex_features 19 | -------------------------------------------------------------------------------- /config/runtime.exs: -------------------------------------------------------------------------------- 1 | # import Config 2 | # config :elixir, :dbg_callback, {Macro, :dbg, []} 3 | -------------------------------------------------------------------------------- /examples/distilbert/README.md: -------------------------------------------------------------------------------- 1 | # DistilBert exported to ONNX with HuggingFace transformers 2 | 3 | ### Running 4 | 5 | Run `python export.py` to create the ONNX model for distilbert/distilbert-base-uncased-finetuned-sst-2-english, then `mix run` the `distilbert_classification.exs` script. 6 | 7 | ### Labels 8 | 9 | When exporting the model from huggingface transformers to ONNX, a `config.json` file is added to the chosen directory. This file has the id to label mappings and you can extract them directly to give a label to the input, as shown in `distilbert_classification.exs`. 
10 | -------------------------------------------------------------------------------- /examples/distilbert/distilbert_classification.exs: -------------------------------------------------------------------------------- 1 | defmodule Inference do 2 | def id_to_label(id) do 3 | {:ok, config_json} = File.read("./models/distilbert-onnx/config.json") 4 | {:ok, %{"id2label" => id2label}} = Jason.decode(config_json) 5 | Map.get(id2label, to_string(id)) 6 | end 7 | 8 | def run() do 9 | model = Ortex.load("./models/distilbert-onnx/model.onnx") 10 | 11 | text = 12 | "the movie had a lot of nuance and interesting artistic choices, would like to see more support in the industry for these types of productions" 13 | 14 | {:ok, tokenizer} = Tokenizers.Tokenizer.from_file("./models/distilbert-onnx/tokenizer.json") 15 | {:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, text) 16 | 17 | input = Nx.tensor([Tokenizers.Encoding.get_ids(encoding)]) 18 | mask = Nx.tensor([Tokenizers.Encoding.get_attention_mask(encoding)]) 19 | 20 | {output} = Ortex.run(model, {input, mask}) 21 | 22 | IO.inspect(output) 23 | 24 | IO.inspect( 25 | output 26 | |> Nx.backend_transfer() 27 | |> Nx.argmax() 28 | |> Nx.to_number() 29 | |> id_to_label() 30 | ) 31 | end 32 | end 33 | 34 | Inference.run() 35 | -------------------------------------------------------------------------------- /examples/distilbert/export.py: -------------------------------------------------------------------------------- 1 | """ 2 | ### Install dependencies: 3 | 4 | $ pip install transformers 5 | $ pip install optimum 6 | $ pip install "transformers[onnx]" 7 | 8 | """ 9 | 10 | from transformers import DistilBertTokenizer 11 | from optimum.onnxruntime import ORTModelForSequenceClassification 12 | 13 | save_directory = "./models/distilbert-onnx/" 14 | 15 | tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english") 16 | model = 
ORTModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", export=True) 17 | print(model) 18 | 19 | model.save_pretrained(save_directory) 20 | tokenizer.save_pretrained(save_directory) 21 | -------------------------------------------------------------------------------- /examples/stablelm/README.md: -------------------------------------------------------------------------------- 1 | # StableLM 2 | 3 | Run `python export.py` to create the ONNX model for stablelm-3b, copy the model to the 4 | models directory (or change where `stablelm.exs` loads the model from), then `mix run` 5 | the `stablelm.exs` script. 6 | -------------------------------------------------------------------------------- /examples/stablelm/export.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoModelForCausalLM, 3 | AutoTokenizer, 4 | ) 5 | import torch 6 | 7 | tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-3b") 8 | model = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-tuned-alpha-3b") 9 | print(model) 10 | 11 | prompt = "<|ASSISTANT|>" 12 | 13 | inputs = tokenizer(prompt, return_tensors="pt") 14 | torch.onnx.export( 15 | model, 16 | (inputs["input_ids"].cpu(), inputs["attention_mask"].cpu()), 17 | "output/stability-lm-tuned-3b.onnx", 18 | input_names=["input_ids", "attention_mask"], 19 | dynamic_axes={ 20 | "input_ids": {0: "batch_size", 1: "sequence_length"}, 21 | "attention_mask": {0: "batch_size", 1: "sequence_length"}, 22 | }, 23 | ) 24 | -------------------------------------------------------------------------------- /examples/stablelm/stablelm.exs: -------------------------------------------------------------------------------- 1 | model = Ortex.load("./models/stability-lm-3b/stability-lm-tuned-3b.onnx") 2 | 3 | prompt = "<|SYSTEM|># StableLM Tuned (Alpha version) 4 | - StableLM is a helpful and harmless open-source AI language model 
developed by StabilityAI. 5 | - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. 6 | - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. 7 | - StableLM will refuse to participate in anything that could harm a human. 8 | <|USER|>How are you feeling? <|ASSISTANT|> 9 | " 10 | 11 | {:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-3b") 12 | {:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, prompt) 13 | 14 | input = Nx.tensor([Tokenizers.Encoding.get_ids(encoding)]) 15 | mask = Nx.tensor([Tokenizers.Encoding.get_attention_mask(encoding)]) 16 | 17 | defmodule M do 18 | def generate(_model, input, _mask, 500) do 19 | input 20 | end 21 | 22 | def generate(model, input, mask, iter) do 23 | [output | _] = 24 | Ortex.run(model, { 25 | input, 26 | mask 27 | }) 28 | |> Tuple.to_list() 29 | 30 | x = output |> Nx.backend_transfer() |> Nx.argmax(axis: 2) 31 | last = x[[.., -1]] |> Nx.new_axis(0) 32 | IO.inspect(last[0][0] |> Nx.to_number) 33 | 34 | case Enum.member?([50278, 50279, 50277, 1, 0], last[0][0] |> Nx.to_number) do 35 | true -> 36 | input 37 | 38 | false -> 39 | generate( 40 | model, 41 | Nx.concatenate([input, last], axis: 1), 42 | Nx.concatenate([mask, Nx.tensor([[1]])], axis: 1), 43 | iter + 1 44 | ) 45 | end 46 | end 47 | end 48 | 49 | result = M.generate(model, input, mask, 0) 50 | IO.inspect(result) 51 | 52 | IO.inspect( 53 | Tokenizers.Tokenizer.decode( 54 | tokenizer, 55 | result 56 | |> Nx.backend_transfer() 57 | |> Nx.to_batched(1) 58 | |> Enum.map(&Nx.to_flat_list/1) 59 | ) 60 | ) 61 | -------------------------------------------------------------------------------- /lib/ortex.ex: -------------------------------------------------------------------------------- 1 | defmodule Ortex do 2 | @moduledoc """ 3 | Documentation for `Ortex`. 
4 | 5 | `Ortex` is an Elixir wrapper around [ONNX Runtime](https://onnxruntime.ai/) using 6 | [Rustler](https://hexdocs.pm/rustler) and [ORT](https://github.com/pykeio/ort). 7 | """ 8 | 9 | @doc """ 10 | Load an `Ortex.Model` from disk. Optionally pass the execution providers as a list 11 | of descending priority and graph optimization level 1-3. Any graph optimization level 12 | beyond the range of 1-3 will disable graph optimization. 13 | 14 | By default, `Ortex` only includes some of the supported execution providers of ONNX Runtime. 15 | To enable others, first ensure you have downloaded or compiled a version of 16 | `libonnxruntime` that includes them, then set the environment variable `ORT_LIB_LOCATION` 17 | to its location. Then add `config :ortex, Ortex.Native, features: [EXECUTION_PROVIDERS]` to your 18 | `config.exs` where `EXECUTION_PROVIDERS` is a list of strings of which execution providers 19 | to enable. 20 | 21 | ## Examples 22 | 23 | iex> Ortex.load("./models/tinymodel.onnx") 24 | iex> Ortex.load("./models/tinymodel.onnx", [:cuda, :cpu]) 25 | iex> Ortex.load("./models/tinymodel.onnx", [:cpu], 0) 26 | 27 | """ 28 | defdelegate load(path, eps \\ [:cpu], opt \\ 3), to: Ortex.Model 29 | 30 | @doc """ 31 | Run a forward pass through a model. 32 | 33 | This takes a model and tuple of `Nx.Tensors`, 34 | optionally transfers them to the `Ortex.Backend` if they aren't there already, 35 | and runs a forward pass through the model. This will return a tuple of `Ortex.Backend` 36 | tensors, it's up to the user to transfer these back to another backend if additional 37 | ops are required. 38 | 39 | If there is only one input you can optionally pass a bare tensor rather than a tuple. 
40 | 41 | ## Examples 42 | 43 | iex> model = Ortex.load("./models/tinymodel.onnx") 44 | iex> {%Nx.Tensor{shape: {1, 10}}, 45 | ...> %Nx.Tensor{shape: {1, 10}}, 46 | ...> %Nx.Tensor{shape: {1, 10}}} = Ortex.run( 47 | ...> model, { 48 | ...> Nx.broadcast(0, {1, 100}) |> Nx.as_type(:s32), 49 | ...> Nx.broadcast(0, {1, 100}) |> Nx.as_type(:f32) 50 | ...> }) 51 | 52 | """ 53 | defdelegate run(model, tensors), to: Ortex.Model 54 | end 55 | -------------------------------------------------------------------------------- /lib/ortex/backend.ex: -------------------------------------------------------------------------------- 1 | defmodule Ortex.Backend do 2 | @moduledoc """ 3 | Documentation for `Ortex.Backend`. 4 | 5 | This implements the `Nx.Backend` behaviour for `Ortex` tensors. Most `Nx` operations 6 | are not implemented for this (although they may be in the future). This is mainly 7 | for ergonomic tensor construction and deconstruction from Ortex inputs and outputs. 8 | 9 | Since this does not implement most `Nx` operations, it's best *NOT* to set this as 10 | the default backend. 
11 | """ 12 | 13 | @behaviour Nx.Backend 14 | @enforce_keys [:ref] 15 | 16 | @derive {Nx.Container, containers: [:ref]} 17 | 18 | defstruct [:ref] 19 | 20 | alias Ortex.Backend, as: B 21 | alias Nx.Tensor, as: T 22 | 23 | @impl true 24 | def init(opts) do 25 | if opts != [] do 26 | raise ArgumentError, "Ortex.Backend accepts no options" 27 | end 28 | 29 | opts 30 | end 31 | 32 | @impl true 33 | def from_binary(%T{shape: shape, type: type} = tensor, binary, _backend_options) do 34 | data = 35 | case Ortex.Native.from_binary(binary, shape, type) do 36 | {:error, msg} -> raise msg 37 | res -> res 38 | end 39 | 40 | put_in(tensor.data, %Ortex.Backend{ref: data}) 41 | end 42 | 43 | @impl true 44 | def to_binary(%T{data: %B{ref: ref}, type: {_, size}}, limit) do 45 | case Ortex.Native.to_binary(ref, size, limit) do 46 | {:error, msg} -> raise msg 47 | res -> res 48 | end 49 | end 50 | 51 | @impl true 52 | def backend_transfer(tensor, Nx.Tensor, _opts) do 53 | tensor 54 | end 55 | 56 | @impl true 57 | def backend_transfer(tensor, Ortex.Backend, _opts) do 58 | tensor 59 | end 60 | 61 | @impl true 62 | def backend_transfer(tensor, backend, opts) do 63 | backend.from_binary(tensor, to_binary(tensor), opts) 64 | end 65 | 66 | defp to_binary(%T{data: %{ref: tensor}}) do 67 | # filling the bits and limits with 0 since we aren't using them right now 68 | Ortex.Native.to_binary(tensor, 0, 0) 69 | end 70 | 71 | @impl true 72 | def inspect(%T{} = tensor, inspect_opts) do 73 | limit = if inspect_opts.limit == :infinity, do: :infinity, else: inspect_opts.limit + 1 74 | 75 | tensor 76 | |> to_binary(min(limit, Nx.size(tensor))) 77 | |> then(&Nx.Backend.inspect(tensor, &1, inspect_opts)) 78 | |> maybe_add_signature(tensor) 79 | end 80 | 81 | @impl true 82 | def slice(out, %T{data: %B{ref: tensor_ref}}, start_indicies, lengths, strides) do 83 | r = Ortex.Native.slice(tensor_ref, start_indicies, lengths, strides) 84 | put_in(out.data, %B{ref: r}) 85 | end 86 | 87 | @impl true 88 | def 
reshape(out, %T{data: %B{ref: ref}}) do 89 | shape = Nx.shape(out) |> Tuple.to_list() 90 | put_in(out.data, %B{ref: Ortex.Native.reshape(ref, shape)}) 91 | end 92 | 93 | @impl true 94 | def squeeze(out, tensor, axes) do 95 | %T{shape: old_shape, names: names, data: %B{ref: ref}} = tensor 96 | {new_shape, new_names} = Nx.Shape.squeeze(old_shape, axes, names) 97 | 98 | if old_shape == new_shape do 99 | %{out | data: %B{ref: ref}} 100 | else 101 | %{ 102 | out 103 | | shape: new_shape, 104 | names: new_names, 105 | data: %B{ref: Ortex.Native.reshape(ref, new_shape |> Tuple.to_list())} 106 | } 107 | end 108 | end 109 | 110 | @impl true 111 | def concatenate(out, tensors, axis) do 112 | if not Enum.all?(tensors, fn t -> t.type == out.type end) do 113 | raise "Ortex does not currently support concatenation of vectors with differing types." 114 | end 115 | 116 | tensor_refs = 117 | Enum.map(tensors, fn t -> 118 | %T{data: %B{ref: ref}} = t 119 | ref 120 | end) 121 | 122 | type = out.type 123 | 124 | %{out | data: %B{ref: Ortex.Native.concatenate(tensor_refs, type, axis)}} 125 | end 126 | 127 | if Application.compile_env(:ortex, :add_backend_on_inspect, true) do 128 | defp maybe_add_signature(result, %T{data: %B{ref: _mat_ref}}) do 129 | Inspect.Algebra.concat([ 130 | "Ortex.Backend", 131 | Inspect.Algebra.line(), 132 | result 133 | ]) 134 | end 135 | else 136 | defp maybe_add_signature(result, _tensor) do 137 | result 138 | end 139 | end 140 | 141 | funs = Nx.Backend.behaviour_info(:callbacks) -- Module.definitions_in(__MODULE__, :def) 142 | 143 | @doc false 144 | def __unimplemented__, do: unquote(funs) 145 | 146 | for {fun, arity} <- funs do 147 | args = Macro.generate_arguments(arity, __MODULE__) 148 | 149 | @impl true 150 | def unquote(fun)(unquote_splicing(args)) do 151 | raise "operation #{unquote(fun)} is not yet supported on Ortex.Backend." 
152 | end 153 | end 154 | end 155 | -------------------------------------------------------------------------------- /lib/ortex/model.ex: -------------------------------------------------------------------------------- 1 | defmodule Ortex.Model do 2 | @moduledoc """ 3 | A model for running Ortex inference with. 4 | 5 | Implements a human-readable representation of a model including the name, dimension, and 6 | type of each input and output 7 | 8 | ``` 9 | #Ortex.Model< 10 | inputs: [{"x", "Int32", [nil, 100]}, {"y", "Float32", [nil, 100]}] 11 | outputs: [ 12 | {"9", "Float32", [nil, 10]}, 13 | {"onnx::Add_7", "Float32", [nil, 10]}, 14 | {"onnx::Add_8", "Float32", [nil, 10]} 15 | ]> 16 | ``` 17 | 18 | `nil` values represent dynamic dimensions 19 | """ 20 | 21 | @enforce_keys [:reference] 22 | defstruct [:reference] 23 | 24 | @doc false 25 | def load(path, eps \\ [:cpu], opt \\ 3) do 26 | case Ortex.Native.init(path, eps, opt) do 27 | {:error, msg} -> 28 | raise msg 29 | 30 | model -> 31 | %Ortex.Model{reference: model} 32 | end 33 | end 34 | 35 | @doc false 36 | def run(%Ortex.Model{} = model, tensor) when not is_tuple(tensor) do 37 | run(model, {tensor}) 38 | end 39 | 40 | @doc false 41 | def run(%Ortex.Model{reference: model}, tensors) do 42 | # Move tensors into Ortex backend and pass the reference to the Ortex NIF 43 | output = 44 | case Ortex.Native.run( 45 | model, 46 | tensors 47 | |> Tuple.to_list() 48 | |> Enum.map(fn x -> x |> Nx.backend_transfer(Ortex.Backend) end) 49 | |> Enum.map(fn %Nx.Tensor{data: %Ortex.Backend{ref: x}} -> x end) 50 | ) do 51 | {:error, msg} -> raise msg 52 | output -> output 53 | end 54 | 55 | # Pack the output into new Ortex.Backend tensor(s) 56 | output 57 | |> Enum.map(fn {ref, shape, dtype_atom, dtype_bits} -> 58 | %Nx.Tensor{ 59 | data: %Ortex.Backend{ref: ref}, 60 | shape: shape |> List.to_tuple(), 61 | type: {dtype_atom, dtype_bits}, 62 | names: List.duplicate(nil, length(shape)) 63 | } 64 | end) 65 | |> List.to_tuple() 66 | 
end 67 | end 68 | 69 | defimpl Inspect, for: Ortex.Model do 70 | import Inspect.Algebra 71 | 72 | def inspect(%Ortex.Model{reference: model}, inspect_opts) do 73 | case Ortex.Native.show_session(model) do 74 | {:error, msg} -> 75 | raise msg 76 | 77 | {inputs, outputs} -> 78 | force_unfit( 79 | concat([ 80 | color("#Ortex.Model<", :map, inspect_opts), 81 | line(), 82 | nest(concat([" inputs: ", Inspect.List.inspect(inputs, inspect_opts)]), 2), 83 | line(), 84 | nest(concat([" outputs: ", Inspect.List.inspect(outputs, inspect_opts)]), 2), 85 | color(">", :map, inspect_opts) 86 | ]) 87 | ) 88 | end 89 | end 90 | end 91 | -------------------------------------------------------------------------------- /lib/ortex/native.ex: -------------------------------------------------------------------------------- 1 | defmodule Ortex.Native do 2 | @moduledoc false 3 | 4 | @rustler_version Application.spec(:rustler, :vsn) |> to_string() |> Version.parse!() 5 | 6 | # We have to compile the crate before `use Rustler` compiles the crate since 7 | # cargo downloads the onnxruntime shared libraries and they are not available 8 | # to load or copy into Elixir's during the on_load or Elixir compile steps. 9 | # In the future, this may be configurable in Rustler. 10 | if Version.compare(@rustler_version, "0.30.0") in [:gt, :eq] do 11 | Rustler.Compiler.compile_crate(:ortex, Application.compile_env(:ortex, __MODULE__, []), 12 | otp_app: :ortex, 13 | crate: :ortex 14 | ) 15 | else 16 | Rustler.Compiler.compile_crate(__MODULE__, otp_app: :ortex, crate: :ortex) 17 | end 18 | 19 | Ortex.Util.copy_ort_libs() 20 | 21 | use Rustler, 22 | otp_app: :ortex, 23 | crate: :ortex, 24 | skip_compilation?: true 25 | 26 | # When loading a NIF module, dummy clauses for all NIF functions are required. 27 | # NIF dummies usually just error out when called when the NIF is not loaded, as that should never normally happen. 
28 | def init(_model_path, _execution_providers, _optimization_level), 29 | do: :erlang.nif_error(:nif_not_loaded) 30 | 31 | def run(_model, _inputs), do: :erlang.nif_error(:nif_not_loaded) 32 | def from_binary(_bin, _shape, _type), do: :erlang.nif_error(:nif_not_loaded) 33 | def to_binary(_reference, _bits, _limit), do: :erlang.nif_error(:nif_not_loaded) 34 | def show_session(_model), do: :erlang.nif_error(:nif_not_loaded) 35 | 36 | def slice(_tensor, _start_indicies, _lengths, _strides), 37 | do: :erlang.nif_error(:nif_not_loaded) 38 | 39 | def reshape(_tensor, _shape), do: :erlang.nif_error(:nif_not_loaded) 40 | 41 | def concatenate(_tensors_refs, _type, _axis), do: :erlang.nif_error(:nif_not_loaded) 42 | end 43 | -------------------------------------------------------------------------------- /lib/ortex/serving.ex: -------------------------------------------------------------------------------- 1 | defmodule Ortex.Serving do 2 | @moduledoc """ 3 | `Ortex.Serving` Documentation 4 | 5 | This is a lightweight wrapper for using `Nx.Serving` behaviour with `Ortex`. Using `jit` and 6 | `defn` functions in this are not supported, it is strictly for serving batches to 7 | an `Ortex.Model` for inference. 
8 | 9 | ## Examples 10 | 11 | ### Inline/serverless workflow 12 | 13 | To quickly create an `Ortex.Serving` and run it 14 | 15 | ```elixir 16 | iex> model = Ortex.load("./models/resnet50.onnx") 17 | iex> serving = Nx.Serving.new(Ortex.Serving, model) 18 | iex> batch = Nx.Batch.stack([{Nx.broadcast(0.0, {3, 224, 224})}]) 19 | iex> {result} = Nx.Serving.run(serving, batch) 20 | iex> result |> Nx.backend_transfer |> Nx.argmax(axis: 1) 21 | #Nx.Tensor< 22 | s64[1] 23 | [499] 24 | > 25 | ``` 26 | 27 | ### Stateful/process workflow 28 | 29 | An `Ortex.Serving` can also be started in your Application's supervision tree 30 | ```elixir 31 | model = Ortex.load("./models/resnet50.onnx") 32 | children = [ 33 | {Nx.Serving, 34 | serving: Nx.Serving.new(Ortex.Serving, model), 35 | name: MyServing, 36 | batch_size: 10, 37 | batch_timeout: 100} 38 | ] 39 | opts = [strategy: :one_for_one, name: OrtexServing.Supervisor] 40 | Supervisor.start_link(children, opts) 41 | ``` 42 | 43 | With the application started, batches can now be sent to the `Ortex.Serving` process 44 | 45 | ```elixir 46 | iex> Nx.Serving.batched_run(MyServing, Nx.Batch.stack([{Nx.broadcast(0.0, {3, 224, 224})}])) 47 | ...> {#Nx.Tensor< 48 | f32[1][1000] 49 | Ortex.Backend 50 | [ 51 | [...] 
52 | ] 53 | >} 54 | 55 | ``` 56 | 57 | """ 58 | 59 | @behaviour Nx.Serving 60 | 61 | @impl true 62 | def init(_inline_or_process, model, [_defn_options]) do 63 | func = fn x -> Ortex.run(model, x) end 64 | {:ok, func} 65 | end 66 | 67 | @impl true 68 | def handle_batch(batch, _partition, function) do 69 | # A hack to move the batch into a tensor for Ortex 70 | out = function.(Nx.Defn.jit_apply(&Function.identity/1, [batch])) 71 | {:execute, fn -> {out, :server_info} end, function} 72 | end 73 | end 74 | -------------------------------------------------------------------------------- /lib/ortex/util.ex: -------------------------------------------------------------------------------- 1 | defmodule Ortex.Util do 2 | @moduledoc false 3 | @doc """ 4 | Copies the libraries downloaded during the ORT build into a path that 5 | Elixir can use 6 | """ 7 | def copy_ort_libs() do 8 | build_root = Path.absname(:code.priv_dir(:ortex)) |> Path.dirname() 9 | 10 | rust_env = 11 | case Path.join([build_root, "native/ortex/release"]) |> File.ls() do 12 | {:ok, _} -> "release" 13 | _ -> "debug" 14 | end 15 | 16 | # where the libonnxruntime files are stored 17 | rust_path = Path.join([build_root, "native/ortex", rust_env]) 18 | 19 | onnx_runtime_paths = 20 | case :os.type() do 21 | {:win32, _} -> Path.join([rust_path, "libonnxruntime*.dll*"]) 22 | {:unix, :darwin} -> Path.join([rust_path, "libonnxruntime*.dylib*"]) 23 | {:unix, _} -> Path.join([rust_path, "libonnxruntime*.so*"]) 24 | end 25 | |> Path.wildcard() 26 | 27 | # where we need to copy the libraries to 28 | destination_dir = Path.join([:code.priv_dir(:ortex), "native"]) 29 | 30 | onnx_runtime_paths 31 | |> Enum.map(fn x -> 32 | File.cp!(x, Path.join([destination_dir, Path.basename(x)])) 33 | end) 34 | end 35 | end 36 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule Ortex.MixProject do 2 | use 
Mix.Project 3 | 4 | def project do 5 | [ 6 | app: :ortex, 7 | version: "0.1.10", 8 | elixir: "~> 1.14", 9 | start_permanent: Mix.env() == :prod, 10 | deps: deps(), 11 | 12 | # Docs 13 | name: "Ortex", 14 | source_url: "https://github.com/elixir-nx/ortex", 15 | homepage_url: "http://github.com/elixir-nx/ortex", 16 | docs: [ 17 | main: "readme", 18 | extras: ["README.md"] 19 | ], 20 | package: package() 21 | ] 22 | end 23 | 24 | # Run "mix help compile.app" to learn about applications. 25 | def application do 26 | [ 27 | extra_applications: [:logger] 28 | ] 29 | end 30 | 31 | # Run "mix help deps" to learn about dependencies. 32 | defp deps do 33 | [ 34 | {:rustler, "~> 0.27"}, 35 | {:nx, "~> 0.6"}, 36 | {:tokenizers, "~> 0.4", only: :dev}, 37 | {:ex_doc, "0.29.4", only: :dev, runtime: false}, 38 | {:exla, "~> 0.6", only: :dev}, 39 | {:torchx, "~> 0.6", only: :dev} 40 | ] 41 | end 42 | 43 | defp package do 44 | [ 45 | files: ~w(lib .formatter.exs mix.exs README* LICENSE* native/ortex/src/ config/config.exs 46 | native/ortex/Cargo.lock native/ortex/Cargo.toml native/ortex/.cargo/config.toml), 47 | licenses: ["MIT"], 48 | links: %{"GitHub" => "https://github.com/elixir-nx/ortex"}, 49 | description: "ONNX Runtime bindings for Elixir" 50 | ] 51 | end 52 | end 53 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "axon": {:hex, :axon, "0.5.1", "1ae3a2193df45e51fca912158320b2ca87cb7fba4df242bd3ebe245504d0ea1a", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:kino_vega_lite, "~> 0.1.7", [hex: :kino_vega_lite, repo: "hexpm", optional: true]}, {:nx, "~> 0.5.0", [hex: :nx, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1", [hex: :table_rex, repo: "hexpm", optional: true]}], "hexpm", "d36f2a11c34c6c2b458f54df5c71ffdb7ed91c6a9ccd908faba909c84cc6a38e"}, 3 | "axon_onnx": {:hex, :axon_onnx, "0.4.0", 
"7be4b5ac7a44340ec65eb59c24122a8fe2aa8105da33b3321a378b455a6cd9c6", [:mix], [{:axon, "~> 0.5", [hex: :axon, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}, {:protox, "~> 1.6.10", [hex: :protox, repo: "hexpm", optional: false]}], "hexpm", "b98c84e5656caf156ef8998296836349a62bc35598f05cc21eececbbef022d09"}, 4 | "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"}, 5 | "cc_precompiler": {:hex, :cc_precompiler, "0.1.7", "77de20ac77f0e53f20ca82c563520af0237c301a1ec3ab3bc598e8a96c7ee5d9", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "2768b28bf3c2b4f788c995576b39b8cb5d47eb788526d93bd52206c1d8bf4b75"}, 6 | "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, 7 | "decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"}, 8 | "dll_loader_helper": {:hex, :dll_loader_helper, "1.1.0", "e7d015e980942a0d67e306827ec907e7e853a21186bd92bb968d986698591a0f", [:mix], [{:dll_loader_helper_beam, "~> 1.1", [hex: :dll_loader_helper_beam, repo: "hexpm", optional: false]}], "hexpm", "2b6c11ee7bb48f6a132ce8f872202f9e828c019988da1e2d40ad41496195df0c"}, 9 | "dll_loader_helper_beam": {:hex, :dll_loader_helper_beam, "1.2.0", "557c43befb8e3b119b718da302adccde3bd855acdb999498a14a2a8d2814b8b9", [:rebar3], [], "hexpm", "a2115d4bf1cca488a7b33f3c648847f64019b32c0382d10286d84dd5c3cbc0e5"}, 10 | "earmark_parser": {:hex, :earmark_parser, "1.4.32", "fa739a0ecfa34493de19426681b23f6814573faee95dfd4b4aafe15a7b5b32c6", [:mix], [], "hexpm", 
"b8b0dd77d60373e77a3d7e8afa598f325e49e8663a51bcc2b88ef41838cca755"}, 11 | "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, 12 | "erlex": {:hex, :erlex, "0.2.6", "c7987d15e899c7a2f34f5420d2a2ea0d659682c06ac607572df55a43753aa12e", [:mix], [], "hexpm", "2ed2e25711feb44d52b17d2780eabf998452f6efda104877a3881c2f8c0c0c75"}, 13 | "ex_doc": {:hex, :ex_doc, "0.29.4", "6257ecbb20c7396b1fe5accd55b7b0d23f44b6aa18017b415cb4c2b91d997729", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "2c6699a737ae46cb61e4ed012af931b57b699643b24dabe2400a8168414bc4f5"}, 14 | "exla": {:hex, :exla, "0.6.1", "a4400933a04d018c5fb508c75a080c73c3c1986f6c16a79bbfee93ba22830d4d", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.5.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "f0e95b0f91a937030cf9fcbe900c9d26933cb31db2a26dfc8569aa239679e6d4"}, 15 | "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, 16 | "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", 
"0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, 17 | "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, 18 | "makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"}, 19 | "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, 20 | "nx": {:hex, :nx, "0.6.2", "f1d137f477b1a6f84f8db638f7a6d5a0f8266caea63c9918aa4583db38ebe1d6", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "ac913b68d53f25f6eb39bddcf2d2cd6ea2e9bcb6f25cf86a79e35d0411ba96ad"}, 21 | "protox": {:hex, :protox, "1.6.10", "41d0b0c5b9190e7d5e6a2b1a03a09257ead6f3d95e6a0cf8b81430b526126908", [:mix], [{:decimal, "~> 1.9 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}, {:jason, "~> 1.2", [hex: :jason, repo: "hexpm", optional: true]}], "hexpm", "9769fca26ae7abfc5cc61308a1e8d9e2400ff89a799599cee7930d21132832d9"}, 22 | "rustler": {:hex, :rustler, "0.29.1", "880f20ae3027bd7945def6cea767f5257bc926f33ff50c0d5d5a5315883c084d", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "109497d701861bfcd26eb8f5801fe327a8eef304f56a5b63ef61151ff44ac9b6"}, 23 | "rustler_precompiled": {:hex, 
:rustler_precompiled, "0.7.0", "5d0834fc06dbc76dd1034482f17b1797df0dba9b491cef8bb045fcaca94bcade", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "fdf43a6835f4e4de5bfbc4c019bfb8c46d124bd4635fefa3e20d9a2bbbec1512"}, 24 | "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, 25 | "tokenizers": {:hex, :tokenizers, "0.4.0", "140283ca74a971391ddbd83cd8cbdb9bd03736f37a1b6989b82d245a95e1eb97", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "ef1a9824f5a893cd3b831c0e5b3d72caa250d2ec462035cc6afef6933b13a82e"}, 26 | "toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"}, 27 | "torchx": {:hex, :torchx, "0.6.1", "2a9862ebc4b397f42c51f0fa3f9f4e3451a83df6fba42882f8523cbc925c8ae1", [:make, :mix], [{:dll_loader_helper, "~> 0.1 or ~> 1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.1", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "99b3fc73b52d6cfbe5cad8bdd74277ddc99297ce8fc6765b1dabec80681e8d9d"}, 28 | "useful": {:hex, :useful, "1.11.0", "b2d89223563c3354fd56f4da75b63f07f52cb32b243289a7f1fcc37869bcf9c2", [:mix], [], "hexpm", "2e5b2a47acc191bfb38e936f5f1bc57dad3b11133e0defe59a32fda10ebafcff"}, 29 | "xla": {:hex, :xla, "0.5.1", "8ba4c2c51c1a708ff54e9d4f88158c1a75b7f2cb3e5db02bf222b5b3852afffd", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, 
repo: "hexpm", optional: false]}], "hexpm", "82a2490f6e9a76c8a29d1aedb47f07c59e3d5081095eac5a74db34d46c8212bc"}, 30 | } 31 | -------------------------------------------------------------------------------- /models/tinymodel.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/ortex/450dbe6ec6cc96e4e5509a746937ccde25de0144/models/tinymodel.onnx -------------------------------------------------------------------------------- /native/ortex/.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [profile.dev] 2 | rpath=true 3 | 4 | [profile.release] 5 | rpath=true 6 | 7 | [target.'cfg(target_os = "macos")'] 8 | rustflags = [ 9 | "-C", "link-arg=-undefined", 10 | "-C", "link-arg=dynamic_lookup", 11 | "-C", "link-arg=-fapple-link-rtlib", 12 | "-C", "link-args=-Wl,-rpath,@loader_path", 13 | ] 14 | [target.x86_64-unknown-linux-gnu] 15 | rustflags = [ "-Clink-args=-Wl,-rpath,$ORIGIN" ] 16 | [target.aarch64-unknown-linux-gnu] 17 | rustflags = [ "-Clink-args=-Wl,-rpath,$ORIGIN" ] 18 | -------------------------------------------------------------------------------- /native/ortex/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /native/ortex/Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 3 4 | 5 | [[package]] 6 | name = "adler" 7 | version = "1.0.2" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" 10 | 11 | [[package]] 12 | name = "aho-corasick" 13 | version = "0.7.20" 14 | source = "registry+https://github.com/rust-lang/crates.io-index" 15 | checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" 16 | dependencies = [ 17 | "memchr", 18 | ] 19 | 20 | [[package]] 21 | name = "autocfg" 22 | version = "1.1.0" 23 | source = "registry+https://github.com/rust-lang/crates.io-index" 24 | checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" 25 | 26 | [[package]] 27 | name = "base64" 28 | version = "0.21.7" 29 | source = "registry+https://github.com/rust-lang/crates.io-index" 30 | checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" 31 | 32 | [[package]] 33 | name = "bitflags" 34 | version = "1.3.2" 35 | source = "registry+https://github.com/rust-lang/crates.io-index" 36 | checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" 37 | 38 | [[package]] 39 | name = "bitflags" 40 | version = "2.4.2" 41 | source = "registry+https://github.com/rust-lang/crates.io-index" 42 | checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" 43 | 44 | [[package]] 45 | name = "block-buffer" 46 | version = "0.10.4" 47 | source = "registry+https://github.com/rust-lang/crates.io-index" 48 | checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" 49 | dependencies = [ 50 | "generic-array", 51 | ] 52 | 53 | [[package]] 54 | name = "byteorder" 55 | version = "1.5.0" 56 | source = "registry+https://github.com/rust-lang/crates.io-index" 57 | checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" 58 | 59 | [[package]] 60 | name = "cc" 61 | version = "1.0.83" 62 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 63 | checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" 64 | dependencies = [ 65 | "libc", 66 | ] 67 | 68 | [[package]] 69 | name = "cfg-if" 70 | version = "1.0.0" 71 | source = "registry+https://github.com/rust-lang/crates.io-index" 72 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 73 | 74 | [[package]] 75 | name = "cpufeatures" 76 | version = "0.2.12" 77 | source = "registry+https://github.com/rust-lang/crates.io-index" 78 | checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" 79 | dependencies = [ 80 | "libc", 81 | ] 82 | 83 | [[package]] 84 | name = "crc32fast" 85 | version = "1.4.0" 86 | source = "registry+https://github.com/rust-lang/crates.io-index" 87 | checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" 88 | dependencies = [ 89 | "cfg-if", 90 | ] 91 | 92 | [[package]] 93 | name = "crunchy" 94 | version = "0.2.2" 95 | source = "registry+https://github.com/rust-lang/crates.io-index" 96 | checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" 97 | 98 | [[package]] 99 | name = "crypto-common" 100 | version = "0.1.6" 101 | source = "registry+https://github.com/rust-lang/crates.io-index" 102 | checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" 103 | dependencies = [ 104 | "generic-array", 105 | "typenum", 106 | ] 107 | 108 | [[package]] 109 | name = "digest" 110 | version = "0.10.7" 111 | source = "registry+https://github.com/rust-lang/crates.io-index" 112 | checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" 113 | dependencies = [ 114 | "block-buffer", 115 | "crypto-common", 116 | ] 117 | 118 | [[package]] 119 | name = "errno" 120 | version = "0.3.8" 121 | source = "registry+https://github.com/rust-lang/crates.io-index" 122 | checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" 123 | dependencies = 
[ 124 | "libc", 125 | "windows-sys", 126 | ] 127 | 128 | [[package]] 129 | name = "filetime" 130 | version = "0.2.23" 131 | source = "registry+https://github.com/rust-lang/crates.io-index" 132 | checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd" 133 | dependencies = [ 134 | "cfg-if", 135 | "libc", 136 | "redox_syscall", 137 | "windows-sys", 138 | ] 139 | 140 | [[package]] 141 | name = "flate2" 142 | version = "1.0.28" 143 | source = "registry+https://github.com/rust-lang/crates.io-index" 144 | checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" 145 | dependencies = [ 146 | "crc32fast", 147 | "miniz_oxide", 148 | ] 149 | 150 | [[package]] 151 | name = "form_urlencoded" 152 | version = "1.2.1" 153 | source = "registry+https://github.com/rust-lang/crates.io-index" 154 | checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" 155 | dependencies = [ 156 | "percent-encoding", 157 | ] 158 | 159 | [[package]] 160 | name = "generic-array" 161 | version = "0.14.7" 162 | source = "registry+https://github.com/rust-lang/crates.io-index" 163 | checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" 164 | dependencies = [ 165 | "typenum", 166 | "version_check", 167 | ] 168 | 169 | [[package]] 170 | name = "getrandom" 171 | version = "0.2.12" 172 | source = "registry+https://github.com/rust-lang/crates.io-index" 173 | checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" 174 | dependencies = [ 175 | "cfg-if", 176 | "libc", 177 | "wasi", 178 | ] 179 | 180 | [[package]] 181 | name = "half" 182 | version = "2.3.1" 183 | source = "registry+https://github.com/rust-lang/crates.io-index" 184 | checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" 185 | dependencies = [ 186 | "cfg-if", 187 | "crunchy", 188 | ] 189 | 190 | [[package]] 191 | name = "heck" 192 | version = "0.4.1" 193 | source = "registry+https://github.com/rust-lang/crates.io-index" 
194 | checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" 195 | 196 | [[package]] 197 | name = "idna" 198 | version = "0.5.0" 199 | source = "registry+https://github.com/rust-lang/crates.io-index" 200 | checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" 201 | dependencies = [ 202 | "unicode-bidi", 203 | "unicode-normalization", 204 | ] 205 | 206 | [[package]] 207 | name = "lazy_static" 208 | version = "1.4.0" 209 | source = "registry+https://github.com/rust-lang/crates.io-index" 210 | checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" 211 | 212 | [[package]] 213 | name = "libc" 214 | version = "0.2.153" 215 | source = "registry+https://github.com/rust-lang/crates.io-index" 216 | checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" 217 | 218 | [[package]] 219 | name = "linux-raw-sys" 220 | version = "0.4.13" 221 | source = "registry+https://github.com/rust-lang/crates.io-index" 222 | checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" 223 | 224 | [[package]] 225 | name = "log" 226 | version = "0.4.17" 227 | source = "registry+https://github.com/rust-lang/crates.io-index" 228 | checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" 229 | dependencies = [ 230 | "cfg-if", 231 | ] 232 | 233 | [[package]] 234 | name = "matchers" 235 | version = "0.1.0" 236 | source = "registry+https://github.com/rust-lang/crates.io-index" 237 | checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" 238 | dependencies = [ 239 | "regex-automata", 240 | ] 241 | 242 | [[package]] 243 | name = "matrixmultiply" 244 | version = "0.3.2" 245 | source = "registry+https://github.com/rust-lang/crates.io-index" 246 | checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84" 247 | dependencies = [ 248 | "rawpointer", 249 | ] 250 | 251 | [[package]] 252 | name = "memchr" 253 | version = "2.5.0" 254 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 255 | checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" 256 | 257 | [[package]] 258 | name = "miniz_oxide" 259 | version = "0.7.2" 260 | source = "registry+https://github.com/rust-lang/crates.io-index" 261 | checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" 262 | dependencies = [ 263 | "adler", 264 | ] 265 | 266 | [[package]] 267 | name = "ndarray" 268 | version = "0.16.1" 269 | source = "registry+https://github.com/rust-lang/crates.io-index" 270 | checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" 271 | dependencies = [ 272 | "matrixmultiply", 273 | "num-complex", 274 | "num-integer", 275 | "num-traits", 276 | "portable-atomic", 277 | "portable-atomic-util", 278 | "rawpointer", 279 | ] 280 | 281 | [[package]] 282 | name = "nu-ansi-term" 283 | version = "0.46.0" 284 | source = "registry+https://github.com/rust-lang/crates.io-index" 285 | checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" 286 | dependencies = [ 287 | "overload", 288 | "winapi", 289 | ] 290 | 291 | [[package]] 292 | name = "num-complex" 293 | version = "0.4.3" 294 | source = "registry+https://github.com/rust-lang/crates.io-index" 295 | checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" 296 | dependencies = [ 297 | "num-traits", 298 | ] 299 | 300 | [[package]] 301 | name = "num-integer" 302 | version = "0.1.45" 303 | source = "registry+https://github.com/rust-lang/crates.io-index" 304 | checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" 305 | dependencies = [ 306 | "autocfg", 307 | "num-traits", 308 | ] 309 | 310 | [[package]] 311 | name = "num-traits" 312 | version = "0.2.15" 313 | source = "registry+https://github.com/rust-lang/crates.io-index" 314 | checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" 315 | dependencies = [ 316 | "autocfg", 317 | ] 318 | 319 
| [[package]] 320 | name = "once_cell" 321 | version = "1.19.0" 322 | source = "registry+https://github.com/rust-lang/crates.io-index" 323 | checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" 324 | 325 | [[package]] 326 | name = "ort" 327 | version = "2.0.0-rc.8" 328 | source = "registry+https://github.com/rust-lang/crates.io-index" 329 | checksum = "11826e6118cc42fea0cb2b102f7d006c1bb339cb167f8badb5fb568616438234" 330 | dependencies = [ 331 | "half", 332 | "ndarray", 333 | "ort-sys", 334 | "tracing", 335 | ] 336 | 337 | [[package]] 338 | name = "ort-sys" 339 | version = "2.0.0-rc.8" 340 | source = "registry+https://github.com/rust-lang/crates.io-index" 341 | checksum = "c4780a8b8681e653b2bed85c7f0e2c6e8547224c3e983e5ad27bf0457e012407" 342 | dependencies = [ 343 | "flate2", 344 | "pkg-config", 345 | "sha2", 346 | "tar", 347 | "ureq", 348 | ] 349 | 350 | [[package]] 351 | name = "ortex" 352 | version = "0.1.0" 353 | dependencies = [ 354 | "half", 355 | "ndarray", 356 | "num-traits", 357 | "ort", 358 | "rustler", 359 | "rustls", 360 | "tracing-subscriber", 361 | ] 362 | 363 | [[package]] 364 | name = "overload" 365 | version = "0.1.1" 366 | source = "registry+https://github.com/rust-lang/crates.io-index" 367 | checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" 368 | 369 | [[package]] 370 | name = "percent-encoding" 371 | version = "2.3.1" 372 | source = "registry+https://github.com/rust-lang/crates.io-index" 373 | checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" 374 | 375 | [[package]] 376 | name = "pin-project-lite" 377 | version = "0.2.9" 378 | source = "registry+https://github.com/rust-lang/crates.io-index" 379 | checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" 380 | 381 | [[package]] 382 | name = "pkg-config" 383 | version = "0.3.31" 384 | source = "registry+https://github.com/rust-lang/crates.io-index" 385 | checksum = 
"953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" 386 | 387 | [[package]] 388 | name = "portable-atomic" 389 | version = "1.9.0" 390 | source = "registry+https://github.com/rust-lang/crates.io-index" 391 | checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" 392 | 393 | [[package]] 394 | name = "portable-atomic-util" 395 | version = "0.2.3" 396 | source = "registry+https://github.com/rust-lang/crates.io-index" 397 | checksum = "90a7d5beecc52a491b54d6dd05c7a45ba1801666a5baad9fdbfc6fef8d2d206c" 398 | dependencies = [ 399 | "portable-atomic", 400 | ] 401 | 402 | [[package]] 403 | name = "proc-macro2" 404 | version = "1.0.78" 405 | source = "registry+https://github.com/rust-lang/crates.io-index" 406 | checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" 407 | dependencies = [ 408 | "unicode-ident", 409 | ] 410 | 411 | [[package]] 412 | name = "quote" 413 | version = "1.0.27" 414 | source = "registry+https://github.com/rust-lang/crates.io-index" 415 | checksum = "8f4f29d145265ec1c483c7c654450edde0bfe043d3938d6972630663356d9500" 416 | dependencies = [ 417 | "proc-macro2", 418 | ] 419 | 420 | [[package]] 421 | name = "rawpointer" 422 | version = "0.2.1" 423 | source = "registry+https://github.com/rust-lang/crates.io-index" 424 | checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" 425 | 426 | [[package]] 427 | name = "redox_syscall" 428 | version = "0.4.1" 429 | source = "registry+https://github.com/rust-lang/crates.io-index" 430 | checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" 431 | dependencies = [ 432 | "bitflags 1.3.2", 433 | ] 434 | 435 | [[package]] 436 | name = "regex" 437 | version = "1.7.1" 438 | source = "registry+https://github.com/rust-lang/crates.io-index" 439 | checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" 440 | dependencies = [ 441 | "aho-corasick", 442 | "memchr", 443 | "regex-syntax", 444 | ] 445 | 446 
| [[package]] 447 | name = "regex-automata" 448 | version = "0.1.10" 449 | source = "registry+https://github.com/rust-lang/crates.io-index" 450 | checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" 451 | dependencies = [ 452 | "regex-syntax", 453 | ] 454 | 455 | [[package]] 456 | name = "regex-syntax" 457 | version = "0.6.28" 458 | source = "registry+https://github.com/rust-lang/crates.io-index" 459 | checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" 460 | 461 | [[package]] 462 | name = "ring" 463 | version = "0.17.8" 464 | source = "registry+https://github.com/rust-lang/crates.io-index" 465 | checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" 466 | dependencies = [ 467 | "cc", 468 | "cfg-if", 469 | "getrandom", 470 | "libc", 471 | "spin", 472 | "untrusted", 473 | "windows-sys", 474 | ] 475 | 476 | [[package]] 477 | name = "rustix" 478 | version = "0.38.31" 479 | source = "registry+https://github.com/rust-lang/crates.io-index" 480 | checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" 481 | dependencies = [ 482 | "bitflags 2.4.2", 483 | "errno", 484 | "libc", 485 | "linux-raw-sys", 486 | "windows-sys", 487 | ] 488 | 489 | [[package]] 490 | name = "rustler" 491 | version = "0.29.1" 492 | source = "registry+https://github.com/rust-lang/crates.io-index" 493 | checksum = "0884cb623b9f43d3e2c51f9071c5e96a5acf3e6e6007866812884ff0cb983f1e" 494 | dependencies = [ 495 | "lazy_static", 496 | "rustler_codegen", 497 | "rustler_sys", 498 | ] 499 | 500 | [[package]] 501 | name = "rustler_codegen" 502 | version = "0.29.1" 503 | source = "registry+https://github.com/rust-lang/crates.io-index" 504 | checksum = "50e277af754f2560cf4c4ebedb68c1a735292fb354505c6133e47ec406e699cf" 505 | dependencies = [ 506 | "heck", 507 | "proc-macro2", 508 | "quote", 509 | "syn", 510 | ] 511 | 512 | [[package]] 513 | name = "rustler_sys" 514 | version = "2.3.2" 515 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 516 | checksum = "ff76ba8524729d7c9db2b3e80f2269d1fdef39b5a60624c33fd794797e69b558" 517 | dependencies = [ 518 | "regex", 519 | "unreachable", 520 | ] 521 | 522 | [[package]] 523 | name = "rustls" 524 | version = "0.22.4" 525 | source = "registry+https://github.com/rust-lang/crates.io-index" 526 | checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" 527 | dependencies = [ 528 | "log", 529 | "ring", 530 | "rustls-pki-types", 531 | "rustls-webpki", 532 | "subtle", 533 | "zeroize", 534 | ] 535 | 536 | [[package]] 537 | name = "rustls-pki-types" 538 | version = "1.3.0" 539 | source = "registry+https://github.com/rust-lang/crates.io-index" 540 | checksum = "048a63e5b3ac996d78d402940b5fa47973d2d080c6c6fffa1d0f19c4445310b7" 541 | 542 | [[package]] 543 | name = "rustls-webpki" 544 | version = "0.102.2" 545 | source = "registry+https://github.com/rust-lang/crates.io-index" 546 | checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" 547 | dependencies = [ 548 | "ring", 549 | "rustls-pki-types", 550 | "untrusted", 551 | ] 552 | 553 | [[package]] 554 | name = "sha2" 555 | version = "0.10.8" 556 | source = "registry+https://github.com/rust-lang/crates.io-index" 557 | checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" 558 | dependencies = [ 559 | "cfg-if", 560 | "cpufeatures", 561 | "digest", 562 | ] 563 | 564 | [[package]] 565 | name = "sharded-slab" 566 | version = "0.1.7" 567 | source = "registry+https://github.com/rust-lang/crates.io-index" 568 | checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" 569 | dependencies = [ 570 | "lazy_static", 571 | ] 572 | 573 | [[package]] 574 | name = "smallvec" 575 | version = "1.13.1" 576 | source = "registry+https://github.com/rust-lang/crates.io-index" 577 | checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" 578 | 579 | [[package]] 580 | name = "socks" 581 
| version = "0.3.4" 582 | source = "registry+https://github.com/rust-lang/crates.io-index" 583 | checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" 584 | dependencies = [ 585 | "byteorder", 586 | "libc", 587 | "winapi", 588 | ] 589 | 590 | [[package]] 591 | name = "spin" 592 | version = "0.9.8" 593 | source = "registry+https://github.com/rust-lang/crates.io-index" 594 | checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" 595 | 596 | [[package]] 597 | name = "subtle" 598 | version = "2.5.0" 599 | source = "registry+https://github.com/rust-lang/crates.io-index" 600 | checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" 601 | 602 | [[package]] 603 | name = "syn" 604 | version = "2.0.16" 605 | source = "registry+https://github.com/rust-lang/crates.io-index" 606 | checksum = "a6f671d4b5ffdb8eadec19c0ae67fe2639df8684bd7bc4b83d986b8db549cf01" 607 | dependencies = [ 608 | "proc-macro2", 609 | "quote", 610 | "unicode-ident", 611 | ] 612 | 613 | [[package]] 614 | name = "tar" 615 | version = "0.4.40" 616 | source = "registry+https://github.com/rust-lang/crates.io-index" 617 | checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb" 618 | dependencies = [ 619 | "filetime", 620 | "libc", 621 | "xattr", 622 | ] 623 | 624 | [[package]] 625 | name = "thread_local" 626 | version = "1.1.7" 627 | source = "registry+https://github.com/rust-lang/crates.io-index" 628 | checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" 629 | dependencies = [ 630 | "cfg-if", 631 | "once_cell", 632 | ] 633 | 634 | [[package]] 635 | name = "tinyvec" 636 | version = "1.6.0" 637 | source = "registry+https://github.com/rust-lang/crates.io-index" 638 | checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" 639 | dependencies = [ 640 | "tinyvec_macros", 641 | ] 642 | 643 | [[package]] 644 | name = "tinyvec_macros" 645 | version = "0.1.1" 646 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 647 | checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" 648 | 649 | [[package]] 650 | name = "tracing" 651 | version = "0.1.37" 652 | source = "registry+https://github.com/rust-lang/crates.io-index" 653 | checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" 654 | dependencies = [ 655 | "cfg-if", 656 | "pin-project-lite", 657 | "tracing-core", 658 | ] 659 | 660 | [[package]] 661 | name = "tracing-core" 662 | version = "0.1.30" 663 | source = "registry+https://github.com/rust-lang/crates.io-index" 664 | checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a" 665 | dependencies = [ 666 | "once_cell", 667 | "valuable", 668 | ] 669 | 670 | [[package]] 671 | name = "tracing-log" 672 | version = "0.2.0" 673 | source = "registry+https://github.com/rust-lang/crates.io-index" 674 | checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" 675 | dependencies = [ 676 | "log", 677 | "once_cell", 678 | "tracing-core", 679 | ] 680 | 681 | [[package]] 682 | name = "tracing-subscriber" 683 | version = "0.3.18" 684 | source = "registry+https://github.com/rust-lang/crates.io-index" 685 | checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" 686 | dependencies = [ 687 | "matchers", 688 | "nu-ansi-term", 689 | "once_cell", 690 | "regex", 691 | "sharded-slab", 692 | "smallvec", 693 | "thread_local", 694 | "tracing", 695 | "tracing-core", 696 | "tracing-log", 697 | ] 698 | 699 | [[package]] 700 | name = "typenum" 701 | version = "1.17.0" 702 | source = "registry+https://github.com/rust-lang/crates.io-index" 703 | checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" 704 | 705 | [[package]] 706 | name = "unicode-bidi" 707 | version = "0.3.15" 708 | source = "registry+https://github.com/rust-lang/crates.io-index" 709 | checksum = 
"08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" 710 | 711 | [[package]] 712 | name = "unicode-ident" 713 | version = "1.0.6" 714 | source = "registry+https://github.com/rust-lang/crates.io-index" 715 | checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" 716 | 717 | [[package]] 718 | name = "unicode-normalization" 719 | version = "0.1.23" 720 | source = "registry+https://github.com/rust-lang/crates.io-index" 721 | checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" 722 | dependencies = [ 723 | "tinyvec", 724 | ] 725 | 726 | [[package]] 727 | name = "unreachable" 728 | version = "1.0.0" 729 | source = "registry+https://github.com/rust-lang/crates.io-index" 730 | checksum = "382810877fe448991dfc7f0dd6e3ae5d58088fd0ea5e35189655f84e6814fa56" 731 | dependencies = [ 732 | "void", 733 | ] 734 | 735 | [[package]] 736 | name = "untrusted" 737 | version = "0.9.0" 738 | source = "registry+https://github.com/rust-lang/crates.io-index" 739 | checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" 740 | 741 | [[package]] 742 | name = "ureq" 743 | version = "2.9.6" 744 | source = "registry+https://github.com/rust-lang/crates.io-index" 745 | checksum = "11f214ce18d8b2cbe84ed3aa6486ed3f5b285cf8d8fbdbce9f3f767a724adc35" 746 | dependencies = [ 747 | "base64", 748 | "log", 749 | "once_cell", 750 | "rustls", 751 | "rustls-pki-types", 752 | "rustls-webpki", 753 | "socks", 754 | "url", 755 | "webpki-roots", 756 | ] 757 | 758 | [[package]] 759 | name = "url" 760 | version = "2.5.0" 761 | source = "registry+https://github.com/rust-lang/crates.io-index" 762 | checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" 763 | dependencies = [ 764 | "form_urlencoded", 765 | "idna", 766 | "percent-encoding", 767 | ] 768 | 769 | [[package]] 770 | name = "valuable" 771 | version = "0.1.0" 772 | source = "registry+https://github.com/rust-lang/crates.io-index" 773 | checksum = 
"830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" 774 | 775 | [[package]] 776 | name = "version_check" 777 | version = "0.9.4" 778 | source = "registry+https://github.com/rust-lang/crates.io-index" 779 | checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" 780 | 781 | [[package]] 782 | name = "void" 783 | version = "1.0.2" 784 | source = "registry+https://github.com/rust-lang/crates.io-index" 785 | checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" 786 | 787 | [[package]] 788 | name = "wasi" 789 | version = "0.11.0+wasi-snapshot-preview1" 790 | source = "registry+https://github.com/rust-lang/crates.io-index" 791 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 792 | 793 | [[package]] 794 | name = "webpki-roots" 795 | version = "0.26.1" 796 | source = "registry+https://github.com/rust-lang/crates.io-index" 797 | checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" 798 | dependencies = [ 799 | "rustls-pki-types", 800 | ] 801 | 802 | [[package]] 803 | name = "winapi" 804 | version = "0.3.9" 805 | source = "registry+https://github.com/rust-lang/crates.io-index" 806 | checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" 807 | dependencies = [ 808 | "winapi-i686-pc-windows-gnu", 809 | "winapi-x86_64-pc-windows-gnu", 810 | ] 811 | 812 | [[package]] 813 | name = "winapi-i686-pc-windows-gnu" 814 | version = "0.4.0" 815 | source = "registry+https://github.com/rust-lang/crates.io-index" 816 | checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" 817 | 818 | [[package]] 819 | name = "winapi-x86_64-pc-windows-gnu" 820 | version = "0.4.0" 821 | source = "registry+https://github.com/rust-lang/crates.io-index" 822 | checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" 823 | 824 | [[package]] 825 | name = "windows-sys" 826 | version = "0.52.0" 827 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 828 | checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" 829 | dependencies = [ 830 | "windows-targets", 831 | ] 832 | 833 | [[package]] 834 | name = "windows-targets" 835 | version = "0.52.0" 836 | source = "registry+https://github.com/rust-lang/crates.io-index" 837 | checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" 838 | dependencies = [ 839 | "windows_aarch64_gnullvm", 840 | "windows_aarch64_msvc", 841 | "windows_i686_gnu", 842 | "windows_i686_msvc", 843 | "windows_x86_64_gnu", 844 | "windows_x86_64_gnullvm", 845 | "windows_x86_64_msvc", 846 | ] 847 | 848 | [[package]] 849 | name = "windows_aarch64_gnullvm" 850 | version = "0.52.0" 851 | source = "registry+https://github.com/rust-lang/crates.io-index" 852 | checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" 853 | 854 | [[package]] 855 | name = "windows_aarch64_msvc" 856 | version = "0.52.0" 857 | source = "registry+https://github.com/rust-lang/crates.io-index" 858 | checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" 859 | 860 | [[package]] 861 | name = "windows_i686_gnu" 862 | version = "0.52.0" 863 | source = "registry+https://github.com/rust-lang/crates.io-index" 864 | checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" 865 | 866 | [[package]] 867 | name = "windows_i686_msvc" 868 | version = "0.52.0" 869 | source = "registry+https://github.com/rust-lang/crates.io-index" 870 | checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" 871 | 872 | [[package]] 873 | name = "windows_x86_64_gnu" 874 | version = "0.52.0" 875 | source = "registry+https://github.com/rust-lang/crates.io-index" 876 | checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" 877 | 878 | [[package]] 879 | name = "windows_x86_64_gnullvm" 880 | version = "0.52.0" 881 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 882 | checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" 883 | 884 | [[package]] 885 | name = "windows_x86_64_msvc" 886 | version = "0.52.0" 887 | source = "registry+https://github.com/rust-lang/crates.io-index" 888 | checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" 889 | 890 | [[package]] 891 | name = "xattr" 892 | version = "1.3.1" 893 | source = "registry+https://github.com/rust-lang/crates.io-index" 894 | checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f" 895 | dependencies = [ 896 | "libc", 897 | "linux-raw-sys", 898 | "rustix", 899 | ] 900 | 901 | [[package]] 902 | name = "zeroize" 903 | version = "1.7.0" 904 | source = "registry+https://github.com/rust-lang/crates.io-index" 905 | checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" 906 | -------------------------------------------------------------------------------- /native/ortex/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ortex" 3 | version = "0.1.0" 4 | authors = [] 5 | edition = "2018" 6 | 7 | [lib] 8 | name = "ortex" 9 | path = "src/lib.rs" 10 | crate-type = ["cdylib"] 11 | 12 | [dependencies] 13 | rustler = "0.29.0" 14 | ort = { version = "2.0.0-rc.8" } 15 | ndarray = "0.16.1" 16 | half = "2.2.1" 17 | tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] } 18 | num-traits = "0.2.15" 19 | rustls = "0.22.4" 20 | 21 | [features] 22 | # ONNXRuntime Execution providers 23 | directml = ["ort/directml"] 24 | coreml = ["ort/coreml"] 25 | cuda = ["ort/cuda"] 26 | tensorrt = ["ort/tensorrt"] 27 | -------------------------------------------------------------------------------- /native/ortex/README.md: -------------------------------------------------------------------------------- 1 | # NIF for Elixir.Ortex.Native 2 | 3 | ## To build the NIF module: 4 | 5 | 
- Your NIF will now build along with your project. 6 | 7 | ## To load the NIF: 8 | 9 | ```elixir 10 | defmodule Ortex.Native do 11 | use Rustler, otp_app: :ortex, crate: "ortex" 12 | 13 | # When your NIF is loaded, it will override this function. 14 | def add(_a, _b), do: :erlang.nif_error(:nif_not_loaded) 15 | end 16 | ``` 17 | 18 | ## Examples 19 | 20 | [This](https://github.com/rusterlium/NifIo) is a complete example of a NIF written in Rust. 21 | 22 | ## Docs 23 | To build the documentation for this locally, run: 24 | 25 | ```shell 26 | cargo doc --open --no-deps --document-private-items 27 | ``` 28 | -------------------------------------------------------------------------------- /native/ortex/src/constants.rs: -------------------------------------------------------------------------------- 1 | pub const CUDA: &str = "cuda"; 2 | pub const CPU: &str = "cpu"; 3 | pub const TENSORRT: &str = "tensorrt"; 4 | pub const ACL: &str = "acl"; 5 | pub const ONEDNN: &str = "onednn"; 6 | pub const COREML: &str = "coreml"; 7 | pub const DIRECTML: &str = "directml"; 8 | pub const ROCM: &str = "rocm"; 9 | 10 | pub mod ortex_atoms { 11 | rustler::atoms! { 12 | // Tensor types available 13 | s8, s16, s32, s64, 14 | u8, u16, u32, u64, 15 | f16, f32, f64, 16 | bf16, 17 | c64, c128, 18 | s, u, f, bf, c, 19 | // Execution provider atoms 20 | cpu, cuda, tensorrt, acl, dnnl, 21 | onednn, coreml, directml, rocm 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /native/ortex/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # Ortex 2 | //! Rust bindings between [ONNX Runtime](https://github.com/microsoft/onnxruntime) and 3 | //! Erlang/Elixir using [Ort](https://docs.rs/ort) and [Rustler](https://docs.rs/rustler). 4 | //! These are only meant to be accessed via the NIF interface provided by Rustler and not 5 | //! directly.
6 | 7 | mod constants; 8 | mod model; 9 | mod tensor; 10 | mod utils; 11 | 12 | use model::OrtexModel; 13 | use tensor::OrtexTensor; 14 | 15 | use rustler::resource::ResourceArc; 16 | use rustler::types::Binary; 17 | use rustler::{Atom, Env, NifResult, Term}; 18 | 19 | #[rustler::nif(schedule = "DirtyIo")] 20 | fn init( 21 | env: Env, 22 | model_path: String, 23 | eps: Vec, 24 | opt: i32, 25 | ) -> NifResult> { 26 | let eps = utils::map_eps(env, eps); 27 | let model = model::init(model_path, eps, opt) 28 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?; 29 | Ok(ResourceArc::new(model)) 30 | } 31 | 32 | #[rustler::nif] 33 | fn show_session( 34 | model: ResourceArc, 35 | ) -> NifResult<( 36 | Vec<(String, String, Option>)>, 37 | Vec<(String, String, Option>)>, 38 | )> { 39 | Ok(model::show(model)) 40 | } 41 | 42 | #[rustler::nif(schedule = "DirtyIo")] 43 | fn run( 44 | model: ResourceArc, 45 | inputs: Vec>, 46 | ) -> NifResult, Vec, Atom, usize)>> { 47 | model::run(model, inputs).map_err(|e| rustler::Error::Term(Box::new(e.to_string()))) 48 | } 49 | 50 | #[rustler::nif(schedule = "DirtyCpu")] 51 | fn from_binary(bin: Binary, shape: Term, dtype: Term) -> NifResult> { 52 | let shape: Vec = rustler::types::tuple::get_tuple(shape)? 53 | .iter() 54 | .map(|x| -> NifResult { Ok(x.decode::())? 
}) 55 | .collect::>>()?; 56 | let (dtype_t, dtype_bits): (Term, usize) = dtype.decode()?; 57 | let dtype_str = dtype_t.atom_to_string()?; 58 | 59 | utils::from_binary(bin, shape, dtype_str, dtype_bits) 60 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string()))) 61 | } 62 | 63 | #[rustler::nif(schedule = "DirtyCpu")] 64 | fn to_binary<'a>( 65 | env: Env<'a>, 66 | reference: ResourceArc, 67 | bits: usize, 68 | limit: usize, 69 | ) -> NifResult> { 70 | utils::to_binary(env, reference, bits, limit) 71 | } 72 | 73 | #[rustler::nif] 74 | pub fn slice<'a>( 75 | tensor: ResourceArc, 76 | start_indicies: Vec, 77 | lengths: Vec, 78 | strides: Vec, 79 | ) -> NifResult> { 80 | Ok(ResourceArc::new(tensor.slice( 81 | start_indicies, 82 | lengths, 83 | strides, 84 | ))) 85 | } 86 | 87 | #[rustler::nif] 88 | pub fn reshape<'a>( 89 | tensor: ResourceArc, 90 | shape: Vec, 91 | ) -> NifResult> { 92 | Ok(ResourceArc::new(tensor.reshape(shape)?)) 93 | } 94 | 95 | #[rustler::nif] 96 | pub fn concatenate<'a>( 97 | tensors: Vec>, 98 | dtype: Term, 99 | axis: i32, 100 | ) -> NifResult> { 101 | let (dtype_t, dtype_bits): (Term, usize) = dtype.decode()?; 102 | let dtype_str = dtype_t.atom_to_string()?; 103 | let concatted = tensor::concatenate(tensors, (&dtype_str, dtype_bits), axis as usize); 104 | Ok(ResourceArc::new(concatted)) 105 | } 106 | 107 | rustler::init!( 108 | "Elixir.Ortex.Native", 109 | [ 110 | run, 111 | init, 112 | from_binary, 113 | to_binary, 114 | show_session, 115 | slice, 116 | reshape, 117 | concatenate 118 | ], 119 | load = |env: Env, _| { 120 | rustler::resource!(OrtexModel, env); 121 | rustler::resource!(OrtexTensor, env); 122 | true 123 | } 124 | ); 125 | -------------------------------------------------------------------------------- /native/ortex/src/model.rs: -------------------------------------------------------------------------------- 1 | //! Abstractions for creating an ONNX Runtime Session and Environment which can be safely 2 | //! 
passed to and from the BEAM. 3 | //! 4 | //! # Examples 5 | //! 6 | //! ``` 7 | //! let model = init("./models/resnet50.onnx".to_string(), vec![], 3)?; 8 | //! let (inputs, outputs) = show(ResourceArc::new(model)); 9 | //! ``` 10 | 11 | use crate::tensor::OrtexTensor; 12 | use crate::utils::{is_bool_input, map_opt_level}; 13 | use std::convert::TryInto; 14 | use std::iter::zip; 15 | 16 | use ort::{Error, ExecutionProviderDispatch, Session}; 17 | use rustler::resource::ResourceArc; 18 | use rustler::Atom; 19 | 20 | /// Holds the model state, which is the onnxruntime session; the onnxruntime environment is 21 | /// global. Both are threadsafe so this can be called concurrently from the BEAM. 22 | pub struct OrtexModel { 23 | pub session: ort::Session, 24 | } 25 | 26 | // Since we're only using the session for inference and 27 | // inference is threadsafe, this Sync is safe. Additionally, 28 | // Environment is global and also threadsafe 29 | // https://github.com/microsoft/onnxruntime/issues/114 30 | unsafe impl Sync for OrtexModel {} 31 | 32 | /// Creates a model given the path to the model, a vector of execution providers, and a graph 33 | /// optimization level (see `map_opt_level`). The providers arrive already mapped from Erlang/Elixir atoms. 34 | pub fn init( 35 | model_path: String, 36 | eps: Vec, 37 | opt: i32, 38 | ) -> Result { 39 | // TODO: send tracing logs to erlang/elixir _somehow_ 40 | // tracing_subscriber::fmt::init(); 41 | 42 | let session = Session::builder()? 43 | .with_optimization_level(map_opt_level(opt))? 44 | .with_execution_providers(eps)? 45 | .commit_from_file(model_path)?; 46 | 47 | let state = OrtexModel { session }; 48 | Ok(state) 49 | } 50 | 51 | /// Returns input/output information about a model. The result is a Tuple of 52 | /// `inputs` and `outputs` with elements of `(Name, Type, Dimension)` where 53 | /// `Dimension` elements of -1 are dynamic.
54 | pub fn show( 55 | model: ResourceArc, 56 | ) -> ( 57 | Vec<(String, String, Option>)>, 58 | Vec<(String, String, Option>)>, 59 | ) { 60 | let model: &OrtexModel = &*model; 61 | 62 | let mut inputs = Vec::new(); 63 | for input in model.session.inputs.iter() { 64 | let name = input.name.to_string(); 65 | let repr = format!("{:#?}", input.input_type); 66 | let dims = Option::<&Vec>::cloned(input.input_type.tensor_dimensions()); 67 | inputs.push((name, repr, dims)); 68 | } 69 | 70 | let mut outputs = Vec::new(); 71 | for output in model.session.outputs.iter() { 72 | let name = output.name.to_string(); 73 | let repr = format!("{:#?}", output.output_type); 74 | let dims = Option::<&Vec>::cloned(output.output_type.tensor_dimensions()); 75 | outputs.push((name, repr, dims)); 76 | } 77 | 78 | (inputs, outputs) 79 | } 80 | 81 | /// Runs the model with the given inputs. Returns one `(tensor, shape, dtype, bits)` tuple per 82 | /// model output. Use `model::show` to see what the model expects for input and output shapes. 83 | pub fn run( 84 | model: ResourceArc, 85 | inputs: Vec>, 86 | ) -> Result, Vec, Atom, usize)>, Error> { 87 | // Grab the session and run a forward pass with it 88 | let session: &ort::Session = &model.session; 89 | 90 | let mut ortified_inputs: Vec = Vec::new(); 91 | 92 | for (elixir_input, onnx_input) in zip(inputs, &session.inputs) { 93 | let derefed_input: &OrtexTensor = &elixir_input; 94 | if is_bool_input(&onnx_input.input_type) { 95 | // this assumes that the boolean input isn't huge -- we're cloning it twice; 96 | // once below, once in the try_into() 97 | let boolified_input: &OrtexTensor = &derefed_input.clone().to_bool(); 98 | let v: ort::SessionInputValue = boolified_input.try_into()?; 99 | ortified_inputs.push(v); 100 | } else { 101 | let v: ort::SessionInputValue = derefed_input.try_into()?; 102 | ortified_inputs.push(v); 103 | } 104 | } 105 | 106 | // Run the forward pass; each declared output is then converted into an OrtexTensor 107 | let outputs = 
session.run(&ortified_inputs[..])?; 108 | let mut collected_outputs = Vec::new(); 109 | 110 | for output_descriptor in &session.outputs { 111 | let output_name: &str = &output_descriptor.name; 112 | let val = outputs.get(output_name).expect( 113 | &format!( 114 | "Expected {} to be in the outputs, but didn't find it", 115 | output_name 116 | )[..], 117 | ); 118 | 119 | // NOTE: try_into impl here will implicitly map bool outputs to u8 outputs 120 | let ortextensor: OrtexTensor = val.try_into()?; 121 | let shape = ortextensor.shape(); 122 | let (dtype, bits) = ortextensor.dtype(); 123 | 124 | let collected_output = (ResourceArc::new(ortextensor), shape, dtype, bits); 125 | collected_outputs.push(collected_output) 126 | } 127 | 128 | Ok(collected_outputs) 129 | } 130 | -------------------------------------------------------------------------------- /native/ortex/src/tensor.rs: -------------------------------------------------------------------------------- 1 | //! Conversions for packing/unpacking `OrtexTensor`s into different types 2 | use core::convert::TryFrom; 3 | use ndarray::prelude::*; 4 | use ndarray::{ArrayBase, ArrayView, Data, IxDyn, IxDynImpl, ViewRepr}; 5 | use ort::{DynValue, Error, Value}; 6 | use rustler::resource::ResourceArc; 7 | use rustler::Atom; 8 | use std::convert::TryInto; 9 | 10 | use crate::constants::ortex_atoms; 11 | 12 | #[derive(Debug)] 13 | #[allow(non_camel_case_types)] 14 | /// Enum for wrapping different types to pass back to the BEAM since rustler can't 15 | /// pass type generics back and forth 16 | pub enum OrtexTensor { 17 | s8(Array), 18 | s16(Array), 19 | s32(Array), 20 | s64(Array), 21 | u8(Array), 22 | u16(Array), 23 | u32(Array), 24 | u64(Array), 25 | f16(Array), 26 | bf16(Array), 27 | f32(Array), 28 | f64(Array), 29 | // the bool input is for internal use only. 
30 | // Any Nx facing ops should panic if called on a bool input 31 | bool(Array), 32 | } 33 | 34 | impl OrtexTensor { 35 | pub fn shape(&self) -> Vec { 36 | match self { 37 | OrtexTensor::s8(y) => y.shape().to_owned(), 38 | OrtexTensor::s16(y) => y.shape().to_owned(), 39 | OrtexTensor::s32(y) => y.shape().to_owned(), 40 | OrtexTensor::s64(y) => y.shape().to_owned(), 41 | OrtexTensor::u8(y) => y.shape().to_owned(), 42 | OrtexTensor::u16(y) => y.shape().to_owned(), 43 | OrtexTensor::u32(y) => y.shape().to_owned(), 44 | OrtexTensor::u64(y) => y.shape().to_owned(), 45 | OrtexTensor::f16(y) => y.shape().to_owned(), 46 | OrtexTensor::bf16(y) => y.shape().to_owned(), 47 | OrtexTensor::f32(y) => y.shape().to_owned(), 48 | OrtexTensor::f64(y) => y.shape().to_owned(), 49 | _ => panic!("Can't convert this type to Nx format"), 50 | } 51 | } 52 | 53 | pub fn reshape(&self, shape: Vec) -> rustler::NifResult { 54 | match self { 55 | OrtexTensor::s8(y) => Ok(OrtexTensor::s8( 56 | y.clone() 57 | .into_shape_with_order(shape) 58 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?, 59 | )), 60 | OrtexTensor::s16(y) => Ok(OrtexTensor::s16( 61 | y.clone() 62 | .into_shape_with_order(shape) 63 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?, 64 | )), 65 | OrtexTensor::s32(y) => Ok(OrtexTensor::s32( 66 | y.clone() 67 | .into_shape_with_order(shape) 68 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?, 69 | )), 70 | OrtexTensor::s64(y) => Ok(OrtexTensor::s64( 71 | y.clone() 72 | .into_shape_with_order(shape) 73 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?, 74 | )), 75 | OrtexTensor::u8(y) => Ok(OrtexTensor::u8( 76 | y.clone() 77 | .into_shape_with_order(shape) 78 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?, 79 | )), 80 | OrtexTensor::u16(y) => Ok(OrtexTensor::u16( 81 | y.clone() 82 | .into_shape_with_order(shape) 83 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?, 84 | )), 85 | OrtexTensor::u32(y) 
=> Ok(OrtexTensor::u32( 86 | y.clone() 87 | .into_shape_with_order(shape) 88 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?, 89 | )), 90 | OrtexTensor::u64(y) => Ok(OrtexTensor::u64( 91 | y.clone() 92 | .into_shape_with_order(shape) 93 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?, 94 | )), 95 | OrtexTensor::f16(y) => Ok(OrtexTensor::f16( 96 | y.clone() 97 | .into_shape_with_order(shape) 98 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?, 99 | )), 100 | OrtexTensor::bf16(y) => Ok(OrtexTensor::bf16( 101 | y.clone() 102 | .into_shape_with_order(shape) 103 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?, 104 | )), 105 | OrtexTensor::f32(y) => Ok(OrtexTensor::f32( 106 | y.clone() 107 | .into_shape_with_order(shape) 108 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?, 109 | )), 110 | OrtexTensor::f64(y) => Ok(OrtexTensor::f64( 111 | y.clone() 112 | .into_shape_with_order(shape) 113 | .map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?, 114 | )), 115 | _ => panic!("Can't convert this type to Nx format"), 116 | } 117 | } 118 | 119 | pub fn dtype(&self) -> (Atom, usize) { 120 | match self { 121 | OrtexTensor::s8(_) => (ortex_atoms::s(), 8), 122 | OrtexTensor::s16(_) => (ortex_atoms::s(), 16), 123 | OrtexTensor::s32(_) => (ortex_atoms::s(), 32), 124 | OrtexTensor::s64(_) => (ortex_atoms::s(), 64), 125 | OrtexTensor::u8(_) => (ortex_atoms::u(), 8), 126 | OrtexTensor::u16(_) => (ortex_atoms::u(), 16), 127 | OrtexTensor::u32(_) => (ortex_atoms::u(), 32), 128 | OrtexTensor::u64(_) => (ortex_atoms::u(), 64), 129 | OrtexTensor::f16(_) => (ortex_atoms::f(), 16), 130 | OrtexTensor::bf16(_) => (ortex_atoms::bf(), 16), 131 | OrtexTensor::f32(_) => (ortex_atoms::f(), 32), 132 | OrtexTensor::f64(_) => (ortex_atoms::f(), 64), 133 | _ => panic!("Can't convert this type to Nx format"), 134 | } 135 | } 136 | 137 | pub fn to_bytes<'a>(&'a self) -> &'a [u8] { 138 | let contents: &'a [u8] = match self { 139 
| OrtexTensor::s8(y) => get_bytes(y), 140 | OrtexTensor::s16(y) => get_bytes(y), 141 | OrtexTensor::s32(y) => get_bytes(y), 142 | OrtexTensor::s64(y) => get_bytes(y), 143 | OrtexTensor::u8(y) => get_bytes(y), 144 | OrtexTensor::u16(y) => get_bytes(y), 145 | OrtexTensor::u32(y) => get_bytes(y), 146 | OrtexTensor::u64(y) => get_bytes(y), 147 | OrtexTensor::f16(y) => get_bytes(y), 148 | OrtexTensor::bf16(y) => get_bytes(y), 149 | OrtexTensor::f32(y) => get_bytes(y), 150 | OrtexTensor::f64(y) => get_bytes(y), 151 | _ => panic!("Can't convert this type to Nx format"), 152 | }; 153 | contents 154 | } 155 | 156 | pub fn slice<'a>( 157 | &'a self, 158 | start_indicies: Vec, 159 | lengths: Vec, 160 | strides: Vec, 161 | ) -> Self { 162 | let mut slice_specs: Vec<(isize, Option, isize)> = vec![]; 163 | for ((start_index, length), stride) in start_indicies 164 | .iter() 165 | .zip(lengths.iter()) 166 | .zip(strides.iter()) 167 | { 168 | slice_specs.push((*start_index, Some(*length + *start_index), *stride)); 169 | } 170 | match self { 171 | OrtexTensor::s8(y) => OrtexTensor::s8(slice_array(y, &slice_specs).to_owned()), 172 | OrtexTensor::s16(y) => OrtexTensor::s16(slice_array(y, &slice_specs).to_owned()), 173 | OrtexTensor::s32(y) => OrtexTensor::s32(slice_array(y, &slice_specs).to_owned()), 174 | OrtexTensor::s64(y) => OrtexTensor::s64(slice_array(y, &slice_specs).to_owned()), 175 | OrtexTensor::u8(y) => OrtexTensor::u8(slice_array(y, &slice_specs).to_owned()), 176 | OrtexTensor::u16(y) => OrtexTensor::u16(slice_array(y, &slice_specs).to_owned()), 177 | OrtexTensor::u32(y) => OrtexTensor::u32(slice_array(y, &slice_specs).to_owned()), 178 | OrtexTensor::u64(y) => OrtexTensor::u64(slice_array(y, &slice_specs).to_owned()), 179 | OrtexTensor::f16(y) => OrtexTensor::f16(slice_array(y, &slice_specs).to_owned()), 180 | OrtexTensor::bf16(y) => OrtexTensor::bf16(slice_array(y, &slice_specs).to_owned()), 181 | OrtexTensor::f32(y) => OrtexTensor::f32(slice_array(y, 
&slice_specs).to_owned()), 182 | OrtexTensor::f64(y) => OrtexTensor::f64(slice_array(y, &slice_specs).to_owned()), 183 | _ => panic!("Can't convert this type to Nx format"), 184 | } 185 | } 186 | 187 | pub fn to_bool(self) -> OrtexTensor { 188 | match self { 189 | OrtexTensor::u8(y) => { 190 | let bool_tensor = y.to_owned().mapv(|x| match x { 191 | 0 => false, 192 | 1 => true, 193 | _ => { 194 | panic!( 195 | "Tried to convert a u8 tensor to bool, but not every element is 0 or 1" 196 | ) 197 | } 198 | }); 199 | OrtexTensor::bool(bool_tensor) 200 | } 201 | t => panic!("Can't convert this type {:?} to bool", t.dtype()), 202 | } 203 | } 204 | } 205 | 206 | fn slice_array<'a, T, D>( 207 | array: &'a Array, 208 | slice_specs: &'a Vec<(isize, Option, isize)>, 209 | ) -> ArrayView<'a, T, D> 210 | where 211 | D: Dimension, 212 | { 213 | array.slice_each_axis(|ax: ndarray::AxisDescription| { 214 | let (start, end, step) = slice_specs[ax.axis.index()]; 215 | ndarray::Slice { start, end, step } 216 | }) 217 | } 218 | 219 | fn get_bytes<'a, T>(array: &'a ArrayBase) -> &'a [u8] 220 | where 221 | T: Data, 222 | { 223 | let len = array.len(); 224 | let binding = unsafe { std::mem::zeroed() }; 225 | let f = array.get(0).unwrap_or(&binding); 226 | let size: usize = std::mem::size_of_val(f); 227 | unsafe { std::slice::from_raw_parts(array.as_ptr() as *const u8, len * size) } 228 | } 229 | 230 | impl TryFrom<&Value> for OrtexTensor { 231 | type Error = Error; 232 | fn try_from(e: &Value) -> Result { 233 | let dtype: ort::ValueType = e.dtype(); 234 | let ty = match dtype { 235 | ort::ValueType::Tensor { 236 | ty: t, 237 | dimensions: _, 238 | } => t, 239 | _ => panic!("can't decode non tensor, got {}", dtype), 240 | }; 241 | 242 | let tensor = match ty { 243 | ort::TensorElementType::Bfloat16 => { 244 | OrtexTensor::bf16(e.try_extract_tensor::()?.into_owned()) 245 | } 246 | ort::TensorElementType::Float16 => { 247 | OrtexTensor::f16(e.try_extract_tensor::()?.into_owned()) 248 | } 249 
| ort::TensorElementType::Float32 => { 250 | OrtexTensor::f32(e.try_extract_tensor::()?.into_owned()) 251 | } 252 | ort::TensorElementType::Float64 => { 253 | OrtexTensor::f64(e.try_extract_tensor::()?.into_owned()) 254 | } 255 | ort::TensorElementType::Uint8 => { 256 | OrtexTensor::u8(e.try_extract_tensor::()?.into_owned()) 257 | } 258 | ort::TensorElementType::Uint16 => { 259 | OrtexTensor::u16(e.try_extract_tensor::()?.into_owned()) 260 | } 261 | ort::TensorElementType::Uint32 => { 262 | OrtexTensor::u32(e.try_extract_tensor::()?.into_owned()) 263 | } 264 | ort::TensorElementType::Uint64 => { 265 | OrtexTensor::u64(e.try_extract_tensor::()?.into_owned()) 266 | } 267 | ort::TensorElementType::Int8 => { 268 | OrtexTensor::s8(e.try_extract_tensor::()?.into_owned()) 269 | } 270 | ort::TensorElementType::Int16 => { 271 | OrtexTensor::s16(e.try_extract_tensor::()?.into_owned()) 272 | } 273 | ort::TensorElementType::Int32 => { 274 | OrtexTensor::s32(e.try_extract_tensor::()?.into_owned()) 275 | } 276 | ort::TensorElementType::Int64 => { 277 | OrtexTensor::s64(e.try_extract_tensor::()?.into_owned()) 278 | } 279 | ort::TensorElementType::String => { 280 | todo!("Can't return string tensors") 281 | } 282 | // map the output into u8 space 283 | ort::TensorElementType::Bool => { 284 | let nd_array = e.try_extract_tensor::()?.into_owned(); 285 | OrtexTensor::u8(nd_array.mapv(|x| x as u8)) 286 | } 287 | }; 288 | 289 | Ok(tensor) 290 | } 291 | } 292 | 293 | impl TryFrom<&OrtexTensor> for ort::SessionInputValue<'_> { 294 | type Error = Error; 295 | fn try_from(ort_tensor: &OrtexTensor) -> Result { 296 | let r: DynValue = match ort_tensor { 297 | OrtexTensor::s8(arr) => arr.to_owned().try_into()?, 298 | OrtexTensor::s16(arr) => arr.clone().try_into()?, 299 | OrtexTensor::s32(arr) => arr.clone().try_into()?, 300 | OrtexTensor::s64(arr) => arr.clone().try_into()?, 301 | OrtexTensor::f16(arr) => arr.clone().try_into()?, 302 | OrtexTensor::f32(arr) => arr.clone().try_into()?, 303 | 
OrtexTensor::f64(arr) => arr.clone().try_into()?, 304 | OrtexTensor::bf16(arr) => arr.clone().try_into()?, 305 | OrtexTensor::u8(arr) => arr.clone().try_into()?, 306 | OrtexTensor::u16(arr) => arr.clone().try_into()?, 307 | OrtexTensor::u32(arr) => arr.clone().try_into()?, 308 | OrtexTensor::u64(arr) => arr.clone().try_into()?, 309 | OrtexTensor::bool(arr) => arr.clone().try_into()?, 310 | }; 311 | Ok(r.into()) 312 | } 313 | } 314 | 315 | impl Clone for OrtexTensor { 316 | fn clone(&self) -> Self { 317 | match self { 318 | OrtexTensor::s8(t) => OrtexTensor::s8(t.clone()), 319 | OrtexTensor::s16(t) => OrtexTensor::s16(t.clone()), 320 | OrtexTensor::s32(t) => OrtexTensor::s32(t.clone()), 321 | OrtexTensor::s64(t) => OrtexTensor::s64(t.clone()), 322 | OrtexTensor::bf16(t) => OrtexTensor::bf16(t.clone()), 323 | OrtexTensor::f16(t) => OrtexTensor::f16(t.clone()), 324 | OrtexTensor::f32(t) => OrtexTensor::f32(t.clone()), 325 | OrtexTensor::f64(t) => OrtexTensor::f64(t.clone()), 326 | OrtexTensor::u8(t) => OrtexTensor::u8(t.clone()), 327 | OrtexTensor::u16(t) => OrtexTensor::u16(t.clone()), 328 | OrtexTensor::u32(t) => OrtexTensor::u32(t.clone()), 329 | OrtexTensor::u64(t) => OrtexTensor::u64(t.clone()), 330 | OrtexTensor::bool(t) => OrtexTensor::bool(t.clone()), 331 | } 332 | } 333 | } 334 | 335 | // Currently only supports concatenating tensors of the same type. 336 | // 337 | // This is a similar structure to the above match clauses, except each function 338 | // in map is more complex and needs to be written out explicitly. To reduce 339 | // repetition, the concatenate! macro expands that code and makes the necessary 340 | // minor tweaks 341 | 342 | macro_rules! 
concatenate { 343 | // `typ` is the actual datatype, `ort_tensor_kind` is the OrtexTensor variant 344 | ($tensors:expr, $axis:expr, $typ:ty, $ort_tensor_kind:ident) => {{ 345 | type ArrayType<'a> = ArrayBase, Dim>; 346 | fn filter(tensor: &OrtexTensor) -> Option { 347 | match tensor { 348 | OrtexTensor::$ort_tensor_kind(x) => Some(x.view()), 349 | _ => None, 350 | } 351 | } 352 | // Hacky type coercion: silently drops any tensors whose variant is not the 353 | // requested one, so mismatched inputs are ignored rather than reported 354 | let tensors: Vec = $tensors 355 | .iter() 356 | .filter_map(|tensor| filter(tensor)) 357 | .collect(); 358 | 359 | let tensors = ndarray::concatenate(Axis($axis), &tensors).unwrap(); 360 | // data is not contiguous after the concatenation above. To decode 361 | // properly, need to create a new contiguous vector 362 | let tensors = 363 | Array::from_shape_vec(tensors.raw_dim(), tensors.iter().cloned().collect()).unwrap(); 364 | OrtexTensor::$ort_tensor_kind(tensors) 365 | }}; 366 | } 367 | 368 | pub fn concatenate( 369 | tensors: Vec>, 370 | dtype: (&str, usize), 371 | axis: usize, 372 | ) -> OrtexTensor { 373 | match dtype { 374 | ("s", 8) => concatenate!(tensors, axis, i8, s8), 375 | ("s", 16) => concatenate!(tensors, axis, i16, s16), 376 | ("s", 32) => concatenate!(tensors, axis, i32, s32), 377 | ("s", 64) => concatenate!(tensors, axis, i64, s64), 378 | ("u", 8) => concatenate!(tensors, axis, u8, u8), 379 | ("u", 16) => concatenate!(tensors, axis, u16, u16), 380 | ("u", 32) => concatenate!(tensors, axis, u32, u32), 381 | ("u", 64) => concatenate!(tensors, axis, u64, u64), 382 | ("f", 16) => concatenate!(tensors, axis, half::f16, f16), 383 | ("bf", 16) => concatenate!(tensors, axis, half::bf16, bf16), 384 | ("f", 32) => concatenate!(tensors, axis, f32, f32), 385 | ("f", 64) => concatenate!(tensors, axis, f64, f64), 386 | _ => unimplemented!(), 387 | } 388 | } 389 | -------------------------------------------------------------------------------- /native/ortex/src/utils.rs: 
-------------------------------------------------------------------------------- 1 | //! Serialization and deserialization to transfer between Ortex and BinaryBackend 2 | //! [Nx](https://hexdocs.pm/nx) backend. 3 | 4 | use crate::constants::*; 5 | use crate::tensor::OrtexTensor; 6 | use ndarray::{ArrayViewMut, Ix, IxDyn}; 7 | 8 | use ndarray::ShapeError; 9 | 10 | use rustler::resource::ResourceArc; 11 | use rustler::types::Binary; 12 | use rustler::{Atom, Env, NifResult}; 13 | 14 | use ort::{ExecutionProviderDispatch, GraphOptimizationLevel}; 15 | 16 | /// A faster (unsafe) way of creating an Array from an Erlang binary 17 | fn initialize_from_raw_ptr(ptr: *const T, shape: &[Ix]) -> ArrayViewMut { 18 | let array = unsafe { ArrayViewMut::from_shape_ptr(shape, ptr as *mut T) }; 19 | array 20 | } 21 | 22 | /// Given a Binary term, shape, and dtype from the BEAM, constructs an OrtexTensor and 23 | /// returns the reference to be used as an Nx.Backend representation. 24 | /// 25 | /// # Example 26 | /// 27 | /// ```elixir 28 | /// bin = <<1, 0, 0, 0, 1, 0, 0, 0>> 29 | /// ``` 30 | /// 31 | /// Create a shape `[2]` u32 OrtexTensor from a binary of 8 bytes 32 | /// ```elixir 33 | /// {:ok, reference} = from_binary(bin, {2}, {:u, 32}) 34 | /// ``` 35 | pub fn from_binary( 36 | bin: Binary, 37 | shape: Vec, 38 | dtype_str: String, 39 | dtype_bits: usize, 40 | ) -> Result, ShapeError> { 41 | match (dtype_str.as_ref(), dtype_bits) { 42 | ("bf", 16) => Ok(ResourceArc::new(OrtexTensor::bf16( 43 | initialize_from_raw_ptr(bin.as_ptr() as *const half::bf16, &shape).to_owned(), 44 | ))), 45 | ("f", 16) => Ok(ResourceArc::new(OrtexTensor::f16( 46 | initialize_from_raw_ptr(bin.as_ptr() as *const half::f16, &shape).to_owned(), 47 | ))), 48 | ("f", 32) => Ok(ResourceArc::new(OrtexTensor::f32( 49 | initialize_from_raw_ptr(bin.as_ptr() as *const f32, &shape).to_owned(), 50 | ))), 51 | ("f", 64) => Ok(ResourceArc::new(OrtexTensor::f64( 52 | initialize_from_raw_ptr(bin.as_ptr() as *const 
f64, &shape).to_owned(), 53 | ))), 54 | ("s", 8) => Ok(ResourceArc::new(OrtexTensor::s8( 55 | initialize_from_raw_ptr(bin.as_ptr() as *const i8, &shape).to_owned(), 56 | ))), 57 | ("s", 16) => Ok(ResourceArc::new(OrtexTensor::s16( 58 | initialize_from_raw_ptr(bin.as_ptr() as *const i16, &shape).to_owned(), 59 | ))), 60 | ("s", 32) => Ok(ResourceArc::new(OrtexTensor::s32( 61 | initialize_from_raw_ptr(bin.as_ptr() as *const i32, &shape).to_owned(), 62 | ))), 63 | ("s", 64) => Ok(ResourceArc::new(OrtexTensor::s64( 64 | initialize_from_raw_ptr(bin.as_ptr() as *const i64, &shape).to_owned(), 65 | ))), 66 | ("u", 8) => Ok(ResourceArc::new(OrtexTensor::u8( 67 | initialize_from_raw_ptr(bin.as_ptr() as *const u8, &shape).to_owned(), 68 | ))), 69 | ("u", 16) => Ok(ResourceArc::new(OrtexTensor::u16( 70 | initialize_from_raw_ptr(bin.as_ptr() as *const u16, &shape).to_owned(), 71 | ))), 72 | ("u", 32) => Ok(ResourceArc::new(OrtexTensor::u32( 73 | initialize_from_raw_ptr(bin.as_ptr() as *const u32, &shape).to_owned(), 74 | ))), 75 | ("u", 64) => Ok(ResourceArc::new(OrtexTensor::u64( 76 | initialize_from_raw_ptr(bin.as_ptr() as *const u64, &shape).to_owned(), 77 | ))), 78 | (&_, _) => unimplemented!(), 79 | } 80 | } 81 | 82 | /// Given a reference to an OrtexTensor return the binary representation to be used 83 | /// by the BinaryBackend of Nx. 84 | pub fn to_binary<'a>( 85 | env: Env<'a>, 86 | reference: ResourceArc, 87 | _bits: usize, 88 | _limit: usize, 89 | ) -> NifResult> { 90 | Ok(reference.make_binary(env, |x| x.to_bytes())) 91 | } 92 | 93 | /// Takes a vec of Atoms and transforms them into a vec of ExecutionProvider Enums 94 | pub fn map_eps(env: rustler::env::Env, eps: Vec) -> Vec { 95 | eps.iter() 96 | .map(|e| match &e.to_term(env).atom_to_string().unwrap()[..] 
{ 97 | CPU => ort::CPUExecutionProvider::default().build(), 98 | CUDA => ort::CUDAExecutionProvider::default().build(), 99 | TENSORRT => ort::TensorRTExecutionProvider::default().build(), 100 | ACL => ort::ACLExecutionProvider::default().build(), 101 | ONEDNN => ort::OneDNNExecutionProvider::default().build(), 102 | COREML => ort::CoreMLExecutionProvider::default().build(), 103 | DIRECTML => ort::DirectMLExecutionProvider::default().build(), 104 | ROCM => ort::ROCmExecutionProvider::default().build(), 105 | _ => ort::CPUExecutionProvider::default().build(), 106 | }) 107 | .collect() 108 | } 109 | 110 | /// Takes an integer optimization level and returns the matching `GraphOptimizationLevel`: 1, 2 and 3 select that level, any other value disables graph optimization 111 | pub fn map_opt_level(opt: i32) -> GraphOptimizationLevel { 112 | match opt { 113 | 1 => GraphOptimizationLevel::Level1, 114 | 2 => GraphOptimizationLevel::Level2, 115 | 3 => GraphOptimizationLevel::Level3, 116 | _ => GraphOptimizationLevel::Disable, 117 | } 118 | } 119 | 120 | pub fn is_bool_input(inp: &ort::ValueType) -> bool { 121 | match inp { 122 | ort::ValueType::Tensor { ty, .. } => ty == &ort::TensorElementType::Bool, 123 | ort::ValueType::Map { value, .. 
} => value == &ort::TensorElementType::Bool, 124 | ort::ValueType::Sequence(boxed_input) => is_bool_input(boxed_input), 125 | ort::ValueType::Optional(boxed_input) => is_bool_input(boxed_input), 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /python/export_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models as models 3 | 4 | model = models.resnet50(pretrained=True) 5 | 6 | model.eval() 7 | onnx_input = torch.randn(1, 3, 224, 224) 8 | 9 | torch.onnx.export( 10 | model, 11 | onnx_input, 12 | "resnet50.onnx", 13 | verbose=False, 14 | input_names=["input"], 15 | output_names=["output"], 16 | dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, 17 | export_params=True, 18 | opset_version=19, 19 | ) 20 | -------------------------------------------------------------------------------- /python/multi_input.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MultiInputModel(torch.nn.Module): 5 | """ 6 | A simple model for testing Ortex multi-input and multi-output 7 | with different dtypes 8 | """ 9 | 10 | def __init__(self): 11 | super(MultiInputModel, self).__init__() 12 | 13 | self.linear1 = torch.nn.Linear(100, 10) 14 | self.linear2 = torch.nn.Linear(100, 10) 15 | 16 | def forward(self, x, y): 17 | x = self.linear1(x.float()) 18 | y = self.linear2(y) 19 | return x + y, x, y 20 | 21 | 22 | tinymodel = MultiInputModel() 23 | print(tinymodel) 24 | 25 | x = torch.zeros([100], dtype=torch.int32).unsqueeze(0) 26 | y = torch.zeros([100], dtype=torch.float32).unsqueeze(0) 27 | 28 | tinymodel(x, y) 29 | 30 | torch.onnx.export( 31 | tinymodel, 32 | (x, y), 33 | "tinymodel.onnx", 34 | input_names=["x", "y"], 35 | output_names=["output1", "output2", "output3"], 36 | dynamic_axes={ 37 | "x": {0: "batch_size"}, 38 | "y": {0: "batch_size"}, 39 | }, 40 | 
opset_version=19, 41 | ) 42 | -------------------------------------------------------------------------------- /test/dtype/dtype_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Ortex.TestDtypes do 2 | use ExUnit.Case 3 | 4 | {tensor, _} = Nx.Random.uniform(Nx.Random.key(42), 0, 256, shape: {100, 100}) 5 | @tensor tensor 6 | 7 | defp bin_binary(dtype) do 8 | %{data: %{state: bin}} = @tensor |> Nx.as_type(dtype) 9 | bin 10 | end 11 | 12 | defp bin_ortex(dtype) do 13 | %{data: %{state: bin}} = 14 | @tensor 15 | |> Nx.as_type(dtype) 16 | |> Nx.backend_transfer(Ortex.Backend) 17 | |> Nx.backend_transfer(Nx.BinaryBackend) 18 | 19 | bin 20 | end 21 | 22 | test "size 0 tensor" do 23 | %{data: %{state: bin1}} = Nx.tensor(0) 24 | 25 | %{data: %{state: bin2}} = 26 | Nx.tensor(0) 27 | |> Nx.backend_transfer(Ortex.Backend) 28 | |> Nx.backend_transfer(Nx.BinaryBackend) 29 | 30 | assert bin1 == bin2 31 | end 32 | 33 | test "u8 conversion" do 34 | assert bin_binary(:u8) == bin_ortex(:u8) 35 | end 36 | 37 | test "u16 conversion" do 38 | assert bin_binary(:u16) == bin_ortex(:u16) 39 | end 40 | 41 | test "u32 conversion" do 42 | assert bin_binary(:u32) == bin_ortex(:u32) 43 | end 44 | 45 | test "u64 conversion" do 46 | assert bin_binary(:u64) == bin_ortex(:u64) 47 | end 48 | 49 | test "s8 conversion" do 50 | assert bin_binary(:s8) == bin_ortex(:s8) 51 | end 52 | 53 | test "s16 conversion" do 54 | assert bin_binary(:s16) == bin_ortex(:s16) 55 | end 56 | 57 | test "s32 conversion" do 58 | assert bin_binary(:s32) == bin_ortex(:s32) 59 | end 60 | 61 | test "s64 conversion" do 62 | assert bin_binary(:s64) == bin_ortex(:s64) 63 | end 64 | 65 | test "f16 conversion" do 66 | assert bin_binary(:f16) == bin_ortex(:f16) 67 | end 68 | 69 | test "bf16 conversion" do 70 | assert bin_binary(:bf16) == bin_ortex(:bf16) 71 | end 72 | 73 | test "f32 conversion" do 74 | assert bin_binary(:f32) == bin_ortex(:f32) 75 | end 76 | 77 | test "f64 
conversion" do 78 | assert bin_binary(:f64) == bin_ortex(:f64) 79 | end 80 | end 81 | -------------------------------------------------------------------------------- /test/ortex_test.exs: -------------------------------------------------------------------------------- 1 | defmodule OrtexTest do 2 | use ExUnit.Case 3 | doctest Ortex 4 | 5 | @tag :resnet50 6 | test "resnet50" do 7 | model = Ortex.load("./models/resnet50.onnx") 8 | 9 | input = Nx.broadcast(0.0, {1, 3, 224, 224}) 10 | {output} = Ortex.run(model, {input}) 11 | argmax = output |> Nx.backend_transfer() |> Nx.argmax(axis: 1) 12 | 13 | assert argmax == Nx.tensor([499]) 14 | end 15 | 16 | @tag :resnet50 17 | test "Nx.Serving with resnet50" do 18 | model = Ortex.load("./models/resnet50.onnx") 19 | 20 | serving = Nx.Serving.new(Ortex.Serving, model) 21 | batch = Nx.Batch.stack([{Nx.broadcast(0.0, {3, 224, 224})}]) 22 | {result} = Nx.Serving.run(serving, batch) 23 | assert result |> Nx.backend_transfer() |> Nx.argmax(axis: 1) == Nx.tensor([499]) 24 | end 25 | 26 | test "Nx.Serving with tinymodel" do 27 | model = Ortex.load("./models/tinymodel.onnx") 28 | 29 | serving = Nx.Serving.new(Ortex.Serving, model) 30 | 31 | # Create a batch of size 3 with {int32, float32} inputs 32 | batch = 33 | Nx.Batch.stack([ 34 | {Nx.broadcast(0, {100}) |> Nx.as_type(:s32), 35 | Nx.broadcast(0.0, {100}) |> Nx.as_type(:f32)}, 36 | {Nx.broadcast(1, {100}) |> Nx.as_type(:s32), 37 | Nx.broadcast(1.0, {100}) |> Nx.as_type(:f32)}, 38 | {Nx.broadcast(2, {100}) |> Nx.as_type(:s32), Nx.broadcast(2.0, {100}) |> Nx.as_type(:f32)} 39 | ]) 40 | 41 | {%Nx.Tensor{shape: {3, 10}}, %Nx.Tensor{shape: {3, 10}}, %Nx.Tensor{shape: {3, 10}}} = 42 | Nx.Serving.run(serving, batch) 43 | end 44 | end 45 | -------------------------------------------------------------------------------- /test/shape/concat_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Ortex.TestConcat do 2 | use ExUnit.Case 3 | 4 | # 
Testing each type, since there's a bunch of boilerplate that we want to 5 | # check for errors on the Rust side; each label maps to its own Nx type tuple 6 | %{ 7 | "s8" => {:s, 8}, 8 | "s16" => {:s, 16}, 9 | "s32" => {:s, 32}, 10 | "s64" => {:s, 64}, 11 | "u8" => {:u, 8}, 12 | "u16" => {:u, 16}, 13 | "u32" => {:u, 32}, 14 | "u64" => {:u, 64}, 15 | "f16" => {:f, 16}, 16 | "bf16" => {:bf, 16}, 17 | "f32" => {:f, 32}, 18 | "f64" => {:f, 64} 19 | } 20 | |> Enum.each(fn {type_str, type_tuple} -> 21 | test "Concat 1d tensors #{type_str}" do 22 | t1 = 23 | Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 24 | 25 | t2 = 26 | Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 27 | 28 | concatted = Nx.concatenate([t1, t2]) |> Nx.backend_transfer() 29 | expected = Nx.tensor([1, 2, 3, 4, 1, 2, 3, 4], type: unquote(type_tuple)) 30 | assert concatted == expected 31 | end 32 | 33 | test "Concat 3d tensors #{type_str}" do 34 | o1 = Nx.iota({2, 3, 5}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 35 | o2 = Nx.iota({1, 3, 5}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 36 | concatted = Nx.concatenate([o1, o2]) |> Nx.backend_transfer() 37 | expected = Nx.concatenate([o1 |> Nx.backend_transfer(), o2 |> Nx.backend_transfer()]) 38 | assert concatted == expected 39 | end 40 | 41 | test "Concat 3 #{type_str} vectors" do 42 | t1 = 43 | Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 44 | 45 | t2 = 46 | Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 47 | 48 | t3 = 49 | Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 50 | 51 | concatted = Nx.concatenate([t1, t2, t3]) |> Nx.backend_transfer() 52 | expected = Nx.tensor([1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4], type: unquote(type_tuple)) 53 | assert concatted == expected 54 | end 55 | 56 | test "Concat axis #{type_str} 1" do 57 | o1 = Nx.iota({3, 5}, 
type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 58 | o2 = Nx.iota({3, 5}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 59 | 60 | concatted = Nx.concatenate([o1, o2], axis: 1) |> Nx.backend_transfer() 61 | 62 | n1 = Nx.iota({3, 5}, type: unquote(type_tuple)) 63 | n2 = Nx.iota({3, 5}, type: unquote(type_tuple)) 64 | 65 | expected = Nx.concatenate([n1, n2], axis: 1) 66 | assert concatted == expected 67 | end 68 | 69 | test "Concat axis 1 of three 3-dimensional #{type_str} vector" do 70 | t1 = Nx.iota({3, 5, 7}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 71 | t2 = Nx.iota({3, 5, 7}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 72 | t3 = Nx.iota({3, 5, 7}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 73 | 74 | concatted = Nx.concatenate([t1, t2, t3], axis: 1) |> Nx.backend_transfer() 75 | 76 | expected = 77 | Nx.concatenate( 78 | [ 79 | Nx.iota({3, 5, 7}, type: unquote(type_tuple)), 80 | Nx.iota({3, 5, 7}, type: unquote(type_tuple)), 81 | Nx.iota({3, 5, 7}, type: unquote(type_tuple)) 82 | ], 83 | axis: 1 84 | ) 85 | 86 | assert concatted == expected 87 | end 88 | 89 | test "Concat axis 2 of three 3-dimensional #{type_str} vector" do 90 | t1 = Nx.iota({3, 5, 7}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 91 | t2 = Nx.iota({3, 5, 7}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 92 | t3 = Nx.iota({3, 5, 7}, type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 93 | 94 | concatted = Nx.concatenate([t1, t2, t3], axis: 2) |> Nx.backend_transfer() 95 | 96 | expected = 97 | Nx.concatenate( 98 | [ 99 | Nx.iota({3, 5, 7}, type: unquote(type_tuple)), 100 | Nx.iota({3, 5, 7}, type: unquote(type_tuple)), 101 | Nx.iota({3, 5, 7}, type: unquote(type_tuple)) 102 | ], 103 | axis: 2 104 | ) 105 | 106 | assert concatted == expected 107 | end 108 | 109 | test "Concat doesn't alter component #{type_str} vectors" do 110 | t1 = 111 | 
Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 112 | 113 | t2 = 114 | Nx.tensor([1, 2, 3, 4], type: unquote(type_tuple)) |> Nx.backend_transfer(Ortex.Backend) 115 | 116 | concatted = Nx.concatenate([t1, t2]) |> Nx.backend_transfer() 117 | second_concatted = Nx.concatenate([t1, t2]) |> Nx.backend_transfer() 118 | 119 | assert concatted == second_concatted 120 | end 121 | end) 122 | 123 | test "Concat fails to concat vectors of differing types" do 124 | assert_raise RuntimeError, 125 | "Ortex does not currently support concatenation of vectors with differing types.", 126 | fn -> 127 | t1 = Nx.tensor([1, 2, 3], type: {:s, 16}) |> Nx.backend_transfer(Ortex.Backend) 128 | t2 = Nx.tensor([1, 2, 3], type: {:s, 32}) |> Nx.backend_transfer(Ortex.Backend) 129 | _err = Nx.concatenate([t1, t2]) 130 | end 131 | end 132 | end 133 | -------------------------------------------------------------------------------- /test/shape/reshape_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Ortex.TestReshape do 2 | # TODO: Fix this, it is not truly validating the reshaping on the ortex side 3 | use ExUnit.Case 4 | 5 | test "1d reshape" do 6 | t = Nx.tensor([1, 2, 3, 4]) 7 | bin = t |> Nx.reshape({2, 2}) 8 | 9 | ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.reshape({2, 2}) |> Nx.backend_transfer() 10 | 11 | assert bin == ort 12 | end 13 | 14 | test "2d reshape" do 15 | shape = Nx.tensor([[0], [0], [0], [0]]) 16 | t = Nx.tensor([1, 2, 3, 4]) 17 | bin = t |> Nx.reshape(shape) 18 | 19 | ort = 20 | t 21 | |> Nx.backend_copy(Ortex.Backend) 22 | |> Nx.reshape(shape |> Nx.backend_copy(Ortex.Backend)) 23 | |> Nx.backend_transfer() 24 | 25 | assert bin == ort 26 | end 27 | 28 | test "scalar reshape" do 29 | shape = {1, 1, 1} 30 | t = Nx.tensor(1) 31 | bin = t |> Nx.reshape(shape) 32 | 33 | ort = 34 | t 35 | |> Nx.backend_copy(Ortex.Backend) 36 | |> Nx.reshape(shape) 37 | |> Nx.backend_transfer() 38 | 
39 | assert bin == ort 40 | end 41 | 42 | test "auto reshape" do 43 | shape = {:auto, 2} 44 | t = Nx.tensor([[1, 2, 3], [4, 5, 6]]) 45 | bin = t |> Nx.reshape(shape) 46 | 47 | ort = 48 | t 49 | |> Nx.backend_copy(Ortex.Backend) 50 | |> Nx.reshape(shape) 51 | |> Nx.backend_transfer() 52 | 53 | assert bin == ort 54 | end 55 | end 56 | -------------------------------------------------------------------------------- /test/shape/slice_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Ortex.TestSlice do 2 | use ExUnit.Case 3 | 4 | {tensor1d, _} = Nx.Random.uniform(Nx.Random.key(42), 0, 256, shape: {10}) 5 | {tensor2d, _} = Nx.Random.uniform(Nx.Random.key(42), 0, 256, shape: {10, 10}) 6 | 7 | @tensor1d tensor1d 8 | @tensor2d tensor2d 9 | 10 | defp tensor_binary(tensor, dtype) do 11 | tensor |> Nx.as_type(dtype) 12 | end 13 | 14 | defp tensor_ortex(tensor, dtype) do 15 | tensor 16 | |> Nx.as_type(dtype) 17 | |> Nx.backend_transfer(Ortex.Backend) 18 | end 19 | 20 | test "1d slice f32" do 21 | bin = tensor_binary(@tensor1d, :f32) |> Nx.slice([0], [4]) 22 | 23 | ort = tensor_ortex(@tensor1d, :f32) |> Nx.slice([0], [4]) |> Nx.backend_transfer() 24 | 25 | assert bin == ort 26 | end 27 | 28 | test "2d slice f32" do 29 | bin = tensor_binary(@tensor2d, :f32) |> Nx.slice([0, 2], [4, 6]) 30 | 31 | ort = 32 | tensor_ortex(@tensor2d, :f32) 33 | |> Nx.slice([0, 2], [4, 6]) 34 | |> Nx.backend_transfer() 35 | 36 | assert bin == ort 37 | end 38 | 39 | test "1d slice u8" do 40 | bin = tensor_binary(@tensor1d, :u8) |> Nx.slice([0], [4]) 41 | 42 | ort = tensor_ortex(@tensor1d, :u8) |> Nx.slice([0], [4]) |> Nx.backend_transfer() 43 | 44 | assert bin == ort 45 | end 46 | 47 | test "2d slice u8" do 48 | bin = tensor_binary(@tensor2d, :u8) |> Nx.slice([0, 2], [4, 6]) 49 | 50 | ort = 51 | tensor_ortex(@tensor2d, :u8) 52 | |> Nx.slice([0, 2], [4, 6]) 53 | |> Nx.backend_transfer() 54 | 55 | assert bin == ort 56 | end 57 | 58 | test "1d slice 
f32 strided" do 59 | bin = tensor_binary(@tensor1d, :f32) |> Nx.slice([0], [4], strides: [2]) 60 | 61 | ort = 62 | tensor_ortex(@tensor1d, :f32) |> Nx.slice([0], [4], strides: [2]) |> Nx.backend_transfer() 63 | 64 | assert bin == ort 65 | end 66 | 67 | test "2d slice f32 strided" do 68 | bin = tensor_binary(@tensor2d, :f32) |> Nx.slice([0, 2], [4, 6], strides: [2, 1]) 69 | 70 | ort = 71 | tensor_ortex(@tensor2d, :f32) 72 | |> Nx.slice([0, 2], [4, 6], strides: [2, 1]) 73 | |> Nx.backend_transfer() 74 | 75 | assert bin == ort 76 | end 77 | 78 | test "1d slice u8 strided" do 79 | bin = tensor_binary(@tensor1d, :u8) |> Nx.slice([0], [4], strides: [2]) 80 | 81 | ort = 82 | tensor_ortex(@tensor1d, :u8) |> Nx.slice([0], [4], strides: [2]) |> Nx.backend_transfer() 83 | 84 | assert bin == ort 85 | end 86 | 87 | test "2d slice u8 strided" do 88 | bin = tensor_binary(@tensor2d, :u8) |> Nx.slice([0, 2], [4, 6], strides: [2, 1]) 89 | 90 | ort = 91 | tensor_ortex(@tensor2d, :u8) 92 | |> Nx.slice([0, 2], [4, 6], strides: [2, 1]) 93 | |> Nx.backend_transfer() 94 | 95 | assert bin == ort 96 | end 97 | end 98 | -------------------------------------------------------------------------------- /test/shape/squeeze_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Ortex.TestSqueeze do 2 | use ExUnit.Case 3 | 4 | test "1d squeeze" do 5 | t = Nx.tensor([[[1, 2, 3, 4]]]) 6 | bin = t |> Nx.squeeze() 7 | 8 | ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.squeeze() |> Nx.backend_transfer() 9 | 10 | assert bin == ort 11 | end 12 | 13 | test "2d squeeze" do 14 | t = Nx.tensor([[[[1, 2]], [[3, 4]]]]) 15 | bin = t |> Nx.squeeze() 16 | 17 | ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.squeeze() |> Nx.backend_transfer() 18 | 19 | assert bin == ort 20 | end 21 | 22 | test "axis squeeze" do 23 | t = Nx.tensor([[[[1, 2]], [[3, 4]]]]) 24 | bin = t |> Nx.squeeze(axes: [0]) 25 | 26 | ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.squeeze(axes: 
[0]) |> Nx.backend_transfer() 27 | 28 | assert bin == ort 29 | end 30 | 31 | test "named squeeze" do 32 | t = Nx.tensor([[[[1, 2]], [[3, 4]]]], names: [:w, :x, :y, :z]) 33 | bin = t |> Nx.squeeze(axes: [:w]) 34 | 35 | ort = t |> Nx.backend_copy(Ortex.Backend) |> Nx.squeeze(axes: [:w]) |> Nx.backend_transfer() 36 | 37 | assert bin == ort 38 | end 39 | end 40 | -------------------------------------------------------------------------------- /test/test_helper.exs: -------------------------------------------------------------------------------- 1 | exclude = 2 | if File.exists?("models/resnet50.onnx") do 3 | [] 4 | else 5 | IO.warn( 6 | """ 7 | skipping resnet50 tests because model is not available. 8 | Run python/export_resnet.py before for a complete test suite\ 9 | """, 10 | [] 11 | ) 12 | 13 | [:resnet50] 14 | end 15 | 16 | ExUnit.start(exclude: exclude) 17 | --------------------------------------------------------------------------------