├── .tool-versions ├── native └── baml_elixir │ ├── Cross.toml │ ├── .cargo │ └── config.toml │ ├── Cargo.toml │ ├── README.md │ └── src │ ├── collector.rs │ ├── type_builder.rs │ └── lib.rs ├── .formatter.exs ├── lib ├── baml_elixir.ex └── baml_elixir │ ├── collector.ex │ ├── type_builder.ex │ ├── native.ex │ └── client.ex ├── test ├── test_helper.exs ├── support │ ├── openai_handler.ex │ └── fake_openai_server.ex ├── baml_src │ └── baml_elixir_test.baml └── baml_elixir_test.exs ├── .gitmodules ├── checksum-Elixir.BamlElixir.Native.exs ├── RELEASE.md ├── .gitignore ├── mix.exs ├── scripts └── update_baml.sh ├── .github └── workflows │ └── release.yml ├── mix.lock ├── README.md └── LICENSE /.tool-versions: -------------------------------------------------------------------------------- 1 | erlang 27.1 2 | elixir 1.18.3-otp-27 3 | -------------------------------------------------------------------------------- /native/baml_elixir/Cross.toml: -------------------------------------------------------------------------------- 1 | [build.env] 2 | passthrough = [ 3 | "RUSTLER_NIF_VERSION" 4 | ] 5 | -------------------------------------------------------------------------------- /.formatter.exs: -------------------------------------------------------------------------------- 1 | # Used by "mix format" 2 | [ 3 | inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] 4 | ] 5 | -------------------------------------------------------------------------------- /lib/baml_elixir.ex: -------------------------------------------------------------------------------- 1 | defmodule BamlElixir do 2 | @baml_version "0.208.3" # BAML release tag kept in sync with the native/baml_elixir/baml submodule by scripts/update_baml.sh 3 | 4 | def baml_version, do: @baml_version # Returns the BAML release tag string above. 5 | end 6 | -------------------------------------------------------------------------------- /test/test_helper.exs: -------------------------------------------------------------------------------- 1 | ExUnit.start() 2 | 3 | Mox.defmock(BamlElixirTest.OpenAIHandlerMock, for: BamlElixirTest.OpenAIHandler) # HTTP handler mock driven by BamlElixirTest.FakeOpenAIServer 4 |
-------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "native/baml_elixir/baml"] 2 | path = native/baml_elixir/baml 3 | url = https://github.com/BoundaryML/baml.git 4 | branch = canary 5 | -------------------------------------------------------------------------------- /native/baml_elixir/.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [target.aarch64-apple-darwin] 2 | rustflags = [ 3 | "-C", "link-arg=-undefined", 4 | "-C", "link-arg=dynamic_lookup", 5 | ] -------------------------------------------------------------------------------- /checksum-Elixir.BamlElixir.Native.exs: -------------------------------------------------------------------------------- 1 | %{ 2 | "libbaml_elixir-v1.0.0-pre.22-nif-2.15-aarch64-apple-darwin.so.tar.gz" => "sha256:8f378fce2e43b9e4217acd1520506f76ad5b1455f2bf6ae455abeaa0f9fad20d", 3 | "libbaml_elixir-v1.0.0-pre.22-nif-2.15-x86_64-unknown-linux-gnu.so.tar.gz" => "sha256:a7a38b869b41a8eee50ea93828ef62f1fab38ba9fba7f765dbaa8e84a8ff6653", 4 | } 5 | -------------------------------------------------------------------------------- /test/support/openai_handler.ex: -------------------------------------------------------------------------------- 1 | defmodule BamlElixirTest.OpenAIHandler do # Behaviour for the fake OpenAI server's request handler; test_helper.exs defines BamlElixirTest.OpenAIHandlerMock against it 2 | @moduledoc false 3 | 4 | @type header_map :: %{optional(String.t()) => String.t()} # FakeOpenAIServer passes header names lowercased 5 | 6 | @callback handle_request(path :: String.t(), headers :: header_map, body :: binary()) :: %{ 7 | required(:status) => pos_integer(), 8 | optional(:headers) => [{String.t(), String.t()}], 9 | required(:body) => binary() 10 | } 11 | end 12 | -------------------------------------------------------------------------------- /native/baml_elixir/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "baml_elixir" 3 |
version = "0.1.0" 4 | authors = [] 5 | edition = "2021" 6 | 7 | [lib] 8 | name = "baml_elixir" 9 | crate-type = ["cdylib"] 10 | 11 | [dependencies] 12 | rustler = { version = "0.36.1", default-features = false, features = ["derive", "nif_version_2_15"] } 13 | baml-runtime = { path = "baml/engine/baml-runtime", features = ["internal"] } 14 | baml-types = { path = "baml/engine/baml-lib/baml-types" } 15 | internal-baml-core = { path = "baml/engine/baml-lib/baml-core" } -------------------------------------------------------------------------------- /native/baml_elixir/README.md: -------------------------------------------------------------------------------- 1 | # NIF for BamlElixir.Native 2 | 3 | ## To build the NIF module: 4 | 5 | - By default `BamlElixir.Native` downloads a precompiled NIF. Set `BAML_ELIXIR_BUILD=1` (or `true`) to build this crate locally as part of `mix compile`. 6 | 7 | ## To load the NIF: 8 | 9 | ```elixir 10 | defmodule BamlElixir.Native do 11 | use RustlerPrecompiled, otp_app: :baml_elixir # plus base_url, version and targets; see lib/baml_elixir/native.ex 12 | 13 | # When the NIF is loaded, it replaces these stub functions. 14 | def parse_baml(_path), do: :erlang.nif_error(:nif_not_loaded) 15 | end 16 | ``` 17 | 18 | ## Examples 19 | 20 | [This](https://github.com/rusterlium/NifIo) is a complete example of a NIF written in Rust.
21 | -------------------------------------------------------------------------------- /lib/baml_elixir/collector.ex: -------------------------------------------------------------------------------- 1 | defmodule BamlElixir.Collector do # Elixir wrapper around a native BAML collector; every function delegates to BamlElixir.Native. 2 | defstruct reference: nil # reference returned by BamlElixir.Native.collector_new/1 3 | 4 | def new(name) when is_binary(name) do # Creates a collector via BamlElixir.Native.collector_new/1 and wraps its reference. 5 | reference = BamlElixir.Native.collector_new(name) 6 | %__MODULE__{reference: reference} 7 | end 8 | 9 | def usage(%__MODULE__{reference: reference}) when is_reference(reference) do # Returns what BamlElixir.Native.collector_usage/1 reports (shape defined by the NIF). 10 | BamlElixir.Native.collector_usage(reference) 11 | end 12 | 13 | def last_function_log(%__MODULE__{reference: reference}) when is_reference(reference) do # Returns BamlElixir.Native.collector_last_function_log/1's result; presumably the log of the most recent call - confirm in the NIF. 14 | BamlElixir.Native.collector_last_function_log(reference) 15 | end 16 | end 17 | -------------------------------------------------------------------------------- /lib/baml_elixir/type_builder.ex: -------------------------------------------------------------------------------- 1 | defmodule BamlElixir.TypeBuilder do # Plain structs describing types passed through the `tb:` call option. 2 | defmodule Class do 3 | defstruct [:name, :fields] 4 | end 5 | 6 | defmodule Enum do # Nested defmodule auto-aliases Enum (and below, Map/List) to BamlElixir.TypeBuilder.* for the rest of this module body; use Elixir.Enum here if the stdlib module is ever needed. 7 | defstruct [:name, :values] 8 | end 9 | 10 | defmodule EnumValue do 11 | defstruct [:value, :description, :alias, :skip] 12 | end 13 | 14 | defmodule Field do 15 | defstruct [:name, :type, :description, :alias, :skip] 16 | end 17 | 18 | defmodule Union do 19 | defstruct [:name, :types] 20 | end 21 | 22 | defmodule Literal do 23 | defstruct [:name, :value] 24 | end 25 | 26 | defmodule Map do 27 | defstruct [:key_type, :value_type] 28 | end 29 | 30 | defmodule List do 31 | defstruct [:type] 32 | end 33 | end 34 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | # How to release 2 | 3 | Because we use 4 | [`RustlerPrecompiled`](https://hexdocs.pm/rustler_precompiled/RustlerPrecompiled.html), releasing 5 | is a bit more involved than it would be otherwise. 6 | 7 | 1.
Ensure the version in `mix.exs` (and the install snippet in `README.md`) is updated. 8 | 2. Cut a GitHub release and tag the commit with the version number prefixed by `v` (e.g. `v1.0.0-pre.23`), because `BamlElixir.Native` downloads NIFs from `releases/download/v#{version}/`. 9 | 3. This will kick off the "Build precompiled NIFs" GitHub Action. Wait for this to complete. It 10 | usually takes around 5-10 minutes. 11 | 4. While the NIFs are compiling, ensure you have the latest version of `main` and don't have any 12 | intermediate builds by running `rm -rf native/baml_elixir/target`. 13 | 5. Once the NIFs are built, use `mix rustler_precompiled.download BamlElixir.Native --all --print` to download them and generate the checksum file. 14 | 6. Run `mix hex.publish`. 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # The directory Mix will write compiled artifacts to. 2 | /_build/ 3 | 4 | # If you run "mix test --cover", coverage assets end up here. 5 | /cover/ 6 | 7 | # The directory Mix downloads your dependencies sources to. 8 | /deps/ 9 | 10 | # Where third-party dependencies like ExDoc output generated docs. 11 | /doc/ 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | # Ignore package tarball (built via "mix hex.build"). 23 | baml_elixir-*.tar 24 | 25 | # Temporary files, for example, from tests.
26 | /tmp/ 27 | 28 | # Rust binary artifacts 29 | target/ 30 | 31 | /priv/native/ 32 | baml_client/ 33 | baml_client.tmp 34 | /priv 35 | -------------------------------------------------------------------------------- /lib/baml_elixir/native.ex: -------------------------------------------------------------------------------- 1 | defmodule BamlElixir.Native do # NIF bindings; the stubs below raise :nif_not_loaded until the native library replaces them. 2 | version = Mix.Project.config()[:version] # read at compile time; selects the GitHub release the NIFs are downloaded from 3 | 4 | use RustlerPrecompiled, 5 | otp_app: :baml_elixir, 6 | base_url: "https://github.com/emilsoman/baml_elixir/releases/download/v#{version}/", 7 | force_build: System.get_env("BAML_ELIXIR_BUILD") in ["1", "true"], # BAML_ELIXIR_BUILD=1 (or true) compiles the crate locally instead of downloading 8 | version: version, 9 | targets: [ # keep in sync with the build matrix in .github/workflows/release.yml 10 | "aarch64-apple-darwin", 11 | "x86_64-unknown-linux-gnu", 12 | "aarch64-unknown-linux-gnu" 13 | ] 14 | 15 | def call(_function_name, _args, _path, _collectors, _client_registry, _tb), 16 | do: :erlang.nif_error(:nif_not_loaded) 17 | 18 | def stream(_pid, _reference, _function_name, _args, _path, _collectors, _client_registry, _tb), 19 | do: :erlang.nif_error(:nif_not_loaded) 20 | 21 | def collector_new(_name), do: :erlang.nif_error(:nif_not_loaded) 22 | 23 | def collector_usage(_collector), do: :erlang.nif_error(:nif_not_loaded) 24 | 25 | def collector_last_function_log(_collector), do: :erlang.nif_error(:nif_not_loaded) 26 | 27 | def parse_baml(_path), do: :erlang.nif_error(:nif_not_loaded) 28 | end 29 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule BamlElixir.MixProject do 2 | use Mix.Project 3 | 4 | @version "1.0.0-pre.23" 5 | 6 | def project do 7 | [ 8 | app: :baml_elixir, 9 | description: "Call BAML functions from Elixir.", 10 | version: @version, 11 | elixir: "~> 1.17", 12 | start_permanent: Mix.env() == :prod, 13 | elixirc_paths: elixirc_paths(Mix.env()), 14 | deps: deps(), 15 | package: package() 16 | ] 17 | end 18 | 19 | # Run "mix help compile.app" to learn about
applications. 20 | def application do 21 | [ 22 | extra_applications: [:logger] 23 | ] 24 | end 25 | 26 | # Run "mix help deps" to learn about dependencies. 27 | defp elixirc_paths(:test), do: ["lib", "test/support"] 28 | defp elixirc_paths(_), do: ["lib"] 29 | 30 | defp deps do 31 | [ 32 | {:rustler, "~> 0.36.1", optional: true}, 33 | {:rustler_precompiled, "~> 0.8"}, 34 | {:mox, "~> 1.1", only: :test}, 35 | {:ex_doc, ">= 0.0.0", only: :dev, runtime: false} 36 | ] 37 | end 38 | 39 | defp package do 40 | [ 41 | files: [ 42 | "lib", 43 | "checksum-*.exs", 44 | "mix.exs", 45 | "LICENSE" 46 | ], 47 | licenses: ["Apache-2.0"], 48 | links: %{"GitHub" => "https://github.com/emilsoman/baml_elixir"}, 49 | maintainers: ["Emil Soman"] 50 | ] 51 | end 52 | end 53 | -------------------------------------------------------------------------------- /scripts/update_baml.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if a tag argument was provided 4 | if [ -z "$1" ]; then 5 | echo "Error: Please provide a BAML tag to update to" 6 | echo "Usage: $0 <tag>" 7 | echo "Example: $0 0.87.2" 8 | exit 1 9 | fi 10 | 11 | TAG=$1 12 | 13 | # Ensure we're in the root directory of the project 14 | if [ ! -d "native/baml_elixir/baml" ]; then 15 | echo "Error: Could not find BAML submodule at native/baml_elixir/baml" 16 | exit 1 17 | fi 18 | 19 | # Update the submodule to the specified tag 20 | echo "Updating BAML submodule to tag: $TAG" 21 | cd native/baml_elixir/baml || exit 1 22 | git fetch --tags 23 | git checkout "$TAG" 24 | if [ $? -ne 0 ]; then 25 | echo "Error: Failed to checkout tag $TAG" 26 | exit 1 27 | fi 28 | 29 | # Update the submodule in the main repository 30 | cd ../../.. || exit 1
31 | git add native/baml_elixir/baml 32 | 33 | # Update the BAML version in the Elixir module (-i.bak + rm works with both GNU and BSD sed) 34 | echo "Updating BAML version in Elixir module to: $TAG" 35 | sed -i.bak "s/@baml_version \".*\"/@baml_version \"$TAG\"/" lib/baml_elixir.ex && rm -f lib/baml_elixir.ex.bak 36 | git add lib/baml_elixir.ex 37 | 38 | # Compile the project, forcing a local NIF build (BAML_ELIXIR_BUILD=1) so the new BAML sources are actually compiled and Cargo.lock is updated 39 | echo "Compiling project..." 40 | BAML_ELIXIR_BUILD=1 mix compile 41 | if [ $? -ne 0 ]; then 42 | echo "Error: Compilation failed. Please fix any issues before committing." 43 | exit 1 44 | fi 45 | 46 | # Add the updated Cargo.lock 47 | git add native/baml_elixir/Cargo.lock 48 | 49 | # Show preview of changes 50 | echo -e "\nPreview of changes to be committed:" 51 | echo "----------------------------------------" 52 | git diff --cached 53 | echo "----------------------------------------" 54 | 55 | # Ask for confirmation 56 | read -p "Do you want to commit these changes? (y/N) " -n 1 -r 57 | echo 58 | if [[ ! $REPLY =~ ^[Yy]$ ]]; then 59 | echo "Aborting commit. Changes are staged but not committed."
60 | exit 1 61 | fi 62 | 63 | # Commit all changes 64 | git commit -m "Update BAML submodule and version to $TAG" 65 | 66 | echo "Successfully updated BAML submodule and version to $TAG" -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Build precompiled NIFs 2 | 3 | on: 4 | push: 5 | tags: 6 | - "*" 7 | 8 | jobs: 9 | build_release: 10 | name: NIF ${{ matrix.nif }} - ${{ matrix.job.target }} (${{ matrix.job.os }}) 11 | runs-on: ${{ matrix.job.os }} 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | nif: ["2.15"] 16 | job: # keep targets in sync with the `targets:` list in lib/baml_elixir/native.ex 17 | - { target: aarch64-apple-darwin, os: macos-latest } 18 | - { target: x86_64-unknown-linux-gnu, os: ubuntu-latest } 19 | - { 20 | target: aarch64-unknown-linux-gnu, 21 | os: ubuntu-latest, 22 | use-cross: true, 23 | } 24 | 25 | steps: 26 | - name: Checkout source code 27 | uses: actions/checkout@v4 28 | with: 29 | submodules: recursive 30 | 31 | - name: Extract crate information 32 | shell: bash 33 | run: | 34 | # Get the project version from mix.exs 35 | echo "PROJECT_VERSION=$(sed -n 's/^ @version "\(.*\)"/\1/p' mix.exs | head -n1)" >> "$GITHUB_ENV" 36 | 37 | - name: Install Rust toolchain 38 | uses: dtolnay/rust-toolchain@stable 39 | with: 40 | target: ${{ matrix.job.target }} 41 | 42 | - name: Build the project 43 | id: build-crate 44 | uses: philss/rustler-precompiled-action@v1.0.1 45 | with: 46 | nif-version: ${{ matrix.nif }} 47 | project-dir: "native/baml_elixir" 48 | project-name: baml_elixir 49 | project-version: ${{ env.PROJECT_VERSION }} 50 | target: ${{ matrix.job.target }} 51 | use-cross: ${{ matrix.job.use-cross }} 52 | 53 | - name: Artifact upload 54 | uses: actions/upload-artifact@v4 55 | with: 56 | name: ${{ steps.build-crate.outputs.file-name }} 57 | path: ${{ steps.build-crate.outputs.file-path }} 58 | 59 | - name: Publish archives and packages 60 | uses:
softprops/action-gh-release@v2 61 | if: github.ref_type == 'tag' 62 | with: 63 | files: | 64 | ${{ steps.build-crate.outputs.file-path }} 65 | -------------------------------------------------------------------------------- /test/baml_src/baml_elixir_test.baml: -------------------------------------------------------------------------------- 1 | 2 | client GPT4 { 3 | provider openai 4 | options { 5 | model gpt-4o-mini 6 | api_key env.OPENAI_API_KEY 7 | } 8 | } 9 | 10 | client DeepSeekR1 { 11 | provider openai-generic 12 | options { 13 | base_url "https://api.together.ai/v1" 14 | api_key env.TOGETHER_API_KEY 15 | model "deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free" 16 | } 17 | } 18 | 19 | client Claude35Sonnet { 20 | provider anthropic 21 | options { 22 | model claude-3-5-sonnet-latest 23 | api_key env.ANTHROPIC_API_KEY 24 | } 25 | } 26 | 27 | client Gemini25Pro { 28 | provider google-ai 29 | options { 30 | model "gemini-2.5-pro" 31 | api_key env.GEMINI_API_KEY 32 | } 33 | } 34 | 35 | function ExtractPerson(info: string) -> Person { 36 | client GPT4 37 | prompt #" 38 | {{ ctx.output_format }} 39 | 40 | Extract the person's information from the following string: 41 | {{ info }} 42 | "# 43 | } 44 | 45 | class Person { 46 | name string 47 | age int 48 | } 49 | 50 | function DescribeImage(myImg: image) -> string { 51 | client GPT4 52 | prompt #" 53 | {{ _.role("user")}} 54 | Describe the image in four words: 55 | {{ myImg }} 56 | "# 57 | } 58 | 59 | enum Model { 60 | DeepSeekR1 61 | GPT4oMini 62 | } 63 | 64 | class MyClass { 65 | property1 string 66 | property2 int?
67 | @@dynamic // allows adding fields dynamically at runtime 68 | } 69 | 70 | class NewEmployeeFullyDynamic { 71 | employee_id string 72 | @@dynamic // allows adding fields dynamically at runtime 73 | } 74 | 75 | function CreateEmployee() -> NewEmployeeFullyDynamic { // fields beyond employee_id are expected to come from a runtime type builder (tb option) 76 | client GPT4 77 | prompt #" 78 | Create a fake employee data with the following information: 79 | {{ ctx.output_format }} 80 | "# 81 | } 82 | 83 | function WhichModel() -> Model { 84 | client GPT4 85 | prompt #" 86 | Which model are you? 87 | 88 | {{ ctx.output_format }} 89 | "# 90 | } 91 | 92 | function WhichModelUnion() -> "DeepSeek" | "GPT" { 93 | client GPT4 94 | prompt #" 95 | Which model are you? 96 | 97 | {{ ctx.output_format }} 98 | "# 99 | } 100 | 101 | class DummyOutput { 102 | nonce string 103 | nonce2 string 104 | } 105 | 106 | function DummyOutputFunction() -> DummyOutput { // NOTE(review): prompt omits ctx.output_format and asks for plain text, presumably to exercise a parse failure - confirm against the test suite 107 | client GPT4 108 | prompt #" 109 | Say "hello there". 110 | "# 111 | } 112 | 113 | class Attendees { 114 | hosts Person[] 115 | guests Person[] 116 | } 117 | 118 | function ParseAttendees(attendees: string) -> Attendees { 119 | client GPT4 120 | prompt #" 121 | {{ ctx.output_format }} 122 | 123 | Parse the following string into an Attendees struct: 124 | {{ attendees }} 125 | "# 126 | } 127 | 128 | function FlipSwitch(switch: bool) -> bool { 129 | client GPT4 130 | prompt #" 131 | Flip the switch: 132 | {{ switch }} 133 | 134 | {{ ctx.output_format }} 135 | "# 136 | } -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "castore": {:hex, :castore, "1.0.12", "053f0e32700cbec356280c0e835df425a3be4bc1e0627b714330ad9d0f05497f", [:mix], [], "hexpm", "3dca286b2186055ba0c9449b4e95b97bf1b57b47c1f2644555879e659960c224"}, 3 | "earmark_parser": {:hex, :earmark_parser, "1.4.44", "f20830dd6b5c77afe2b063777ddbbff09f9759396500cdbe7523efd58d7a339c", [:mix], [], "hexpm",
"4778ac752b4701a5599215f7030989c989ffdc4f6df457c5f36938cc2d2a2750"}, 4 | "ex_doc": {:hex, :ex_doc, "0.37.3", "f7816881a443cd77872b7d6118e8a55f547f49903aef8747dbcb345a75b462f9", [:mix], [{:earmark_parser, "~> 1.4.42", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "e6aebca7156e7c29b5da4daa17f6361205b2ae5f26e5c7d8ca0d3f7e18972233"}, 5 | "jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"}, 6 | "makeup": {:hex, :makeup, "1.2.1", "e90ac1c65589ef354378def3ba19d401e739ee7ee06fb47f94c687016e3713d1", [:mix], [{:nimble_parsec, "~> 1.4", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "d36484867b0bae0fea568d10131197a4c2e47056a6fbe84922bf6ba71c8d17ce"}, 7 | "makeup_elixir": {:hex, :makeup_elixir, "1.0.1", "e928a4f984e795e41e3abd27bfc09f51db16ab8ba1aebdba2b3a575437efafc2", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "7284900d412a3e5cfd97fdaed4f5ed389b8f2b4cb49efc0eb3bd10e2febf9507"}, 8 | "makeup_erlang": {:hex, :makeup_erlang, "1.0.2", "03e1804074b3aa64d5fad7aa64601ed0fb395337b982d9bcf04029d68d51b6a7", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "af33ff7ef368d5893e4a267933e7744e46ce3cf1f61e2dccf53a111ed3aa3727"}, 9 | "mox": {:hex, :mox, "1.2.0", "a2cd96b4b80a3883e3100a221e8adc1b98e4c3a332a8fc434c39526babafd5b3", 
[:mix], [{:nimble_ownership, "~> 1.0", [hex: :nimble_ownership, repo: "hexpm", optional: false]}], "hexpm", "c7b92b3cc69ee24a7eeeaf944cd7be22013c52fcb580c1f33f50845ec821089a"}, 10 | "nimble_ownership": {:hex, :nimble_ownership, "1.0.2", "fa8a6f2d8c592ad4d79b2ca617473c6aefd5869abfa02563a77682038bf916cf", [:mix], [], "hexpm", "098af64e1f6f8609c6672127cfe9e9590a5d3fcdd82bc17a377b8692fd81a879"}, 11 | "nimble_parsec": {:hex, :nimble_parsec, "1.4.2", "8efba0122db06df95bfaa78f791344a89352ba04baedd3849593bfce4d0dc1c6", [:mix], [], "hexpm", "4b21398942dda052b403bbe1da991ccd03a053668d147d53fb8c4e0efe09c973"}, 12 | "rustler": {:hex, :rustler, "0.36.1", "2d4b1ff57ea2789a44756a40dbb5fbb73c6ee0a13d031dcba96d0a5542598a6a", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.7", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "f3fba4ad272970e0d1bc62972fc4a99809651e54a125c5242de9bad4574b2d02"}, 13 | "rustler_precompiled": {:hex, :rustler_precompiled, "0.8.2", "5f25cbe220a8fac3e7ad62e6f950fcdca5a5a5f8501835d2823e8c74bf4268d5", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "63d1bd5f8e23096d1ff851839923162096364bac8656a4a3c00d1fff8e83ee0a"}, 14 | "toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"}, 15 | } 16 | -------------------------------------------------------------------------------- /test/support/fake_openai_server.ex: -------------------------------------------------------------------------------- 1 | defmodule BamlElixirTest.FakeOpenAIServer do 2 | @moduledoc false 3 | 4 | # A tiny OpenAI-compatible HTTP server for tests. 5 | # Request handling is delegated to a handler module (usually a Mox mock).
6 | 7 | @mock BamlElixirTest.OpenAIHandlerMock 8 | 9 | @doc """ 10 | Convenience helper for tests: expect a single chat completion request and respond 11 | with an OpenAI-compatible JSON body whose assistant message content is `response_content`. 12 | 13 | The expectation is set on #{@mock} so tests don't need to reference the mock directly. 14 | """ 15 | @type expected_header_value :: String.t() | :present | {:contains, String.t()} 16 | @type expected_headers :: %{optional(String.t()) => expected_header_value} 17 | 18 | @spec expect_chat_completion(String.t(), expected_headers()) :: :ok 19 | def expect_chat_completion(response_content, expected_headers \\ %{}) 20 | when is_binary(response_content) and is_map(expected_headers) do 21 | body = 22 | Jason.encode!(%{ 23 | "id" => "chatcmpl-test", 24 | "object" => "chat.completion", 25 | "created" => 1_700_000_000, 26 | "model" => "gpt-4o-mini", 27 | "choices" => [ 28 | %{ 29 | "index" => 0, 30 | "message" => %{"role" => "assistant", "content" => response_content}, 31 | "finish_reason" => "stop" 32 | } 33 | ], 34 | "usage" => %{"prompt_tokens" => 1, "completion_tokens" => 1, "total_tokens" => 2} 35 | }) 36 | 37 | normalized_expected_headers = 38 | Map.new(expected_headers, fn {k, v} -> {String.downcase(to_string(k)), v} end) 39 | 40 | Mox.expect(@mock, :handle_request, fn path, headers, _body -> 41 | # Basic sanity check that the runtime hit the expected endpoint 42 | if !String.contains?(path, "chat/completions") do 43 | raise "Unexpected path: #{inspect(path)}" 44 | end 45 | 46 | assert_expected_headers!(headers, normalized_expected_headers) 47 | 48 | %{status: 200, headers: [{"content-type", "application/json"}], body: body} 49 | end) 50 | 51 | :ok 52 | end 53 | 54 | defp assert_expected_headers!(_headers, expected) when expected == %{}, do: :ok 55 | 56 | defp assert_expected_headers!(headers, expected) when is_map(headers) and is_map(expected) do 57 | Enum.each(expected, fn {key, expectation} -> 58 | case
expectation do 59 | :present -> 60 | if !Map.has_key?(headers, key) do 61 | raise "Expected header #{inspect(key)} to be present, got: #{inspect(Map.keys(headers))}" 62 | end 63 | 64 | {:contains, substring} when is_binary(substring) -> 65 | actual = Map.get(headers, key) 66 | 67 | if is_nil(actual) or !String.contains?(actual, substring) do 68 | raise "Expected header #{inspect(key)} to contain #{inspect(substring)}, got: #{inspect(actual)}" 69 | end 70 | 71 | value when is_binary(value) -> 72 | actual = Map.get(headers, key) 73 | 74 | if actual != value do 75 | raise "Expected header #{inspect(key)} to equal #{inspect(value)}, got: #{inspect(actual)}" 76 | end 77 | 78 | other -> 79 | raise "Unsupported header expectation for #{inspect(key)}: #{inspect(other)}" 80 | end 81 | end) 82 | end 83 | 84 | @doc """ 85 | Convenience helper for tests: starts the server wired to the Mox mock and returns 86 | a base_url suitable for `openai-generic` (`.../v1`). 87 | """ 88 | @spec start_base_url() :: String.t() 89 | def start_base_url() do 90 | {:ok, server_pid, port} = start_link(@mock) 91 | Mox.allow(@mock, self(), server_pid) 92 | "http://127.0.0.1:#{port}/v1" 93 | end 94 | 95 | @spec start_link(module()) :: {:ok, pid(), non_neg_integer()} 96 | def start_link(handler_module) when is_atom(handler_module) do 97 | {:ok, listen_socket} = 98 | :gen_tcp.listen(0, [:binary, packet: :raw, active: false, reuseaddr: true, ip: {127, 0, 0, 1}]) 99 | 100 | {:ok, port} = :inet.port(listen_socket) 101 | 102 | pid = 103 | spawn_link(fn -> 104 | {:ok, socket} = :gen_tcp.accept(listen_socket) 105 | 106 | {:ok, header_blob, early_body} = recv_until_headers(socket, <<>>) # early_body: body bytes that arrived together with the headers 107 | {path, headers, content_length} = parse_request(header_blob) 108 | 109 | body = 110 | if content_length > byte_size(early_body) do 111 | case :gen_tcp.recv(socket, content_length - byte_size(early_body), 5_000) do 112 | {:ok, b} -> early_body <> b 113 | _ -> early_body 114 | end 115 | else 116 | binary_part(early_body, 0, content_length) 117 | end 118 | 119 | response = handler_module.handle_request(path, headers, body) 120 | %{status: status, body:
resp_body} = response 121 | resp_headers = Map.get(response, :headers, []) 122 | 123 | # Ensure minimal headers exist 124 | resp_headers = 125 | resp_headers 126 | |> ensure_header("content-length", Integer.to_string(byte_size(resp_body))) 127 | |> ensure_header("connection", "close") 128 | 129 | resp = 130 | "HTTP/1.1 #{status} OK\r\n" <> 131 | Enum.map_join(resp_headers, "", fn {k, v} -> "#{k}: #{v}\r\n" end) <> 132 | "\r\n" <> 133 | resp_body 134 | 135 | :gen_tcp.send(socket, resp) 136 | :gen_tcp.close(socket) 137 | :gen_tcp.close(listen_socket) 138 | end) 139 | 140 | {:ok, pid, port} 141 | end 142 | 143 | defp ensure_header(headers, key, value) do 144 | key_down = String.downcase(key) 145 | 146 | if Enum.any?(headers, fn {k, _} -> String.downcase(k) == key_down end) do 147 | headers 148 | else 149 | [{key, value} | headers] 150 | end 151 | end 152 | 153 | defp recv_until_headers(socket, acc) do # Returns {:ok, head, rest}: head is the request line + headers (terminator stripped), rest is any body bytes already read 154 | case :binary.split(acc, "\r\n\r\n") do 155 | [head, rest] -> 156 | {:ok, head, rest} 157 | 158 | [_incomplete] -> 159 | case :gen_tcp.recv(socket, 0, 5_000) do 160 | {:ok, chunk} -> recv_until_headers(socket, acc <> chunk) 161 | other -> other 162 | end 163 | end 164 | end 165 | 166 | defp parse_request(header_blob) do 167 | [request_line | header_lines] = 168 | header_blob 169 | |> String.split("\r\n", trim: true) 170 | 171 | path = 172 | case String.split(request_line, " ", parts: 3) do 173 | [_method, p, _http] -> p 174 | _ -> "" 175 | end 176 | 177 | headers = 178 | header_lines 179 | |> Enum.reduce(%{}, fn line, acc -> 180 | case String.split(line, ":", parts: 2) do 181 | [k, v] -> Map.put(acc, String.downcase(String.trim(k)), String.trim(v)) 182 | _ -> acc 183 | end 184 | end) 185 | 186 | content_length = 187 | headers 188 | |> Map.get("content-length", "0") 189 | |> String.to_integer() 190 | 191 | {path, headers, content_length} 192 | end 193 | end 194 | -------------------------------------------------------------------------------- /README.md:
-------------------------------------------------------------------------------- 1 | # BamlElixir 2 | 3 | Call BAML functions from Elixir, using a Rust NIF. 4 | 5 | ## First of all, can this be used in production? 6 | 7 | Well, I use it in production. But it's way too early for you if you expect stable APIs 8 | and things to not break at all. If you're okay with debugging issues with me when things go wrong, 9 | please go ahead! 10 | 11 | What this library does: 12 | 13 | - Generates Elixir structs, types and functions from BAML files. 14 | - Gives you autocomplete and dialyzer type checking. 15 | - Parses BAML results into Elixir structs. 16 | - Lets you switch between different LLM clients. 17 | - Collects usage data using collectors. 18 | 19 | What this library does not do: 20 | 21 | - Generate `baml_client` source files on disk. Code generation happens at compile time, in the module that calls `use BamlElixir.Client`. 22 | - Parse results of classes with dynamic fields (`@@dynamic`) into structs. Those are returned as maps with a `__baml_class__` key instead. 23 | 24 | ## Usage 25 | 26 | Create a `baml_src` directory in `priv` and add a BAML file to it: 27 | 28 | ```baml 29 | client GPT4 { 30 | provider openai 31 | options { 32 | model gpt-4o-mini 33 | api_key env.OPENAI_API_KEY 34 | } 35 | } 36 | 37 | class Resume { 38 | name string 39 | job_title string 40 | company string 41 | } 42 | 43 | function ExtractResume(resume: string) -> Resume { 44 | client GPT4 45 | prompt #" 46 | {{ _.role('system') }} 47 | 48 | Extract the following information from the resume: 49 | 50 | Resume: 51 | <<<< 52 | {{ resume }} 53 | <<<< 54 | 55 | Output JSON schema: 56 | {{ ctx.output_format }} 57 | 58 | JSON: 59 | "# 60 | } 61 | ``` 62 | 63 | Now create a BAML client module: 64 | 65 | ```elixir 66 | defmodule MyApp.BamlClient do 67 | # {:my_app, "priv/baml_src"} will be expanded to Application.app_dir(:my_app, "priv/baml_src") 68 | use BamlElixir.Client, path: {:my_app, "priv/baml_src"} 69 | end 70 | ``` 71 | 72 | Now call the BAML function: 73 | 74 | ```elixir 75 | MyApp.BamlClient.ExtractResume.call(%{resume: "John Doe is
the CTO of Acme Inc."}) 76 | ``` 77 | 78 | ### Stream results 79 | 80 | ```elixir 81 | MyApp.BamlClient.ExtractResume.stream(%{resume: "John Doe is the CTO of Acme Inc."}, fn 82 | {:partial, result} -> 83 | IO.inspect(result) 84 | 85 | {:done, result} -> 86 | IO.inspect(result) 87 | 88 | {:error, error} -> 89 | IO.inspect(error) 90 | end) 91 | ``` 92 | 93 | You can also use `sync_stream`, which passes each partial result to your callback and blocks until the function is done, returning `{:ok, result}` or `{:error, error}`. 94 | 95 | ```elixir 96 | case MyApp.BamlClient.ExtractResume.sync_stream( 97 | %{resume: "John Doe is the CTO of Acme Inc."}, 98 | fn result -> 99 | IO.inspect(result) 100 | end 101 | ) do 102 | {:ok, result} -> 103 | IO.inspect(result) 104 | 105 | {:error, error} -> 106 | IO.inspect(error) 107 | end 108 | ``` 109 | 110 | ### Images 111 | 112 | Send an image URL: 113 | 114 | ```elixir 115 | MyApp.BamlClient.DescribeImage.call(%{ 116 | myImg: %{ 117 | url: "https://upload.wikimedia.org/wikipedia/en/4/4d/Shrek_%28character%29.png" 118 | } 119 | }) 120 | |> IO.inspect() 121 | ``` 122 | 123 | Or send base64-encoded image data: 124 | 125 | ```elixir 126 | MyApp.BamlClient.DescribeImage.stream(%{ 127 | myImg: %{ 128 | base64: "data:image/png;base64,..." 129 | } 130 | }, fn result -> 131 | IO.inspect(result) 132 | end) 133 | ``` 134 | 135 | ### Collect usage data 136 | 137 | ```elixir 138 | collector = BamlElixir.Collector.new("my_collector") 139 | 140 | MyApp.BamlClient.ExtractResume.call(%{resume: "John Doe is the CTO of Acme Inc."}, %{ 141 | collectors: [collector] 142 | }) 143 | 144 | BamlElixir.Collector.usage(collector) 145 | ``` 146 | 147 | When streaming, usage data is available once the `{:done, result}` message has been received. 148 | 149 | ### Switch LLM clients 150 | 151 | You can switch to a different LLM client defined in your BAML files by passing its name as the `llm_client` option:
 152 | 153 | ```elixir 154 | MyApp.BamlClient.WhichModel.call(%{}, %{ 155 | llm_client: "GPT4oMini" 156 | }) 157 | |> IO.inspect() 158 | # => "gpt-4o-mini" 159 | 160 | MyApp.BamlClient.WhichModel.call(%{}, %{ 161 | llm_client: "DeepSeekR1" 162 | }) 163 | |> IO.inspect() 164 | # => "deepseek-r1" 165 | ``` 166 | 167 | ### Type Builder 168 | 169 | You can provide a type builder to dynamically define types at runtime. This is useful for classes with `@@dynamic` attributes or when you need to create types that aren't defined in your BAML files. 170 | 171 | Pass the type builder in the `tb` option as a list of `BamlElixir.TypeBuilder` structs (the example below assumes `alias BamlElixir.TypeBuilder`): 172 | 173 | #### Example 174 | 175 | Given this BAML class, returned by a `CreateEmployee` function: 176 | 177 | ```baml 178 | class DynamicEmployee { 179 | employee_id string 180 | @@dynamic // allows adding fields dynamically at runtime 181 | } 182 | ``` 183 | 184 | ```elixir 185 | {:ok, 186 | %{ 187 | __baml_class__: "DynamicEmployee", 188 | employee_id: _, 189 | person: %{ 190 | name: "Foobar123", 191 | age: _, 192 | children_count: _, 193 | favorite_day: _, 194 | favorite_color: :RED, 195 | __baml_class__: "TestPerson" 196 | } 197 | }} = 198 | MyApp.BamlClient.CreateEmployee.call(%{}, %{ 199 | tb: [ 200 | %TypeBuilder.Class{ 201 | name: "TestPerson", 202 | fields: [ 203 | %TypeBuilder.Field{ 204 | name: "name", 205 | type: :string, 206 | description: "The name of the person - this should always be Foobar123" 207 | }, 208 | %TypeBuilder.Field{name: "age", type: :int}, 209 | %TypeBuilder.Field{name: "children_count", type: 1}, 210 | %TypeBuilder.Field{name: "favorite_day", type: %TypeBuilder.Union{types: ["sunday", "monday"]}}, 211 | %TypeBuilder.Field{name: "favorite_color", type: %TypeBuilder.Enum{name: "FavoriteColor"}} 212 | ] 213 | }, 214 | %TypeBuilder.Class{ 215 | name: "DynamicEmployee", 216 | fields: [ 217 | %TypeBuilder.Field{name: "person", type: 
%TypeBuilder.Class{name: "TestPerson"}} 218 | ] 219 | }, 220 | %TypeBuilder.Enum{ 221 | name: "FavoriteColor", 222 | values: [ 223 |
 %TypeBuilder.EnumValue{value: "RED", description: "Pick this always"}, 224 | %TypeBuilder.EnumValue{value: "GREEN"}, 225 | %TypeBuilder.EnumValue{value: "BLUE"} 226 | ] 227 | } 228 | ] 229 | }) 230 | ``` 231 | 232 | **Note**: Classes with dynamic fields are not parsed into structs. They are returned as maps with a `__baml_class__` key, which can be used for pattern matching. 233 | 234 | ## Installation 235 | 236 | Add `baml_elixir` to the dependencies in your `mix.exs`: 237 | 238 | ```elixir 239 | def deps do 240 | [ 241 | {:baml_elixir, "~> 1.0.0-pre.23"} 242 | ] 243 | end 244 | ``` 245 | 246 | This also downloads the prebuilt NIFs for these targets: 247 | 248 | - aarch64-apple-darwin (Apple Silicon) 249 | - x86_64-unknown-linux-gnu 250 | - aarch64-unknown-linux-gnu 251 | 252 | If you need the NIFs for other targets, clone the repository and build them locally as documented below. 253 | 254 | ### TODO 255 | 256 | - Type aliases 257 | - Dynamic types (WIP, works partially) 258 | - Stream cancellation 259 | - Add support for audio, PDF, and video output types 260 | - Stream metadata exposure (`@stream.done`, `@stream.not_null`, `@stream.with_state`) 261 | - OnTick callbacks 262 | - Structured error types (replace `{:error, String.t()}` with structured error types like `BamlValidationError`) 263 | - Runtime strategy configuration via Elixir 264 | 265 | ### Development 266 | 267 | This project includes Git submodules. To clone the repository with all its submodules, use (replacing `REPOSITORY_URL` with this repository's Git URL): 268 | 269 | ```bash 270 | git clone --recurse-submodules REPOSITORY_URL 271 | ``` 272 | 273 | If you've already cloned the repository without submodules, initialize them with: 274 | 275 | ```bash 276 | git submodule init 277 | git submodule update 278 | ``` 279 | 280 | The project includes Rust code in the `native/` directory: 281 | 282 | - `native/baml_elixir/` - Main Rust NIF code 283 | - `native/baml_elixir/baml/` - Submodule containing BAML, which is 
a dependency of the NIF 284 | 285 | ### Building 286 | 287 | 1.
Ensure you have Rust installed (https://rustup.rs/). Can use asdf to install it. 288 | 2. Build the project: 289 | 290 | ```bash 291 | mix deps.get 292 | mix compile 293 | ``` 294 | -------------------------------------------------------------------------------- /native/baml_elixir/src/collector.rs: -------------------------------------------------------------------------------- 1 | use baml_runtime::tracingv2::storage::storage::Collector as BamlCollector; 2 | use rustler::{Encoder, Env, Resource, ResourceArc, Term}; 3 | use std::sync::{Arc, Mutex}; 4 | /* NOTE(review): several generic type arguments in this listing appear to have been stripped (e.g. `pub inner: Arc,`, `name: Option`, `-> ResourceArc {`, `Arc>`, `collect::>()`), so the code as shown would not compile; confirm the real types against the repository source. */ 5 | #[rustler::resource_impl()] 6 | impl Resource for CollectorResource {} 7 | 8 | /* Rustler resource handed to Elixir; owns a reference-counted (`Arc`) BAML collector that `usage` and `last_function_log` below read from. */ pub struct CollectorResource { 9 | pub inner: Arc, 10 | } 11 | 12 | impl CollectorResource { 13 | /* Creates a BAML collector with the given optional name and wraps it in a new `ResourceArc`. */ pub fn new(name: Option) -> ResourceArc { 14 | let collector = BamlCollector::new(name); 15 | ResourceArc::new(CollectorResource { 16 | inner: Arc::new(collector), 17 | }) 18 | } 19 | 20 | /* Returns the collector's `usage()` wrapped in `Usage` so it can be encoded as an Elixir map. NOTE(review): unless `usage()` consumes its receiver, the `.clone()` (an Arc refcount bump) looks unnecessary. */ pub fn usage(&self) -> Usage { 21 | Usage { 22 | inner: self.inner.clone().usage(), 23 | } 24 | } 25 | 26 | /* Returns the collector's last function log (None if there is none), wrapped in `Arc::new(Mutex::new(..))` for the `FunctionLog` encoder. */ pub fn last_function_log(&self) -> Option { 27 | 
self.inner.last_function_log().map(|log| FunctionLog { 28 | inner: Arc::new(Mutex::new(log)), 29 | }) 30 | } 31 | } 32 | 33 | /* Encodable wrapper around a BAML function log; its encoder locks the mutex to read the log. */ pub struct FunctionLog { 34 | pub inner: Arc>, 35 | } 36 | 37 | /* The single-field wrappers below presumably exist because Rust's orphan rule forbids implementing rustler's `Encoder` directly on baml_runtime types. */ pub struct Usage { 38 | pub inner: baml_runtime::tracingv2::storage::storage::Usage, 39 | } 40 | 41 | pub struct Timing { 42 | pub inner: baml_runtime::tracingv2::storage::storage::Timing, 43 | } 44 | 45 | pub struct StreamTiming { 46 | pub inner: baml_runtime::tracingv2::storage::storage::StreamTiming, 47 | } 48 | 49 | pub struct LLMCallKind { 50 | pub inner: baml_runtime::tracingv2::storage::storage::LLMCallKind, 51 | } 52 | 53 | pub struct LLMCall { 54 | pub inner: baml_runtime::tracingv2::storage::storage::LLMCall, 55 | } 56 | 57 | pub struct LLMStreamCall { 58 | pub inner: baml_runtime::tracingv2::storage::storage::LLMStreamCall, 59 | } 60 | 61 | /* Encodes usage as a map with the string keys "input_tokens" and "output_tokens". */ impl Encoder for Usage { 62 | fn encode<'a>(&self, env:
Env<'a>) -> Term<'a> { 63 | let map = Term::map_new(env); 64 | map.map_put("input_tokens", self.inner.input_tokens) 65 | .unwrap() 66 | .map_put("output_tokens", self.inner.output_tokens) 67 | .unwrap() 68 | } 69 | } 70 | 71 | /* Encodes timing as a map with the string keys "start_time_utc_ms" and "duration_ms". */ impl Encoder for Timing { 72 | fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { 73 | let map = Term::map_new(env); 74 | map.map_put("start_time_utc_ms", self.inner.start_time_utc_ms) 75 | .unwrap() 76 | .map_put("duration_ms", self.inner.duration_ms) 77 | .unwrap() 78 | // TODO: BAML doesn't track this yet 79 | // .map_put( 80 | // "time_to_first_parsed_ms", 81 | // self.inner.time_to_first_parsed_ms, 82 | // ) 83 | // .unwrap() 84 | } 85 | } 86 | 87 | /* Encodes the same keys as `Timing`; the time-to-first-parsed and time-to-first-token fields stay commented out until BAML tracks them. */ impl Encoder for StreamTiming { 88 | fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { 89 | let map = Term::map_new(env); 90 | map.map_put("start_time_utc_ms", self.inner.start_time_utc_ms) 91 | .unwrap() 92 | .map_put("duration_ms", self.inner.duration_ms) 93 | .unwrap() 94 | // TODO: BAML doesn't track this yet 95 | // .map_put( 96 | // "time_to_first_parsed_ms", 97 | // self.inner.time_to_first_parsed_ms, 98 | // ) 99 | // .unwrap() 100 | // TODO: BAML doesn't track this yet 101 | // .map_put("time_to_first_token_ms", self.inner.time_to_first_token_ms) 102 | // .unwrap() 103 | } 104 | } 105 | 106 | /* Encodes a basic LLM call as a map with "client_name", "provider", "timing", "request", "response" and "usage". "request"/"response" are nested maps, or nil when absent (Rustler encodes `None` as nil); bodies fall back to "" via `unwrap_or_default()`, and a missing usage falls back to the default `Usage`. */ impl Encoder for LLMCall { 107 | fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { 108 | let map = Term::map_new(env); 109 | map.map_put("client_name", self.inner.client_name.clone()) 110 | .unwrap() 111 | 
.map_put("provider", self.inner.provider.clone()) 112 | .unwrap() 113 | .map_put( 114 | "timing", 115 | Timing { 116 | inner: self.inner.timing.clone(), 117 | }, 118 | ) 119 | .unwrap() 120 | .map_put( 121 | "request", 122 | self.inner.request.as_deref().map(|r| { 123 | let map = Term::map_new(env); 124 | map.map_put("method", r.method.clone()) 125 | .unwrap() 126 | .map_put("url", r.url.clone()) 127 | .unwrap() 128 | /* NOTE(review): the `LLMStreamCall` encoder passes `r.headers()` without `.clone()`; one of the two is likely redundant. */ .map_put("headers", r.headers().clone()) 129 | .unwrap() 130 | .map_put("body",
r.body.text().unwrap_or_default().encode(env)) 131 | .unwrap() 132 | }), 133 | ) 134 | .unwrap() 135 | .map_put( 136 | "response", 137 | self.inner.response.as_deref().map(|r| { 138 | let map = Term::map_new(env); 139 | map.map_put("status", r.status.clone()) 140 | .unwrap() 141 | .map_put("headers", r.headers()) 142 | .unwrap() 143 | .map_put("body", r.body.text().unwrap_or_default().encode(env)) 144 | .unwrap() 145 | }), 146 | ) 147 | .unwrap() 148 | .map_put( 149 | "usage", 150 | Usage { 151 | inner: self.inner.usage.clone().unwrap_or_default(), 152 | }, 153 | ) 154 | .unwrap() 155 | } 156 | } 157 | 158 | /* Same map keys as the `LLMCall` encoder. The shared fields are read from `self.inner.llm_call`, but timing is encoded as `StreamTiming` from `self.inner.timing`. */ impl Encoder for LLMStreamCall { 159 | fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { 160 | let map = Term::map_new(env); 161 | map.map_put("client_name", self.inner.llm_call.client_name.clone()) 162 | .unwrap() 163 | .map_put("provider", self.inner.llm_call.provider.clone()) 164 | .unwrap() 165 | .map_put( 166 | "timing", 167 | StreamTiming { 168 | inner: self.inner.timing.clone(), 169 | }, 170 | ) 171 | .unwrap() 172 | .map_put( 173 | "request", 174 | self.inner.llm_call.request.as_deref().map(|r| { 175 | let map = Term::map_new(env); 176 | map.map_put("method", r.method.clone()) 177 | .unwrap() 178 | .map_put("url", r.url.clone()) 179 | .unwrap() 180 | .map_put("headers", r.headers()) 181 | .unwrap() 182 | .map_put("body", r.body.text().unwrap_or_default().encode(env)) 183 | .unwrap() 184 | }), 185 | ) 186 | .unwrap() 187 | .map_put( 188 | "response", 189 | self.inner.llm_call.response.as_deref().map(|r| { 190 | let map = Term::map_new(env); 191 | map.map_put("status", r.status.clone()) 192 | .unwrap() 193 | .map_put("headers", r.headers()) 194 | .unwrap() 195 | .map_put("body", r.body.text().unwrap_or_default().encode(env)) 196 | .unwrap() 197 | }), 198 | ) 199 | .unwrap() 200 | .map_put( 201 | "usage", 202 | 
Usage { 203 | inner: self.inner.llm_call.usage.clone().unwrap_or_default(), 204 | }, 205 | ) 206 | .unwrap() 207 | } 208 | } 209 | 210 | /* Delegates to the `LLMCall` encoder for basic calls and to the `LLMStreamCall` encoder for streamed calls; both produce maps with the same keys. */ impl
Encoder for LLMCallKind { 211 | fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { 212 | match &self.inner { 213 | baml_runtime::tracingv2::storage::storage::LLMCallKind::Basic(call) => LLMCall { 214 | inner: call.clone(), 215 | } 216 | .encode(env), 217 | baml_runtime::tracingv2::storage::storage::LLMCallKind::Stream(stream) => { 218 | LLMStreamCall { 219 | inner: stream.clone(), 220 | } 221 | .encode(env) 222 | } 223 | } 224 | } 225 | } 226 | 227 | /* Encodes a function log as a map with "id" (stringified), "function_name", "log_type", "timing", "usage", "calls" (one encoded map per LLM call) and "raw_llm_response" (falls back to the default, an empty string, when unavailable). */ impl Encoder for FunctionLog { 228 | fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { 229 | let map = Term::map_new(env); 230 | /* `mut` guard: presumably some accessors below take `&mut self`; confirm. NOTE(review): `unwrap()` panics here if the mutex was poisoned by an earlier panic. */ let mut inner = self.inner.lock().unwrap(); 231 | map.map_put("id", inner.id().to_string()) 232 | .unwrap() 233 | .map_put("function_name", inner.function_name()) 234 | .unwrap() 235 | .map_put("log_type", inner.log_type().clone()) 236 | .unwrap() 237 | .map_put( 238 | "timing", 239 | Timing { 240 | inner: inner.timing().clone(), 241 | }, 242 | ) 243 | .unwrap() 244 | .map_put( 245 | "usage", 246 | Usage { 247 | inner: inner.usage().clone(), 248 | }, 249 | ) 250 | .unwrap() 251 | .map_put( 252 | "calls", 253 | inner 254 | .calls() 255 | .iter() 256 | .map(|c| LLMCallKind { inner: c.clone() }.encode(env)) 257 | .collect::>(), 258 | ) 259 | .unwrap() 260 | .map_put( 261 | "raw_llm_response", 262 | inner.raw_llm_response().unwrap_or_default().encode(env), 263 | ) 264 | .unwrap() 265 | } 266 | } 267 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | 
http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /test/baml_elixir_test.exs: -------------------------------------------------------------------------------- 1 | defmodule BamlElixirTest do 2 | use ExUnit.Case 3 | use BamlElixir.Client, path: "test/baml_src" 4 | 5 | import Mox 6 | 7 | alias BamlElixir.TypeBuilder 8 | 9 | doctest BamlElixir 10 | 11 | setup :set_mox_from_context 12 | setup :verify_on_exit! 
13 | 14 | @tag :client_registry 15 | test "client_registry supports clients key (list form)" do 16 | client_registry = %{ 17 | primary: "InjectedClient", 18 | clients: [ 19 | %{ 20 | name: "InjectedClient", 21 | provider: "definitely-not-a-provider", 22 | retry_policy: nil, 23 | options: %{model: "gpt-4o-mini"} 24 | } 25 | ] 26 | } 27 | 28 | # parse: false to avoid any parsing work; we want to exercise registry decoding/validation 29 | assert {:error, msg} = 30 | BamlElixirTest.WhichModel.call(%{}, %{client_registry: client_registry, parse: false}) 31 | 32 | assert msg =~ "Invalid client provider" 33 | end 34 | 35 | @tag :client_registry 36 | test "client_registry supports clients key (map form)" do 37 | client_registry = %{ 38 | primary: "InjectedClient", 39 | clients: %{ 40 | "InjectedClient" => %{ 41 | provider: "definitely-not-a-provider", 42 | retry_policy: nil, 43 | options: %{model: "gpt-4o-mini"} 44 | } 45 | } 46 | } 47 | 48 | assert {:error, msg} = 49 | BamlElixirTest.WhichModel.call(%{}, %{client_registry: client_registry, parse: false}) 50 | 51 | assert msg =~ "Invalid client provider" 52 | end 53 | 54 | @tag :client_registry 55 | test "client_registry can inject and select a client not present in the BAML files (success path)" do 56 | BamlElixirTest.FakeOpenAIServer.expect_chat_completion("GPT") 57 | base_url = BamlElixirTest.FakeOpenAIServer.start_base_url() 58 | 59 | client_registry = %{ 60 | primary: "InjectedClient", 61 | clients: [ 62 | %{ 63 | name: "InjectedClient", 64 | provider: "openai-generic", 65 | retry_policy: nil, 66 | options: %{ 67 | base_url: base_url, 68 | api_key: "test-key", 69 | model: "gpt-4o-mini" 70 | } 71 | } 72 | ] 73 | } 74 | 75 | # This function declares `client GPT4` in the .baml file, so success here proves 76 | # `client_registry.primary` overrides the static client selection. 
77 | assert {:ok, "GPT"} = 78 | BamlElixirTest.WhichModelUnion.call(%{}, %{client_registry: client_registry}) 79 | end 80 | 81 | @tag :client_registry 82 | test "client_registry passes clients[].options.headers into the HTTP request" do 83 | BamlElixirTest.FakeOpenAIServer.expect_chat_completion("GPT", %{ 84 | "x-test-header" => "hello-from-elixir" 85 | }) 86 | 87 | base_url = BamlElixirTest.FakeOpenAIServer.start_base_url() 88 | 89 | client_registry = %{ 90 | primary: "InjectedClient", 91 | clients: [ 92 | %{ 93 | name: "InjectedClient", 94 | provider: "openai-generic", 95 | retry_policy: nil, 96 | options: %{ 97 | base_url: base_url, 98 | api_key: "test-key", 99 | model: "gpt-4o-mini", 100 | headers: %{ 101 | "x-test-header" => "hello-from-elixir" 102 | } 103 | } 104 | } 105 | ] 106 | } 107 | 108 | assert {:ok, "GPT"} = 109 | BamlElixirTest.WhichModelUnion.call(%{}, %{client_registry: client_registry}) 110 | end 111 | 112 | test "parses into a struct" do 113 | assert {:ok, %BamlElixirTest.Person{name: "John Doe", age: 28}} = 114 | BamlElixirTest.ExtractPerson.call(%{info: "John Doe, 28, Engineer"}) 115 | end 116 | 117 | test "parsing into a struct with streaming" do 118 | pid = self() 119 | 120 | BamlElixirTest.ExtractPerson.stream(%{info: "John Doe, 28, Engineer"}, fn result -> 121 | send(pid, result) 122 | end) 123 | 124 | messages = wait_for_all_messages() 125 | 126 | # assert more than 1 partial message 127 | assert Enum.filter(messages, fn {type, _} -> type == :partial end) |> length() > 1 128 | 129 | assert Enum.filter(messages, fn {type, _} -> type == :done end) == [ 130 | {:done, %BamlElixirTest.Person{name: "John Doe", age: 28}} 131 | ] 132 | end 133 | 134 | test "parsing into a struct with sync_stream" do 135 | {:ok, agent_pid} = Agent.start_link(fn -> 0 end, name: :counter) 136 | 137 | assert {:ok, %BamlElixirTest.Person{name: "John Doe", age: 28}} = 138 | BamlElixirTest.ExtractPerson.sync_stream( 139 | %{info: "John Doe, 28, Engineer"}, 140 | fn 
_result -> 141 | Agent.update(agent_pid, fn count -> count + 1 end) 142 | end 143 | ) 144 | 145 | assert Agent.get(agent_pid, fn count -> count end) > 1 146 | end 147 | 148 | test "bool input and output" do 149 | assert {:ok, true} = BamlElixirTest.FlipSwitch.call(%{switch: false}) 150 | end 151 | 152 | test "parses into a struct with a type builder" do 153 | assert {:ok, 154 | %{ 155 | __baml_class__: "NewEmployeeFullyDynamic", 156 | employee_id: _, 157 | person: %{ 158 | name: "Foobar123", 159 | age: _, 160 | owned_houses_count: _, 161 | favorite_day: _, 162 | favorite_color: :RED, 163 | __baml_class__: "TestPerson" 164 | } 165 | }} = 166 | BamlElixirTest.CreateEmployee.call(%{}, %{ 167 | tb: [ 168 | %TypeBuilder.Class{ 169 | name: "TestPerson", 170 | fields: [ 171 | %TypeBuilder.Field{ 172 | name: "name", 173 | type: :string, 174 | description: "The name of the person - this should always be Foobar123" 175 | }, 176 | %TypeBuilder.Field{name: "age", type: :int}, 177 | %TypeBuilder.Field{name: "owned_houses_count", type: 1}, 178 | %TypeBuilder.Field{ 179 | name: "favorite_day", 180 | type: %TypeBuilder.Union{types: ["sunday", "monday"]} 181 | }, 182 | %TypeBuilder.Field{ 183 | name: "favorite_color", 184 | type: %TypeBuilder.Enum{name: "FavoriteColor"} 185 | } 186 | ] 187 | }, 188 | %TypeBuilder.Enum{ 189 | name: "FavoriteColor", 190 | values: [ 191 | %TypeBuilder.EnumValue{value: "RED", description: "Pick this always"}, 192 | %TypeBuilder.EnumValue{value: "GREEN"}, 193 | %TypeBuilder.EnumValue{value: "BLUE"} 194 | ] 195 | }, 196 | %TypeBuilder.Class{ 197 | name: "NewEmployeeFullyDynamic", 198 | fields: [ 199 | %TypeBuilder.Field{ 200 | name: "person", 201 | type: %TypeBuilder.Class{name: "TestPerson"} 202 | } 203 | ] 204 | } 205 | ] 206 | }) 207 | end 208 | 209 | test "parses type builder with nested types" do 210 | assert {:ok, 211 | %{ 212 | __baml_class__: "NewEmployeeFullyDynamic", 213 | employee_id: _, 214 | person: %{ 215 | __baml_class__: 
"ThisClassIsNotDefinedInTheBAMLFile",
                 name: _,
                 age: _,
                 departments: list_of_deps,
                 managers: list_of_managers,
                 work_experience: work_exp_map
               }
             } = employee} =
             BamlElixirTest.CreateEmployee.call(%{}, %{
               tb: [
                 %TypeBuilder.Class{
                   name: "NewEmployeeFullyDynamic",
                   fields: [
                     %TypeBuilder.Field{
                       name: "person",
                       type: %TypeBuilder.Class{
                         name: "ThisClassIsNotDefinedInTheBAMLFile",
                         fields: [
                           %TypeBuilder.Field{name: "name", type: :string},
                           %TypeBuilder.Field{name: "age", type: :int},
                           %TypeBuilder.Field{
                             name: "departments",
                             type: %TypeBuilder.List{
                               type: %TypeBuilder.Class{
                                 name: "Department",
                                 fields: [
                                   %TypeBuilder.Field{name: "name", type: :string},
                                   %TypeBuilder.Field{name: "location", type: :string}
                                 ]
                               }
                             }
                           },
                           %TypeBuilder.Field{
                             name: "managers",
                             type: %TypeBuilder.List{type: :string}
                           },
                           %TypeBuilder.Field{
                             name: "work_experience",
                             type: %TypeBuilder.Map{
                               key_type: :string,
                               value_type: :string
                             }
                           }
                         ]
                       }
                     }
                   ]
                 }
               ]
             })

    assert Enum.sort(Map.keys(employee)) ==
             Enum.sort([:__baml_class__, :employee_id, :person])

    assert Enum.sort(Map.keys(employee.person)) ==
             Enum.sort([:__baml_class__, :name, :age, :departments, :managers, :work_experience])

    assert is_list(list_of_deps)
    assert is_list(list_of_managers)
    assert is_map(work_exp_map)
    assert Enum.all?(work_exp_map, fn {key, value} -> is_binary(key) and is_binary(value) end)
  end

  test "change default model" do
    assert BamlElixirTest.WhichModel.call(%{}, %{llm_client: "GPT4"}) == {:ok, :GPT4oMini}
    assert BamlElixirTest.WhichModel.call(%{}, %{llm_client: "DeepSeekR1"}) == {:ok, :DeepSeekR1}
  end

  test "get union type" do
    assert BamlElixirTest.WhichModelUnion.call(%{}, %{llm_client: "GPT4"}) == {:ok, "GPT"}

    assert BamlElixirTest.WhichModelUnion.call(%{}, %{llm_client: "DeepSeekR1"}) ==
             {:ok, "DeepSeek"}
  end

  test "Error when parsing the output of a function" do
    assert {:error, "Failed to coerce value" <> _} = BamlElixirTest.DummyOutputFunction.call(%{})
  end

  test "get usage from collector" do
    collector = BamlElixir.Collector.new("test-collector")

    assert BamlElixirTest.WhichModel.call(%{}, %{llm_client: "GPT4", collectors: [collector]}) ==
             {:ok, :GPT4oMini}

    usage = BamlElixir.Collector.usage(collector)
    assert usage["input_tokens"] == 33
    assert usage["output_tokens"] > 0
  end

  test "get usage from collector with streaming using GPT4" do
    collector = BamlElixir.Collector.new("test-collector")
    pid = self()

    BamlElixirTest.CreateEmployee.stream(
      %{},
      fn result -> send(pid, result) end,
      %{llm_client: "GPT4", collectors: [collector]}
    )

    _messages = wait_for_all_messages()

    usage = BamlElixir.Collector.usage(collector)
    assert usage["input_tokens"] == 32
  end

  test "get last function log from collector" do
    collector = BamlElixir.Collector.new("test-collector")

    assert BamlElixirTest.WhichModel.call(%{}, %{llm_client: "GPT4", collectors: [collector]}) ==
             {:ok, :GPT4oMini}

    last_function_log = BamlElixir.Collector.last_function_log(collector)
    assert last_function_log["function_name"] == "WhichModel"

    response_body =
      last_function_log["calls"]
      |> Enum.at(0)
      |> Map.get("response")
      |> Map.get("body")
      |> Jason.decode!()

    assert response_body["usage"]["prompt_tokens_details"] == %{
             "audio_tokens" => 0,
             "cached_tokens" => 0
           }

    assert Map.keys(last_function_log) == [
             "calls",
             "function_name",
             "id",
             "log_type",
             "raw_llm_response",
             "timing",
             "usage"
           ]
  end

  test "get last function log from collector with streaming" do
    collector = BamlElixir.Collector.new("test-collector")
    pid = self()

    BamlElixirTest.CreateEmployee.stream(
      %{},
      fn result -> send(pid, result) end,
      %{llm_client: "GPT4", collectors: [collector]}
    )

    _messages = wait_for_all_messages()

    last_function_log = BamlElixir.Collector.last_function_log(collector)

    %{"messages" => messages} =
      last_function_log["calls"]
      |> Enum.at(0)
      |> Map.get("request")
      |> Map.get("body")
      |> Jason.decode!()

    assert messages == [
             %{
               "content" => [
                 %{
                   "text" =>
                     "Create a fake employee data with the following information:\nAnswer in JSON using this schema:\n{\n employee_id: string,\n}",
                   "type" => "text"
                 }
               ],
               "role" => "system"
             }
           ]
  end

  test "parsing of nested structs" do
    attendees = %BamlElixirTest.Attendees{
      hosts: [
        %BamlElixirTest.Person{name: "John Doe", age: 28},
        %BamlElixirTest.Person{name: "Bob Johnson", age: 35}
      ],
      guests: [
        %BamlElixirTest.Person{name: "Alice Smith", age: 25},
        %BamlElixirTest.Person{name: "Carol Brown", age: 30},
        %BamlElixirTest.Person{name: "Jane Doe", age: 28}
      ]
    }

    assert {:ok, attendees} ==
             BamlElixirTest.ParseAttendees.call(%{
               attendees: """
               John Doe 28 - Host
               Alice Smith 25 - Guest
               Bob Johnson 35 - Host
               Carol Brown 30 - Guest
               Jane Doe 28 - Guest
               """
             })
  end

  defp wait_for_all_messages(messages \\ []) do
    receive do
      {:partial, _} = message ->
        wait_for_all_messages([message | messages])

      {:done, _} = message ->
        [message | messages]
        |> Enum.reverse()

      {:error, message} ->
        raise "Error: #{inspect(message)}"
    end
  end
end

--------------------------------------------------------------------------------
/lib/baml_elixir/client.ex:
--------------------------------------------------------------------------------
defmodule BamlElixir.Client do
  @moduledoc """
  A client for interacting with BAML functions.
  Data structures and functions are generated from BAML source files.

  > #### `use BamlElixir.Client, path: "priv/baml_src"` {: .info}
  >
  > When you `use BamlElixir.Client`, it will define:
  > - A module for each function in the BAML source files with `call/2`, `stream/3` and `sync_stream/3` functions along with the types.
  > - A module with `defstruct/1` and `@type t/0` for each class in the BAML source file.
  > - A module with `@type t/0` for each enum in the BAML source file.
  >
  > The `path` option is optional and defaults to `"baml_src"`, you may want to set it to `"priv/baml_src"`.
  > It may also be given as `{app, path}`, which is resolved with `Application.app_dir/2`.

  This module also provides functionality to call BAML functions either sync/async.
  """

  defmacro __using__(opts) do
    path = Keyword.get(opts, :path, "baml_src")
    {baml_src_path, _} = Code.eval_quoted(app_path(path), [], __CALLER__)

    # Get all .baml files in the directory
    baml_files = get_baml_files(baml_src_path)

    # One `@external_resource` per BAML file so the compiler records the caller's
    # dependency on it. These quoted expressions only take effect because they
    # are spliced into the AST returned below.
    external_resources =
      for baml_file <- baml_files do
        quote do
          @external_resource unquote(baml_file)
        end
      end

    # Get BAML types
    baml_types = BamlElixir.Native.parse_baml(baml_src_path)
    baml_class_types = baml_types[:classes]
    baml_enum_types = baml_types[:enums]
    baml_functions = baml_types[:functions]

    baml_class_types_quoted = generate_class_types(baml_class_types, __CALLER__)
    baml_enum_types_quoted = generate_enum_types(baml_enum_types, __CALLER__)
    baml_functions_quoted = generate_function_modules(baml_functions, path, __CALLER__)
    recompile_function = generate_recompile_function(baml_src_path, baml_files)

    quote do
      import BamlElixir.Client

      unquote_splicing(external_resources)
      unquote(baml_class_types_quoted)
      unquote(baml_enum_types_quoted)
      unquote(baml_functions_quoted)
      unquote(recompile_function)
    end
  end

  @doc """
  Calls a BAML function synchronously.

  ## Parameters
    - `function_name`: The name of the BAML function to call
    - `args`: A map (or struct) of arguments to pass to the function
    - `opts`: A map of options
      - `path`: The path to the BAML source directory (defaults to `"baml_src"`)
      - `collectors`: A list of collectors to use
      - `client_registry`: A client registry passed to the runtime as-is
      - `llm_client`: The name of the LLM client to use (ignored when `client_registry` is set)
      - `tb`: Type builder definitions; when set, classes are returned as plain maps
      - `parse`: Set to `false` to get the raw result instead of structs/atoms
      - `prefix`: The module prefix under which the generated class modules live

  ## Returns
    - `{:ok, term()}` on success, where the term is the function's return value
    - `{:error, String.t()}` on failure, with an error message

  ## Examples
      {:ok, result} = BamlElixir.Client.call("MyFunction", %{arg1: "value"})
  """
  @spec call(String.t(), map(), map()) ::
          {:ok, term()} | {:error, String.t()}
  def call(function_name, args, opts \\ %{}) do
    {path, collectors, client_registry, tb} = prepare_opts(opts)
    args = to_map(args)

    with {:ok, result} <-
           BamlElixir.Native.call(function_name, args, path, collectors, client_registry, tb) do
      {:ok, maybe_parse(result, opts)}
    end
  end

  @doc """
  Streams a BAML function asynchronously.

  ## Parameters
    - `function_name`: The name of the BAML function to stream
    - `args`: A map (or struct) of arguments to pass to the function
    - `callback`: Called with `{:partial, result}` for each partial result, then once
      with either `{:done, result}` or `{:error, reason}`
    - `opts`: A map of options (same keys as `call/3`)

  Returns the pid of the linked process that runs the stream and invokes `callback`.
  """
  def stream(function_name, args, callback, opts \\ %{}) do
    ref = make_ref()
    args = to_map(args)

    spawn_link(fn ->
      start_sync_stream(self(), ref, function_name, args, opts)
      handle_stream_result(ref, callback, opts)
    end)
  end

  @doc """
  Streams partial output and also blocks until the function is done.
  `callback` receives each partial result (unwrapped).
  Finally returns {:ok, result} or {:error, error}
  """
  def sync_stream(function_name, args, callback, opts \\ %{}) do
    pid = self()
    # Tag the final message with a unique ref so the receive below cannot
    # consume unrelated {:done, _} messages already in the caller's mailbox.
    ref = make_ref()

    stream(
      function_name,
      args,
      fn
        {:partial, result} ->
          callback.(result)

        result ->
          send(pid, {ref, result})
      end,
      opts
    )

    receive do
      {^ref, {:error, error}} ->
        {:error, error}

      {^ref, {:done, result}} ->
        {:ok, result}
    end
  end

  # Resolves `{app, path}` to a path inside the app's directory; any other value
  # is returned unchanged.
  def app_path(path) do
    case path do
      {app, path} ->
        Application.app_dir(app, path)

      _ ->
        path
    end
  end

  # Get all .baml files in the specified directory
  def get_baml_files(baml_src_path) do
    if File.exists?(baml_src_path) and File.dir?(baml_src_path) do
      Path.wildcard(Path.join(baml_src_path, "**/*.baml"))
    else
      []
    end
  end

  # Create a hash from a list of file paths (based on each file's mtime and size)
  def create_files_hash(file_paths) do
    file_paths
    |> Enum.map(&File.stat!/1)
    |> Enum.map(fn stat -> {stat.mtime, stat.size} end)
    |> inspect()
    |> :erlang.md5()
  end

  # Generate the __mix_recompile__?/0 function that checks if any .baml files have changed
  defp generate_recompile_function(baml_src_path, baml_files) do
    # Create a hash of all BAML files at compile time
    files_hash = create_files_hash(baml_files)

    quote do
      def __mix_recompile__?() do
        baml_src_path = unquote(baml_src_path)

        # Check if the directory still exists
        if not File.exists?(baml_src_path) or not File.dir?(baml_src_path) do
          true
        else
          # Get current BAML files and compare hashes
          current_baml_files = BamlElixir.Client.get_baml_files(baml_src_path)
          current_files_hash = BamlElixir.Client.create_files_hash(current_baml_files)
          current_files_hash != unquote(files_hash)
        end
      end
    end
  end

  # Runs the native stream in a linked process; partial results are delivered to
  # `pid` by the NIF and the final result is sent as `{ref, result}`.
  defp start_sync_stream(pid, ref, function_name, args, opts) do
    {path, collectors, client_registry, tb} = prepare_opts(opts)

    spawn_link(fn ->
      result =
        BamlElixir.Native.stream(
          pid,
          ref,
          function_name,
          args,
          path,
          collectors,
          client_registry,
          tb
        )

      send(pid, {ref, result})
    end)
  end

  # Forwards `{ref, ...}` messages to `callback` until a `:done` or `:error` arrives.
  defp handle_stream_result(ref, callback, opts) do
    receive do
      {^ref, {:partial, result}} ->
        callback.({:partial, maybe_parse(result, opts)})
        handle_stream_result(ref, callback, opts)

      {^ref, {:error, _} = msg} ->
        callback.(msg)

      {^ref, {:done, result}} ->
        callback.({:done, maybe_parse(result, opts)})
    end
  end

  # Converts a raw native result into structs/atoms unless `parse: false` was given.
  defp maybe_parse(result, opts) do
    if opts[:parse] != false do
      parse_result(result, opts[:prefix], opts[:tb])
    else
      result
    end
  end

  # Every class in the BAML source file is converted to an Elixir module
  # with a `defstruct/1` and a `@type t/0` type.
  defp generate_class_types(class_types, caller) do
    module = caller.module

    for {type_name, %{"fields" => fields, "dynamic" => dynamic}} <- class_types do
      field_names = get_field_names(fields)
      field_types = get_field_types(fields, caller)
      module_name = Module.concat([module, type_name])

      quote do
        defmodule unquote(module_name) do
          defstruct unquote(field_names)
          @type t :: %__MODULE__{unquote_splicing(field_types)}

          def name, do: unquote(type_name)
          def type, do: :class
          def dynamic?, do: unquote(dynamic)
        end
      end
    end
  end

  # Every enum in the BAML source file is converted to an Elixir module
  # with a `@type t/0` type.
265 | defp generate_enum_types(enum_types, caller) do 266 | module = caller.module 267 | 268 | for {enum_name, variants} <- enum_types do 269 | variant_atoms = Enum.map(variants, &String.to_atom/1) 270 | module_name = Module.concat([module, enum_name]) 271 | 272 | union_type = 273 | Enum.reduce(variant_atoms, fn atom, acc -> 274 | {:|, [], [atom, acc]} 275 | end) 276 | 277 | quote do 278 | defmodule unquote(module_name) do 279 | @type t :: unquote(union_type) 280 | 281 | def name, do: unquote(enum_name) 282 | def values, do: unquote(variant_atoms) 283 | def type, do: :enum 284 | end 285 | end 286 | end 287 | end 288 | 289 | # Every function in the BAML source file is converted to an Elixir module 290 | # which has a `call/2` function and a `stream/3` function. 291 | defp generate_function_modules(functions, path, caller) do 292 | module = caller.module 293 | 294 | for {function_name, function_info} <- functions do 295 | module_name = Module.concat(module, function_name) 296 | 297 | param_types = 298 | for {param_name, param_type} <- function_info["params"] do 299 | {String.to_atom(param_name), to_elixir_type(param_type, caller)} 300 | end 301 | 302 | return_type = to_elixir_type(function_info["return_type"], caller) 303 | 304 | quote do 305 | defmodule unquote(module_name) do 306 | @spec call(%{unquote_splicing(param_types)}, map()) :: 307 | {:ok, unquote(return_type)} | {:error, String.t()} 308 | def call(args, opts \\ %{}) do 309 | opts = 310 | opts 311 | |> Map.put(:path, BamlElixir.Client.app_path(unquote(path))) 312 | |> Map.put(:prefix, unquote(module)) 313 | 314 | BamlElixir.Client.call(unquote(function_name), args, opts) 315 | end 316 | 317 | @spec stream( 318 | %{unquote_splicing(param_types)}, 319 | ({:ok, unquote(return_type) | {:error, String.t()} | :done} -> any()), 320 | map() 321 | ) :: 322 | Enumerable.t() 323 | def stream(args, callback, opts \\ %{}) do 324 | opts = 325 | opts 326 | |> Map.put(:path, BamlElixir.Client.app_path(unquote(path))) 327 | 
|> Map.put(:prefix, unquote(module)) 328 | 329 | BamlElixir.Client.stream(unquote(function_name), args, callback, opts) 330 | end 331 | 332 | @spec sync_stream( 333 | %{unquote_splicing(param_types)}, 334 | (unquote(return_type) -> any()), 335 | map() 336 | ) :: {:ok, unquote(return_type)} | {:error, String.t()} 337 | def sync_stream(args, callback, opts \\ %{}) do 338 | opts = 339 | opts 340 | |> Map.put(:path, BamlElixir.Client.app_path(unquote(path))) 341 | |> Map.put(:prefix, unquote(module)) 342 | 343 | BamlElixir.Client.sync_stream(unquote(function_name), args, callback, opts) 344 | end 345 | end 346 | end 347 | end 348 | end 349 | 350 | defp to_elixir_type(type, caller) do 351 | case type do 352 | {:primitive, primitive} -> 353 | case primitive do 354 | :string -> 355 | quote(do: String.t()) 356 | 357 | :integer -> 358 | quote(do: integer()) 359 | 360 | :float -> 361 | quote(do: float()) 362 | 363 | :boolean -> 364 | quote(do: boolean()) 365 | 366 | nil -> 367 | quote(do: nil) 368 | 369 | :media -> 370 | quote( 371 | do: 372 | %{url: String.t()} 373 | | %{url: String.t(), media_type: String.t()} 374 | | %{base64: String.t()} 375 | | %{base64: String.t(), media_type: String.t()} 376 | ) 377 | end 378 | 379 | {:enum, name} -> 380 | # Convert enum name to module reference with .t() 381 | module = Module.concat([caller.module, name]) 382 | quote(do: unquote(module).t()) 383 | 384 | {:class, name} -> 385 | # Convert class name to module reference with .t() 386 | module = Module.concat([caller.module, name]) 387 | quote(do: unquote(module).t()) 388 | 389 | {:list, inner_type} -> 390 | # Convert to list type 391 | quote(do: [unquote(to_elixir_type(inner_type, caller))]) 392 | 393 | {:map, key_type, value_type} -> 394 | # Convert to map type 395 | quote( 396 | do: %{ 397 | unquote(to_elixir_type(key_type, caller)) => 398 | unquote(to_elixir_type(value_type, caller)) 399 | } 400 | ) 401 | 402 | {:literal, value} -> 403 | # For literals, use the value directly 404 | 
case value do 405 | v when is_atom(v) -> v 406 | v when is_integer(v) -> v 407 | v when is_boolean(v) -> v 408 | end 409 | 410 | {:union, types} -> 411 | # Convert union to pipe operator 412 | [first_type | rest_types] = types 413 | first_ast = to_elixir_type(first_type, caller) 414 | 415 | Enum.reduce(rest_types, first_ast, fn type, acc -> 416 | {:|, [], [to_elixir_type(type, caller), acc]} 417 | end) 418 | 419 | {:tuple, types} -> 420 | # Convert to tuple type 421 | types_ast = Enum.map(types, &to_elixir_type(&1, caller)) 422 | {:{}, [], types_ast} 423 | 424 | {:optional, inner_type} -> 425 | # Convert optional to union with nil 426 | {:|, [], [to_elixir_type(inner_type, caller), nil]} 427 | 428 | {:alias, name} -> 429 | # For recursive type aliases, use the name with .t() 430 | module = String.to_atom(name) 431 | quote(do: unquote(module).t()) 432 | 433 | _ -> 434 | # Fallback to any 435 | quote(do: any()) 436 | end 437 | end 438 | 439 | defp get_field_names(fields) do 440 | for {field_name, _} <- fields do 441 | String.to_atom(field_name) 442 | end 443 | end 444 | 445 | defp get_field_types(fields, caller) do 446 | for {field_name, field_type} <- fields do 447 | elixir_type = to_elixir_type(field_type, caller) 448 | {String.to_atom(field_name), elixir_type} 449 | end 450 | end 451 | 452 | defp prepare_opts(opts) do 453 | path = opts[:path] || "baml_src" 454 | collectors = (opts[:collectors] || []) |> Enum.map(fn collector -> collector.reference end) 455 | 456 | client_registry = 457 | if opts[:client_registry] do 458 | opts[:client_registry] 459 | else 460 | if opts[:llm_client] do 461 | %{primary: opts[:llm_client]} 462 | else 463 | nil 464 | end 465 | end 466 | 467 | {path, collectors, client_registry, opts[:tb]} 468 | end 469 | 470 | # If type builder is provided, return as map instead of struct 471 | defp parse_result(%{:__baml_class__ => _class_name} = result, prefix, tb) 472 | when not is_nil(tb) do 473 | Map.new(result, fn {key, value} -> {key, 
parse_result(value, prefix, tb)} end) 474 | end 475 | 476 | defp parse_result(%{:__baml_class__ => class_name} = result, prefix, tb) do 477 | module = Module.concat(prefix, class_name) 478 | values = Enum.map(result, fn {key, value} -> {key, parse_result(value, prefix, tb)} end) 479 | struct(module, values) 480 | end 481 | 482 | defp parse_result(%{:__baml_enum__ => _, :value => value}, _prefix, _tb) do 483 | String.to_atom(value) 484 | end 485 | 486 | defp parse_result(list, prefix, tb) when is_list(list) do 487 | Enum.map(list, fn item -> parse_result(item, prefix, tb) end) 488 | end 489 | 490 | defp parse_result(result, _prefix, _tb) do 491 | result 492 | end 493 | 494 | defp to_map(args) when is_struct(args) do 495 | args 496 | |> Map.from_struct() 497 | |> to_map() 498 | end 499 | 500 | defp to_map(args) when is_map(args) do 501 | Map.new(args, fn {key, value} -> {key, to_map(value)} end) 502 | end 503 | 504 | defp to_map(args) when is_list(args) do 505 | Enum.map(args, &to_map/1) 506 | end 507 | 508 | defp to_map(args) do 509 | args 510 | end 511 | end 512 | -------------------------------------------------------------------------------- /native/baml_elixir/src/type_builder.rs: -------------------------------------------------------------------------------- 1 | use crate::Error; 2 | use baml_runtime::type_builder::{TypeBuilder, WithMeta}; 3 | use baml_types::{ir_type::UnionConstructor, LiteralValue, TypeIR}; 4 | use rustler::{Env, MapIterator, Term}; 5 | 6 | pub fn parse_type_builder_spec<'a>( 7 | env: Env<'a>, 8 | term: Term<'a>, 9 | builder: &TypeBuilder, 10 | ) -> Result<(), Error> { 11 | if !term.is_list() { 12 | return Err(Error::Term(Box::new( 13 | "TypeBuilder specification must be a list", 14 | ))); 15 | } 16 | 17 | // New format: list of TypeBuilder structs 18 | let list: Vec = term.decode()?; 19 | for item in list { 20 | parse_type_builder_item(env, item, builder)?; 21 | } 22 | Ok(()) 23 | } 24 | 25 | fn parse_type_builder_item<'a>( 26 | env: 
Env<'a>, 27 | term: Term<'a>, 28 | builder: &TypeBuilder, 29 | ) -> Result<(), Error> { 30 | if !term.is_map() { 31 | return Err(Error::Term(Box::new("TypeBuilder item must be a map"))); 32 | } 33 | 34 | let iter = MapIterator::new(term).ok_or(Error::Term(Box::new("Invalid map")))?; 35 | let mut item_type = None; 36 | 37 | for (key_term, value_term) in iter { 38 | let key = term_to_string(key_term)?; 39 | match key.as_str() { 40 | "__struct__" => match term_to_string(value_term) { 41 | Ok(struct_name) => { 42 | item_type = Some(struct_name); 43 | } 44 | Err(e) => { 45 | return Err(e); 46 | } 47 | }, 48 | _ => { 49 | // Ignore other fields for now 50 | } 51 | } 52 | } 53 | 54 | match item_type.as_deref() { 55 | Some("Elixir.BamlElixir.TypeBuilder.Class") => { 56 | parse_class_item(env, term, builder)?; 57 | } 58 | Some("Elixir.BamlElixir.TypeBuilder.Enum") => { 59 | parse_enum_item(term, builder)?; 60 | } 61 | Some(other) => { 62 | return Err(Error::Term(Box::new(format!( 63 | "Unsupported TypeBuilder struct: {}", 64 | other 65 | )))); 66 | } 67 | None => { 68 | return Err(Error::Term(Box::new("Missing __struct__ field"))); 69 | } 70 | } 71 | 72 | Ok(()) 73 | } 74 | 75 | fn parse_class_item<'a>( 76 | env: Env<'a>, 77 | class_term: Term<'a>, 78 | builder: &TypeBuilder, 79 | ) -> Result<(), Error> { 80 | if !class_term.is_map() { 81 | return Err(Error::Term(Box::new("Class data must be a map"))); 82 | } 83 | 84 | let iter = MapIterator::new(class_term).ok_or(Error::Term(Box::new("Invalid class map")))?; 85 | let mut class_name = None; 86 | let mut fields = None; 87 | 88 | for (key_term, value_term) in iter { 89 | let key = term_to_string(key_term)?; 90 | match key.as_str() { 91 | "name" => { 92 | class_name = Some(term_to_string(value_term)?); 93 | } 94 | "fields" => { 95 | fields = Some(value_term); 96 | } 97 | _ => {} 98 | } 99 | } 100 | 101 | let class_name = class_name.ok_or(Error::Term(Box::new("Class missing name field")))?; 102 | let fields = 
fields.ok_or(Error::Term(Box::new("Class missing fields")))?; 103 | 104 | // Create the class in the type builder 105 | let cls = builder.upsert_class(&class_name); 106 | let cls = cls.lock().unwrap(); 107 | 108 | if fields.is_list() { 109 | let field_list: Vec = fields.decode()?; 110 | for field_term in field_list { 111 | parse_field_item(env, field_term, builder, &class_name, &cls)?; 112 | } 113 | } else { 114 | return Err(Error::Term(Box::new("Class fields must be a list"))); 115 | } 116 | 117 | Ok(()) 118 | } 119 | 120 | fn parse_enum_item<'a>(enum_term: Term<'a>, builder: &TypeBuilder) -> Result<(), Error> { 121 | if !enum_term.is_map() { 122 | return Err(Error::Term(Box::new("Enum data must be a map"))); 123 | } 124 | 125 | let iter = MapIterator::new(enum_term).ok_or(Error::Term(Box::new("Invalid enum map")))?; 126 | let mut enum_name = None; 127 | let mut values = None; 128 | 129 | for (key_term, value_term) in iter { 130 | let key = term_to_string(key_term)?; 131 | match key.as_str() { 132 | "name" => { 133 | enum_name = Some(term_to_string(value_term)?); 134 | } 135 | "values" => { 136 | values = Some(value_term); 137 | } 138 | _ => {} 139 | } 140 | } 141 | 142 | let enum_name = enum_name.ok_or(Error::Term(Box::new("Enum missing name field")))?; 143 | let values = values.ok_or(Error::Term(Box::new("Enum missing values")))?; 144 | 145 | // Create the enum in the type builder 146 | let enum_builder = builder.upsert_enum(&enum_name); 147 | let enum_builder = enum_builder.lock().unwrap(); 148 | 149 | if values.is_list() { 150 | let value_list: Vec = values.decode()?; 151 | for value_term in value_list { 152 | parse_enum_value_item(value_term, &enum_builder)?; 153 | } 154 | } else { 155 | return Err(Error::Term(Box::new("Enum values must be a list"))); 156 | } 157 | 158 | Ok(()) 159 | } 160 | 161 | fn parse_enum_value_item<'a>( 162 | value_term: Term<'a>, 163 | enum_builder: &std::sync::MutexGuard, 164 | ) -> Result<(), Error> { 165 | let iter = 166 | 
MapIterator::new(value_term).ok_or(Error::Term(Box::new("Invalid enum value map")))?; 167 | let mut value_name = None; 168 | let mut description = None; 169 | 170 | for (key_term, value_term) in iter { 171 | let key = term_to_string(key_term)?; 172 | match key.as_str() { 173 | "__struct__" => { 174 | // Check if this is an EnumValue struct 175 | let struct_name = term_to_string(value_term)?; 176 | if struct_name != "Elixir.BamlElixir.TypeBuilder.EnumValue" { 177 | return Err(Error::Term(Box::new(format!( 178 | "Expected EnumValue struct, got: {}", 179 | struct_name 180 | )))); 181 | } 182 | } 183 | "value" => { 184 | value_name = Some(term_to_string(value_term)?); 185 | } 186 | "description" => { 187 | description = Some(term_to_string(value_term)?); 188 | } 189 | _ => {} 190 | } 191 | } 192 | 193 | let value_name = value_name.ok_or(Error::Term(Box::new("Enum value missing value field")))?; 194 | 195 | // Add the enum value 196 | let value_builder = enum_builder.upsert_value(&value_name); 197 | let value_builder = value_builder.lock().unwrap(); 198 | 199 | // Add description if provided 200 | if let Some(desc) = description { 201 | value_builder.with_meta("description", baml_types::BamlValue::String(desc)); 202 | } 203 | 204 | Ok(()) 205 | } 206 | 207 | fn parse_field_item<'a>( 208 | env: Env<'a>, 209 | field_term: Term<'a>, 210 | builder: &TypeBuilder, 211 | parent_class: &str, 212 | cls: &std::sync::MutexGuard, 213 | ) -> Result<(), Error> { 214 | if !field_term.is_map() { 215 | return Err(Error::Term(Box::new("Field must be a map"))); 216 | } 217 | 218 | let iter = MapIterator::new(field_term).ok_or(Error::Term(Box::new("Invalid field map")))?; 219 | let mut field_name = None; 220 | let mut field_type = None; 221 | let mut description = None; 222 | 223 | for (key_term, value_term) in iter { 224 | let key = term_to_string(key_term)?; 225 | match key.as_str() { 226 | "name" => { 227 | field_name = Some(term_to_string(value_term)?); 228 | } 229 | "type" => { 230 | 
field_type = Some(value_term); 231 | } 232 | "description" => { 233 | description = Some(term_to_string(value_term)?); 234 | } 235 | _ => {} 236 | } 237 | } 238 | 239 | let field_name = field_name.ok_or(Error::Term(Box::new("Missing field name")))?; 240 | let field_type_term = field_type.ok_or(Error::Term(Box::new("Missing field type")))?; 241 | 242 | let type_ir = parse_field_type( 243 | env, 244 | field_type_term, 245 | builder, 246 | Some(parent_class), 247 | Some(&field_name), 248 | )?; 249 | 250 | // Add the field to the class 251 | let property = cls.upsert_property(&field_name); 252 | let property = property.lock().unwrap(); 253 | property.set_type(type_ir); 254 | 255 | // Add description if provided 256 | if let Some(desc) = description { 257 | property.with_meta("description", baml_types::BamlValue::String(desc)); 258 | } 259 | 260 | Ok(()) 261 | } 262 | 263 | fn parse_field_type<'a>( 264 | env: Env<'a>, 265 | term: Term<'a>, 266 | builder: &TypeBuilder, 267 | parent_class: Option<&str>, 268 | field_name: Option<&str>, 269 | ) -> Result { 270 | if term.is_atom() { 271 | let atom_str = term 272 | .atom_to_string() 273 | .map_err(|_| Error::Term(Box::new("Invalid atom")))?; 274 | 275 | match atom_str.as_str() { 276 | "string" => Ok(TypeIR::string()), 277 | "int" => Ok(TypeIR::int()), 278 | "float" => Ok(TypeIR::float()), 279 | "bool" => Ok(TypeIR::bool()), 280 | _ => Ok(TypeIR::class(&atom_str)), 281 | } 282 | } else if let Ok(string_value) = term.decode::() { 283 | // Handle string literals like "1", "hello", etc. 284 | Ok(TypeIR::literal(LiteralValue::String(string_value))) 285 | } else if let Ok(int_value) = term.decode::() { 286 | // Handle integer literals like 1, 42, etc. 
287 | Ok(TypeIR::literal(LiteralValue::Int(int_value))) 288 | } else if let Ok(bool_value) = term.decode::() { 289 | // Handle boolean literals like true, false 290 | Ok(TypeIR::literal(LiteralValue::Bool(bool_value))) 291 | } else if term.is_map() { 292 | // Check if this is a TypeBuilder struct 293 | let iter = 294 | MapIterator::new(term).ok_or(Error::Term(Box::new("Invalid map for object type")))?; 295 | let mut struct_type = None; 296 | 297 | for (key_term, value_term) in iter { 298 | let key = term_to_string(key_term)?; 299 | match key.as_str() { 300 | "__struct__" => { 301 | let struct_name = term_to_string(value_term)?; 302 | struct_type = Some(struct_name); 303 | } 304 | _ => {} 305 | } 306 | } 307 | 308 | match struct_type.as_deref() { 309 | Some("Elixir.BamlElixir.TypeBuilder.Class") => { 310 | // Check if this is a class definition (has fields) or class reference (only has name) 311 | let iter = 312 | MapIterator::new(term).ok_or(Error::Term(Box::new("Invalid class map")))?; 313 | let mut class_name = None; 314 | let mut has_fields = false; 315 | 316 | for (key_term, value_term) in iter { 317 | let key = term_to_string(key_term)?; 318 | match key.as_str() { 319 | "name" => { 320 | class_name = Some(term_to_string(value_term)?); 321 | } 322 | "fields" => { 323 | if value_term.is_list() { 324 | has_fields = true; 325 | } 326 | } 327 | _ => {} 328 | } 329 | } 330 | 331 | if let Some(name) = class_name { 332 | if has_fields { 333 | // This is a class definition, parse it 334 | parse_class_item(env, term, builder)?; 335 | } 336 | // Return the class type (whether it was just defined or already existed) 337 | return Ok(TypeIR::class(&name)); 338 | } 339 | Err(Error::Term(Box::new("Could not extract class name"))) 340 | } 341 | Some("Elixir.BamlElixir.TypeBuilder.List") => { 342 | // Extract the inner type from the list 343 | let iter = 344 | MapIterator::new(term).ok_or(Error::Term(Box::new("Invalid list map")))?; 345 | for (key_term, value_term) in iter { 
346 | let key = term_to_string(key_term)?; 347 | if key == "type" { 348 | let inner_type = 349 | parse_field_type(env, value_term, builder, parent_class, field_name)?; 350 | return Ok(TypeIR::list(inner_type)); 351 | } 352 | } 353 | Err(Error::Term(Box::new("Could not extract list inner type"))) 354 | } 355 | Some("Elixir.BamlElixir.TypeBuilder.Map") => { 356 | // Extract key and value types from the map 357 | let mut key_type = None; 358 | let mut value_type = None; 359 | 360 | let iter = 361 | MapIterator::new(term).ok_or(Error::Term(Box::new("Invalid map map")))?; 362 | for (key_term, value_term) in iter { 363 | let key = term_to_string(key_term)?; 364 | match key.as_str() { 365 | "key_type" => { 366 | key_type = Some(parse_field_type( 367 | env, 368 | value_term, 369 | builder, 370 | parent_class, 371 | field_name, 372 | )?); 373 | } 374 | "value_type" => { 375 | value_type = Some(parse_field_type( 376 | env, 377 | value_term, 378 | builder, 379 | parent_class, 380 | field_name, 381 | )?); 382 | } 383 | _ => {} 384 | } 385 | } 386 | 387 | if let (Some(key), Some(value)) = (key_type, value_type) { 388 | return Ok(TypeIR::map(key, value)); 389 | } 390 | Err(Error::Term(Box::new( 391 | "Could not extract map key and value types", 392 | ))) 393 | } 394 | Some("Elixir.BamlElixir.TypeBuilder.Union") => { 395 | // Extract types from the union 396 | let iter = 397 | MapIterator::new(term).ok_or(Error::Term(Box::new("Invalid union map")))?; 398 | for (key_term, value_term) in iter { 399 | let key = term_to_string(key_term)?; 400 | if key == "types" { 401 | if value_term.is_list() { 402 | let types_list: Vec = value_term.decode()?; 403 | let mut union_types = Vec::new(); 404 | for type_term in types_list { 405 | // Recursively parse each type in the union 406 | let parsed_type = parse_field_type( 407 | env, 408 | type_term, 409 | builder, 410 | parent_class, 411 | field_name, 412 | )?; 413 | union_types.push(parsed_type); 414 | } 415 | return 
Ok(TypeIR::union(union_types)); 416 | } else { 417 | return Err(Error::Term(Box::new("Union types must be a list"))); 418 | } 419 | } 420 | } 421 | Err(Error::Term(Box::new("Could not extract union types"))) 422 | } 423 | Some("Elixir.BamlElixir.TypeBuilder.Enum") => { 424 | // Check if this is an enum definition (has values) or enum reference (only has name) 425 | let iter = 426 | MapIterator::new(term).ok_or(Error::Term(Box::new("Invalid enum map")))?; 427 | let mut enum_name = None; 428 | let mut has_values = false; 429 | 430 | for (key_term, value_term) in iter { 431 | let key = term_to_string(key_term)?; 432 | match key.as_str() { 433 | "name" => { 434 | enum_name = Some(term_to_string(value_term)?); 435 | } 436 | "values" => { 437 | if value_term.is_list() { 438 | has_values = true; 439 | } 440 | } 441 | _ => {} 442 | } 443 | } 444 | 445 | if let Some(name) = enum_name { 446 | if has_values { 447 | // This is an enum definition, parse it 448 | parse_enum_item(term, builder)?; 449 | } 450 | // Return the enum type (whether it was just defined or already existed) 451 | return Ok(TypeIR::r#enum(&name)); 452 | } 453 | Err(Error::Term(Box::new("Could not extract enum name"))) 454 | } 455 | _ => Err(Error::Term(Box::new(format!( 456 | "Unsupported TypeBuilder struct: {:?}", 457 | struct_type 458 | )))), 459 | } 460 | } else { 461 | Err(Error::Term(Box::new("Unsupported field type"))) 462 | } 463 | } 464 | 465 | // Helper function to convert a Term to a String 466 | fn term_to_string(term: Term) -> Result { 467 | if term.is_atom() { 468 | term.atom_to_string() 469 | .map_err(|_| Error::Term(Box::new("Invalid atom"))) 470 | } else if let Ok(string_value) = term.decode::() { 471 | Ok(string_value) 472 | } else { 473 | Err(Error::Term(Box::new("Term is not a string or atom"))) 474 | } 475 | } 476 | -------------------------------------------------------------------------------- /native/baml_elixir/src/lib.rs: 
-------------------------------------------------------------------------------- 1 | use baml_runtime::client_registry::{ClientProperty, ClientProvider, ClientRegistry}; 2 | use baml_runtime::tracingv2::storage::storage::Collector; 3 | use baml_runtime::type_builder::TypeBuilder; 4 | use baml_runtime::{BamlRuntime, FunctionResult, RuntimeContextManager, TripWire}; 5 | use baml_types::ir_type::UnionTypeViewGeneric; 6 | use baml_types::{BamlMap, BamlValue, LiteralValue, TypeIR}; 7 | use rustler::types::atom; 8 | 9 | use collector::{FunctionLog, Usage}; 10 | use rustler::{ 11 | Encoder, Env, Error, LocalPid, MapIterator, NifResult, NifStruct, ResourceArc, Term, 12 | }; 13 | use std::collections::HashMap; 14 | use std::path::Path; 15 | use std::str::FromStr; 16 | use std::sync::Arc; 17 | mod atoms { 18 | rustler::atoms! { 19 | partial, 20 | done, 21 | } 22 | } 23 | 24 | mod collector; 25 | mod type_builder; 26 | 27 | fn term_to_string(term: Term) -> Result { 28 | if term.is_atom() { 29 | term.atom_to_string().map(|s| s.to_owned()) 30 | } else { 31 | term.decode() 32 | } 33 | } 34 | 35 | fn term_to_baml_value<'a>(term: Term<'a>) -> Result { 36 | if term.is_number() { 37 | if let Ok(int) = term.decode::() { 38 | return Ok(BamlValue::Int(int)); 39 | } 40 | if let Ok(float) = term.decode::() { 41 | return Ok(BamlValue::Float(float)); 42 | } 43 | } 44 | 45 | if let Ok(string) = term.decode::() { 46 | return Ok(BamlValue::String(string)); 47 | } 48 | 49 | if let Ok(list) = term.decode::>() { 50 | let mut baml_list = Vec::new(); 51 | for item in list { 52 | baml_list.push(term_to_baml_value(item)?); 53 | } 54 | return Ok(BamlValue::List(baml_list)); 55 | } 56 | 57 | if term.is_map() { 58 | let mut map = BamlMap::new(); 59 | for (key_term, value_term) in 60 | MapIterator::new(term).ok_or(Error::Term(Box::new("Invalid map")))? 
61 | { 62 | let key = term_to_string(key_term)?; 63 | let value = term_to_baml_value(value_term)?; 64 | map.insert(key, value); 65 | } 66 | return Ok(BamlValue::Map(map)); 67 | } 68 | 69 | if term.is_atom() && term.decode::()? == atom::nil() { 70 | return Ok(BamlValue::Null); 71 | } 72 | 73 | if term.is_atom() && term.decode::()? == atom::true_() { 74 | return Ok(BamlValue::Bool(true)); 75 | } 76 | 77 | if term.is_atom() && term.decode::()? == atom::false_() { 78 | return Ok(BamlValue::Bool(false)); 79 | } 80 | 81 | Err(Error::Term(Box::new(format!( 82 | "Unsupported type: {:?}", 83 | term 84 | )))) 85 | } 86 | 87 | fn term_to_optional_string(term: Term) -> Result, Error> { 88 | if term.is_atom() && term.decode::()? == atom::nil() { 89 | Ok(None) 90 | } else { 91 | Ok(Some(term_to_string(term)?)) 92 | } 93 | } 94 | 95 | fn term_to_baml_map(term: Term) -> Result, Error> { 96 | if term.is_atom() && term.decode::()? == atom::nil() { 97 | return Ok(BamlMap::new()); 98 | } 99 | if !term.is_map() { 100 | return Err(Error::Term(Box::new("Expected a map"))); 101 | } 102 | let mut map = BamlMap::new(); 103 | for (key_term, value_term) in 104 | MapIterator::new(term).ok_or(Error::Term(Box::new("Invalid map")))? 
105 | { 106 | let key = term_to_string(key_term)?; 107 | let value = term_to_baml_value(value_term)?; 108 | map.insert(key, value); 109 | } 110 | Ok(map) 111 | } 112 | 113 | fn term_to_client_property(term: Term, name_override: Option) -> Result { 114 | if !term.is_map() { 115 | return Err(Error::Term(Box::new("Client must be a map"))); 116 | } 117 | 118 | let mut name: Option = name_override; 119 | let mut provider: Option = None; 120 | let mut retry_policy: Option = None; 121 | let mut options: BamlMap = BamlMap::new(); 122 | 123 | let iter = MapIterator::new(term).ok_or(Error::Term(Box::new("Invalid client map")))?; 124 | for (key_term, value_term) in iter { 125 | let key = term_to_string(key_term)?; 126 | match key.as_str() { 127 | "name" => { 128 | name = Some(term_to_string(value_term)?); 129 | } 130 | "provider" => { 131 | let provider_str = term_to_string(value_term)?; 132 | provider = Some( 133 | ClientProvider::from_str(&provider_str).map_err(|e| { 134 | Error::Term(Box::new(format!("Invalid client provider: {e}"))) 135 | })?, 136 | ); 137 | } 138 | "retry_policy" => { 139 | retry_policy = term_to_optional_string(value_term)?; 140 | } 141 | "options" => { 142 | options = term_to_baml_map(value_term)?; 143 | } 144 | _ => {} 145 | } 146 | } 147 | 148 | let name = name.ok_or(Error::Term(Box::new("Client missing required key: name")))?; 149 | let provider = provider.ok_or(Error::Term(Box::new("Client missing required key: provider")))?; 150 | 151 | Ok(ClientProperty::new(name, provider, retry_policy, options)) 152 | } 153 | 154 | fn baml_value_to_term<'a>(env: Env<'a>, value: &BamlValue) -> NifResult> { 155 | match value { 156 | BamlValue::String(s) => Ok(s.encode(env)), 157 | BamlValue::Int(i) => Ok(i.encode(env)), 158 | BamlValue::Float(f) => Ok(f.encode(env)), 159 | BamlValue::Bool(b) => Ok(b.encode(env)), 160 | BamlValue::Null => Ok(atom::nil().encode(env)), 161 | BamlValue::List(items) => { 162 | let terms: Result, Error> = items 163 | .iter() 164 | 
.map(|item| baml_value_to_term(env, item)) 165 | .collect(); 166 | Ok(terms?.encode(env)) 167 | } 168 | BamlValue::Map(map) => { 169 | let mut result_map = Term::map_new(env); 170 | for (key, value) in map.iter() { 171 | let value_term = baml_value_to_term(env, value)?; 172 | result_map = result_map 173 | .map_put(key.encode(env), value_term) 174 | .map_err(|_| Error::Term(Box::new("Failed to add key to map")))?; 175 | } 176 | Ok(result_map) 177 | } 178 | BamlValue::Class(class_name, map) => { 179 | let mut result_map = Term::map_new(env); 180 | let class_atom = rustler::Atom::from_str(env, "__baml_class__") 181 | .map_err(|_| Error::Term(Box::new("Failed to create atom")))?; 182 | result_map = result_map 183 | .map_put(class_atom.encode(env), class_name.encode(env)) 184 | .map_err(|_| Error::Term(Box::new("Failed to add class name")))?; 185 | for (key, value) in map.iter() { 186 | let key_atom = rustler::Atom::from_str(env, key) 187 | .map_err(|_| Error::Term(Box::new("Failed to create key atom")))?; 188 | let value_term = baml_value_to_term(env, value)?; 189 | result_map = result_map 190 | .map_put(key_atom.encode(env), value_term) 191 | .map_err(|_| Error::Term(Box::new("Failed to add key to map")))?; 192 | } 193 | Ok(result_map) 194 | } 195 | BamlValue::Media(_media) => { 196 | // For now, return an error since we need to check the actual BamlMedia structure 197 | Err(Error::Term(Box::new("Media type not yet supported"))) 198 | } 199 | BamlValue::Enum(enum_type, variant) => { 200 | // Convert enum to a map with __baml_enum__ and value 201 | let mut result_map = Term::map_new(env); 202 | let enum_atom = rustler::Atom::from_str(env, "__baml_enum__") 203 | .map_err(|_| Error::Term(Box::new("Failed to create enum atom")))?; 204 | let value_atom = rustler::Atom::from_str(env, "value") 205 | .map_err(|_| Error::Term(Box::new("Failed to create value atom")))?; 206 | result_map = result_map 207 | .map_put(enum_atom.encode(env), enum_type.encode(env)) 208 | .map_err(|_| 
Error::Term(Box::new("Failed to add enum type")))?; 209 | result_map = result_map 210 | .map_put(value_atom.encode(env), variant.encode(env)) 211 | .map_err(|_| Error::Term(Box::new("Failed to add enum variant")))?; 212 | Ok(result_map) 213 | } 214 | } 215 | } 216 | 217 | #[derive(NifStruct)] 218 | #[module = "BamlElixir.Client"] 219 | struct Client<'a> { 220 | from: String, 221 | client_registry: Term<'a>, 222 | collectors: Vec>, 223 | } 224 | 225 | fn prepare_request<'a>( 226 | env: Env<'a>, 227 | args: Term<'a>, 228 | path: String, 229 | collectors: Vec>, 230 | client_registry: Term<'a>, 231 | tb_elixir: Term<'a>, 232 | ) -> Result< 233 | ( 234 | BamlRuntime, 235 | BamlMap, 236 | RuntimeContextManager, 237 | Option>>, 238 | Option, 239 | Option, 240 | ), 241 | Error, 242 | > { 243 | let runtime = match BamlRuntime::from_directory( 244 | &Path::new(&path), 245 | std::env::vars().collect(), 246 | internal_baml_core::feature_flags::FeatureFlags::new(), 247 | ) { 248 | Ok(r) => r, 249 | Err(e) => return Err(Error::Term(Box::new(e.to_string()))), 250 | }; 251 | 252 | // Convert args to BamlMap 253 | let mut params = BamlMap::new(); 254 | if args.is_map() { 255 | let iter = MapIterator::new(args).ok_or(Error::Term(Box::new("Invalid map")))?; 256 | for (key_term, value_term) in iter { 257 | let key = term_to_string(key_term)?; 258 | let value = term_to_baml_value(value_term)?; 259 | params.insert(key.clone(), value); 260 | } 261 | } else { 262 | return Err(Error::Term(Box::new("Arguments must be a map"))); 263 | } 264 | 265 | // Create context 266 | let ctx = runtime.create_ctx_manager( 267 | BamlValue::String("elixir".to_string()), 268 | None, // baml source reader 269 | ); 270 | 271 | let collectors = if collectors.is_empty() { 272 | None 273 | } else { 274 | Some(collectors.iter().map(|c| c.inner.clone()).collect()) 275 | }; 276 | 277 | let client_registry = 278 | if client_registry.is_atom() && client_registry.decode::()? 
== atom::nil() { 279 | None 280 | } else if client_registry.is_map() { 281 | let mut registry = ClientRegistry::new(); 282 | let iter = MapIterator::new(client_registry) 283 | .ok_or(Error::Term(Box::new("Invalid registry map")))?; 284 | for (key_term, value_term) in iter { 285 | let key = term_to_string(key_term)?; 286 | if key == "primary" { 287 | let primary = term_to_string(value_term)?; 288 | registry.set_primary(primary); 289 | } else if key == "clients" { 290 | // Accept either: 291 | // - a list of client maps: [%{name: ..., provider: ..., ...}, ...] 292 | // - a map of name => client map: %{ "name" => %{provider: ..., ...}, ... } 293 | if let Ok(list) = value_term.decode::>() { 294 | for client_term in list { 295 | let client = term_to_client_property(client_term, None)?; 296 | registry.add_client(client); 297 | } 298 | } else if value_term.is_map() { 299 | let client_iter = MapIterator::new(value_term) 300 | .ok_or(Error::Term(Box::new("Invalid clients map")))?; 301 | for (name_term, client_term) in client_iter { 302 | let name = term_to_string(name_term)?; 303 | let client = term_to_client_property(client_term, Some(name))?; 304 | registry.add_client(client); 305 | } 306 | } else if value_term.is_atom() 307 | && value_term.decode::()? 
== atom::nil() 308 | { 309 | // allow nil clients 310 | } else { 311 | return Err(Error::Term(Box::new( 312 | "Client registry clients must be a list, a map, or nil", 313 | ))); 314 | } 315 | } 316 | } 317 | Some(registry) 318 | } else { 319 | return Err(Error::Term(Box::new( 320 | "Client registry must be nil or a map", 321 | ))); 322 | }; 323 | 324 | let tb = if tb_elixir.is_list() { 325 | let builder = TypeBuilder::new(); 326 | 327 | // Use the parse_type_builder_spec function from type_builder module 328 | if let Err(e) = type_builder::parse_type_builder_spec(env, tb_elixir, &builder) { 329 | return Err(e); 330 | } 331 | 332 | Some(builder) 333 | } else { 334 | None 335 | }; 336 | 337 | Ok((runtime, params, ctx, collectors, client_registry, tb)) 338 | } 339 | 340 | fn parse_function_result_call<'a>(env: Env<'a>, result: FunctionResult) -> NifResult> { 341 | let parsed_value = result.parsed(); 342 | match parsed_value { 343 | Some(Ok(response_baml_value)) => { 344 | let baml_value = response_baml_value.0.clone().value(); 345 | let result_term = baml_value_to_term(env, &baml_value)?; 346 | Ok((atom::ok(), result_term).encode(env)) 347 | } 348 | Some(Err(e)) => Ok((atom::error(), format!("{:?}", e)).encode(env)), 349 | None => Ok((atom::error(), "No parsed value available").encode(env)), 350 | } 351 | } 352 | 353 | fn parse_function_result_stream<'a>( 354 | env: Env<'a>, 355 | result: FunctionResult, 356 | ) -> Result, String> { 357 | let parsed_value = result.parsed(); 358 | match parsed_value { 359 | Some(Ok(response_baml_value)) => { 360 | let baml_value = response_baml_value.0.clone().value(); 361 | let result_term = baml_value_to_term(env, &baml_value) 362 | .map_err(|e| format!("Failed to convert BAML value to term: {:?}", e))?; 363 | Ok(result_term) 364 | } 365 | Some(Err(e)) => Err(e.to_string()), 366 | None => Err("No parsed value available".to_string()), 367 | } 368 | } 369 | 370 | #[rustler::nif(schedule = "DirtyIo")] 371 | fn call<'a>( 372 | env: 
Env<'a>, 373 | function_name: String, 374 | arguments: Term<'a>, 375 | path: String, 376 | collectors: Vec>, 377 | client_registry: Term<'a>, 378 | tb: Term<'a>, 379 | ) -> NifResult> { 380 | let (runtime, params, ctx, collectors, client_registry, tb) = 381 | prepare_request(env, arguments, path, collectors, client_registry, tb)?; 382 | 383 | // Call function synchronously 384 | let (result, _trace_id) = runtime.call_function_sync( 385 | function_name, 386 | ¶ms, 387 | &ctx, 388 | tb.as_ref(), // type builder (optional) 389 | client_registry.as_ref(), // client registry (optional) 390 | collectors, 391 | std::env::vars().collect(), 392 | TripWire::new(None), // TODO: Add tripwire 393 | ); 394 | 395 | // Handle result 396 | match result { 397 | Ok(function_result) => parse_function_result_call(env, function_result), 398 | Err(e) => Ok((atom::error(), format!("{:?}", e)).encode(env)), 399 | } 400 | } 401 | 402 | #[rustler::nif(schedule = "DirtyIo")] 403 | fn stream<'a>( 404 | env: Env<'a>, 405 | pid: Term<'a>, 406 | reference: Term<'a>, 407 | function_name: String, 408 | arguments: Term<'a>, 409 | path: String, 410 | collectors: Vec>, 411 | client_registry: Term<'a>, 412 | tb: Term<'a>, 413 | ) -> NifResult> { 414 | let pid = pid.decode::()?; 415 | let (runtime, params, ctx, collectors, client_registry, tb) = 416 | prepare_request(env, arguments, path, collectors, client_registry, tb)?; 417 | 418 | let on_event = |r: FunctionResult| { 419 | match parse_function_result_stream(env, r) { 420 | Ok(result_term) => { 421 | let wrapped_result = (reference, (atoms::partial(), result_term)).encode(env); 422 | let _ = env.send(&pid, wrapped_result); 423 | } 424 | Err(_) => { 425 | // Do nothing on error because this can happen when 426 | // the result cannot be coerced to a BAML value. 427 | // This can happen when the result is incomplete. 428 | // We'll get the final result and check for a real error then. 
429 | return; 430 | } 431 | } 432 | }; 433 | 434 | let result = runtime.stream_function( 435 | function_name, 436 | ¶ms, 437 | &ctx, 438 | tb.as_ref(), 439 | client_registry.as_ref(), 440 | collectors, 441 | std::env::vars().collect(), 442 | TripWire::new(None), // TODO: Add tripwire 443 | ); 444 | 445 | match result { 446 | Ok(mut stream) => { 447 | let (result, _trace_id) = stream.run_sync( 448 | None::, 449 | Some(on_event), 450 | &ctx, 451 | None, 452 | None, 453 | std::env::vars().collect(), 454 | ); 455 | match result { 456 | Ok(r) => match r.parsed() { 457 | Some(Ok(result)) => { 458 | let baml_value = result.0.clone().value(); 459 | let result_term = baml_value_to_term(env, &baml_value)?; 460 | Ok((atoms::done(), result_term).encode(env)) 461 | } 462 | Some(Err(e)) => Ok((atom::error(), format!("{:?}", e)).encode(env)), 463 | None => Ok((atom::error(), "No parsed value available").encode(env)), 464 | }, 465 | Err(e) => Ok((atom::error(), format!("{:?}", e)).encode(env)), 466 | } 467 | } 468 | Err(e) => Ok((atom::error(), format!("{:?}", e)).encode(env)), 469 | } 470 | } 471 | 472 | #[rustler::nif] 473 | fn collector_new(name: Option) -> ResourceArc { 474 | collector::CollectorResource::new(name) 475 | } 476 | 477 | #[rustler::nif] 478 | fn collector_usage(collector: ResourceArc) -> Usage { 479 | collector.usage() 480 | } 481 | 482 | #[rustler::nif] 483 | fn collector_last_function_log( 484 | collector: ResourceArc, 485 | ) -> Option { 486 | collector.last_function_log() 487 | } 488 | 489 | #[rustler::nif] 490 | fn parse_baml(env: Env, path: Option) -> NifResult { 491 | let path = path.unwrap_or_else(|| "baml_src".to_string()); 492 | 493 | // Create runtime 494 | let runtime = match BamlRuntime::from_directory( 495 | &Path::new(&path), 496 | std::env::vars().collect(), 497 | internal_baml_core::feature_flags::FeatureFlags::new(), 498 | ) { 499 | Ok(r) => r, 500 | Err(e) => return Err(Error::Term(Box::new(e.to_string()))), 501 | }; 502 | 503 | let ir = 
runtime.inner.ir.clone(); 504 | 505 | // Create a map of the classes and their fields along with their types 506 | let mut class_fields = HashMap::new(); 507 | let mut class_attributes = HashMap::new(); 508 | for class in ir.walk_classes() { 509 | let mut fields = HashMap::new(); 510 | for field in class.walk_fields() { 511 | let field_type = to_elixir_type(env, &field.r#type()); 512 | fields.insert(field.name().to_string(), field_type); 513 | } 514 | class_fields.insert(class.name().to_string(), fields); 515 | 516 | // Check if class has @@dynamic attribute 517 | let is_dynamic = class.item.attributes.get("dynamic_type").is_some(); 518 | class_attributes.insert(class.name().to_string(), is_dynamic); 519 | } 520 | 521 | // Create a map of the enums and their variants 522 | let mut enum_variants = HashMap::new(); 523 | for r#enum in ir.walk_enums() { 524 | let mut variants = Vec::new(); 525 | for variant in r#enum.walk_values() { 526 | variants.push(variant.name().to_string()); 527 | } 528 | enum_variants.insert(r#enum.name().to_string(), variants); 529 | } 530 | 531 | // Create a map of the functions and their parameters 532 | let mut function_params = HashMap::new(); 533 | for function in ir.walk_functions() { 534 | let mut params = HashMap::new(); 535 | 536 | // Get input parameters 537 | for (name, field_type) in function.inputs() { 538 | let param_type = to_elixir_type(env, field_type); 539 | params.insert(name.to_string(), param_type); 540 | } 541 | 542 | // Get return type 543 | let return_type = to_elixir_type(env, &function.output()); 544 | 545 | function_params.insert(function.name().to_string(), (params, return_type)); 546 | } 547 | 548 | // convert to elixir map term 549 | let mut map = Term::map_new(env); 550 | 551 | // Add classes 552 | let mut classes_map = Term::map_new(env); 553 | for (class_name, fields) in class_fields { 554 | let mut class_map = Term::map_new(env); 555 | 556 | // Add fields 557 | let mut field_map = Term::map_new(env); 558 | for 
(field_name, field_type) in fields { 559 | field_map = field_map.map_put(field_name.encode(env), field_type)?; 560 | } 561 | class_map = class_map.map_put("fields".encode(env), field_map)?; 562 | 563 | // Add dynamic attribute 564 | let is_dynamic = class_attributes.get(&class_name).unwrap_or(&false); 565 | class_map = class_map.map_put("dynamic".encode(env), is_dynamic.encode(env))?; 566 | 567 | classes_map = classes_map.map_put(class_name.encode(env), class_map)?; 568 | } 569 | map = map.map_put( 570 | rustler::Atom::from_str(env, "classes").unwrap().encode(env), 571 | classes_map, 572 | )?; 573 | 574 | // Add enums 575 | let mut enums_map = Term::map_new(env); 576 | for (enum_name, variants) in enum_variants { 577 | let variants_list = variants.encode(env); 578 | enums_map = enums_map.map_put(enum_name.encode(env), variants_list)?; 579 | } 580 | map = map.map_put( 581 | rustler::Atom::from_str(env, "enums").unwrap().encode(env), 582 | enums_map, 583 | )?; 584 | 585 | // Add functions 586 | let mut functions_map = Term::map_new(env); 587 | for (function_name, (params, return_type)) in function_params { 588 | let mut function_map = Term::map_new(env); 589 | 590 | // Add parameters 591 | let mut params_map = Term::map_new(env); 592 | for (param_name, param_type) in params { 593 | params_map = params_map.map_put(param_name.encode(env), param_type)?; 594 | } 595 | function_map = function_map.map_put("params".encode(env), params_map)?; 596 | 597 | // Add return type 598 | function_map = function_map.map_put("return_type".encode(env), return_type)?; 599 | 600 | functions_map = functions_map.map_put(function_name.encode(env), function_map)?; 601 | } 602 | map = map.map_put( 603 | rustler::Atom::from_str(env, "functions") 604 | .unwrap() 605 | .encode(env), 606 | functions_map, 607 | )?; 608 | 609 | Ok(map) 610 | } 611 | 612 | fn to_elixir_type<'a>(env: Env<'a>, field_type: &TypeIR) -> Term<'a> { 613 | match field_type { 614 | TypeIR::Top(_) => panic!( 615 | "TypeIR::Top 
should have been resolved by the compiler before code generation. \ 616 | This indicates a bug in the type resolution phase." 617 | ), 618 | TypeIR::Enum { name, .. } => { 619 | // Return {:enum, name} 620 | (rustler::Atom::from_str(env, "enum").unwrap(), name).encode(env) 621 | } 622 | TypeIR::Class { name, .. } => { 623 | // Return {:class, name} 624 | (rustler::Atom::from_str(env, "class").unwrap(), name).encode(env) 625 | } 626 | TypeIR::List(inner, _) => { 627 | // Return {:list, inner_type} 628 | let inner_type = to_elixir_type(env, inner); 629 | (rustler::Atom::from_str(env, "list").unwrap(), inner_type).encode(env) 630 | } 631 | TypeIR::Map(key, value, _) => { 632 | // Return {:map, key_type, value_type} 633 | let key_type = to_elixir_type(env, key); 634 | let value_type = to_elixir_type(env, value); 635 | ( 636 | rustler::Atom::from_str(env, "map").unwrap(), 637 | key_type, 638 | value_type, 639 | ) 640 | .encode(env) 641 | } 642 | TypeIR::Primitive(r#type, _) => { 643 | // Return {:primitive, primitive_value} 644 | let primitive_value = match r#type { 645 | baml_types::TypeValue::String => rustler::Atom::from_str(env, "string").unwrap(), 646 | baml_types::TypeValue::Int => rustler::Atom::from_str(env, "integer").unwrap(), 647 | baml_types::TypeValue::Float => rustler::Atom::from_str(env, "float").unwrap(), 648 | baml_types::TypeValue::Bool => rustler::Atom::from_str(env, "boolean").unwrap(), 649 | baml_types::TypeValue::Null => atom::nil(), 650 | baml_types::TypeValue::Media(_) => rustler::Atom::from_str(env, "media").unwrap(), 651 | }; 652 | ( 653 | rustler::Atom::from_str(env, "primitive").unwrap(), 654 | primitive_value, 655 | ) 656 | .encode(env) 657 | } 658 | TypeIR::Literal(value, _) => { 659 | // Return {:literal, value} 660 | let literal_value = match value { 661 | LiteralValue::String(s) => rustler::Atom::from_str(env, &s).unwrap().encode(env), 662 | LiteralValue::Int(i) => i.encode(env), 663 | LiteralValue::Bool(b) => b.encode(env), 664 | }; 665 
| ( 666 | rustler::Atom::from_str(env, "literal").unwrap(), 667 | literal_value, 668 | ) 669 | .encode(env) 670 | } 671 | TypeIR::Union(inner, _) => match inner.view() { 672 | UnionTypeViewGeneric::Null => (atom::nil()).encode(env), 673 | UnionTypeViewGeneric::Optional(inner) => { 674 | // Return {:optional, type} 675 | let inner_type = to_elixir_type(env, inner); 676 | ( 677 | rustler::Atom::from_str(env, "optional").unwrap(), 678 | inner_type, 679 | ) 680 | .encode(env) 681 | } 682 | UnionTypeViewGeneric::OneOf(inner) => { 683 | // Return {:union, list_of_types} 684 | let types: Vec = inner.iter().map(|t| to_elixir_type(env, t)).collect(); 685 | (rustler::Atom::from_str(env, "union").unwrap(), types).encode(env) 686 | } 687 | UnionTypeViewGeneric::OneOfOptional(inner) => { 688 | // Return {:optional, {:union, list_of_types}} 689 | let types: Vec = inner.iter().map(|t| to_elixir_type(env, t)).collect(); 690 | ( 691 | rustler::Atom::from_str(env, "optional").unwrap(), 692 | (rustler::Atom::from_str(env, "union").unwrap(), types), 693 | ) 694 | .encode(env) 695 | } 696 | }, 697 | TypeIR::Tuple(inner, _) => { 698 | // Return {:tuple, list_of_types} 699 | let types: Vec = inner.iter().map(|t| to_elixir_type(env, t)).collect(); 700 | (rustler::Atom::from_str(env, "tuple").unwrap(), types).encode(env) 701 | } 702 | TypeIR::RecursiveTypeAlias { name, .. } => { 703 | // Return {:alias, name} 704 | (rustler::Atom::from_str(env, "alias").unwrap(), name).encode(env) 705 | } 706 | TypeIR::Arrow(..) => { 707 | // Arrow types are not supported in Elixir type specs 708 | panic!("Arrow types are not supported in Elixir") 709 | } 710 | } 711 | } 712 | 713 | rustler::init!("Elixir.BamlElixir.Native"); 714 | --------------------------------------------------------------------------------