├── .formatter.exs
├── .github
├── images
│ ├── background.jpg
│ ├── kino_bumblebee_token_classification.png
│ └── phx_image_classification.png
└── workflows
│ ├── nightly.yml
│ └── test.yaml
├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── config
└── config.exs
├── examples
└── phoenix
│ ├── README.md
│ ├── image_classification.exs
│ ├── speech_to_text.exs
│ └── text_classification.exs
├── lib
├── bumblebee.ex
└── bumblebee
│ ├── application.ex
│ ├── audio.ex
│ ├── audio
│ ├── speech_to_text_whisper.ex
│ ├── whisper.ex
│ └── whisper_featurizer.ex
│ ├── configurable.ex
│ ├── conversion
│ ├── pytorch_loader.ex
│ ├── pytorch_loader
│ │ └── file_tensor.ex
│ └── pytorch_params.ex
│ ├── diffusion
│ ├── controlnet.ex
│ ├── ddim_scheduler.ex
│ ├── layers.ex
│ ├── layers
│ │ └── unet.ex
│ ├── lcm_scheduler.ex
│ ├── pndm_scheduler.ex
│ ├── scheduler_utils.ex
│ ├── stable_diffusion.ex
│ ├── stable_diffusion
│ │ └── safety_checker.ex
│ ├── stable_diffusion_controlnet.ex
│ ├── unet_2d_conditional.ex
│ └── vae_kl.ex
│ ├── featurizer.ex
│ ├── huggingface
│ ├── hub.ex
│ └── transformers
│ │ ├── config.ex
│ │ ├── model.ex
│ │ └── utils.ex
│ ├── layers.ex
│ ├── layers
│ ├── decoder.ex
│ └── transformer.ex
│ ├── model_spec.ex
│ ├── multimodal
│ ├── blip.ex
│ ├── clip.ex
│ └── layout_lm.ex
│ ├── scheduler.ex
│ ├── shared.ex
│ ├── shared
│ └── converters.ex
│ ├── text.ex
│ ├── text
│ ├── albert.ex
│ ├── bart.ex
│ ├── bert.ex
│ ├── blenderbot.ex
│ ├── blip_text.ex
│ ├── clip_text.ex
│ ├── distilbert.ex
│ ├── fill_mask.ex
│ ├── gemma.ex
│ ├── generation.ex
│ ├── generation
│ │ └── logits_processing.ex
│ ├── generation_config.ex
│ ├── gpt2.ex
│ ├── gpt_big_code.ex
│ ├── gpt_neo_x.ex
│ ├── llama.ex
│ ├── m2m100.ex
│ ├── mbart.ex
│ ├── mistral.ex
│ ├── phi.ex
│ ├── phi3.ex
│ ├── pre_trained_tokenizer.ex
│ ├── question_answering.ex
│ ├── roberta.ex
│ ├── t5.ex
│ ├── text_classification.ex
│ ├── text_embedding.ex
│ ├── text_generation.ex
│ ├── token_classification.ex
│ ├── translation.ex
│ ├── whisper_generation_config.ex
│ └── zero_shot_classification.ex
│ ├── tokenizer.ex
│ ├── utils.ex
│ ├── utils
│ ├── axon.ex
│ ├── http.ex
│ ├── image.ex
│ ├── model.ex
│ └── nx.ex
│ ├── vision.ex
│ └── vision
│ ├── bit_featurizer.ex
│ ├── blip_featurizer.ex
│ ├── blip_vision.ex
│ ├── clip_featurizer.ex
│ ├── clip_vision.ex
│ ├── convnext.ex
│ ├── convnext_featurizer.ex
│ ├── deit.ex
│ ├── deit_featurizer.ex
│ ├── dino_v2.ex
│ ├── image_classification.ex
│ ├── image_embedding.ex
│ ├── image_to_text.ex
│ ├── resnet.ex
│ ├── swin.ex
│ ├── vit.ex
│ └── vit_featurizer.ex
├── mix.exs
├── mix.lock
├── notebooks
├── examples.livemd
├── fine_tuning.livemd
├── llms.livemd
├── llms_rag.livemd
└── stable_diffusion.livemd
└── test
├── bumblebee
├── audio
│ ├── speech_to_text_whisper_test.exs
│ ├── whisper_featurizer_test.exs
│ └── whisper_test.exs
├── conversion
│ ├── pytorch_loader_test.exs
│ └── pytorch_params_test.exs
├── diffusion
│ ├── controlnet_test.exs
│ ├── ddim_scheduler_test.exs
│ ├── lcm_scheduler_test.exs
│ ├── pndm_scheduler_test.exs
│ ├── stable_diffusion
│ │ └── safety_checker_test.exs
│ ├── stable_diffusion_controlnet_test.exs
│ ├── stable_diffusion_test.exs
│ ├── unet_2d_conditional_test.exs
│ └── vae_kl_test.exs
├── huggingface
│ └── hub_test.exs
├── multimodal
│ ├── blip_test.exs
│ ├── clip_test.exs
│ └── layout_lm_test.exs
├── shared_test.exs
├── text
│ ├── albert_test.exs
│ ├── bart_test.exs
│ ├── bert_test.exs
│ ├── blenderbot_test.exs
│ ├── blip_text_test.exs
│ ├── camembert_test.exs
│ ├── clip_text_test.exs
│ ├── distilbert_test.exs
│ ├── fill_mask_test.exs
│ ├── gemma_test.exs
│ ├── generation
│ │ └── logits_processing_test.exs
│ ├── generation_config_test.exs
│ ├── generation_test.exs
│ ├── gpt2_test.exs
│ ├── gpt_big_code_test.exs
│ ├── gpt_neo_x_test.exs
│ ├── llama_test.exs
│ ├── m2m100_test.exs
│ ├── mbart_test.exs
│ ├── mistral_test.exs
│ ├── nllb_test.exs
│ ├── phi3_test.exs
│ ├── phi_test.exs
│ ├── pre_trained_tokenizer_test.exs
│ ├── question_answering_test.exs
│ ├── roberta_test.exs
│ ├── t5_test.exs
│ ├── text_classification_test.exs
│ ├── text_embedding_test.exs
│ ├── text_generation_test.exs
│ ├── token_classification_test.exs
│ ├── translation_test.exs
│ ├── xlm_roberta_test.exs
│ └── zero_shot_classification_test.exs
├── utils
│ ├── image_test.exs
│ └── nx_test.exs
└── vision
│ ├── bit_featurizer_test.exs
│ ├── blip_featurizer_test.exs
│ ├── blip_vision_test.exs
│ ├── clip_featurizer_test.exs
│ ├── clip_vision_test.exs
│ ├── convnext_featurizer_test.exs
│ ├── convnext_test.exs
│ ├── deit_featurizer_test.exs
│ ├── deit_test.exs
│ ├── dino_v2_test.exs
│ ├── image_classification_test.exs
│ ├── image_embedding_test.exs
│ ├── image_to_text_test.exs
│ ├── resnet_test.exs
│ ├── swin_test.exs
│ ├── vit_featurizer_test.exs
│ └── vit_test.exs
├── bumblebee_test.exs
├── fixtures
├── audio
│ ├── common_voice
│ │ ├── a6c7706a220eeea7ee3687c1122fe7ac17962d2449d25b6db37cc41cdaace442683e11945b6f581e73941c3083cd4eecfafc938840459cd8c571dae7774ee687.wav
│ │ ├── a6c7706a220eeea7ee3687c1122fe7ac17962d2449d25b6db37cc41cdaace442683e11945b6f581e73941c3083cd4eecfafc938840459cd8c571dae7774ee687_pcm_f32le_16000.bin
│ │ └── info.md
│ ├── generate.sh
│ └── librivox
│ │ ├── 46s.mp3
│ │ ├── 46s_pcm_f32le_16000.bin
│ │ └── info.md
├── images
│ └── coco
│ │ ├── 39769.jpeg
│ │ └── info.md
└── pytorch
│ ├── generate.py
│ ├── noncontiguous_numpy_array.legacy.pt
│ ├── noncontiguous_numpy_array.zip.pt
│ ├── noncontiguous_tensor.legacy.pt
│ ├── noncontiguous_tensor.zip.pt
│ ├── numpy_arrays.legacy.pt
│ ├── numpy_arrays.zip.pt
│ ├── ordered_dict.legacy.pt
│ ├── ordered_dict.zip.pt
│ ├── state_dict_base.zip.pt
│ ├── state_dict_full.zip.pt
│ ├── storage_view.legacy.pt
│ ├── tensors.legacy.pt
│ └── tensors.zip.pt
├── support
└── test_helpers.ex
└── test_helper.exs
/.formatter.exs:
--------------------------------------------------------------------------------
1 | [
2 | import_deps: [:nx],
3 | inputs: ["{mix,.formatter}.exs", "{config,lib,test,examples}/**/*.{ex,exs}"]
4 | ]
5 |
--------------------------------------------------------------------------------
/.github/images/background.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/.github/images/background.jpg
--------------------------------------------------------------------------------
/.github/images/kino_bumblebee_token_classification.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/.github/images/kino_bumblebee_token_classification.png
--------------------------------------------------------------------------------
/.github/images/phx_image_classification.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/.github/images/phx_image_classification.png
--------------------------------------------------------------------------------
/.github/workflows/nightly.yml:
--------------------------------------------------------------------------------
1 | name: Nightly
2 | on:
3 | schedule:
4 | - cron: "0 0 * * *"
5 | workflow_dispatch:
6 | env:
7 | elixir: 1.14.0
8 | otp: 24.0
9 | jobs:
10 | test:
11 | name: Test all
12 | runs-on: ubuntu-latest
13 | env:
14 | MIX_ENV: test
15 | XLA_CACHE_DIR: ${{ github.workspace }}/cache/xla
16 | LIBTORCH_DIR: ${{ github.workspace }}/cache/libtorch
17 | steps:
18 | - uses: actions/checkout@v3
19 | - uses: erlef/setup-beam@v1
20 | with:
21 | otp-version: ${{env.otp}}
22 | elixir-version: ${{env.elixir}}
23 | - uses: actions/cache@v3
24 | with:
25 | path: |
26 | deps
27 | _build
28 | cache
29 | key: ${{ runner.os }}-mix-${{env.elixir}}-${{env.otp}}-${{ hashFiles('**/mix.lock') }}
30 | restore-keys: |
31 | ${{ runner.os }}-mix-
32 | - run: mix deps.get
33 | - run: mix test --include slow
34 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Test
2 | on:
3 | pull_request:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | main:
10 | name: "main (${{ matrix.pair.elixir }}, ${{ matrix.pair.otp }})"
11 | runs-on: ubuntu-latest
12 | strategy:
13 | fail-fast: false
14 | matrix:
15 | include:
16 | - pair:
17 | elixir: "1.15.4"
18 | otp: "26.0.2"
19 | lint: true
20 | slow: true
21 | - pair:
22 | elixir: "1.14.5"
23 | otp: "25.3.2.2"
24 | env:
25 | MIX_ENV: test
26 | XLA_CACHE_DIR: ${{ github.workspace }}/cache/xla
27 | LIBTORCH_DIR: ${{ github.workspace }}/cache/torch
28 | steps:
29 | - uses: actions/checkout@v3
30 | with:
31 | # We need the previous commit for git diff later
32 | fetch-depth: 2
33 | - uses: erlef/setup-beam@v1
34 | with:
35 | otp-version: ${{ matrix.pair.otp }}
36 | elixir-version: ${{ matrix.pair.elixir }}
37 | - uses: actions/cache@v3
38 | with:
39 | path: |
40 | deps
41 | _build
42 | cache
43 | key: ${{ runner.os }}-mix-${{ matrix.pair.elixir }}-${{ matrix.pair.otp }}-${{ hashFiles('**/mix.lock') }}
44 | - run: mix deps.get
45 | - run: mix format --check-formatted
46 | if: ${{ matrix.lint }}
47 | - run: mix deps.unlock --check-unused
48 | if: ${{ matrix.lint }}
49 | - run: mix deps.compile
50 | - run: mix compile --warnings-as-errors
51 | if: ${{ matrix.lint }}
52 | - name: Restore bumblebee cache
53 | id: cache-bumblebee-restore
54 | uses: actions/cache/restore@v3
55 | with:
56 | path: bumblebee_cache
57 | key: ${{ runner.os }}-bumblebee-cache-${{ matrix.pair.elixir }}-${{ matrix.pair.otp }}
58 | - run: mix test
59 | env:
60 | BUMBLEBEE_CACHE_DIR: ${{ github.workspace }}/bumblebee_cache
61 | - name: Save bumblebee cache
62 | id: cache-bumblebee-save
63 | uses: actions/cache/save@v3
64 | with:
65 | path: bumblebee_cache
66 | key: ${{ steps.cache-bumblebee-restore.outputs.cache-primary-key }}
67 | - name: Diff tests
68 | run: |
69 | changed_tests="$(git diff --name-only --diff-filter=AMRC HEAD^1 'test/**/*_test.exs' | tr '\n' ' ')"
70 | echo "Changed test files: $changed_tests"
71 | echo "CHANGED_TESTS=$changed_tests" >> $GITHUB_ENV
72 | - name: Changed slow tests
73 | # mix test exits with a non-zero code if there are no matching tests,
74 | # so we make sure we fail only when the test suite fails
75 | run: mix test test/bumblebee_test.exs --only slow --exit-status 100 ${{ env.CHANGED_TESTS }} || [ $? -ne 100 ]
76 | if: ${{ matrix.slow && env.CHANGED_TESTS != '' }}
77 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # The directory Mix will write compiled artifacts to.
2 | /_build/
3 |
4 | # If you run "mix test --cover", coverage assets end up here.
5 | /cover/
6 |
7 | # The directory Mix downloads your dependencies sources to.
8 | /deps/
9 |
10 | # Where third-party dependencies like ExDoc output generated docs.
11 | /doc/
12 |
13 | # Ignore .fetch files in case you like to edit your project deps locally.
14 | /.fetch
15 |
16 | # If the VM crashes, it generates a dump, let's ignore it too.
17 | erl_crash.dump
18 |
19 | # Also ignore archive artifacts (built via "mix archive.build").
20 | *.ez
21 |
22 | # Ignore package tarball (built via "mix hex.build").
23 | bumblebee-*.tar
24 |
25 | # Temporary files, for example, from tests.
26 | /tmp/
27 |
--------------------------------------------------------------------------------
/config/config.exs:
--------------------------------------------------------------------------------
import Config

# Turn on EXLA's :add_backend_on_inspect option in every environment except
# :test (presumably so inspected tensors look the same regardless of
# backend in test output — confirm).
config :exla, add_backend_on_inspect: config_env() != :test
4 |
--------------------------------------------------------------------------------
/examples/phoenix/text_classification.exs:
--------------------------------------------------------------------------------
dependencies = [
  {:phoenix_playground, "~> 0.1.7"},
  {:bumblebee, "~> 0.6.0"},
  {:nx, "~> 0.9.0"},
  {:exla, "~> 0.9.0"}
]

Mix.install(dependencies)

# Make EXLA the default Nx backend for everything this script computes.
Application.put_env(:nx, :default_backend, EXLA.Backend)
9 |
10 | defmodule DemoLive do
11 | use Phoenix.LiveView
12 |
@impl true
def mount(_params, _session, socket) do
  # Begin with an empty text input and no prediction.
  socket = assign(socket, text: "", label: nil)
  {:ok, socket}
end
17 |
18 | @impl true
19 | def render(assigns) do
20 | ~H"""
21 |
23 |
24 |
25 |
26 |
41 |
42 | Emotion:
43 | <.async_result :let={label} :if={@label} assign={@label}>
44 | <:loading>
45 | <.spinner />
46 |
47 | <:failed :let={_reason}>
48 | Oops, something went wrong!
49 |
50 | <%= label %>
51 |
52 |
53 |
54 |
55 | """
56 | end
57 |
58 | defp spinner(assigns) do
59 | ~H"""
60 |
75 | """
76 | end
77 |
@impl true
def handle_event("predict", %{"text" => text}, socket) do
  # Drop the previous label so the async result goes back to its loading
  # state while the new prediction is computed.
  socket = assign(socket, text: text, label: nil)

  socket =
    assign_async(socket, :label, fn ->
      # The serving is configured with top_k: 1, so a single prediction is expected.
      %{predictions: [%{label: label}]} = Nx.Serving.batched_run(Demo.Serving, text)
      {:ok, %{label: label}}
    end)

  {:noreply, socket}
end
93 | end
94 |
# Application startup

{:ok, model_info} = Bumblebee.load_model({:hf, "finiteautomata/bertweet-base-emotion-analysis"})

# The tokenizer is loaded from the base model repository (presumably the
# fine-tuned checkpoint reuses it unchanged — confirm).
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "vinai/bertweet-base"})

serving =
  Bumblebee.Text.text_classification(model_info, tokenizer,
    top_k: 1,
    compile: [batch_size: 4, sequence_length: 100],
    defn_options: [
      compiler: EXLA,
      cache: Path.join(System.tmp_dir!(), "bumblebee_examples/text_classification")
    ]
  )

# Assert the serving actually started; otherwise the web page would come up
# and every prediction request would fail against a missing Demo.Serving.
{:ok, _serving_pid} =
  Nx.Serving.start_link(serving: serving, name: Demo.Serving, batch_timeout: 100)

PhoenixPlayground.start(live: DemoLive, port: 8080)
113 |
--------------------------------------------------------------------------------
/lib/bumblebee/application.ex:
--------------------------------------------------------------------------------
defmodule Bumblebee.Application do
  @moduledoc false

  use Application

  @impl true
  def start(_type, _args) do
    # Set up the inets profile used by Bumblebee.Utils.HTTP, then start a
    # supervisor that currently has no children.
    Bumblebee.Utils.HTTP.start_inets_profile()
    Supervisor.start_link([], strategy: :one_for_one, name: Bumblebee.Supervisor)
  end

  @impl true
  def stop(_state), do: Bumblebee.Utils.HTTP.stop_inets_profile()
end
20 |
--------------------------------------------------------------------------------
/lib/bumblebee/audio/whisper_featurizer.ex:
--------------------------------------------------------------------------------
defmodule Bumblebee.Audio.WhisperFeaturizer do
  alias Bumblebee.Shared

  import Nx.Defn

  options = [
    feature_size: [
      default: 80,
      doc: "the dimension of the extracted features. This corresponds to the number of Mel bins"
    ],
    sampling_rate: [
      default: 16_000,
      doc: "the sampling rate at which the audio files should be digitally expressed in Hertz"
    ],
    num_seconds: [
      default: 30,
      doc: """
      the maximum duration of the audio sequence. This implies that the maximum length of the
      input sequence is `:num_seconds` * `:sampling_rate`
      """
    ],
    hop_length: [
      default: 160,
      doc:
        "the hop between consecutive overlapping windows for the STFT used to obtain Mel Frequency coefficients"
    ],
    fft_length: [
      default: 400,
      doc: "the size of the fourier transform"
    ],
    padding_value: [
      default: 0.0,
      doc: "the value used to pad the audio. Should correspond to silence"
    ]
  ]

  @moduledoc """
  Featurizer that turns raw audio samples into Mel spectrogram input
  features for Whisper models.

  ## Configuration

  #{Shared.options_doc(options)}
  """

  defstruct Shared.option_defaults(options)

  @behaviour Bumblebee.Featurizer
  @behaviour Bumblebee.Configurable

  @impl true
  def config(featurizer, opts), do: Shared.put_config_attrs(featurizer, opts)

  @impl true
  def process_input(featurizer, raw_samples) do
    max_length = featurizer.num_seconds * featurizer.sampling_rate

    # Every sample is brought to exactly `max_length` entries, then all
    # samples are stacked into a single {batch, max_length} tensor.
    raw_samples
    |> List.wrap()
    |> Enum.map(&pad_sample(&1, max_length, featurizer.padding_value))
    |> Nx.stack()
  end

  # Validates that `sample` is a 1-D tensor and pads its end with
  # `padding_value` up to `max_length` entries. When the sample is longer
  # than `max_length` the pad amount is negative.
  defp pad_sample(sample, max_length, padding_value) do
    if Nx.rank(sample) != 1 do
      raise ArgumentError,
            "expected sample to be a 1-rank tensor, got: #{Nx.rank(sample)}-rank"
    end

    pad_size = max_length - Nx.axis_size(sample, 0)
    Nx.pad(sample, padding_value, [{0, pad_size, 0}])
  end

  @impl true
  def batch_template(featurizer, batch_size) do
    Nx.template({batch_size, featurizer.num_seconds * featurizer.sampling_rate}, :f32)
  end

  @impl true
  def process_batch(featurizer, samples) do
    fbank_opts = [
      fft_length: featurizer.fft_length,
      sampling_rate: featurizer.sampling_rate,
      mel_bins: featurizer.feature_size,
      hop_length: featurizer.hop_length
    ]

    # Vectorize over the batch axis so the feature extraction is written
    # for a single waveform.
    vectorized = Nx.vectorize(samples, :batch)
    features = vectorized |> extract_fbank_features(fbank_opts) |> Nx.devectorize()

    %{"input_features" => features}
  end

  defnp extract_fbank_features(waveform, opts \\ []) do
    opts = keyword!(opts, [:fft_length, :sampling_rate, :mel_bins, :hop_length])
    fft_length = opts[:fft_length]
    sampling_rate = opts[:sampling_rate]
    hop_length = opts[:hop_length]

    hann_window = NxSignal.Windows.hann(n: fft_length, is_periodic: true)

    {spectrum, _, _} =
      NxSignal.stft(waveform, hann_window,
        sampling_rate: sampling_rate,
        fft_length: fft_length,
        overlap_length: fft_length - hop_length,
        window_padding: :reflect
      )

    # Drop the last entry along the leading axis (presumably the trailing
    # STFT frame, matching the reference implementation — confirm).
    spectrum = spectrum[0..-2//1]

    # Magic numbers taken from the reference implementation. This yields
    # max_mel ~ 3016
    frequency_spacing = 200.0 / 3
    max_mel = frequency_spacing * 45.245640471924965

    NxSignal.stft_to_mel(spectrum, sampling_rate,
      fft_length: fft_length,
      mel_bins: opts[:mel_bins],
      max_mel: max_mel,
      mel_frequency_spacing: frequency_spacing
    )
  end

  defimpl Bumblebee.HuggingFace.Transformers.Config do
    def load(featurizer, data) do
      import Shared.Converters

      # Translate huggingface/transformers config keys into featurizer options.
      opts =
        convert!(data,
          feature_size: {"feature_size", number()},
          sampling_rate: {"sampling_rate", number()},
          hop_length: {"hop_length", number()},
          num_seconds: {"chunk_length", number()},
          fft_length: {"n_fft", number()},
          padding_value: {"padding_value", number()}
        )

      @for.config(featurizer, opts)
    end
  end
end
141 |
--------------------------------------------------------------------------------
/lib/bumblebee/configurable.ex:
--------------------------------------------------------------------------------
defmodule Bumblebee.Configurable do
  @moduledoc """
  Behaviour for entities whose settings live in a configuration struct.

  Modules implementing this behaviour are expected to define that struct.
  """

  @type t :: struct()

  @doc """
  Returns the given configuration struct updated with `opts`.
  """
  @callback config(config :: t(), opts :: keyword()) :: t()
end
16 |
--------------------------------------------------------------------------------
/lib/bumblebee/conversion/pytorch_loader/file_tensor.ex:
--------------------------------------------------------------------------------
defmodule Bumblebee.Conversion.PyTorchLoader.FileTensor do
  @moduledoc false

  defstruct [:shape, :type, :offset, :strides, :storage]
end

defimpl Nx.LazyContainer, for: Bumblebee.Conversion.PyTorchLoader.FileTensor do
  alias Bumblebee.Conversion.PyTorchLoader

  # Hands `fun` a template plus a loader closure; the tensor data is only
  # read from storage if and when that closure is invoked.
  def traverse(lazy_tensor, acc, fun) do
    template = Nx.template(lazy_tensor.shape, lazy_tensor.type)
    fun.(template, fn -> load(lazy_tensor) end, acc)
  end

  # Reads the backing storage, slices out this tensor's elements and
  # rearranges them into the logical axis order.
  defp load(lazy_tensor) do
    bytes = read_storage(lazy_tensor.storage)

    %{offset: offset, shape: shape, type: type, strides: strides} = lazy_tensor
    {_, bits} = type

    # `offset` and the element count are measured in elements, so both are
    # scaled by the element width to index into the binary.
    element_bytes = div(bits, 8)
    element_count = Tuple.product(shape)
    slice = binary_part(bytes, offset * element_bytes, element_count * element_bytes)

    slice
    |> Nx.from_binary(type)
    |> to_contiguous(shape, strides)
  end

  # Returns the raw bytes of the storage, either a file inside a zip
  # archive or a byte range of a plain file.
  defp read_storage(storage) do
    case storage do
      {:zip, path, file_name} ->
        PyTorchLoader.open_zip!(path, &PyTorchLoader.read_zip_file(&1, file_name))

      {:file, path, offset, size} ->
        File.open!(path, [:read, :raw], fn file ->
          {:ok, binary} = :file.pread(file, offset, size)
          binary
        end)
    end
  end

  # PyTorch describes memory layout through per-axis strides, which Nx has
  # no notion of. Ordering the axes by decreasing stride gives the order in
  # which they are laid out in memory; we read the data in that order and
  # transpose back to the logical order.
  #
  # NOTE(review): this assumes the strides are a permutation of contiguous
  # strides. Views with gaps (e.g. step slicing) occupy more memory than
  # `Tuple.product(shape)` elements and would be decoded incorrectly —
  # confirm such tensors never reach this path.
  defp to_contiguous(tensor, shape, strides) do
    memory_axes_order =
      strides
      |> Tuple.to_list()
      |> Enum.with_index()
      |> Enum.sort_by(fn {stride, _axis} -> stride end, :desc)
      |> Enum.map(fn {_stride, axis} -> axis end)

    logical_axes_order = Nx.axes(shape)

    case memory_axes_order do
      ^logical_axes_order ->
        Nx.reshape(tensor, shape)

      permuted_order ->
        memory_shape = List.to_tuple(for axis <- permuted_order, do: elem(shape, axis))

        tensor
        |> Nx.reshape(memory_shape)
        |> Nx.transpose(axes: inverse_permutation(permuted_order))
    end
  end

  # For a permutation of 0..n-1, returns the list whose i-th element is the
  # position at which i occurs in `permutation`.
  defp inverse_permutation(permutation) do
    permutation
    |> Enum.with_index()
    |> Enum.sort()
    |> Enum.map(fn {_axis, position} -> position end)
  end
end
77 |
--------------------------------------------------------------------------------
/lib/bumblebee/diffusion/scheduler_utils.ex:
--------------------------------------------------------------------------------
defmodule Bumblebee.Diffusion.SchedulerUtils do
  @moduledoc false

  import Nx.Defn

  @pi :math.pi()

  @beta_schedule_types [:linear, :quadratic, :squared_cosine]

  @doc """
  Returns a beta schedule of the given type with `num_timesteps` entries.

  The supported types are:

    * `:linear` - a linear schedule from Ho et al. (https://arxiv.org/pdf/2006.11239.pdf)

    * `:quadratic` - a quadratic schedule specific to the latent diffusion models

    * `:squared_cosine` - a cosine schedule from Nichol et al. (https://arxiv.org/pdf/2102.09672.pdf),
      used in OpenAI GLIDE

  Raises `ArgumentError` for any other type.

  ## Options

    * `:start` - start for the linear and quadratic schedules. Defaults to `0.0001`

    * `:end` - end for the linear and quadratic schedules. Defaults to `0.02`

  """
  deftransform beta_schedule(type, num_timesteps, opts \\ []) do
    opts = Keyword.validate!(opts, start: 0.0001, end: 0.02)
    beta_start = opts[:start]
    beta_end = opts[:end]

    case type do
      :linear ->
        Nx.linspace(beta_start, beta_end, n: num_timesteps)

      :quadratic ->
        # Interpolate linearly between the square roots, then square.
        Nx.linspace(Nx.sqrt(beta_start), Nx.sqrt(beta_end), n: num_timesteps) |> Nx.pow(2)

      :squared_cosine ->
        betas_for_alpha_bar(&squared_cosine_alpha_bar/1, num_timesteps: num_timesteps)

      other ->
        # Without this clause an unknown type surfaced as an opaque CaseClauseError.
        raise ArgumentError,
              "expected beta schedule type to be one of #{inspect(@beta_schedule_types)}, " <>
                "got: #{inspect(other)}"
    end
  end

  defnp squared_cosine_alpha_bar(t) do
    s = 0.008
    Nx.cos((t + s) / (1 + s) * @pi / 2) ** 2
  end

  # Creates a beta schedule that discretizes the given alpha_t_bar function,
  # which defines the cumulative product of (1 - beta) over time t in [0, 1].
  # Each beta is capped at :max_beta (0.999 by default).
  defnp betas_for_alpha_bar(alpha_t_bar_fun, opts \\ []) do
    opts = keyword!(opts, [:num_timesteps, max_beta: 0.999])
    num_timesteps = opts[:num_timesteps]
    max_beta = opts[:max_beta]

    i = Nx.iota({num_timesteps})
    t1 = i / num_timesteps
    t2 = (i + 1) / num_timesteps
    beta = 1 - alpha_t_bar_fun.(t2) / alpha_t_bar_fun.(t1)
    min(beta, max_beta)
  end

  @doc """
  Returns evenly spaced timesteps as used in the DDIM schedule.

  The `num_steps` timesteps are spaced `div(num_train_steps, num_steps)`
  apart, shifted by `offset`, and returned in decreasing order.
  """
  deftransform ddim_timesteps(num_train_steps, num_steps, offset) do
    timestep_gap = div(num_train_steps, num_steps)

    Nx.iota({num_steps})
    |> Nx.multiply(timestep_gap)
    |> Nx.add(offset)
    |> Nx.reverse()
  end
end
75 |
--------------------------------------------------------------------------------
/lib/bumblebee/featurizer.ex:
--------------------------------------------------------------------------------
defmodule Bumblebee.Featurizer do
  @moduledoc """
  Behaviour for configuring and applying featurizers, plus helpers that
  dispatch to the implementing module.

  A featurizer turns raw data into model input. Modules implementing
  this behaviour are expected to define a configuration struct as well.
  """

  @type t :: Bumblebee.Configurable.t()

  @doc """
  Turns the given input into a batched tensor (or a tensor container).

  Whenever possible, numerical batch work belongs in `c:process_batch/2`
  instead.
  """
  @callback process_input(t(), input :: any()) :: Nx.t() | Nx.Container.t()

  @doc """
  Returns an input template for `c:process_batch/2`.

  Its shape matches the result of `c:process_input/2`, apart from the
  batch size.
  """
  @callback batch_template(t(), batch_size :: pos_integer()) :: Nx.t() | Nx.Container.t()

  @doc """
  Optional batch processing stage.

  This is a numerical function. Its input is the result of
  `c:process_input/2`, possibly with a different batch size.

  When the featurizer is used in an `Nx.Serving`, this stage can be
  merged with the model computation and compiled together with it.
  """
  @callback process_batch(t(), input :: Nx.t() | Nx.Container.t()) :: Nx.t() | Nx.Container.t()

  @optional_callbacks batch_template: 2, process_batch: 2

  @doc """
  Turns the given input into a batched tensor (or a tensor container).
  """
  @spec process_input(t(), any()) :: Nx.t() | Nx.Container.t()
  def process_input(%module{} = featurizer, input), do: module.process_input(featurizer, input)

  @doc """
  Returns an input template for `process_batch/2`.

  Returns `nil` when the featurizer has no batch processing stage.
  """
  @spec batch_template(t(), pos_integer()) :: Nx.t() | Nx.Container.t() | nil
  def batch_template(%module{} = featurizer, batch_size) do
    if implements?(module, :batch_template), do: module.batch_template(featurizer, batch_size)
  end

  @doc """
  Optional batch processing stage.

  This is a numerical function. Its input is the result of
  `c:process_input/2`, possibly with a different batch size.

  When the featurizer has no batch processing stage, the input is
  returned unchanged.
  """
  @spec process_batch(t(), Nx.t() | Nx.Container.t()) :: Nx.t() | Nx.Container.t()
  def process_batch(%module{} = featurizer, batch) do
    if implements?(module, :process_batch) do
      module.process_batch(featurizer, batch)
    else
      batch
    end
  end

  # Both optional callbacks have arity 2. The module is loaded first,
  # because function_exported?/3 does not load modules itself.
  defp implements?(module, callback) do
    Code.ensure_loaded?(module) and function_exported?(module, callback, 2)
  end
end
80 |
--------------------------------------------------------------------------------
/lib/bumblebee/huggingface/transformers/config.ex:
--------------------------------------------------------------------------------
defprotocol Bumblebee.HuggingFace.Transformers.Config do
  @moduledoc false

  # This protocol defines a bridge between Bumblebee and huggingface/transformers
  # configuration.

  @doc """
  Updates configuration based on parsed JSON data.

  Implementations are not limited to model specs (featurizers implement
  it too), so the result is the updated struct of the same type as
  `config`, as produced by `c:Bumblebee.Configurable.config/2`.
  """
  # The return type was previously Bumblebee.ModelSpec.t(), which is wrong
  # for non-model implementations such as featurizers.
  @spec load(t(), map()) :: t()
  def load(config, data)
end
13 |
--------------------------------------------------------------------------------
/lib/bumblebee/huggingface/transformers/model.ex:
--------------------------------------------------------------------------------
defprotocol Bumblebee.HuggingFace.Transformers.Model do
  @moduledoc false

  # Describes how a Bumblebee (Axon) model is populated from a
  # huggingface/transformers (PyTorch) checkpoint.

  @type layer_name :: String.t()
  @type param_name :: String.t()

  @type param_ref :: {layer_name(), param_name()}
  @type param_source :: param_ref() | list(param_ref())

  @type param_builder ::
          {list(param_source()), (list(Nx.tensor()) -> Nx.Tensor.t() | Nx.Container.t())}

  @type param_builders :: %{param_name() => param_builder()}

  @type params_source :: layer_name() | list(layer_name()) | param_builders()

  @type params_mapping :: %{layer_name() => params_source()}

  @doc """
  Describes how layers and parameters of an Axon model correspond to
  those of the equivalent huggingface/transformers model.

  ## Mapping format

  In its simplest form the mapping is a map from Axon layer names (the
  target) to PyTorch layer names (the source):

      %{
        "embedder.token_embedding" => "bert.embeddings.word_embeddings",
        ...
      }

  Always use the longest form of a name. Depending on the architecture,
  the PyTorch layer may be called either `"bert.embeddings.word_embeddings"`
  or `"embeddings.word_embeddings"`; prefer the former. Prefixes are added
  or stripped as needed, so loading partial models works automatically.

  Layer names support simple substitutions, which is handy for lists of
  repeated layers:

      %{
        "encoder.blocks.{n}.self_attention.query" => "bert.encoder.layer.{n}.attention.self.query",
        ...
      }

  Parameter names and values of corresponding layers do not always match
  exactly and may need further transformation. For example, the Axon
  convolution `"kernel"` is a transposed PyTorch `"weight"`. Conversions
  like this are handled automatically for the most common layers.

  Model-specific layers and parameters sometimes need finer control. For
  those, the source layer name can be replaced with a map of
  parameter-level transformations:

      %{
        "embedder.class_embedding" => %{
          "embedding" => {
            [{"vit.embeddings", "cls_token"}],
            fn [value] -> Nx.squeeze(value, axes: [0, 1]) end
          }
        },
        ...
      }

  Each parameter lists its source parameters as
  `{source_layer_name, source_param_name}` tuples, together with a
  function that builds the target value from them. Any entry of that list
  may itself be a list of such tuples, whose alternatives are tried one
  by one. Explicit transformations like this can express arbitrary changes
  to parameter names and values.
  """
  @spec params_mapping(t()) :: params_mapping()
  def params_mapping(spec)
end
81 |
--------------------------------------------------------------------------------
/lib/bumblebee/huggingface/transformers/utils.ex:
--------------------------------------------------------------------------------
defmodule Bumblebee.HuggingFace.Transformers.Utils do
  @moduledoc false

  import Bumblebee.Utils.Model, only: [join: 2]

  alias Bumblebee.HuggingFace.Transformers

  @doc """
  Prefixes every target layer name with `target_prefix` and every source
  layer name with `source_prefix` in the given params mapping.

  Either prefix may be `nil`; how that is handled is up to
  `Bumblebee.Utils.Model.join/2` (presumably the name is left as is —
  confirm).
  """
  @spec prefix_params_mapping(
          Transformers.Model.params_mapping(),
          String.t() | nil,
          String.t() | nil
        ) :: Transformers.Model.params_mapping()
  def prefix_params_mapping(params_mapping, target_prefix, source_prefix) do
    prefix_source = fn layer_name -> join(source_prefix, layer_name) end

    for {target_layer_name, params_source} <- params_mapping, into: %{} do
      {join(target_prefix, target_layer_name),
       map_params_source_layer_names(params_source, prefix_source)}
    end
  end

  @doc """
  Applies `fun` to every source layer name in a params mapping value.

  The value may be a single layer name, a list of layer names, or a map
  of param builders, in which case each builder's source references are
  rewritten and its builder function is kept.
  """
  @spec map_params_source_layer_names(
          Transformers.Model.params_source(),
          (String.t() -> String.t())
        ) :: Transformers.Model.params_source()
  def map_params_source_layer_names(%{} = param_builders, fun) do
    Map.new(param_builders, fn {target_param, {sources, builder_fun}} ->
      {target_param, {Enum.map(sources, &map_param_source(&1, fun)), builder_fun}}
    end)
  end

  def map_params_source_layer_names(layer_names, fun) when is_list(layer_names),
    do: Enum.map(layer_names, fun)

  def map_params_source_layer_names(layer_name, fun) when is_binary(layer_name),
    do: fun.(layer_name)

  # Rewrites the layer name of a single {layer, param} reference, or of
  # each reference in a list of alternatives.
  defp map_param_source({layer_name, source_param}, fun), do: {fun.(layer_name), source_param}

  defp map_param_source(refs, fun) do
    for {layer_name, source_param} <- refs, do: {fun.(layer_name), source_param}
  end
end
57 |
--------------------------------------------------------------------------------
/lib/bumblebee/model_spec.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.ModelSpec do
2 | @moduledoc """
3 | An interface for configuring and building models based on the same
4 | architecture.
5 |
6 | Every module implementing this behaviour is expected to also define
7 | a configuration struct.
8 | """
9 |
10 | @type t :: Bumblebee.Configurable.t()
11 |
12 | @doc """
13 | Returns the list of supported model architectures.
14 | """
15 | @callback architectures :: list(atom())
16 |
17 | @doc """
18 | Builds a template input for the model.
19 |
20 | The template is used to compile the model when initializing parameters.
21 | """
22 | @callback input_template(t()) :: map()
23 |
24 | @doc """
25 | Builds an `Axon` model according to the given configuration.
26 | """
27 | @callback model(t()) :: Axon.t()
28 | end
29 |
--------------------------------------------------------------------------------
/lib/bumblebee/scheduler.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Scheduler do
2 | @moduledoc """
3 | An interface for configuring and using schedulers.
4 |
5 | A scheduler defines a sampling method, usually used for multi-step
6 | denoising process, as in stable diffusion.
7 |
8 | Every module implementing this behaviour is expected to also define
9 | a configuration struct.
10 |
11 | ## Context
12 |
13 | Imagine a denoising model trained in 1000 steps. During training,
14 | we take some original data and add random noise 1000 times, this
15 | way we obtain 1000 steps with increasing level of noise. Then, the
16 | model learns to predict noise at each timestep, given data at that
17 | step (sample) and the timestep.
18 |
19 | Once such model is trained, we can obtain brand new data (such as
20 | image) by generating random data and denoising it with our model in
21 | 1000 steps.
22 |
23 | Doing 1000 forward passes of the model for a single generation can
24 | be expensive, hence multiple methods have been developed to reduce
25 | the number of steps during denoising, with no changes to the model.
26 |
27 | Each method specifies a subset of the original timesteps, at each
28 | timestep we need to do a forward pass of the model (or possibly a
29 | few), then the method extrapolates the sample to the next selected
30 | timestep, possibly skipping a lot of timesteps in between.
31 |
32 | ## Note on wording
33 |
34 | Throughout the docs and APIs the word "steps" refers to diffusion
35 | steps, whereas "timesteps" is more specific and refers to the exact
36 | values $t$ (points in time).
37 | """
38 |
39 | @type t :: Bumblebee.Configurable.t()
40 |
41 | @type state :: Nx.Container.t()
42 |
43 | @doc """
44 | Initializes state for a new scheduler loop.
45 |
46 | Returns a pair of `{state, timesteps}`, where `state` is an opaque
47 | `Nx.Container` and `timesteps` is a tensor with the subsequent
48 | timesteps for model forward pass.
49 | """
50 | @callback init(
51 | t(),
52 | num_steps :: pos_integer(),
53 | sample_template :: Nx.Tensor.t(),
54 | prng_key :: Nx.Tensor.t()
55 | ) :: {state :: map(), timesteps :: Nx.Tensor.t()}
56 |
57 | @doc """
58 | Predicts sample at the previous timestep.
59 |
60 | Takes the current `sample` and `prediction` (usually noise) returned
61 | by the model at the current timestep. Returns `{state, prev_sample}`,
62 | where `state` is the updated state and `prev_sample` is the predicted
63 | sample at the previous timestep.
64 | """
65 | @callback step(
66 | t(),
67 | state(),
68 | sample :: Nx.Tensor.t(),
69 | prediction :: Nx.Tensor.t()
70 | ) :: {state :: map(), prev_sample :: Nx.Tensor.t()}
71 | end
72 |
--------------------------------------------------------------------------------
/lib/bumblebee/text/text_classification.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.TextClassification do
2 | @moduledoc false
3 |
4 | alias Bumblebee.Shared
5 |
6 | def text_classification(model_info, tokenizer, opts \\ []) do
7 | %{model: model, params: params, spec: spec} = model_info
8 | Shared.validate_architecture!(spec, :for_sequence_classification)
9 |
10 | opts =
11 | Keyword.validate!(opts, [
12 | :compile,
13 | top_k: 5,
14 | scores_function: :softmax,
15 | defn_options: [],
16 | preallocate_params: false
17 | ])
18 |
19 | top_k = opts[:top_k]
20 | scores_function = opts[:scores_function]
21 | preallocate_params = opts[:preallocate_params]
22 | defn_options = opts[:defn_options]
23 |
24 | compile =
25 | if compile = opts[:compile] do
26 | compile
27 | |> Keyword.validate!([:batch_size, :sequence_length])
28 | |> Shared.require_options!([:batch_size, :sequence_length])
29 | end
30 |
31 | batch_size = compile[:batch_size]
32 | sequence_length = compile[:sequence_length]
33 |
34 | tokenizer =
35 | Bumblebee.configure(tokenizer, length: sequence_length, return_token_type_ids: false)
36 |
37 | {_init_fun, predict_fun} = Axon.build(model)
38 |
39 | scores_fun = fn params, input ->
40 | outputs = predict_fun.(params, input)
41 | scores = Shared.logits_to_scores(outputs.logits, scores_function)
42 | k = min(top_k, Nx.axis_size(scores, 1))
43 | {top_scores, top_indices} = Nx.top_k(scores, k: k)
44 | {top_scores, top_indices}
45 | end
46 |
47 | batch_keys = Shared.sequence_batch_keys(sequence_length)
48 |
49 | Nx.Serving.new(
50 | fn batch_key, defn_options ->
51 | params = Shared.maybe_preallocate(params, preallocate_params, defn_options)
52 |
53 | scope = {:scores, batch_key}
54 |
55 | scores_fun =
56 | Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn ->
57 | {:sequence_length, sequence_length} = batch_key
58 |
59 | inputs = %{
60 | "input_ids" => Nx.template({batch_size, sequence_length}, :u32),
61 | "attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
62 | }
63 |
64 | [params, inputs]
65 | end)
66 |
67 | fn inputs ->
68 | inputs = Shared.maybe_pad(inputs, batch_size)
69 | scores_fun.(params, inputs) |> Shared.serving_post_computation()
70 | end
71 | end,
72 | defn_options
73 | )
74 | |> Nx.Serving.batch_size(batch_size)
75 | |> Nx.Serving.process_options(batch_keys: batch_keys)
76 | |> Nx.Serving.client_preprocessing(fn input ->
77 | {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string_or_pairs/1)
78 |
79 | inputs =
80 | Nx.with_default_backend(Nx.BinaryBackend, fn ->
81 | Bumblebee.apply_tokenizer(tokenizer, texts)
82 | end)
83 |
84 | batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
85 | batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
86 |
87 | {batch, multi?}
88 | end)
89 | |> Nx.Serving.client_postprocessing(fn {{top_scores, top_indices}, _metadata}, multi? ->
90 | Enum.zip_with(
91 | Nx.to_list(top_scores),
92 | Nx.to_list(top_indices),
93 | fn top_scores, top_indices ->
94 | predictions =
95 | Enum.zip_with(top_scores, top_indices, fn score, idx ->
96 | label = spec.id_to_label[idx] || "LABEL_#{idx}"
97 | %{score: score, label: label}
98 | end)
99 |
100 | %{predictions: predictions}
101 | end
102 | )
103 | |> Shared.normalize_output(multi?)
104 | end)
105 | end
106 | end
107 |
--------------------------------------------------------------------------------
/lib/bumblebee/text/whisper_generation_config.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.WhisperGenerationConfig do
2 | alias Bumblebee.Shared
3 |
4 | options = [
5 | no_timestamps_token_id: [
6 | default: nil,
7 | doc: "the id of the no-timestamps token"
8 | ],
9 | language_to_token_id: [
10 | default: %{},
11 | doc: "a map from language code to token id corresponding to that language"
12 | ],
13 | task_to_token_id: [
14 | default: %{},
15 | doc: "a map from task to token id corresponding to that task"
16 | ]
17 | ]
18 |
19 | @moduledoc """
20 | A set of Whisper-specific configuration options controlling text
21 | generation.
22 |
23 | This struct is used in the `Bumblebee.Text.GenerationConfig` struct
24 | under the `:extra_config` attribute.
25 |
26 | ## Configuration
27 |
28 | #{Shared.options_doc(options)}
29 | """
30 |
31 | defstruct Shared.option_defaults(options)
32 |
33 | @behaviour Bumblebee.Configurable
34 |
35 | @type t :: %__MODULE__{}
36 |
37 | @impl true
38 | def config(config, opts \\ []) do
39 | Shared.put_config_attrs(config, opts)
40 | end
41 |
42 | defimpl Bumblebee.HuggingFace.Transformers.Config do
43 | def load(config, data) do
44 | import Shared.Converters
45 |
46 | language_converter = fn name, value ->
47 | with {:ok, value} <- string().(name, value) do
48 | {:ok,
49 | value
50 | |> String.replace_prefix("<|", "")
51 | |> String.replace_suffix("|>", "")}
52 | end
53 | end
54 |
55 | opts =
56 | convert!(data,
57 | no_timestamps_token_id: {"no_timestamps_token_id", number()},
58 | language_to_token_id: {"lang_to_id", map(language_converter, number())},
59 | task_to_token_id: {"task_to_id", map(atom(), number())}
60 | )
61 |
62 | @for.config(config, opts)
63 | end
64 | end
65 | end
66 |
--------------------------------------------------------------------------------
/lib/bumblebee/tokenizer.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Tokenizer do
2 | @moduledoc """
3 | An interface for configuring and applying tokenizers.
4 |
5 | A tokenizer is used to convert raw text data into model input.
6 |
7 | Every module implementing this behaviour is expected to also define
8 | a configuration struct.
9 | """
10 |
11 | @type t :: struct()
12 |
13 | @type input :: String.t() | {String.t(), String.t()}
14 | @type token :: String.t()
15 | @type token_id :: non_neg_integer()
16 |
17 | @typedoc """
18 | A type corresponding to a special token in the vocabulary.
19 |
20 | ## Common types
21 |
22 | * `:bos` - a token representing the beginning of a sentence
23 |
24 | * `:eos` - a token representing the end of a sentence
25 |
26 | * `:unk` - a token representing an out-of-vocabulary token
27 |
28 | * `:sep` - a token separating two different sentences in the same
29 | input
30 |
31 | * `:pad` - a token added when processing a batch of sequences with
32 | different length
33 |
34 | * `:cls` - a token representing the class of the input
35 |
36 | * `:mask` - a token representing a masked token, used for masked
37 | language modeling tasks
38 |
39 | """
40 | @type special_token_type :: atom()
41 |
42 | @doc """
43 | Performs tokenization and encoding on the given input.
44 | """
45 | @callback apply(t(), input() | list(input())) :: any()
46 |
47 | @doc """
48 | Decodes a list of token ids into a sentence.
49 | """
50 | @callback decode(t(), list(token_id()) | list(list(token_id()))) :: String.t()
51 |
52 | @doc """
53 | Converts the given token into the corresponding numeric id.
54 | """
55 | @callback token_to_id(t(), token()) :: token_id() | nil
56 |
57 | @doc """
58 | Converts the given token id the corresponding token.
59 | """
60 | @callback id_to_token(t(), token_id()) :: token() | nil
61 |
62 | @doc """
63 | Returns a map with special tokens.
64 | """
65 | @callback special_tokens(t()) :: %{special_token_type() => token()}
66 |
67 | @doc """
68 | Returns a list with extra special tokens, in addition to the named
69 | `special_tokens/1`.
70 | """
71 | @callback additional_special_tokens(t()) :: MapSet.t(token())
72 |
73 | @doc """
74 | Decodes a list of token ids into a sentence.
75 | """
76 | @spec decode(
77 | t(),
78 | token() | list(token_id()) | list(list(token_id())) | Nx.Tensor.t()
79 | ) :: String.t()
80 | def decode(%module{} = tokenizer, ids) do
81 | ids = with %Nx.Tensor{} <- ids, do: Nx.to_list(ids)
82 | ids = List.wrap(ids)
83 | module.decode(tokenizer, ids)
84 | end
85 |
86 | @doc """
87 | Converts the given token into the corresponding numeric id.
88 | """
89 | @spec token_to_id(t(), token()) :: token_id() | nil
90 | def token_to_id(%module{} = tokenizer, token) do
91 | module.token_to_id(tokenizer, token)
92 | end
93 |
94 | @doc """
95 | Converts the given token id to the corresponding token.
96 | """
97 | @spec id_to_token(t(), token_id()) :: token() | nil
98 | def id_to_token(%module{} = tokenizer, id) do
99 | module.id_to_token(tokenizer, id)
100 | end
101 |
102 | @doc """
103 | Returns a special token by name.
104 | """
105 | @spec special_token(t(), special_token_type()) :: token() | nil
106 | def special_token(%module{} = tokenizer, type) do
107 | special_tokens = module.special_tokens(tokenizer)
108 | special_tokens[type]
109 | end
110 |
111 | @doc """
112 | Returns id of a special token by name.
113 | """
114 | @spec special_token_id(t(), special_token_type()) :: token_id() | nil
115 | def special_token_id(tokenizer, type) do
116 | if token = special_token(tokenizer, type) do
117 | token_to_id(tokenizer, token)
118 | end
119 | end
120 |
121 | @doc """
122 | Returns all special tokens, including any extra tokens.
123 | """
124 | @spec all_special_tokens(t()) :: list(token_id())
125 | def all_special_tokens(%module{} = tokenizer) do
126 | special_tokens = module.special_tokens(tokenizer)
127 | additional_special_tokens = module.additional_special_tokens(tokenizer)
128 | for {_type, token} <- special_tokens, do: token, into: additional_special_tokens
129 | end
130 | end
131 |
--------------------------------------------------------------------------------
/lib/bumblebee/utils.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Utils do
2 | @moduledoc false
3 |
4 | @doc """
5 | Checks if the progress bar is enabled globally.
6 | """
7 | @spec progress_bar_enabled? :: boolean()
8 | def progress_bar_enabled?() do
9 | Application.get_env(:bumblebee, :progress_bar_enabled, true)
10 | end
11 | end
12 |
--------------------------------------------------------------------------------
/lib/bumblebee/utils/image.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Utils.Image do
2 | @moduledoc false
3 |
4 | import Nx.Defn
5 |
6 | @doc """
7 | Converts the given term to a batch of image.
8 | """
9 | defn to_batched_tensor(image) do
10 | case Nx.rank(image) do
11 | 3 ->
12 | Nx.new_axis(image, 0, :batch)
13 |
14 | 4 ->
15 | image
16 |
17 | rank ->
18 | raise ArgumentError,
19 | "expected image to be a rank-3 image or a rank-4 batch, got rank: #{rank}"
20 | end
21 | end
22 |
23 | @doc """
24 | Normalizes an image size to a `{height, width}` tuple.
25 |
26 | Accepts either an existing tuple or a single number used for both
27 | dimensions.
28 | """
29 | def normalize_size(size)
30 |
31 | def normalize_size({height, width}), do: {height, width}
32 | def normalize_size(size) when is_integer(size), do: {size, size}
33 |
34 | @doc """
35 | Matches image against the desired number of channels and applies
36 | automatic conversions if applicable.
37 | """
38 | def normalize_channels(input, channels) do
39 | channel_axis = Nx.axis_index(input, -1)
40 |
41 | case {Nx.axis_size(input, channel_axis), channels} do
42 | {channels, channels} ->
43 | input
44 |
45 | {4, 3} ->
46 | Nx.slice_along_axis(input, 0, 3, axis: channel_axis)
47 |
48 | {1, 3} ->
49 | shape = input |> Nx.shape() |> put_elem(channel_axis, 3)
50 | Nx.broadcast(input, shape)
51 |
52 | {actual, expected} ->
53 | raise ArgumentError,
54 | "expected image with #{expected} channels, but got #{actual} and no automatic conversion applies"
55 | end
56 | end
57 | end
58 |
--------------------------------------------------------------------------------
/lib/bumblebee/utils/model.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Utils.Model do
2 | @moduledoc false
3 |
4 | @doc """
5 | Adds another word to a hierarchical name.
6 | """
7 | @spec join(String.t() | nil, String.Chars.t()) :: String.t()
8 | def join(name, suffix)
9 |
10 | def join(nil, suffix), do: to_string(suffix)
11 | def join(name, suffix), do: name <> "." <> to_string(suffix)
12 |
13 | @doc """
14 | Converts a list of inputs to a map with input names as keys.
15 | """
16 | @spec inputs_to_map(list(Axon.t())) :: %{String.t() => Axon.t()}
17 | def inputs_to_map(inputs) when is_list(inputs) do
18 | for %Axon{output: id, nodes: nodes} = axon <- inputs, into: %{} do
19 | input = nodes[id]
20 | {input.name.(:input, %{}), axon}
21 | end
22 | end
23 | end
24 |
--------------------------------------------------------------------------------
/lib/bumblebee/vision/blip_featurizer.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Vision.BlipFeaturizer do
2 | alias Bumblebee.Shared
3 |
4 | options = [
5 | resize: [
6 | default: true,
7 | doc: "whether to resize the input to the given `:size`"
8 | ],
9 | size: [
10 | default: %{height: 384, width: 384},
11 | doc: """
12 | the size to resize the input to, given as `%{height: ..., width: ...}`. Only has
13 | an effect if `:resize` is `true`
14 | """
15 | ],
16 | resize_method: [
17 | default: :bicubic,
18 | doc:
19 | "the resizing method, either of `:nearest`, `:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`"
20 | ],
21 | normalize: [
22 | default: true,
23 | doc: "whether or not to normalize the input with mean and standard deviation"
24 | ],
25 | image_mean: [
26 | default: [0.48145466, 0.4578275, 0.40821073],
27 | doc: "the sequence of mean values for each channel, to be used when normalizing images"
28 | ],
29 | image_std: [
30 | default: [0.26862954, 0.26130258, 0.27577711],
31 | doc:
32 | "the sequence of standard deviations for each channel, to be used when normalizing images"
33 | ]
34 | ]
35 |
36 | @moduledoc """
37 | BLIP featurizer for image data.
38 |
39 | ## Configuration
40 |
41 | #{Shared.options_doc(options)}
42 | """
43 |
44 | defstruct Shared.option_defaults(options)
45 |
46 | @behaviour Bumblebee.Featurizer
47 | @behaviour Bumblebee.Configurable
48 |
49 | alias Bumblebee.Utils.Image
50 |
51 | @impl true
52 | def config(featurizer, opts) do
53 | Shared.put_config_attrs(featurizer, opts)
54 | end
55 |
56 | @impl true
57 | def process_input(featurizer, images) do
58 | images = List.wrap(images)
59 |
60 | for image <- images do
61 | image =
62 | image
63 | |> Image.to_batched_tensor()
64 | |> Nx.as_type(:f32)
65 | |> Image.normalize_channels(length(featurizer.image_mean))
66 |
67 | if featurizer.resize do
68 | %{height: height, width: width} = featurizer.size
69 | NxImage.resize(image, {height, width}, method: featurizer.resize_method)
70 | else
71 | image
72 | end
73 | end
74 | |> Nx.concatenate()
75 | end
76 |
77 | @impl true
78 | def batch_template(featurizer, batch_size) do
79 | %{height: height, width: width} = featurizer.size
80 | num_channels = length(featurizer.image_mean)
81 | Nx.template({batch_size, height, width, num_channels}, :f32)
82 | end
83 |
84 | @impl true
85 | def process_batch(featurizer, images) do
86 | images = NxImage.to_continuous(images, 0, 1)
87 |
88 | images =
89 | if featurizer.normalize do
90 | NxImage.normalize(
91 | images,
92 | Nx.tensor(featurizer.image_mean),
93 | Nx.tensor(featurizer.image_std)
94 | )
95 | else
96 | images
97 | end
98 |
99 | %{"pixel_values" => images}
100 | end
101 |
102 | defimpl Bumblebee.HuggingFace.Transformers.Config do
103 | def load(featurizer, data) do
104 | import Shared.Converters
105 |
106 | opts =
107 | convert!(data,
108 | resize: {"do_resize", boolean()},
109 | size: {"size", image_size()},
110 | resize_method: {"resample", resize_method()},
111 | normalize: {"do_normalize", boolean()},
112 | image_mean: {"image_mean", list(number())},
113 | image_std: {"image_std", list(number())}
114 | )
115 |
116 | @for.config(featurizer, opts)
117 | end
118 | end
119 | end
120 |
--------------------------------------------------------------------------------
/lib/bumblebee/vision/deit_featurizer.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Vision.DeitFeaturizer do
2 | alias Bumblebee.Shared
3 |
4 | options = [
5 | resize: [
6 | default: true,
7 | doc: "whether to resize the input to the given `:size`"
8 | ],
9 | size: [
10 | default: %{height: 256, width: 256},
11 | doc: """
12 | the size to resize the input to, given as `%{height: ..., width: ...}`. Only has
13 | an effect if `:resize` is `true`
14 | """
15 | ],
16 | resize_method: [
17 | default: :bicubic,
18 | doc:
19 | "the resizing method, either of `:nearest`, `:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`"
20 | ],
21 | center_crop: [
22 | default: true,
23 | doc: """
24 | whether to crop the input at the center. If the input size is smaller than `:crop_size` along
25 | any edge, the image is padded with zeros and then center cropped
26 | """
27 | ],
28 | crop_size: [
29 | default: %{height: 224, width: 224},
30 | doc: """
31 | the size to center crop the image to, given as `%{height: ..., width: ...}`. Only has an effect
32 | if `:center_crop` is `true`
33 | """
34 | ],
35 | normalize: [
36 | default: true,
37 | doc: "whether or not to normalize the input with mean and standard deviation"
38 | ],
39 | image_mean: [
40 | default: [0.485, 0.456, 0.406],
41 | doc: "the sequence of mean values for each channel, to be used when normalizing images"
42 | ],
43 | image_std: [
44 | default: [0.229, 0.224, 0.225],
45 | doc:
46 | "the sequence of standard deviations for each channel, to be used when normalizing images"
47 | ]
48 | ]
49 |
50 | @moduledoc """
51 | DeiT featurizer for image data.
52 |
53 | ## Configuration
54 |
55 | #{Shared.options_doc(options)}
56 | """
57 |
58 | defstruct Shared.option_defaults(options)
59 |
60 | @behaviour Bumblebee.Featurizer
61 | @behaviour Bumblebee.Configurable
62 |
63 | alias Bumblebee.Utils.Image
64 |
65 | @impl true
66 | def config(featurizer, opts) do
67 | Shared.put_config_attrs(featurizer, opts)
68 | end
69 |
70 | @impl true
71 | def process_input(featurizer, images) do
72 | images = List.wrap(images)
73 |
74 | for image <- images do
75 | images =
76 | image
77 | |> Image.to_batched_tensor()
78 | |> Nx.as_type(:f32)
79 | |> Image.normalize_channels(length(featurizer.image_mean))
80 |
81 | if featurizer.resize do
82 | %{height: height, width: width} = featurizer.size
83 | NxImage.resize(images, {height, width}, method: featurizer.resize_method)
84 | else
85 | images
86 | end
87 | end
88 | |> Nx.concatenate()
89 | end
90 |
91 | @impl true
92 | def batch_template(featurizer, batch_size) do
93 | %{height: height, width: width} = featurizer.size
94 | num_channels = length(featurizer.image_mean)
95 | Nx.template({batch_size, height, width, num_channels}, :f32)
96 | end
97 |
98 | @impl true
99 | def process_batch(featurizer, images) do
100 | images =
101 | if featurizer.center_crop do
102 | %{height: height, width: width} = featurizer.crop_size
103 | NxImage.center_crop(images, {height, width})
104 | else
105 | images
106 | end
107 |
108 | images = NxImage.to_continuous(images, 0, 1)
109 |
110 | images =
111 | if featurizer.normalize do
112 | NxImage.normalize(
113 | images,
114 | Nx.tensor(featurizer.image_mean),
115 | Nx.tensor(featurizer.image_std)
116 | )
117 | else
118 | images
119 | end
120 |
121 | %{"pixel_values" => images}
122 | end
123 |
124 | defimpl Bumblebee.HuggingFace.Transformers.Config do
125 | def load(featurizer, data) do
126 | import Shared.Converters
127 |
128 | opts =
129 | convert!(data,
130 | resize: {"do_resize", boolean()},
131 | size: {"size", image_size()},
132 | resize_method: {"resample", resize_method()},
133 | center_crop: {"do_center_crop", boolean()},
134 | crop_size: {"crop_size", image_size()},
135 | normalize: {"do_normalize", boolean()},
136 | image_mean: {"image_mean", list(number())},
137 | image_std: {"image_std", list(number())}
138 | )
139 |
140 | @for.config(featurizer, opts)
141 | end
142 | end
143 | end
144 |
--------------------------------------------------------------------------------
/lib/bumblebee/vision/image_classification.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Vision.ImageClassification do
2 | @moduledoc false
3 |
4 | alias Bumblebee.Shared
5 |
6 | def image_classification(model_info, featurizer, opts \\ []) do
7 | %{model: model, params: params, spec: spec} = model_info
8 |
9 | Shared.validate_architecture!(spec, [
10 | :for_image_classification,
11 | :for_image_classification_with_teacher
12 | ])
13 |
14 | opts =
15 | Keyword.validate!(opts, [
16 | :compile,
17 | top_k: 5,
18 | scores_function: :softmax,
19 | defn_options: [],
20 | preallocate_params: false
21 | ])
22 |
23 | top_k = opts[:top_k]
24 | scores_function = opts[:scores_function]
25 | preallocate_params = opts[:preallocate_params]
26 | defn_options = opts[:defn_options]
27 |
28 | compile =
29 | if compile = opts[:compile] do
30 | compile
31 | |> Keyword.validate!([:batch_size])
32 | |> Shared.require_options!([:batch_size])
33 | end
34 |
35 | batch_size = compile[:batch_size]
36 |
37 | {_init_fun, predict_fun} = Axon.build(model)
38 |
39 | scores_fun = fn params, input ->
40 | input = Bumblebee.Featurizer.process_batch(featurizer, input)
41 | outputs = predict_fun.(params, input)
42 | scores = Shared.logits_to_scores(outputs.logits, scores_function)
43 | k = min(top_k, Nx.axis_size(scores, 1))
44 | {top_scores, top_indices} = Nx.top_k(scores, k: k)
45 | {top_scores, top_indices}
46 | end
47 |
48 | Nx.Serving.new(
49 | fn defn_options ->
50 | params = Shared.maybe_preallocate(params, preallocate_params, defn_options)
51 |
52 | scores_fun =
53 | Shared.compile_or_jit(scores_fun, :scores, defn_options, compile != nil, fn ->
54 | inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size)
55 | [params, inputs]
56 | end)
57 |
58 | fn inputs ->
59 | inputs = Shared.maybe_pad(inputs, batch_size)
60 | scores_fun.(params, inputs) |> Shared.serving_post_computation()
61 | end
62 | end,
63 | defn_options
64 | )
65 | |> Nx.Serving.batch_size(batch_size)
66 | |> Nx.Serving.client_preprocessing(fn input ->
67 | {images, multi?} = Shared.validate_serving_input!(input, &Shared.validate_image/1)
68 | inputs = Bumblebee.Featurizer.process_input(featurizer, images)
69 | {Nx.Batch.concatenate([inputs]), multi?}
70 | end)
71 | |> Nx.Serving.client_postprocessing(fn {{top_scores, top_indices}, _metadata}, multi? ->
72 | Enum.zip_with(
73 | Nx.to_list(top_scores),
74 | Nx.to_list(top_indices),
75 | fn top_scores, top_indices ->
76 | predictions =
77 | Enum.zip_with(top_scores, top_indices, fn score, idx ->
78 | label = spec.id_to_label[idx] || "LABEL_#{idx}"
79 | %{score: score, label: label}
80 | end)
81 |
82 | %{predictions: predictions}
83 | end
84 | )
85 | |> Shared.normalize_output(multi?)
86 | end)
87 | end
88 | end
89 |
--------------------------------------------------------------------------------
/lib/bumblebee/vision/image_embedding.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Vision.ImageEmbedding do
2 | @moduledoc false
3 |
4 | alias Bumblebee.Shared
5 |
6 | def image_embedding(model_info, featurizer, opts \\ []) do
7 | %{model: model, params: params} = model_info
8 |
9 | opts =
10 | Keyword.validate!(opts, [
11 | :compile,
12 | output_attribute: :pooled_state,
13 | embedding_processor: nil,
14 | defn_options: [],
15 | preallocate_params: false
16 | ])
17 |
18 | output_attribute = opts[:output_attribute]
19 | embedding_processor = opts[:embedding_processor]
20 | preallocate_params = opts[:preallocate_params]
21 | defn_options = opts[:defn_options]
22 |
23 | compile =
24 | if compile = opts[:compile] do
25 | compile
26 | |> Keyword.validate!([:batch_size])
27 | |> Shared.require_options!([:batch_size])
28 | end
29 |
30 | batch_size = compile[:batch_size]
31 |
32 | {_init_fun, encoder} = Axon.build(model)
33 |
34 | embedding_fun = fn params, inputs ->
35 | inputs = Bumblebee.Featurizer.process_batch(featurizer, inputs)
36 |
37 | output = encoder.(params, inputs)
38 |
39 | output =
40 | case output do
41 | %{^output_attribute => output} ->
42 | output
43 |
44 | %{} ->
45 | keys = output |> Map.keys() |> Enum.sort()
46 |
47 | raise ArgumentError,
48 | "key #{inspect(output_attribute)} not found in the output map," <>
49 | " you may want to set :output_attribute to one of the map keys: #{inspect(keys)}"
50 |
51 | _ ->
52 | output
53 | end
54 |
55 | output =
56 | case embedding_processor do
57 | nil ->
58 | output
59 |
60 | :l2_norm ->
61 | Bumblebee.Utils.Nx.normalize(output)
62 |
63 | other ->
64 | raise ArgumentError,
65 | "expected :embedding_processor to be one of nil or :l2_norm, got: #{inspect(other)}"
66 | end
67 |
68 | output
69 | end
70 |
71 | Nx.Serving.new(
72 | fn defn_options ->
73 | params = Shared.maybe_preallocate(params, preallocate_params, defn_options)
74 |
75 | embedding_fun =
76 | Shared.compile_or_jit(embedding_fun, :embedding, defn_options, compile != nil, fn ->
77 | inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size)
78 | [params, inputs]
79 | end)
80 |
81 | fn inputs ->
82 | inputs = Shared.maybe_pad(inputs, batch_size)
83 | embedding_fun.(params, inputs) |> Shared.serving_post_computation()
84 | end
85 | end,
86 | defn_options
87 | )
88 | |> Nx.Serving.batch_size(batch_size)
89 | |> Nx.Serving.client_preprocessing(fn input ->
90 | {images, multi?} = Shared.validate_serving_input!(input, &Shared.validate_image/1)
91 |
92 | inputs = Bumblebee.Featurizer.process_input(featurizer, images)
93 |
94 | {Nx.Batch.concatenate([inputs]), multi?}
95 | end)
96 | |> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? ->
97 | for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do
98 | %{embedding: embedding}
99 | end
100 | |> Shared.normalize_output(multi?)
101 | end)
102 | end
103 | end
104 |
--------------------------------------------------------------------------------
/lib/bumblebee/vision/image_to_text.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Vision.ImageToText do
2 | @moduledoc false
3 |
4 | alias Bumblebee.Shared
5 | alias Bumblebee.Text
6 |
7 | def image_to_text(
8 | model_info,
9 | featurizer,
10 | tokenizer,
11 | %Text.GenerationConfig{} = generation_config,
12 | opts \\ []
13 | ) do
14 | opts = Keyword.validate!(opts, [:compile, defn_options: [], preallocate_params: false])
15 |
16 | %{model: model, params: params, spec: spec} = model_info
17 |
18 | Shared.validate_architecture!(spec, [:for_conditional_generation])
19 |
20 | preallocate_params = opts[:preallocate_params]
21 | defn_options = opts[:defn_options]
22 |
23 | compile =
24 | if compile = opts[:compile] do
25 | compile
26 | |> Keyword.validate!([:batch_size])
27 | |> Shared.require_options!([:batch_size])
28 | end
29 |
30 | batch_size = compile[:batch_size]
31 |
32 | generate_fun = Text.Generation.build_generate(model, spec, generation_config)
33 |
34 | generate_fun = fn params, {inputs, seed} ->
35 | inputs = Bumblebee.Featurizer.process_batch(featurizer, inputs)
36 | inputs = Map.put(inputs, "seed", seed)
37 | %{token_ids: token_ids} = generate_fun.(params, inputs)
38 | token_ids
39 | end
40 |
41 | Nx.Serving.new(
42 | fn defn_options ->
43 | params = Shared.maybe_preallocate(params, preallocate_params, defn_options)
44 |
45 | generate_fun =
46 | Shared.compile_or_jit(generate_fun, :generate, defn_options, compile != nil, fn ->
47 | inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size)
48 | seed = Nx.template({batch_size}, :s64)
49 | [params, {inputs, seed}]
50 | end)
51 |
52 | fn inputs ->
53 | inputs = Shared.maybe_pad(inputs, batch_size)
54 | generate_fun.(params, inputs) |> Shared.serving_post_computation()
55 | end
56 | end,
57 | defn_options
58 | )
59 | |> Nx.Serving.batch_size(batch_size)
60 | |> Nx.Serving.client_preprocessing(fn input ->
61 | {inputs, multi?} = Shared.validate_serving_input!(input, &validate_input/1)
62 |
63 | images = Enum.map(inputs, & &1.image)
64 | seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(type: :s64, backend: Nx.BinaryBackend)
65 |
66 | inputs = Bumblebee.Featurizer.process_input(featurizer, images)
67 | {Nx.Batch.concatenate([{inputs, seed}]), multi?}
68 | end)
69 | |> Nx.Serving.client_postprocessing(fn {token_ids, _metadata}, multi? ->
70 | decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids)
71 |
72 | decoded
73 | |> Enum.map(&%{results: [%{text: &1}]})
74 | |> Shared.normalize_output(multi?)
75 | end)
76 | end
77 |
78 | defp validate_input(%{image: image} = input) do
79 | if Shared.image?(image) do
80 | {:ok, %{image: image, seed: input[:seed] || :erlang.system_time()}}
81 | else
82 | {:error, "expected an image, got: #{inspect(image)}"}
83 | end
84 | end
85 |
86 | defp validate_input(input), do: validate_input(%{image: input})
87 | end
88 |
--------------------------------------------------------------------------------
/lib/bumblebee/vision/vit_featurizer.ex:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Vision.VitFeaturizer do
2 | alias Bumblebee.Shared
3 |
4 | options = [
5 | resize: [
6 | default: true,
7 | doc: "whether to resize the input to the given `:size`"
8 | ],
9 | size: [
10 | default: %{height: 224, width: 224},
11 | doc: """
12 | the size to resize the input to, given as `%{height: ..., width: ...}`. Only has
13 | an effect if `:resize` is `true`
14 | """
15 | ],
16 | resize_method: [
17 | default: :bilinear,
18 | doc:
19 | "the resizing method, either of `:nearest`, `:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`"
20 | ],
21 | normalize: [
22 | default: true,
23 | doc: "whether or not to normalize the input with mean and standard deviation"
24 | ],
25 | image_mean: [
26 | default: [0.5, 0.5, 0.5],
27 | doc: "the sequence of mean values for each channel, to be used when normalizing images"
28 | ],
29 | image_std: [
30 | default: [0.5, 0.5, 0.5],
31 | doc:
32 | "the sequence of standard deviations for each channel, to be used when normalizing images"
33 | ]
34 | ]
35 |
36 | @moduledoc """
37 | ViT featurizer for image data.
38 |
39 | ## Configuration
40 |
41 | #{Shared.options_doc(options)}
42 | """
43 |
44 | defstruct Shared.option_defaults(options)
45 |
46 | @behaviour Bumblebee.Featurizer
47 | @behaviour Bumblebee.Configurable
48 |
49 | alias Bumblebee.Utils.Image
50 |
51 | @impl true
52 | def config(featurizer, opts) do
53 | Shared.put_config_attrs(featurizer, opts)
54 | end
55 |
56 | @impl true
57 | def process_input(featurizer, images) do
58 | images = List.wrap(images)
59 |
60 | for image <- images do
61 | image =
62 | image
63 | |> Image.to_batched_tensor()
64 | |> Nx.as_type(:f32)
65 | |> Image.normalize_channels(length(featurizer.image_mean))
66 |
67 | if featurizer.resize do
68 | %{height: height, width: width} = featurizer.size
69 | NxImage.resize(image, {height, width}, method: featurizer.resize_method)
70 | else
71 | image
72 | end
73 | end
74 | |> Nx.concatenate()
75 | end
76 |
77 | @impl true
78 | def batch_template(featurizer, batch_size) do
79 | %{height: height, width: width} = featurizer.size
80 | num_channels = length(featurizer.image_mean)
81 | Nx.template({batch_size, height, width, num_channels}, :f32)
82 | end
83 |
84 | @impl true
85 | def process_batch(featurizer, images) do
86 | images = NxImage.to_continuous(images, 0, 1)
87 |
88 | images =
89 | if featurizer.normalize do
90 | NxImage.normalize(
91 | images,
92 | Nx.tensor(featurizer.image_mean),
93 | Nx.tensor(featurizer.image_std)
94 | )
95 | else
96 | images
97 | end
98 |
99 | %{"pixel_values" => images}
100 | end
101 |
102 | defimpl Bumblebee.HuggingFace.Transformers.Config do
103 | def load(featurizer, data) do
104 | import Shared.Converters
105 |
106 | opts =
107 | convert!(data,
108 | resize: {"do_resize", boolean()},
109 | size: {"size", image_size()},
110 | resize_method: {"resample", resize_method()},
111 | normalize: {"do_normalize", boolean()},
112 | image_mean: {"image_mean", list(number())},
113 | image_std: {"image_std", list(number())}
114 | )
115 |
116 | @for.config(featurizer, opts)
117 | end
118 | end
119 | end
120 |
--------------------------------------------------------------------------------
/notebooks/stable_diffusion.livemd:
--------------------------------------------------------------------------------
1 | # Stable Diffusion
2 |
3 | ```elixir
4 | Mix.install([
5 | {:bumblebee, "~> 0.6.0"},
6 | {:nx, "~> 0.9.0"},
7 | {:exla, "~> 0.9.0"},
8 | {:kino, "~> 0.14.0"}
9 | ])
10 |
11 | Nx.global_default_backend({EXLA.Backend, client: :host})
12 | ```
13 |
14 | ## Introduction
15 |
16 | Stable Diffusion is a latent text-to-image diffusion model, primarily used to generate images based on a text prompt. Ever since it [became open-source](https://stability.ai/blog/stable-diffusion-public-release), the research, applications and tooling around it exploded. You can find a ton of resources and examples online, meanwhile let's see how to run Stable Diffusion using Bumblebee!
17 |
18 |
19 |
20 | > **Note:** Stable Diffusion is a very involved model, so the generation can take a long time if you run it on a CPU. Also, running on the GPU currently requires at least 5GiB of VRAM (or 3GiB with lower speed, see below).
21 |
22 |
23 |
24 | ## Text to image
25 |
26 | Stable Diffusion is composed of several separate models and preprocessors, so we will load all of them.
27 |
28 | ```elixir
29 | repo_id = "CompVis/stable-diffusion-v1-4"
30 | opts = [params_variant: "fp16", type: :bf16, backend: EXLA.Backend]
31 |
32 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})
33 | {:ok, clip} = Bumblebee.load_model({:hf, repo_id, subdir: "text_encoder"}, opts)
34 | {:ok, unet} = Bumblebee.load_model({:hf, repo_id, subdir: "unet"}, opts)
35 | {:ok, vae} = Bumblebee.load_model({:hf, repo_id, subdir: "vae"}, [architecture: :decoder] ++ opts)
36 | {:ok, scheduler} = Bumblebee.load_scheduler({:hf, repo_id, subdir: "scheduler"})
37 | {:ok, featurizer} = Bumblebee.load_featurizer({:hf, repo_id, subdir: "feature_extractor"})
38 | {:ok, safety_checker} = Bumblebee.load_model({:hf, repo_id, subdir: "safety_checker"}, opts)
39 |
40 | :ok
41 | ```
42 |
43 | > **Note:** some checkpoints, such as [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5), require a license agreement. In those cases, sign up on Hugging Face, accept the license on the repository page, generate an access token in [the settings](https://huggingface.co/settings/tokens) and add it to the repository specification via `:auth_token`. You can use Livebook secrets to pass the token securely.
44 |
45 |
46 |
47 | With all the models loaded, we can now configure a serving implementation of the text-to-image task.
48 |
49 | ```elixir
50 | serving =
51 | Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
52 | num_steps: 20,
53 | num_images_per_prompt: 1,
54 | safety_checker: safety_checker,
55 | safety_checker_featurizer: featurizer,
56 | compile: [batch_size: 1, sequence_length: 60],
57 | # Option 1
58 | defn_options: [compiler: EXLA]
59 | # Option 2 (reduces GPU usage, but runs noticeably slower)
60 | # Also remove `backend: EXLA.Backend` from the loading options above
61 | # defn_options: [compiler: EXLA, lazy_transfers: :always]
62 | )
63 |
64 | Kino.start_child({Nx.Serving, name: StableDiffusion, serving: serving})
65 | ```
66 |
67 | ```elixir
68 | prompt_input =
69 | Kino.Input.text("Prompt", default: "numbat, forest, high quality, detailed, digital art")
70 |
71 | negative_prompt_input = Kino.Input.text("Negative Prompt", default: "darkness, rainy, foggy")
72 |
73 | Kino.Layout.grid([prompt_input, negative_prompt_input])
74 | ```
75 |
76 | We are ready to generate images!
77 |
78 | ```elixir
79 | prompt = Kino.Input.read(prompt_input)
80 | negative_prompt = Kino.Input.read(negative_prompt_input)
81 |
82 | output =
83 | Nx.Serving.batched_run(StableDiffusion, %{prompt: prompt, negative_prompt: negative_prompt})
84 |
85 | for result <- output.results do
86 | Kino.Image.new(result.image)
87 | end
88 | |> Kino.Layout.grid(columns: 2)
89 | ```
90 |
91 | To achieve a better quality you can increase the number of steps and images.
92 |
--------------------------------------------------------------------------------
/test/bumblebee/audio/whisper_featurizer_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Audio.WhisperFeaturizerTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | describe "integration" do
7 | test "encodes text" do
8 | assert {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"})
9 |
10 | assert %Bumblebee.Audio.WhisperFeaturizer{} = featurizer
11 |
12 | audio = Nx.sin(Nx.iota({100}, type: :f32))
13 |
14 | inputs = Bumblebee.apply_featurizer(featurizer, audio, defn_options: [compiler: EXLA])
15 |
16 | assert_all_close(
17 | inputs["input_features"][[0, 0..3, 0..3]],
18 | Nx.tensor([
19 | [
20 | [0.7313, 0.7820, 0.7391, 0.6787],
21 | [0.4332, 0.4861, 0.4412, 0.3497],
22 | [-0.5938, -0.5938, -0.5938, -0.5938],
23 | [-0.5938, -0.5938, -0.5938, -0.5938]
24 | ]
25 | ])
26 | )
27 | end
28 | end
29 | end
30 |
--------------------------------------------------------------------------------
/test/bumblebee/audio/whisper_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Audio.WhisperTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag model_test_tags()
7 |
8 | test ":base" do
9 | assert {:ok, %{model: model, params: params, spec: spec}} =
10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-WhisperModel"})
11 |
12 | assert %Bumblebee.Audio.Whisper{architecture: :base} = spec
13 |
14 | inputs = %{
15 | "input_features" => Nx.sin(Nx.iota({1, 60, 80}, type: :f32)),
16 | "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]),
17 | "decoder_attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 0, 0]])
18 | }
19 |
20 | outputs = Axon.predict(model, params, inputs)
21 |
22 | assert Nx.shape(outputs.hidden_state) == {1, 8, 16}
23 |
24 | assert_all_close(
25 | outputs.hidden_state[[.., 1..3, 1..3]],
26 | Nx.tensor([
27 | [[-0.3791, -1.6131, -0.6913], [0.1247, -1.3631, 0.0034], [-0.0097, 0.2039, 1.9897]]
28 | ])
29 | )
30 | end
31 |
32 | test ":for_conditional_generation" do
33 | assert {:ok, %{model: model, params: params, spec: spec}} =
34 | Bumblebee.load_model(
35 | {:hf, "hf-internal-testing/tiny-random-WhisperForConditionalGeneration"}
36 | )
37 |
38 | assert %Bumblebee.Audio.Whisper{architecture: :for_conditional_generation} = spec
39 |
40 | inputs = %{
41 | "input_features" => Nx.sin(Nx.iota({1, 60, 80}, type: :f32)),
42 | "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]),
43 | "decoder_attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 0, 0]])
44 | }
45 |
46 | outputs = Axon.predict(model, params, inputs)
47 |
48 | assert Nx.shape(outputs.logits) == {1, 8, 50257}
49 |
50 | assert_all_close(
51 | outputs.logits[[.., 1..3, 1..3]],
52 | Nx.tensor([
53 | [[0.0942, 0.1288, 0.0243], [-0.1667, -0.1401, 0.1191], [0.0398, -0.0449, -0.0574]]
54 | ])
55 | )
56 | end
57 | end
58 |
--------------------------------------------------------------------------------
/test/bumblebee/conversion/pytorch_loader_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Conversion.PyTorchLoaderTest do
2 | use ExUnit.Case, async: true
3 |
4 | alias Bumblebee.Conversion.PyTorchLoader
5 |
6 | setup do
7 | Nx.default_backend(Nx.BinaryBackend)
8 | :ok
9 | end
10 |
11 | @dir Path.expand("../../fixtures/pytorch", __DIR__)
12 |
13 | for format <- ["zip", "legacy"] do
14 | @format format
15 |
16 | describe "#{format} format" do
17 | test "tensors" do
18 | path = Path.join(@dir, "tensors.#{@format}.pt")
19 |
20 | assert path |> PyTorchLoader.load!() |> Enum.map(&Nx.to_tensor/1) == [
21 | Nx.tensor([-1.0, 1.0], type: :f64),
22 | Nx.tensor([-1.0, 1.0], type: :f32),
23 | Nx.tensor([-1.0, 1.0], type: :f16),
24 | Nx.tensor([-1, 1], type: :s64),
25 | Nx.tensor([-1, 1], type: :s32),
26 | Nx.tensor([-1, 1], type: :s16),
27 | Nx.tensor([-1, 1], type: :s8),
28 | Nx.tensor([0, 1], type: :u8),
29 | Nx.tensor([0, 1, 0, 1], type: :u8),
30 | Nx.tensor([-1.0, 1.0], type: :bf16),
31 | Nx.tensor([Complex.new(1, -1), Complex.new(1, 1)], type: :c128),
32 | Nx.tensor([Complex.new(1, -1), Complex.new(1, 1)], type: :c64)
33 | ]
34 | end
35 |
36 | test "numpy arrays" do
37 | path = Path.join(@dir, "numpy_arrays.#{@format}.pt")
38 |
39 | assert PyTorchLoader.load!(path) == [
40 | Nx.tensor([-1.0, 1.0], type: :f64),
41 | Nx.tensor([-1.0, 1.0], type: :f32),
42 | Nx.tensor([-1.0, 1.0], type: :f16),
43 | Nx.tensor([-1, 1], type: :s64),
44 | Nx.tensor([-1, 1], type: :s32),
45 | Nx.tensor([-1, 1], type: :s16),
46 | Nx.tensor([-1, 1], type: :s8),
47 | Nx.tensor([0, 1], type: :u64),
48 | Nx.tensor([0, 1], type: :u32),
49 | Nx.tensor([0, 1], type: :u16),
50 | Nx.tensor([0, 1], type: :u8),
51 | Nx.tensor([0, 1], type: :u8),
52 | Nx.tensor([Complex.new(1, -1), Complex.new(1, 1)], type: :c128),
53 | Nx.tensor([Complex.new(1, -1), Complex.new(1, 1)], type: :c64)
54 | ]
55 | end
56 |
57 | test "ordered dict" do
58 | path = Path.join(@dir, "ordered_dict.#{@format}.pt")
59 |
60 | assert PyTorchLoader.load!(path) == %{"x" => 1, "y" => 2}
61 | end
62 |
63 | test "noncontiguous tensor" do
64 | path = Path.join(@dir, "noncontiguous_tensor.#{@format}.pt")
65 |
66 | assert path |> PyTorchLoader.load!() |> Nx.to_tensor() ==
67 | Nx.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], type: :s64)
68 | end
69 |
70 | test "numpy array in Fortran order" do
71 | path = Path.join(@dir, "noncontiguous_numpy_array.#{@format}.pt")
72 |
73 | assert PyTorchLoader.load!(path) ==
74 | Nx.tensor([[1, 4], [2, 5], [3, 6]], type: :s64)
75 | end
76 | end
77 | end
78 |
79 | test "legacy format storage view" do
80 | # Note that storage views have been removed in PyTorch v0.4.0,
81 | # this test is based on https://github.com/pytorch/pytorch/blob/v1.11.0/test/test_serialization.py#L554-L575
82 | path = Path.join(@dir, "storage_view.legacy.pt")
83 |
84 | assert {
85 | {:storage, %Unpickler.Global{scope: "torch", name: "FloatStorage"}, storage1},
86 | {:storage, %Unpickler.Global{scope: "torch", name: "FloatStorage"}, storage2}
87 | } = PyTorchLoader.load!(path)
88 |
89 | assert {:file, path, offset, size} = storage1
90 | assert path |> File.read!() |> binary_part(offset, size) == <<0, 0, 0, 0>>
91 |
92 | assert {:file, path, offset, size} = storage2
93 | assert path |> File.read!() |> binary_part(offset, size) == <<0, 0, 0, 0>>
94 | end
95 |
96 | test "raises if the files does not exist" do
97 | assert_raise File.Error, ~r/no such file or directory/, fn ->
98 | PyTorchLoader.load!("nonexistent")
99 | end
100 | end
101 | end
102 |
--------------------------------------------------------------------------------
/test/bumblebee/conversion/pytorch_params_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Conversion.PyTorchParamsTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | alias Bumblebee.Conversion.PyTorchParams
7 |
8 | @dir Path.expand("../../fixtures/pytorch", __DIR__)
9 |
10 | describe "load_params!/3" do
11 | defp base_model() do
12 | Axon.input("input", shape: {nil, 4, 4, 3})
13 | |> Axon.conv(2, kernel_size: 2, name: "conv")
14 | end
15 |
16 | defp full_model() do
17 | Axon.input("input", shape: {nil, 4, 4, 3})
18 | |> Axon.conv(2, kernel_size: 2, name: "base.conv")
19 | |> Axon.flatten()
20 | |> Axon.dense(2, name: "classifier.intermediate")
21 | |> Axon.dense(1, name: "classifier.output")
22 | end
23 |
24 | defp input_template() do
25 | Nx.broadcast(1, {1, 4, 4, 3})
26 | end
27 |
28 | defp params_mapping() do
29 | %{
30 | "base.conv" => "base.conv",
31 | "classifier.intermediate" => "classifier.layers.0",
32 | "classifier.output" => "classifier.layers.1"
33 | }
34 | end
35 |
36 | test "silently loads parameters if all match" do
37 | model = base_model()
38 | path = Path.join(@dir, "state_dict_base.zip.pt")
39 |
40 | log =
41 | ExUnit.CaptureLog.capture_log(fn ->
42 | %Axon.ModelState{data: params} =
43 | PyTorchParams.load_params!(model, input_template(), path,
44 | params_mapping: params_mapping()
45 | )
46 |
47 | assert_equal(params["conv"]["kernel"], Nx.broadcast(1.0, {2, 2, 3, 2}))
48 | assert_equal(params["conv"]["bias"], Nx.broadcast(0.0, {2}))
49 | end)
50 |
51 | refute log =~ "parameters"
52 | end
53 |
54 | test "logs parameters diff" do
55 | model = full_model()
56 | path = Path.join(@dir, "state_dict_full.zip.pt")
57 |
58 | log =
59 | ExUnit.CaptureLog.capture_log(fn ->
60 | PyTorchParams.load_params!(model, input_template(), path,
61 | params_mapping: params_mapping()
62 | )
63 | end)
64 |
65 | assert log =~ """
66 | the following parameters were missing:
67 |
68 | * classifier.output.kernel
69 | * classifier.output.bias
70 | """
71 |
72 | assert log =~ """
73 | the following PyTorch parameters were unused:
74 |
75 | * extra.weight
76 | """
77 |
78 | assert log =~ """
79 | the following parameters were ignored, because of non-matching shape:
80 |
81 | * classifier.intermediate.kernel (expected {18, 2}, got: {1, 1})
82 | * classifier.intermediate.bias (expected {2}, got: {1})
83 | """
84 | end
85 |
86 | test "loads parameters without prefix into a specialised model" do
87 | model = base_model()
88 | path = Path.join(@dir, "state_dict_full.zip.pt")
89 |
90 | log =
91 | ExUnit.CaptureLog.capture_log(fn ->
92 | %Axon.ModelState{data: params} =
93 | PyTorchParams.load_params!(model, input_template(), path,
94 | params_mapping: params_mapping()
95 | )
96 |
97 | assert_equal(params["conv"]["kernel"], Nx.broadcast(1.0, {2, 2, 3, 2}))
98 | assert_equal(params["conv"]["bias"], Nx.broadcast(0.0, {2}))
99 | end)
100 |
101 | refute log =~ "conv"
102 | end
103 |
104 | test "loads parameters with prefix into a base model" do
105 | model = full_model()
106 | path = Path.join(@dir, "state_dict_base.zip.pt")
107 |
108 | log =
109 | ExUnit.CaptureLog.capture_log(fn ->
110 | %Axon.ModelState{data: params} =
111 | PyTorchParams.load_params!(model, input_template(), path,
112 | params_mapping: params_mapping()
113 | )
114 |
115 | assert_equal(params["base.conv"]["kernel"], Nx.broadcast(1.0, {2, 2, 3, 2}))
116 | assert_equal(params["base.conv"]["bias"], Nx.broadcast(0.0, {2}))
117 | end)
118 |
119 | refute log =~ "conv"
120 | end
121 | end
122 | end
123 |
--------------------------------------------------------------------------------
/test/bumblebee/diffusion/controlnet_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Diffusion.ControlNetTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag model_test_tags()
7 |
8 | test ":base" do
9 | assert {:ok, %{model: model, params: params, spec: spec}} =
10 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-controlnet"})
11 |
12 | assert %Bumblebee.Diffusion.ControlNet{architecture: :base} = spec
13 |
14 | inputs = %{
15 | "sample" => Nx.broadcast(0.5, {1, 32, 32, 4}),
16 | "timestep" => Nx.tensor(1),
17 | "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}),
18 | "conditioning" => Nx.broadcast(0.5, {1, 64, 64, 3})
19 | }
20 |
21 | outputs = Axon.predict(model, params, inputs)
22 |
23 | assert Nx.shape(outputs.mid_block_state) == {1, 16, 16, 64}
24 |
25 | assert_all_close(
26 | outputs.mid_block_state[[.., 1..3, 1..3, 1..3]],
27 | Nx.tensor([
28 | [
29 | [[-0.2818, 1.6207, -0.7002], [0.2391, 1.1387, 0.9682], [-0.6386, 0.7026, -0.4218]],
30 | [[1.0681, 1.8418, -1.0586], [0.9387, 0.5971, 1.2284], [1.2914, 0.4060, -0.9559]],
31 | [[0.5841, 1.2935, 0.0081], [0.7306, 0.2915, 0.7736], [0.0875, 0.9619, 0.4108]]
32 | ]
33 | ])
34 | )
35 |
36 | assert tuple_size(outputs.down_block_states) == 6
37 |
38 | first_down_block_state = elem(outputs.down_block_states, 0)
39 | assert Nx.shape(first_down_block_state) == {1, 32, 32, 32}
40 |
41 | assert_all_close(
42 | first_down_block_state[[.., 1..3, 1..3, 1..3]],
43 | Nx.tensor([
44 | [
45 | [[-0.1423, 0.2804, -0.0497], [-0.1425, 0.2798, -0.0485], [-0.1426, 0.2794, -0.0488]],
46 | [[-0.1419, 0.2810, -0.0493], [-0.1427, 0.2803, -0.0479], [-0.1427, 0.2800, -0.0486]],
47 | [[-0.1417, 0.2812, -0.0494], [-0.1427, 0.2807, -0.0480], [-0.1426, 0.2804, -0.0486]]
48 | ]
49 | ])
50 | )
51 |
52 | last_down_block_state = elem(outputs.down_block_states, 5)
53 | assert Nx.shape(last_down_block_state) == {1, 16, 16, 64}
54 |
55 | assert_all_close(
56 | last_down_block_state[[.., 1..3, 1..3, 1..3]],
57 | Nx.tensor([
58 | [
59 | [[-1.1169, 0.8087, 0.1024], [0.4832, 0.0686, 1.0149], [-0.3314, 0.1486, 0.4445]],
60 | [[0.5770, 0.3195, -0.2008], [1.5692, -0.1771, 0.7669], [0.4908, 0.1258, 0.0694]],
61 | [[0.4694, -0.3723, 0.1505], [1.7356, -0.4214, 0.8929], [0.4702, 0.2400, 0.1213]]
62 | ]
63 | ])
64 | )
65 | end
66 | end
67 |
--------------------------------------------------------------------------------
/test/bumblebee/diffusion/ddim_scheduler_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Diffusion.DdimSchedulerTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | test "timesteps" do
7 | scheduler = Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler)
8 |
9 | timesteps = scheduler_timesteps(scheduler, 10)
10 |
11 | assert_equal(timesteps, Nx.tensor([900, 800, 700, 600, 500, 400, 300, 200, 100, 0]))
12 | end
13 |
14 | test "default configuration" do
15 | scheduler = Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler)
16 |
17 | sample = scheduler_loop(scheduler, 10)
18 |
19 | assert_all_close(
20 | sample[[1..3, 1..3, 1..2]],
21 | Nx.tensor([
22 | [[0.0648, 0.0666], [0.0718, 0.0736], [0.0788, 0.0806]],
23 | [[0.1209, 0.1226], [0.1279, 0.1296], [0.1349, 0.1367]],
24 | [[0.1770, 0.1787], [0.1840, 0.1857], [0.1910, 0.1927]]
25 | ])
26 | )
27 |
28 | assert_all_close(Nx.sum(sample), Nx.tensor(57.1861))
29 | end
30 |
31 | test ":quadratic beta schedule" do
32 | scheduler = Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler, beta_schedule: :quadratic)
33 |
34 | sample = scheduler_loop(scheduler, 10)
35 |
36 | assert_all_close(
37 | sample[[1..3, 1..3, 1..2]],
38 | Nx.tensor([
39 | [[0.0675, 0.0693], [0.0747, 0.0766], [0.0820, 0.0839]],
40 | [[0.1258, 0.1276], [0.1331, 0.1349], [0.1404, 0.1422]],
41 | [[0.1841, 0.1860], [0.1914, 0.1932], [0.1987, 0.2005]]
42 | ])
43 | )
44 |
45 | assert_all_close(Nx.sum(sample), Nx.tensor(59.5049))
46 | end
47 |
48 | test ":squared_cosine beta schedule" do
49 | scheduler =
50 | Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler, beta_schedule: :squared_cosine)
51 |
52 | sample = scheduler_loop(scheduler, 10)
53 |
54 | assert_all_close(
55 | sample[[1..3, 1..3, 1..2]],
56 | Nx.tensor([
57 | [[0.0684, 0.0702], [0.0757, 0.0776], [0.0831, 0.0850]],
58 | [[0.1275, 0.1293], [0.1349, 0.1367], [0.1422, 0.1441]],
59 | [[0.1866, 0.1884], [0.1940, 0.1958], [0.2014, 0.2032]]
60 | ])
61 | )
62 |
63 | assert_all_close(Nx.sum(sample), Nx.tensor(60.2983))
64 | end
65 |
66 | test "beta schedule range" do
67 | scheduler =
68 | Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler, beta_start: 0.001, beta_end: 0.02)
69 |
70 | sample = scheduler_loop(scheduler, 10)
71 |
72 | assert_all_close(
73 | sample[[1..3, 1..3, 1..2]],
74 | Nx.tensor([
75 | [[0.0653, 0.0670], [0.0723, 0.0741], [0.0794, 0.0812]],
76 | [[0.1217, 0.1235], [0.1288, 0.1306], [0.1359, 0.1376]],
77 | [[0.1782, 0.1800], [0.1853, 0.1870], [0.1923, 0.1941]]
78 | ])
79 | )
80 |
81 | assert_all_close(Nx.sum(sample), Nx.tensor(57.5869))
82 | end
83 |
84 | test ":angular_velocity prediction type" do
85 | scheduler =
86 | Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler, prediction_type: :angular_velocity)
87 |
88 | sample = scheduler_loop(scheduler, 10)
89 |
90 | assert_all_close(
91 | sample[[1..3, 1..3, 1..2]],
92 | Nx.tensor([
93 | [[0.0198, 0.0203], [0.0219, 0.0225], [0.0241, 0.0246]],
94 | [[0.0369, 0.0375], [0.0391, 0.0396], [0.0412, 0.0417]],
95 | [[0.0540, 0.0546], [0.0562, 0.0567], [0.0583, 0.0589]]
96 | ])
97 | )
98 |
99 | assert_all_close(Nx.sum(sample), Nx.tensor(17.4644))
100 | end
101 | end
102 |
--------------------------------------------------------------------------------
/test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do
2 | use ExUnit.Case, async: false
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag serving_test_tags()
7 |
8 | describe "text_to_image/6" do
9 | test "generates image for a text prompt with controlnet" do
10 | # Since we don't assert on the result in this case, we use
11 | # a tiny random checkpoint. This test is basically to verify
12 | # the whole generation computation end-to-end
13 |
14 | repository_id = "bumblebee-testing/tiny-stable-diffusion"
15 |
16 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})
17 | {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"})
18 |
19 | {:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"})
20 |
21 | {:ok, controlnet} = Bumblebee.load_model({:hf, "bumblebee-testing/tiny-controlnet"})
22 |
23 | {:ok, vae} =
24 | Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder)
25 |
26 | {:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})
27 |
28 | {:ok, featurizer} =
29 | Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"})
30 |
31 | {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"})
32 |
33 | conditioning_size =
34 | controlnet.spec.sample_size *
35 | 2 ** (length(controlnet.spec.conditioning_embedding_hidden_sizes) - 1)
36 |
37 | serving =
38 | Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image(
39 | clip,
40 | unet,
41 | vae,
42 | controlnet,
43 | tokenizer,
44 | scheduler,
45 | num_steps: 3,
46 | safety_checker: safety_checker,
47 | safety_checker_featurizer: featurizer
48 | )
49 |
50 | prompt = "numbat in forest, detailed, digital art"
51 |
52 | conditioning =
53 | Nx.broadcast(Nx.tensor(50, type: :u8), {conditioning_size, conditioning_size, 3})
54 |
55 | assert %{
56 | results: [%{image: %Nx.Tensor{}, is_safe: _boolean}]
57 | } =
58 | Nx.Serving.run(serving, %{
59 | prompt: prompt,
60 | conditioning: conditioning
61 | })
62 |
63 | # Without safety checker
64 |
65 | serving =
66 | Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image(
67 | clip,
68 | unet,
69 | vae,
70 | controlnet,
71 | tokenizer,
72 | scheduler,
73 | num_steps: 3
74 | )
75 |
76 | prompt = "numbat in forest, detailed, digital art"
77 |
78 | assert %{results: [%{image: %Nx.Tensor{}}]} =
79 | Nx.Serving.run(serving, %{
80 | prompt: prompt,
81 | conditioning: conditioning
82 | })
83 |
84 | # With compilation
85 |
86 | serving =
87 | Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image(
88 | clip,
89 | unet,
90 | vae,
91 | controlnet,
92 | tokenizer,
93 | scheduler,
94 | num_steps: 3,
95 | safety_checker: safety_checker,
96 | safety_checker_featurizer: featurizer,
97 | compile: [batch_size: 1, sequence_length: 60],
98 | defn_options: [compiler: EXLA]
99 | )
100 |
101 | prompt = "numbat in forest, detailed, digital art"
102 |
103 | assert %{
104 | results: [%{image: %Nx.Tensor{}, is_safe: _boolean}]
105 | } =
106 | Nx.Serving.run(serving, %{
107 | prompt: prompt,
108 | conditioning: conditioning
109 | })
110 | end
111 | end
112 | end
113 |
--------------------------------------------------------------------------------
/test/bumblebee/diffusion/stable_diffusion_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Diffusion.StableDiffusionTest do
2 | use ExUnit.Case, async: false
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag serving_test_tags()
7 |
8 | describe "text_to_image/6" do
9 | test "generates image for a text prompt" do
10 | # Since we don't assert on the result in this case, we use
11 | # a tiny random checkpoint. This test is basically to verify
12 | # the whole generation computation end-to-end
13 |
14 | # repository_id = "CompVis/stable-diffusion-v1-4"
15 | repository_id = "bumblebee-testing/tiny-stable-diffusion"
16 |
17 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})
18 | {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"})
19 | {:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"})
20 |
21 | {:ok, vae} =
22 | Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder)
23 |
24 | {:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})
25 |
26 | {:ok, featurizer} =
27 | Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"})
28 |
29 | {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"})
30 |
31 | serving =
32 | Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
33 | num_steps: 3,
34 | safety_checker: safety_checker,
35 | safety_checker_featurizer: featurizer
36 | )
37 |
38 | prompt = "numbat in forest, detailed, digital art"
39 |
40 | assert %{
41 | results: [%{image: %Nx.Tensor{}, is_safe: _boolean}]
42 | } = Nx.Serving.run(serving, prompt)
43 |
44 | # Without safety checker
45 |
46 | serving =
47 | Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
48 | num_steps: 3
49 | )
50 |
51 | prompt = "numbat in forest, detailed, digital art"
52 |
53 | assert %{results: [%{image: %Nx.Tensor{}}]} = Nx.Serving.run(serving, prompt)
54 |
55 | # With compilation
56 |
57 | serving =
58 | Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
59 | num_steps: 3,
60 | safety_checker: safety_checker,
61 | safety_checker_featurizer: featurizer,
62 | compile: [batch_size: 1, sequence_length: 50],
63 | defn_options: [compiler: EXLA]
64 | )
65 |
66 | prompt = "numbat in forest, detailed, digital art"
67 |
68 | assert %{
69 | results: [%{image: %Nx.Tensor{}, is_safe: _boolean}]
70 | } = Nx.Serving.run(serving, prompt)
71 | end
72 | end
73 | end
74 |
--------------------------------------------------------------------------------
/test/bumblebee/diffusion/unet_2d_conditional_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Diffusion.UNet2DConditionalTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag model_test_tags()
7 |
8 | # Baseline UNet inputs shared by both tests: a {1, 32, 32, 4} sample,
9 | # timestep 1 and an encoder hidden state of shape {1, 1, 32}.
10 | defp unet_inputs do
11 | %{
12 | "sample" => Nx.broadcast(0.5, {1, 32, 32, 4}),
13 | "timestep" => Nx.tensor(1),
14 | "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32})
15 | }
16 | end
17 |
18 | test ":base" do
19 | repo = {:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "unet"}
20 |
21 | assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(repo)
22 | assert %Bumblebee.Diffusion.UNet2DConditional{architecture: :base} = spec
23 |
24 | output = Axon.predict(model, params, unet_inputs())
25 |
26 | assert Nx.shape(output.sample) == {1, 32, 32, 4}
27 |
28 | expected =
29 | Nx.tensor([
30 | [
31 | [[-1.0813, -0.2179, 1.3359], [-0.5109, -0.2799, 0.8373], [-0.1545, -1.0922, -0.2392]],
32 | [[-0.8094, -0.9485, 0.9448], [-1.2588, -0.8376, -0.0478], [-0.8355, 0.0843, 0.6881]],
33 | [[-0.9218, -0.9650, -0.0154], [-1.2142, -0.7105, -0.5304], [-0.6982, -0.3920, 0.2081]]
34 | ]
35 | ])
36 |
37 | assert_all_close(output.sample[[.., 1..3, 1..3, 1..3]], expected)
38 | end
39 |
40 | test ":base with additional states for skip connection" do
41 | repo = {:hf, "bumblebee-testing/tiny-stable-diffusion", subdir: "unet"}
42 |
43 | assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(repo)
44 | assert %Bumblebee.Diffusion.UNet2DConditional{architecture: :base} = spec
45 |
46 | # Shapes of the extra per-down-block states passed alongside the base inputs.
47 | down_block_shapes = [
48 | {1, 32, 32, 32},
49 | {1, 32, 32, 32},
50 | {1, 32, 32, 32},
51 | {1, 16, 16, 32},
52 | {1, 16, 16, 64},
53 | {1, 16, 16, 64}
54 | ]
55 |
56 | down_block_states =
57 | List.to_tuple(for shape <- down_block_shapes, do: Nx.broadcast(0.5, shape))
58 |
59 | extra_states = %{
60 | "additional_down_block_states" => down_block_states,
61 | "additional_mid_block_state" => Nx.broadcast(0.5, {1, 16, 16, 64})
62 | }
63 |
64 | output = Axon.predict(model, params, Map.merge(unet_inputs(), extra_states))
65 |
66 | assert Nx.shape(output.sample) == {1, 32, 32, 4}
67 |
68 | expected =
69 | Nx.tensor([
70 | [
71 | [[-0.9457, -0.2378, 1.4223], [-0.5736, -0.2456, 0.7603], [-0.4346, -1.1370, -0.1988]],
72 | [[-0.5274, -1.0902, 0.5937], [-1.2290, -0.7996, 0.0264], [-0.3006, -0.1181, 0.7059]],
73 | [[-0.8336, -1.1615, -0.1906], [-1.0489, -0.3815, -0.5497], [-0.6255, 0.0863, 0.3285]]
74 | ]
75 | ])
76 |
77 | assert_all_close(output.sample[[.., 1..3, 1..3, 1..3]], expected)
78 | end
79 | end
80 |
--------------------------------------------------------------------------------
/test/bumblebee/diffusion/vae_kl_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Diffusion.VaeKlTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag model_test_tags()
7 |
8 | # Every architecture below is loaded from this same checkpoint subdirectory.
9 | @vae_repo {:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "vae"}
10 |
11 | test ":base" do
12 | assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(@vae_repo)
13 | assert %Bumblebee.Diffusion.VaeKl{architecture: :base} = spec
14 |
15 | input_sample = Nx.broadcast(0.5, {1, 32, 32, 3})
16 | output = Axon.predict(model, params, %{"sample" => input_sample})
17 |
18 | assert Nx.shape(output.sample) == {1, 32, 32, 3}
19 |
20 | expected =
21 | Nx.tensor([
22 | [
23 | [[0.0164, 0.3587, -0.2398], [-0.1439, 0.4220, 0.2247], [0.4768, 0.1088, -0.2082]],
24 | [[0.3165, 0.4741, -0.1440], [0.0599, 0.4139, 0.0256], [0.1729, 0.6284, -0.0120]],
25 | [[0.1148, 0.4739, -0.0982], [0.5428, 0.1454, -0.3666], [0.6126, 0.3089, -0.1221]]
26 | ]
27 | ])
28 |
29 | assert_all_close(output.sample[[.., 1..3, 1..3, ..]], expected)
30 | end
31 |
32 | test ":decoder" do
33 | assert {:ok, %{model: model, params: params, spec: spec}} =
34 | Bumblebee.load_model(@vae_repo, architecture: :decoder)
35 |
36 | assert %Bumblebee.Diffusion.VaeKl{architecture: :decoder} = spec
37 |
38 | latent = Nx.broadcast(0.5, {1, 16, 16, 4})
39 | output = Axon.predict(model, params, %{"sample" => latent})
40 |
41 | assert Nx.shape(output.sample) == {1, 32, 32, 3}
42 |
43 | expected =
44 | Nx.tensor([
45 | [
46 | [[-0.1682, -0.6000, 0.2776], [-0.1015, -0.0538, 0.6985], [-0.4158, -0.6703, 0.2960]],
47 | [[-0.4621, 0.4113, 0.4759], [0.5176, 0.3203, -0.3528], [-0.0999, -0.5005, -0.7306]],
48 | [[-0.0685, 0.2073, 0.5656], [0.7141, -0.1205, -0.5857], [0.3287, 0.2487, -0.2490]]
49 | ]
50 | ])
51 |
52 | assert_all_close(output.sample[[.., 1..3, 1..3, ..]], expected)
53 | end
54 |
55 | test ":encoder" do
56 | assert {:ok, %{model: model, params: params, spec: spec}} =
57 | Bumblebee.load_model(@vae_repo, architecture: :encoder)
58 |
59 | assert %Bumblebee.Diffusion.VaeKl{architecture: :encoder} = spec
60 |
61 | input_sample = Nx.broadcast(0.5, {1, 32, 32, 3})
62 | output = Axon.predict(model, params, %{"sample" => input_sample})
63 | dist = output.latent_dist
64 |
65 | # All four statistics of the latent distribution are expected to share one shape.
66 | for stat <- [:mean, :var, :logvar, :std] do
67 | assert Nx.shape(Map.fetch!(dist, stat)) == {1, 16, 16, 4}
68 | end
69 |
70 | assert_all_close(
71 | dist.mean[[.., 1..3, 1..3, 1..3]],
72 | Nx.tensor([
73 | [
74 | [[0.1788, -0.6560, 0.2527], [0.2526, -0.1389, 0.6616], [0.3464, -0.1010, -0.1320]],
75 | [[0.1255, -0.0494, 0.2834], [0.4318, -0.5862, -0.1787], [0.0935, -0.2144, -0.1887]],
76 | [[-0.3859, 0.1139, 0.2339], [-0.1090, -0.5287, 0.6370], [-0.1257, -0.3207, -0.1075]]
77 | ]
78 | ])
79 | )
80 |
81 | assert_all_close(
82 | dist.var[[.., 1..3, 1..3, 1..3]],
83 | Nx.tensor([
84 | [
85 | [[0.7926, 0.4830, 1.7315], [0.6405, 1.2762, 1.8338], [0.8108, 1.1277, 1.7099]],
86 | [[0.4721, 0.9835, 1.6843], [0.9543, 1.1715, 1.5880], [0.8660, 1.5034, 1.2972]],
87 | [[0.5069, 1.2341, 1.2979], [0.7749, 1.0105, 1.3841], [0.9574, 1.2950, 1.1591]]
88 | ]
89 | ])
90 | )
91 | end
92 | end
93 |
--------------------------------------------------------------------------------
/test/bumblebee/multimodal/blip_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Multimodal.BlipTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag model_test_tags()
7 |
8 | test ":for_conditional_generation" do
9 | repo = {:hf, "hf-internal-testing/tiny-random-BlipForConditionalGeneration"}
10 |
11 | assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(repo)
12 | assert %Bumblebee.Multimodal.Blip{architecture: :for_conditional_generation} = spec
13 |
14 | # A uniform {1, 30, 30, 3} pixel input plus eight decoder positions, the last two masked out.
15 | inputs = %{
16 | "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}),
17 | "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]),
18 | "decoder_attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 0, 0]])
19 | }
20 |
21 | logits = Axon.predict(model, params, inputs).logits
22 |
23 | assert Nx.shape(logits) == {1, 8, 1124}
24 |
25 | expected =
26 | Nx.tensor([
27 | [[0.1215, 0.0226, -0.1134], [0.1472, 0.1118, 0.1031], [-0.0687, 0.0104, 0.1781]]
28 | ])
29 |
30 | assert_all_close(logits[[.., 1..3, 1..3]], expected)
31 | end
32 | end
33 |
--------------------------------------------------------------------------------
/test/bumblebee/multimodal/clip_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Multimodal.ClipTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag model_test_tags()
7 |
8 | test ":base" do
9 | assert {:ok, %{model: model, params: params, spec: spec}} =
10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-CLIPModel"})
11 |
12 | assert %Bumblebee.Multimodal.Clip{architecture: :base} = spec
13 |
14 | # Two texts of eight tokens each, with the trailing two positions masked out.
15 | input_ids =
16 | Nx.tensor([
17 | [10, 20, 30, 40, 50, 60, 70, 80, 0, 0],
18 | [15, 25, 35, 45, 55, 65, 75, 85, 0, 0]
19 | ])
20 |
21 | attention_mask =
22 | Nx.tensor([
23 | [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
24 | [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
25 | ])
26 |
27 | # Two uniform pixel inputs (0.25 and 0.75) joined along the first axis.
28 | pixel_values =
29 | [0.25, 0.75]
30 | |> Enum.map(&Nx.broadcast(&1, {1, 30, 30, 3}))
31 | |> Nx.concatenate()
32 |
33 | inputs = %{
34 | "input_ids" => input_ids,
35 | "attention_mask" => attention_mask,
36 | "pixel_values" => pixel_values
37 | }
38 |
39 | output = Axon.predict(model, params, inputs)
40 |
41 | assert Nx.shape(output.logits_per_text) == {2, 2}
42 | assert Nx.shape(output.logits_per_image) == {2, 2}
43 |
44 | assert_all_close(output.logits_per_text, Nx.tensor([[0.5381, 0.1981], [0.5212, 0.3291]]))
45 | assert_all_close(output.logits_per_image, Nx.tensor([[0.5381, 0.5212], [0.1981, 0.3291]]))
46 | end
47 | end
48 |
--------------------------------------------------------------------------------
/test/bumblebee/shared_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.SharedTest do
2 | use ExUnit.Case, async: true
3 |
4 | alias Bumblebee.Shared
5 |
6 | describe "validate_label_options/1" do
7 | test "passes when :id_to_label is empty" do
8 | spec = spec_with_labels(%{})
9 |
10 | assert Shared.validate_label_options(spec) == spec
11 | end
12 |
13 | test "passes when :id_to_label matches :num_labels" do
14 | spec = spec_with_labels(%{0 => "cat", 1 => "dog", 2 => "squirrel"})
15 |
16 | assert Shared.validate_label_options(spec) == spec
17 | end
18 |
19 | test "raises an error if mismatched :num_labels and :id_to_label are given" do
20 | spec = spec_with_labels(%{0 => "cat", 1 => "dog"})
21 |
22 | assert_raise ArgumentError,
23 | ~s/size mismatch between :num_labels (3) and :id_to_label (%{0 => "cat", 1 => "dog"})/,
24 | fn -> Shared.validate_label_options(spec) end
25 | end
26 | end
27 |
28 | # Builds a struct-like map (faked via the :__struct__ key) declaring 3 labels,
29 | # so each test only varies the :id_to_label mapping.
30 | defp spec_with_labels(id_to_label) do
31 | %{__struct__: TestConfig, num_labels: 3, id_to_label: id_to_label}
32 | end
33 | end
34 |
--------------------------------------------------------------------------------
/test/bumblebee/text/bart_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.BartTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag model_test_tags()
7 |
8 | # Loads the `hf-internal-testing/tiny-random-<name>` checkpoint.
9 | defp load_tiny(name) do
10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-" <> name})
11 | end
12 |
13 | # Eight tokens followed by two masked-out positions, shared by most tests below.
14 | defp padded_inputs do
15 | %{
16 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
17 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
18 | }
19 | end
20 |
21 | test ":base" do
22 | assert {:ok, %{model: model, params: params, spec: spec}} = load_tiny("BartModel")
23 | assert %Bumblebee.Text.Bart{architecture: :base} = spec
24 |
25 | output = Axon.predict(model, params, padded_inputs())
26 |
27 | assert Nx.shape(output.hidden_state) == {1, 10, 16}
28 |
29 | assert_all_close(
30 | output.hidden_state[[.., 1..3, 1..3]],
31 | Nx.tensor([
32 | [[0.9984, -0.0751, 0.4176], [0.0095, -0.3245, -0.4237], [-0.8061, -0.3498, 0.9201]]
33 | ])
34 | )
35 | end
36 |
37 | test ":for_conditional_generation" do
38 | assert {:ok, %{model: model, params: params, spec: spec}} =
39 | load_tiny("BartForConditionalGeneration")
40 |
41 | assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = spec
42 |
43 | output = Axon.predict(model, params, padded_inputs())
44 |
45 | assert Nx.shape(output.logits) == {1, 10, 1024}
46 |
47 | assert_all_close(
48 | output.logits[[.., 1..3, 1..3]],
49 | Nx.tensor([
50 | [[0.0000, -0.0601, -0.0501], [0.0000, 0.0443, 0.0813], [0.0000, -0.1303, 0.0968]]
51 | ])
52 | )
53 | end
54 |
55 | test ":for_sequence_classification" do
56 | assert {:ok, %{model: model, params: params, spec: spec}} =
57 | load_tiny("BartForSequenceClassification")
58 |
59 | assert %Bumblebee.Text.Bart{architecture: :for_sequence_classification} = spec
60 |
61 | # Longer sequence ending in token 2 before a single masked position
62 | # (presumably the EOS id the classification head relies on - confirm).
63 | inputs = %{
64 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 2, 0]]),
65 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]])
66 | }
67 |
68 | output = Axon.predict(model, params, inputs)
69 |
70 | assert Nx.shape(output.logits) == {1, 3}
71 | assert_all_close(output.logits, Nx.tensor([[-0.0075, -0.0078, -0.0073]]))
72 | end
73 |
74 | test ":for_question_answering" do
75 | assert {:ok, %{model: model, params: params, spec: spec}} =
76 | load_tiny("BartForQuestionAnswering")
77 |
78 | assert %Bumblebee.Text.Bart{architecture: :for_question_answering} = spec
79 |
80 | output = Axon.predict(model, params, padded_inputs())
81 |
82 | assert Nx.shape(output.start_logits) == {1, 10}
83 | assert Nx.shape(output.end_logits) == {1, 10}
84 |
85 | assert_all_close(output.start_logits[[.., 1..3]], Nx.tensor([[0.0474, -0.0767, 0.0278]]))
86 | assert_all_close(output.end_logits[[.., 1..3]], Nx.tensor([[0.1557, -0.1034, -0.1271]]))
87 | end
88 |
89 | test ":for_causal_language_modeling" do
90 | assert {:ok, %{model: model, params: params, spec: spec}} = load_tiny("BartForCausalLM")
91 | assert %Bumblebee.Text.Bart{architecture: :for_causal_language_modeling} = spec
92 |
93 | output = Axon.predict(model, params, padded_inputs())
94 |
95 | assert Nx.shape(output.logits) == {1, 10, 1024}
96 |
97 | assert_all_close(
98 | output.logits[[.., 1..3, 1..3]],
99 | Nx.tensor([
100 | [[0.0000, -0.2084, -0.0013], [0.0000, -0.0502, 0.0656], [0.0000, -0.1301, -0.1234]]
101 | ])
102 | )
103 | end
104 | end
105 |
--------------------------------------------------------------------------------
/test/bumblebee/text/blenderbot_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.BlenderbotTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag model_test_tags()
7 |
8 | # Encoder and decoder token sequences, each ending in two masked-out positions.
9 | defp encoder_decoder_inputs do
10 | %{
11 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
12 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
13 | "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]),
14 | "decoder_attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 0, 0]])
15 | }
16 | end
17 |
18 | test ":base" do
19 | repo = {:hf, "hf-internal-testing/tiny-random-BlenderbotModel"}
20 |
21 | assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(repo)
22 | assert %Bumblebee.Text.Blenderbot{architecture: :base} = spec
23 |
24 | hidden_state = Axon.predict(model, params, encoder_decoder_inputs()).hidden_state
25 |
26 | assert Nx.shape(hidden_state) == {1, 8, 16}
27 |
28 | assert_all_close(
29 | hidden_state[[.., 1..3, 1..3]],
30 | Nx.tensor([
31 | [[0.6578, 1.9730, 0.6908], [-1.8067, 0.0553, -0.7491], [0.1820, -0.4390, -0.8273]]
32 | ])
33 | )
34 | end
35 |
36 | test ":for_conditional_generation" do
37 | repo = {:hf, "hf-internal-testing/tiny-random-BlenderbotForConditionalGeneration"}
38 |
39 | assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(repo)
40 | assert %Bumblebee.Text.Blenderbot{architecture: :for_conditional_generation} = spec
41 |
42 | logits = Axon.predict(model, params, encoder_decoder_inputs()).logits
43 |
44 | assert Nx.shape(logits) == {1, 8, 1024}
45 |
46 | assert_all_close(
47 | logits[[.., 1..3, 1..3]],
48 | Nx.tensor([
49 | [[0.0440, -0.0115, -0.0004], [0.0772, 0.0327, -0.0667], [-0.0419, 0.1483, 0.0140]]
50 | ])
51 | )
52 | end
53 | end
54 |
--------------------------------------------------------------------------------
/test/bumblebee/text/blip_text_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.BlipTextTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag model_test_tags()
7 |
8 | # Loads `repo` as a Bumblebee.Text.BlipText model with the given architecture.
9 | defp load_blip_text(repo, architecture) do
10 | Bumblebee.load_model({:hf, repo}, module: Bumblebee.Text.BlipText, architecture: architecture)
11 | end
12 |
13 | # Eight tokens followed by two masked-out positions, shared by both tests.
14 | defp padded_inputs do
15 | %{
16 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
17 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
18 | }
19 | end
20 |
21 | test ":base" do
22 | assert {:ok, %{model: model, params: params, spec: spec}} =
23 | load_blip_text("hf-internal-testing/tiny-random-BlipModel", :base)
24 |
25 | assert %Bumblebee.Text.BlipText{architecture: :base} = spec
26 |
27 | output = Axon.predict(model, params, padded_inputs())
28 |
29 | assert Nx.shape(output.hidden_state) == {1, 10, 32}
30 |
31 | assert_all_close(
32 | output.hidden_state[[.., 1..3, 1..3]],
33 | Nx.tensor([
34 | [[-0.9281, 1.2373, 0.4223], [-1.1549, 2.1187, -0.9194], [0.0237, -0.7517, 0.5720]]
35 | ])
36 | )
37 | end
38 |
39 | test ":for_causal_language_modeling" do
40 | assert {:ok, %{model: model, params: params, spec: spec}} =
41 | load_blip_text(
42 | "hf-internal-testing/tiny-random-BlipForConditionalGeneration",
43 | :for_causal_language_modeling
44 | )
45 |
46 | assert %Bumblebee.Text.BlipText{architecture: :for_causal_language_modeling} = spec
47 |
48 | output = Axon.predict(model, params, padded_inputs())
49 |
50 | assert Nx.shape(output.logits) == {1, 10, 1124}
51 |
52 | assert_all_close(
53 | output.logits[[.., 1..3, 1..3]],
54 | Nx.tensor([
55 | [[0.0736, -0.0142, 0.2178], [0.0744, 0.0990, 0.1510], [-0.1186, -0.1449, -0.0643]]
56 | ])
57 | )
58 | end
59 | end
60 |
--------------------------------------------------------------------------------
/test/bumblebee/text/camembert_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.CamembertTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag model_test_tags()
7 |
8 | test ":base" do
9 | repo = {:hf, "hf-internal-testing/tiny-random-camembert"}
10 |
11 | # The CamemBERT checkpoint is expected to load as a Bumblebee.Text.Roberta spec.
12 | assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(repo)
13 | assert %Bumblebee.Text.Roberta{architecture: :base} = spec
14 |
15 | token_ids = Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]])
16 | mask = Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
17 |
18 | output = Axon.predict(model, params, %{"input_ids" => token_ids, "attention_mask" => mask})
19 |
20 | assert Nx.shape(output.hidden_state) == {1, 10, 32}
21 |
22 | assert_all_close(
23 | output.hidden_state[[.., 1..3, 1..3]],
24 | Nx.tensor([
25 | [[-0.1734, -0.5058, 0.6278], [-0.2506, -0.3877, -0.0394], [-0.4477, 1.9433, -0.7990]]
26 | ])
27 | )
28 | end
29 | end
30 |
--------------------------------------------------------------------------------
/test/bumblebee/text/clip_text_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.ClipTextTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag model_test_tags()
7 |
8 | # Loads the tiny CLIP checkpoint as a Bumblebee.Text.ClipText model with `architecture`.
9 | defp load_clip_text(architecture) do
10 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-CLIPModel"},
11 | module: Bumblebee.Text.ClipText,
12 | architecture: architecture
13 | )
14 | end
15 |
16 | # Eight tokens followed by two masked-out positions.
17 | defp padded_inputs do
18 | %{
19 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
20 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
21 | }
22 | end
23 |
24 | test ":base" do
25 | assert {:ok, %{model: model, params: params, spec: spec}} = load_clip_text(:base)
26 | assert %Bumblebee.Text.ClipText{architecture: :base} = spec
27 |
28 | output = Axon.predict(model, params, padded_inputs())
29 |
30 | assert Nx.shape(output.hidden_state) == {1, 10, 32}
31 | assert Nx.shape(output.pooled_state) == {1, 32}
32 |
33 | assert_all_close(
34 | output.hidden_state[[.., 1..3, 1..3]],
35 | Nx.tensor([
36 | [[0.1696, -0.2324, -0.1659], [-0.0525, -0.3103, 0.1557], [-0.2566, -0.4519, 0.6398]]
37 | ])
38 | )
39 |
40 | assert_all_close(output.pooled_state[[.., 1..3]], Nx.tensor([[-0.6903, -1.2524, 1.5328]]))
41 | end
42 |
43 | test ":for_embedding" do
44 | assert {:ok, %{model: model, params: params, spec: spec}} = load_clip_text(:for_embedding)
45 | assert %Bumblebee.Text.ClipText{architecture: :for_embedding} = spec
46 |
47 | embedding = Axon.predict(model, params, padded_inputs()).embedding
48 |
49 | assert Nx.shape(embedding) == {1, 64}
50 | assert_all_close(embedding[[.., 1..3]], Nx.tensor([[1.1069, -0.0839, -1.6185]]))
51 | end
52 | end
53 |
--------------------------------------------------------------------------------
/test/bumblebee/text/fill_mask_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.FillMaskTest do
2 | use ExUnit.Case, async: false
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag serving_test_tags()
7 |
8 | # Builds a fill-mask serving from bert-base-uncased; the model and tokenizer
9 | # come from the same repository.
10 | defp bert_fill_mask_serving do
11 | repo = {:hf, "google-bert/bert-base-uncased"}
12 |
13 | {:ok, model_info} = Bumblebee.load_model(repo)
14 | {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
15 |
16 | Bumblebee.Text.FillMask.fill_mask(model_info, tokenizer)
17 | end
18 |
19 | test "returns top scored tokens" do
20 | serving = bert_fill_mask_serving()
21 |
22 | assert %{
23 | predictions: [
24 | %{score: _, token: "france"},
25 | %{score: _, token: "brittany"},
26 | %{score: _, token: "algeria"},
27 | %{score: _, token: "department"},
28 | %{score: _, token: "reunion"}
29 | ]
30 | } = Nx.Serving.run(serving, "The capital of [MASK] is Paris.")
31 | end
32 |
33 | test "raises when there isn't exactly one mask token" do
34 | serving = bert_fill_mask_serving()
35 |
36 | # No mask token at all.
37 | assert_raise ArgumentError,
38 | ~s/expected exactly one occurrence of [MASK], got: 0 in "The capital of France is Paris."/,
39 | fn -> Nx.Serving.run(serving, "The capital of France is Paris.") end
40 |
41 | # Two mask tokens.
42 | assert_raise ArgumentError,
43 | ~s/expected exactly one occurrence of [MASK], got: 2 in "The [MASK] of [MASK] is Paris."/,
44 | fn -> Nx.Serving.run(serving, "The [MASK] of [MASK] is Paris.") end
45 | end
46 | end
47 |
--------------------------------------------------------------------------------
/test/bumblebee/text/gemma_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.GemmaTest do
2 | use ExUnit.Case, async: true
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag model_test_tags()
7 |
8 | # Loads the `bumblebee-testing/tiny-random-<name>` checkpoint.
9 | defp load_tiny(name) do
10 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-" <> name})
11 | end
12 |
13 | # Eight tokens followed by two masked-out positions, shared by every test.
14 | defp padded_inputs do
15 | %{
16 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
17 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
18 | }
19 | end
20 |
21 | test ":base" do
22 | assert {:ok, %{model: model, params: params, spec: spec}} = load_tiny("GemmaModel")
23 | assert %Bumblebee.Text.Gemma{architecture: :base} = spec
24 |
25 | output = Axon.predict(model, params, padded_inputs())
26 |
27 | assert Nx.shape(output.hidden_state) == {1, 10, 32}
28 |
29 | assert_all_close(
30 | output.hidden_state[[.., 1..3, 1..3]],
31 | Nx.tensor([
32 | [[-2.0724, -1.1056, -0.4123], [0.7144, -1.5150, -1.8728], [0.7222, 0.7930, 0.3218]]
33 | ])
34 | )
35 | end
36 |
37 | test ":for_sequence_classification" do
38 | assert {:ok, %{model: model, params: params, spec: spec}} =
39 | load_tiny("GemmaForSequenceClassification")
40 |
41 | assert %Bumblebee.Text.Gemma{architecture: :for_sequence_classification} = spec
42 |
43 | output = Axon.predict(model, params, padded_inputs())
44 |
45 | assert Nx.shape(output.logits) == {1, 2}
46 | assert_all_close(output.logits, Nx.tensor([[-0.1422, 0.0613]]))
47 | end
48 |
49 | test ":for_causal_language_modeling" do
50 | assert {:ok, %{model: model, params: params, spec: spec}} = load_tiny("GemmaForCausalLM")
51 | assert %Bumblebee.Text.Gemma{architecture: :for_causal_language_modeling} = spec
52 |
53 | output = Axon.predict(model, params, padded_inputs())
54 |
55 | assert Nx.shape(output.logits) == {1, 10, 1024}
56 |
57 | assert_all_close(
58 | output.logits[[.., 1..3, 1..3]],
59 | Nx.tensor([
60 | [[0.0924, 0.1602, -0.0448], [-0.0934, -0.0543, -0.1045], [-0.1467, 0.0339, -0.1926]]
61 | ])
62 | )
63 | end
64 | end
65 |
--------------------------------------------------------------------------------
/test/bumblebee/text/generation_config_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.GenerationConfigTest do
2 | use ExUnit.Case, async: true
3 |
4 | alias Bumblebee.Text.GenerationConfig
5 |
6 | describe "config/2" do
7 | test "ensures either of length and new token options are set" do
8 | # Setting one option of a pair resets the other option of that pair to nil.
9 | config = GenerationConfig.config(%GenerationConfig{max_new_tokens: 10}, max_length: 10)
10 | assert %GenerationConfig{max_length: 10, max_new_tokens: nil} = config
11 |
12 | config = GenerationConfig.config(%GenerationConfig{max_length: 10}, max_new_tokens: 10)
13 | assert %GenerationConfig{max_length: nil, max_new_tokens: 10} = config
14 |
15 | config = GenerationConfig.config(%GenerationConfig{min_new_tokens: 10}, min_length: 10)
16 | assert %GenerationConfig{min_length: 10, min_new_tokens: nil} = config
17 |
18 | config = GenerationConfig.config(%GenerationConfig{min_length: 10}, min_new_tokens: 10)
19 | assert %GenerationConfig{min_length: nil, min_new_tokens: 10} = config
20 | end
21 |
22 | test "raises if both length and new token options are set" do
23 | cases = [
24 | {[max_length: 10, max_new_tokens: 10],
25 | "only one of :max_new_tokens or :max_length options must be given, but got both"},
26 | {[min_length: 10, min_new_tokens: 10],
27 | "only one of :min_new_tokens or :min_length options must be given, but got both"}
28 | ]
29 |
30 | for {opts, message} <- cases do
31 | assert_raise ArgumentError, message, fn ->
32 | GenerationConfig.config(%GenerationConfig{}, opts)
33 | end
34 | end
35 | end
36 |
37 | test "raises on invalid strategy" do
38 | cases = [
39 | {%{type: :invalid},
40 | "expected strategy type to be either :greedy_search or :contrastive_search, got: :invalid"},
41 | {%{}, "expected strategy to have :type, but was not present in %{}"},
42 | {:greedy_search, "expected strategy to be a map, but got: :greedy_search"},
43 | {%{type: :contrastive_search},
44 | "missing keys [:alpha, :top_k] for strategy :contrastive_search"},
45 | {%{type: :contrastive_search, top_k: 4, alpha: 0.6, unexpected: true},
46 | "unexpected keys [:unexpected] for strategy :contrastive_search"}
47 | ]
48 |
49 | for {strategy, message} <- cases do
50 | assert_raise ArgumentError, message, fn ->
51 | GenerationConfig.config(%GenerationConfig{}, strategy: strategy)
52 | end
53 | end
54 | end
55 | end
56 | end
57 |
--------------------------------------------------------------------------------
/test/bumblebee/text/generation_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.GenerationTest do
2 | use ExUnit.Case, async: false
3 |
4 | import Bumblebee.TestHelpers
5 |
6 | @moduletag model_test_tags()
7 |
8 | test "decoder model" do
9 | assert {:ok, %{model: model, params: params, spec: spec}} =
10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})
11 |
12 | {:ok, generation_config} =
13 | Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})
14 |
15 | assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec
16 |
17 | inputs = %{
18 | "input_ids" => Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]]),
19 | "attention_mask" => Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]),
20 | "seed" => Nx.tensor([0])
21 | }
22 |
23 | generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3)
24 |
25 | generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
26 | %{token_ids: token_ids} = generate.(params, inputs)
27 |
28 | assert_equal(token_ids, Nx.tensor([[80, 80, 80]]))
29 | end
30 |
31 | test "encoder-decoder model" do
32 | assert {:ok, %{model: model, params: params, spec: spec}} =
33 | Bumblebee.load_model(
34 | {:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"}
35 | )
36 |
37 | {:ok, generation_config} =
38 | Bumblebee.load_generation_config(
39 | {:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"}
40 | )
41 |
42 | assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = spec
43 |
44 | inputs = %{
45 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
46 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
47 | "seed" => Nx.tensor([0])
48 | }
49 |
50 | generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3)
51 |
52 | generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
53 | %{token_ids: token_ids} = generate.(params, inputs)
54 |
55 | assert_equal(token_ids, Nx.tensor([[988, 988, 988]]))
56 | end
57 |
58 | test "encoder-decoder model and lower precision" do
59 | assert {:ok, %{model: model, params: params, spec: spec}} =
60 | Bumblebee.load_model(
61 | {:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"},
62 | type: :f16
63 | )
64 |
65 | {:ok, generation_config} =
66 | Bumblebee.load_generation_config(
67 | {:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"}
68 | )
69 |
70 | assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = spec
71 |
72 | inputs = %{
73 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
74 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
75 | "seed" => Nx.tensor([0])
76 | }
77 |
78 | generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3)
79 |
80 | generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
81 | %{token_ids: token_ids} = generate.(params, inputs)
82 |
83 | assert_equal(token_ids, Nx.tensor([[988, 988, 988]]))
84 | end
85 |
86 | test "multiple end-of-sequence token ids" do
87 | assert {:ok, %{model: model, params: params, spec: spec}} =
88 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})
89 |
90 | {:ok, generation_config} =
91 | Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})
92 |
93 | assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec
94 |
95 | inputs = %{
96 | "input_ids" => Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]]),
97 | "attention_mask" => Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]),
98 | "seed" => Nx.tensor([0])
99 | }
100 |
101 | generation_config =
102 | Bumblebee.configure(generation_config, max_new_tokens: 3, eos_token_id: [0, 80])
103 |
104 | generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
105 | %{token_ids: token_ids} = generate.(params, inputs)
106 |
107 | assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]]))
108 | end
109 | end
110 |
--------------------------------------------------------------------------------
/test/bumblebee/text/gpt2_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.Gpt2Test do
2 |   use ExUnit.Case, async: true
3 |
4 |   import Bumblebee.TestHelpers
5 |
6 |   @moduletag model_test_tags()
7 |
8 |   test ":base" do
9 |     assert {:ok, %{model: model, params: params, spec: spec}} =
10 |              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"})
11 |
12 |     assert %Bumblebee.Text.Gpt2{architecture: :base} = spec
13 |
14 |     outputs = Axon.predict(model, params, padded_inputs())
15 |
16 |     assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
17 |
18 |     assert_all_close(
19 |       outputs.hidden_state[[.., 1..3, 1..3]],
20 |       Nx.tensor([
21 |         [[-0.8136, -0.2392, 0.2378], [0.9714, -0.4651, 0.8788], [-0.0980, 0.2294, -1.1416]]
22 |       ])
23 |     )
24 |   end
25 |
26 |   test ":for_causal_language_modeling" do
27 |     assert {:ok, %{model: model, params: params, spec: spec}} =
28 |              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})
29 |
30 |     assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec
31 |
32 |     outputs = Axon.predict(model, params, padded_inputs())
33 |
34 |     assert Nx.shape(outputs.logits) == {1, 10, 1024}
35 |
36 |     assert_all_close(
37 |       outputs.logits[[.., 1..3, 1..3]],
38 |       Nx.tensor([[[0.1184, -0.0259, 0.1688], [0.1064, 0.1412, 0.1120], [0.1421, -0.2010, 0.3757]]])
39 |     )
40 |   end
41 |
42 |   test ":for_token_classification" do
43 |     assert {:ok, %{model: model, params: params, spec: spec}} =
44 |              Bumblebee.load_model(
45 |                {:hf, "hf-internal-testing/tiny-random-GPT2ForTokenClassification"}
46 |              )
47 |
48 |     assert %Bumblebee.Text.Gpt2{architecture: :for_token_classification} = spec
49 |
50 |     outputs = Axon.predict(model, params, padded_inputs())
51 |
52 |     assert Nx.shape(outputs.logits) == {1, 10, 2}
53 |
54 |     assert_all_close(
55 |       outputs.logits[[.., 1..3//1, ..]],
56 |       Nx.tensor([[[0.0207, 0.1338], [-0.1582, -0.0384], [-0.2225, -0.0400]]])
57 |     )
58 |   end
59 |
60 |   test ":for_sequence_classification" do
61 |     assert {:ok, %{model: model, params: params, spec: spec}} =
62 |              Bumblebee.load_model(
63 |                {:hf, "hf-internal-testing/tiny-random-GPT2ForSequenceClassification"}
64 |              )
65 |
66 |     assert %Bumblebee.Text.Gpt2{architecture: :for_sequence_classification} = spec
67 |
68 |     # The masked tail is filled with token 1023 instead of 0 (presumably the
69 |     # checkpoint's pad token id - TODO confirm).
70 |     outputs = Axon.predict(model, params, padded_inputs(1023))
71 |
72 |     assert Nx.shape(outputs.logits) == {1, 2}
73 |
74 |     assert_all_close(outputs.logits, Nx.tensor([[-0.0098, -0.0456]]))
75 |   end
76 |
77 |   # One-sequence batch: eight real tokens followed by two positions that the
78 |   # attention mask excludes. `pad_id` fills those two "input_ids" positions.
79 |   defp padded_inputs(pad_id \\ 0) do
80 |     %{
81 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, pad_id, pad_id]]),
82 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
83 |     }
84 |   end
85 | end
86 |
--------------------------------------------------------------------------------
/test/bumblebee/text/gpt_big_code_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.GptBigCodeTest do
2 |   use ExUnit.Case, async: true
3 |
4 |   import Bumblebee.TestHelpers
5 |
6 |   @moduletag model_test_tags()
7 |
8 |   test ":base" do
9 |     assert {:ok, %{model: model, params: params, spec: spec}} =
10 |              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPTBigCodeModel"})
11 |
12 |     assert %Bumblebee.Text.GptBigCode{architecture: :base} = spec
13 |
14 |     outputs = Axon.predict(model, params, padded_inputs())
15 |
16 |     assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
17 |
18 |     assert_all_close(
19 |       outputs.hidden_state[[.., 1..3, 1..3]],
20 |       Nx.tensor([
21 |         [[-0.8193, 0.5945, -0.2915], [0.0150, 0.4736, 0.5148], [-0.4247, -1.8000, -1.6479]]
22 |       ])
23 |     )
24 |   end
25 |
26 |   test ":base without multi-query attention" do
27 |     # Parameter loading is covered separately without multi-query attention,
28 |     # because the parameters are laid out differently in that configuration.
29 |     assert {:ok, %{model: model, params: params, spec: spec}} =
30 |              Bumblebee.load_model(
31 |                {:hf, "bumblebee-testing/tiny-random-GPTBigCodeModel-multi_query-False"}
32 |              )
33 |
34 |     assert %Bumblebee.Text.GptBigCode{architecture: :base} = spec
35 |
36 |     outputs = Axon.predict(model, params, padded_inputs())
37 |
38 |     assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
39 |
40 |     assert_all_close(
41 |       outputs.hidden_state[[.., 1..3, 1..3]],
42 |       Nx.tensor([
43 |         [[-1.3966, 0.6641, -1.3937], [-0.5489, 0.3397, 0.4567], [-0.6488, -1.6745, -1.1570]]
44 |       ])
45 |     )
46 |   end
47 |
48 |   test ":for_causal_language_modeling" do
49 |     assert {:ok, %{model: model, params: params, spec: spec}} =
50 |              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPTBigCodeForCausalLM"})
51 |
52 |     assert %Bumblebee.Text.GptBigCode{architecture: :for_causal_language_modeling} = spec
53 |
54 |     outputs = Axon.predict(model, params, padded_inputs())
55 |
56 |     assert Nx.shape(outputs.logits) == {1, 10, 1024}
57 |
58 |     assert_all_close(
59 |       outputs.logits[[.., 1..3, 1..3]],
60 |       Nx.tensor([
61 |         [[-0.1509, -0.1751, 0.1848], [-0.0860, -0.2476, 0.3373], [-0.2671, -0.2028, -0.0896]]
62 |       ])
63 |     )
64 |   end
65 |
66 |   test ":for_token_classification" do
67 |     assert {:ok, %{model: model, params: params, spec: spec}} =
68 |              Bumblebee.load_model(
69 |                {:hf, "hf-internal-testing/tiny-random-GPTBigCodeForTokenClassification"}
70 |              )
71 |
72 |     assert %Bumblebee.Text.GptBigCode{architecture: :for_token_classification} = spec
73 |
74 |     outputs = Axon.predict(model, params, padded_inputs())
75 |
76 |     assert Nx.shape(outputs.logits) == {1, 10, 2}
77 |
78 |     assert_all_close(
79 |       outputs.logits[[.., 1..3//1, ..]],
80 |       Nx.tensor([[[-0.0775, -0.0276], [0.0634, 0.0396], [-0.0695, 0.1575]]])
81 |     )
82 |   end
83 |
84 |   test ":for_sequence_classification" do
85 |     assert {:ok, %{model: model, params: params, spec: spec}} =
86 |              Bumblebee.load_model(
87 |                {:hf, "hf-internal-testing/tiny-random-GPTBigCodeForSequenceClassification"}
88 |              )
89 |
90 |     assert %Bumblebee.Text.GptBigCode{architecture: :for_sequence_classification} = spec
91 |
92 |     # The masked tail is filled with token 1021 instead of 0 (presumably the
93 |     # checkpoint's pad token id - TODO confirm).
94 |     outputs = Axon.predict(model, params, padded_inputs(1021))
95 |
96 |     assert Nx.shape(outputs.logits) == {1, 2}
97 |
98 |     assert_all_close(outputs.logits, Nx.tensor([[0.1722, 0.1999]]))
99 |   end
100 |
101 |   # One-sequence batch: eight real tokens followed by two positions that the
102 |   # attention mask excludes. `pad_id` fills those two "input_ids" positions.
103 |   defp padded_inputs(pad_id \\ 0) do
104 |     %{
105 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, pad_id, pad_id]]),
106 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
107 |     }
108 |   end
109 | end
110 |
--------------------------------------------------------------------------------
/test/bumblebee/text/gpt_neo_x_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.GptNeoXTest do
2 |   use ExUnit.Case, async: true
3 |
4 |   import Bumblebee.TestHelpers
5 |
6 |   @moduletag model_test_tags()
7 |
8 |   test ":base" do
9 |     assert {:ok, %{model: model, params: params, spec: spec}} =
10 |              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPTNeoXModel"})
11 |
12 |     assert %Bumblebee.Text.GptNeoX{architecture: :base} = spec
13 |
14 |     outputs = Axon.predict(model, params, padded_inputs())
15 |
16 |     assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
17 |
18 |     assert_all_close(
19 |       outputs.hidden_state[[.., 1..3, 1..3]],
20 |       Nx.tensor([
21 |         [[0.4428, 0.3349, -1.1917], [-0.1550, -0.4439, -0.5855], [0.3737, 3.4893, -0.6499]]
22 |       ])
23 |     )
24 |   end
25 |
26 |   test ":for_sequence_classification" do
27 |     assert {:ok, %{model: model, params: params, spec: spec}} =
28 |              Bumblebee.load_model(
29 |                {:hf, "hf-internal-testing/tiny-random-GPTNeoXForSequenceClassification"}
30 |              )
31 |
32 |     assert %Bumblebee.Text.GptNeoX{architecture: :for_sequence_classification} = spec
33 |
34 |     outputs = Axon.predict(model, params, padded_inputs())
35 |
36 |     assert Nx.shape(outputs.logits) == {1, 2}
37 |
38 |     assert_all_close(outputs.logits, Nx.tensor([[0.1089, -0.3733]]))
39 |   end
40 |
41 |   test ":for_token_classification" do
42 |     assert {:ok, %{model: model, params: params, spec: spec}} =
43 |              Bumblebee.load_model(
44 |                {:hf, "hf-internal-testing/tiny-random-GPTNeoXForTokenClassification"}
45 |              )
46 |
47 |     assert %Bumblebee.Text.GptNeoX{architecture: :for_token_classification} = spec
48 |
49 |     outputs = Axon.predict(model, params, padded_inputs())
50 |
51 |     assert Nx.shape(outputs.logits) == {1, 10, 2}
52 |
53 |     assert_all_close(
54 |       outputs.logits[[.., 1..3//1, ..]],
55 |       Nx.tensor([[[-0.0900, -0.1853], [0.0567, -0.0443], [-0.0104, -0.1112]]])
56 |     )
57 |   end
58 |
59 |   test ":for_causal_language_modeling" do
60 |     assert {:ok, %{model: model, params: params, spec: spec}} =
61 |              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"})
62 |
63 |     assert %Bumblebee.Text.GptNeoX{architecture: :for_causal_language_modeling} = spec
64 |
65 |     outputs = Axon.predict(model, params, padded_inputs())
66 |
67 |     assert Nx.shape(outputs.logits) == {1, 10, 1024}
68 |
69 |     assert_all_close(
70 |       outputs.logits[[.., 1..3, 1..3]],
71 |       Nx.tensor([
72 |         [[0.1134, 0.0507, -0.0534], [-0.1113, 0.0035, -0.0319], [0.0019, -0.0273, -0.0151]]
73 |       ])
74 |     )
75 |   end
76 |
77 |   # One-sequence batch shared by every test: eight real tokens followed by
78 |   # two zero-id positions that the attention mask excludes.
79 |   defp padded_inputs do
80 |     %{
81 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
82 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
83 |     }
84 |   end
85 | end
86 |
--------------------------------------------------------------------------------
/test/bumblebee/text/llama_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.LlamaTest do
2 |   use ExUnit.Case, async: true
3 |
4 |   import Bumblebee.TestHelpers
5 |
6 |   @moduletag model_test_tags()
7 |
8 |   test ":base" do
9 |     assert {:ok, %{model: model, params: params, spec: spec}} =
10 |              Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-LlamaModel"})
11 |
12 |     assert %Bumblebee.Text.Llama{architecture: :base} = spec
13 |
14 |     outputs = Axon.predict(model, params, padded_inputs())
15 |
16 |     assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
17 |
18 |     assert_all_close(
19 |       outputs.hidden_state[[.., 1..3, 1..3]],
20 |       Nx.tensor([
21 |         [[1.4799, -2.0333, 0.4759], [2.3749, -0.8369, -0.0206], [0.5767, -0.0515, -1.1795]]
22 |       ])
23 |     )
24 |   end
25 |
26 |   test ":base rotary embedding scaling strategy :llama3" do
27 |     assert {:ok, %{model: model, params: params, spec: spec}} =
28 |              Bumblebee.load_model(
29 |                {:hf,
30 |                 "bumblebee-testing/tiny-random-LlamaModel-rope_scaling-llama3-original_max_position_embeddings-64"}
31 |              )
32 |
33 |     assert %Bumblebee.Text.Llama{architecture: :base} = spec
34 |
35 |     outputs = Axon.predict(model, params, padded_inputs())
36 |
37 |     assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
38 |
39 |     assert_all_close(
40 |       outputs.hidden_state[[.., 1..3, 1..3]],
41 |       Nx.tensor([
42 |         [[1.4802, -2.0331, 0.4759], [2.3749, -0.8367, -0.0205], [0.5762, -0.0517, -1.1795]]
43 |       ])
44 |     )
45 |   end
46 |
47 |   test ":for_sequence_classification" do
48 |     assert {:ok, %{model: model, params: params, spec: spec}} =
49 |              Bumblebee.load_model(
50 |                {:hf, "bumblebee-testing/tiny-random-LlamaForSequenceClassification"}
51 |              )
52 |
53 |     assert %Bumblebee.Text.Llama{architecture: :for_sequence_classification} = spec
54 |
55 |     outputs = Axon.predict(model, params, padded_inputs())
56 |
57 |     assert Nx.shape(outputs.logits) == {1, 2}
58 |
59 |     assert_all_close(outputs.logits, Nx.tensor([[-0.1964, -0.1069]]))
60 |   end
61 |
62 |   test ":for_causal_language_modeling" do
63 |     assert {:ok, %{model: model, params: params, spec: spec}} =
64 |              Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-LlamaForCausalLM"})
65 |
66 |     assert %Bumblebee.Text.Llama{architecture: :for_causal_language_modeling} = spec
67 |
68 |     outputs = Axon.predict(model, params, padded_inputs())
69 |
70 |     assert Nx.shape(outputs.logits) == {1, 10, 1024}
71 |
72 |     assert_all_close(
73 |       outputs.logits[[.., 1..3, 1..3]],
74 |       Nx.tensor([
75 |         [[0.0469, -0.0751, 0.0349], [0.0617, -0.1357, -0.0204], [-0.1495, 0.0557, -0.0737]]
76 |       ])
77 |     )
78 |   end
79 |
80 |   # One-sequence batch shared by every test: eight real tokens followed by
81 |   # two zero-id positions that the attention mask excludes.
82 |   defp padded_inputs do
83 |     %{
84 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
85 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
86 |     }
87 |   end
88 | end
89 |
--------------------------------------------------------------------------------
/test/bumblebee/text/m2m100_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.M2m100Test do
2 |   use ExUnit.Case, async: true
3 |
4 |   import Bumblebee.TestHelpers
5 |
6 |   @moduletag model_test_tags()
7 |
8 |   test ":base" do
9 |     assert {:ok, %{model: model, params: params, spec: spec}} =
10 |              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-M2M100Model"},
11 |                architecture: :base
12 |              )
13 |
14 |     assert %Bumblebee.Text.M2m100{architecture: :base} = spec
15 |
16 |     inputs = %{
17 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
18 |       "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]),
19 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
20 |     }
21 |
22 |     outputs = Axon.predict(model, params, inputs)
23 |
24 |     # The sequence axis matches the length of "decoder_input_ids" (8).
25 |     assert Nx.shape(outputs.hidden_state) == {1, 8, 16}
26 |
27 |     assert_all_close(
28 |       outputs.hidden_state[[.., 1..3, 1..3]],
29 |       Nx.tensor([
30 |         [
31 |           [0.7856, -0.3174, -0.4792],
32 |           [0.7265, -0.2752, -0.4823],
33 |           [1.0580, -0.3263, -0.7994]
34 |         ]
35 |       ])
36 |     )
37 |   end
38 |
39 |   test ":for_conditional_generation" do
40 |     assert {:ok, %{model: model, params: params, spec: spec}} =
41 |              Bumblebee.load_model(
42 |                {:hf, "hf-internal-testing/tiny-random-M2M100ForConditionalGeneration"},
43 |                architecture: :for_conditional_generation
44 |              )
45 |
46 |     assert %Bumblebee.Text.M2m100{architecture: :for_conditional_generation} = spec
47 |
48 |     inputs = %{
49 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
50 |       "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]),
51 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
52 |     }
53 |
54 |     outputs = Axon.predict(model, params, inputs)
55 |
56 |     # One row of logits per decoder position (8).
57 |     assert Nx.shape(outputs.logits) == {1, 8, 128_112}
58 |
59 |     assert_all_close(
60 |       outputs.logits[[.., 1..3, 1..3]],
61 |       Nx.tensor([
62 |         [
63 |           [0.0000, -0.0323, 0.0527],
64 |           [0.0000, -0.0404, 0.0713],
65 |           [0.0000, -0.0660, 0.0758]
66 |         ]
67 |       ])
68 |     )
69 |   end
70 | end
71 |
--------------------------------------------------------------------------------
/test/bumblebee/text/mbart_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.MbartTest do
2 |   use ExUnit.Case, async: true
3 |
4 |   import Bumblebee.TestHelpers
5 |
6 |   @moduletag model_test_tags()
7 |
8 |   test ":base" do
9 |     assert {:ok, %{model: model, params: params, spec: spec}} =
10 |              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MBartModel"})
11 |
12 |     assert %Bumblebee.Text.Mbart{architecture: :base} = spec
13 |
14 |     outputs = Axon.predict(model, params, padded_inputs())
15 |
16 |     assert Nx.shape(outputs.hidden_state) == {1, 10, 16}
17 |
18 |     assert_all_close(
19 |       outputs.hidden_state[[.., 1..3, 1..3]],
20 |       Nx.tensor([
21 |         [[0.8300, -0.4815, 0.4641], [-1.6583, 0.9162, -0.3562], [-0.6983, -0.7699, 1.0282]]
22 |       ])
23 |     )
24 |   end
25 |
26 |   test ":for_conditional_generation" do
27 |     assert {:ok, %{model: model, params: params, spec: spec}} =
28 |              Bumblebee.load_model(
29 |                {:hf, "hf-internal-testing/tiny-random-MBartForConditionalGeneration"}
30 |              )
31 |
32 |     assert %Bumblebee.Text.Mbart{architecture: :for_conditional_generation} = spec
33 |
34 |     outputs = Axon.predict(model, params, padded_inputs())
35 |
36 |     assert Nx.shape(outputs.logits) == {1, 10, 250_027}
37 |
38 |     assert_all_close(
39 |       outputs.logits[[.., 1..3, 1..3]],
40 |       Nx.tensor([[[0.0000, 0.0923, 0.0841], [0.0000, 0.1023, -0.0938], [0.0000, 0.0703, 0.1231]]])
41 |     )
42 |   end
43 |
44 |   test ":for_sequence_classification" do
45 |     assert {:ok, %{model: model, params: params, spec: spec}} =
46 |              Bumblebee.load_model(
47 |                {:hf, "hf-internal-testing/tiny-random-MBartForSequenceClassification"}
48 |              )
49 |
50 |     assert %Bumblebee.Text.Mbart{architecture: :for_sequence_classification} = spec
51 |
52 |     # This test uses its own inputs: position 8 holds token 2 and is attended
53 |     # to, so only the final position is masked (presumably 2 is the
54 |     # end-of-sequence token the classification head reads - TODO confirm).
55 |     inputs = %{
56 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 2, 0]]),
57 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0]])
58 |     }
59 |
60 |     outputs = Axon.predict(model, params, inputs)
61 |
62 |     assert Nx.shape(outputs.logits) == {1, 2}
63 |
64 |     assert_all_close(outputs.logits, Nx.tensor([[0.0085, 0.0054]]))
65 |   end
66 |
67 |   test ":for_question_answering" do
68 |     assert {:ok, %{model: model, params: params, spec: spec}} =
69 |              Bumblebee.load_model(
70 |                {:hf, "hf-internal-testing/tiny-random-MBartForQuestionAnswering"}
71 |              )
72 |
73 |     assert %Bumblebee.Text.Mbart{architecture: :for_question_answering} = spec
74 |
75 |     outputs = Axon.predict(model, params, padded_inputs())
76 |
77 |     assert Nx.shape(outputs.start_logits) == {1, 10}
78 |     assert Nx.shape(outputs.end_logits) == {1, 10}
79 |
80 |     assert_all_close(outputs.start_logits[[.., 1..3]], Nx.tensor([[0.1063, -0.1271, -0.1534]]))
81 |     assert_all_close(outputs.end_logits[[.., 1..3]], Nx.tensor([[0.0268, 0.0238, 0.0857]]))
82 |   end
83 |
84 |   test ":for_causal_language_modeling" do
85 |     assert {:ok, %{model: model, params: params, spec: spec}} =
86 |              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MBartForCausalLM"})
87 |
88 |     assert %Bumblebee.Text.Mbart{architecture: :for_causal_language_modeling} = spec
89 |
90 |     outputs = Axon.predict(model, params, padded_inputs())
91 |
92 |     assert Nx.shape(outputs.logits) == {1, 10, 250_027}
93 |
94 |     assert_all_close(
95 |       outputs.logits[[.., 1..3, 1..3]],
96 |       Nx.tensor([
97 |         [[0.0000, -0.0236, -0.0043], [0.0000, -0.0101, 0.0510], [0.0000, 0.0404, 0.0327]]
98 |       ])
99 |     )
100 |   end
101 |
102 |   # One-sequence batch shared by every test except
103 |   # ":for_sequence_classification": eight real tokens followed by two
104 |   # zero-id positions that the attention mask excludes.
105 |   defp padded_inputs do
106 |     %{
107 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
108 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
109 |     }
110 |   end
111 | end
112 |
--------------------------------------------------------------------------------
/test/bumblebee/text/mistral_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.MistralTest do
2 |   use ExUnit.Case, async: true
3 |
4 |   import Bumblebee.TestHelpers
5 |
6 |   @moduletag model_test_tags()
7 |
8 |   test ":base" do
9 |     assert {:ok, %{model: model, params: params, spec: spec}} =
10 |              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MistralModel"})
11 |
12 |     assert %Bumblebee.Text.Mistral{architecture: :base} = spec
13 |
14 |     inputs = %{
15 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
16 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
17 |     }
18 |
19 |     outputs = Axon.predict(model, params, inputs)
20 |
21 |     assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
22 |
23 |     assert_all_close(
24 |       outputs.hidden_state[[.., 1..3, 1..3]],
25 |       Nx.tensor([
26 |         [[0.9450, -1.3945, 0.7331], [-2.1118, -1.3091, -0.7834], [-1.7609, -1.3034, 1.0634]]
27 |       ])
28 |     )
29 |   end
30 |
31 |   test ":base with attention sliding window" do
32 |     assert {:ok, spec} =
33 |              Bumblebee.load_spec({:hf, "hf-internal-testing/tiny-random-MistralModel"})
34 |
35 |     # Override the checkpoint's spec with an attention window of size 2.
36 |     spec = Bumblebee.configure(spec, attention_window_size: 2)
37 |
38 |     assert {:ok, %{model: model, params: params, spec: spec}} =
39 |              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MistralModel"},
40 |                spec: spec
41 |              )
42 |
43 |     assert %Bumblebee.Text.Mistral{architecture: :base} = spec
44 |
45 |     inputs = %{
46 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
47 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
48 |     }
49 |
50 |     outputs = Axon.predict(model, params, inputs)
51 |
52 |     assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
53 |
54 |     # Positions 1 and 2 match the ":base" expectation above; only position 3
55 |     # differs, which is where the window of size 2 changes the result.
56 |     assert_all_close(
57 |       outputs.hidden_state[[.., 1..3, 1..3]],
58 |       Nx.tensor([
59 |         [[0.9450, -1.3945, 0.7331], [-2.1118, -1.3091, -0.7834], [-1.3033, -1.3374, 0.8919]]
60 |       ])
61 |     )
62 |   end
63 |
64 |   test ":for_sequence_classification" do
65 |     assert {:ok, %{model: model, params: params, spec: spec}} =
66 |              Bumblebee.load_model(
67 |                {:hf, "hf-internal-testing/tiny-random-MistralForSequenceClassification"}
68 |              )
69 |
70 |     assert %Bumblebee.Text.Mistral{architecture: :for_sequence_classification} = spec
71 |
72 |     inputs = %{
73 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
74 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
75 |     }
76 |
77 |     outputs = Axon.predict(model, params, inputs)
78 |
79 |     assert Nx.shape(outputs.logits) == {1, 2}
80 |
81 |     assert_all_close(
82 |       outputs.logits,
83 |       Nx.tensor([[0.0035, -0.0357]])
84 |     )
85 |   end
86 |
87 |   test ":for_causal_language_modeling" do
88 |     assert {:ok, %{model: model, params: params, spec: spec}} =
89 |              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MistralForCausalLM"})
90 |
91 |     assert %Bumblebee.Text.Mistral{architecture: :for_causal_language_modeling} = spec
92 |
93 |     inputs = %{
94 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
95 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
96 |     }
97 |
98 |     outputs = Axon.predict(model, params, inputs)
99 |
100 |     assert Nx.shape(outputs.logits) == {1, 10, 32_000}
101 |
102 |     assert_all_close(
103 |       outputs.logits[[.., 1..3, 1..3]],
104 |       Nx.tensor([
105 |         [[-0.1054, 0.0026, 0.0450], [0.1400, 0.1388, 0.0265], [0.0060, -0.1150, -0.1463]]
106 |       ])
107 |     )
108 |   end
109 | end
110 |
--------------------------------------------------------------------------------
/test/bumblebee/text/nllb_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.NllbTest do
2 |   use ExUnit.Case, async: true
3 |
4 |   import Bumblebee.TestHelpers
5 |
6 |   @moduletag model_test_tags()
7 |
8 |   # The NLLB checkpoint is loaded explicitly through the M2M100 module.
9 |   test ":for_conditional_generation" do
10 |     assert {:ok, %{model: model, params: params, spec: spec}} =
11 |              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-nllb"},
12 |                module: Bumblebee.Text.M2m100,
13 |                architecture: :for_conditional_generation
14 |              )
15 |
16 |     assert %Bumblebee.Text.M2m100{architecture: :for_conditional_generation} = spec
17 |
18 |     inputs = %{
19 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
20 |       "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]),
21 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
22 |     }
23 |
24 |     outputs = Axon.predict(model, params, inputs)
25 |
26 |     assert Nx.shape(outputs.logits) == {1, 8, 128_112}
27 |
28 |     assert_all_close(
29 |       outputs.logits[[.., 1..3, 1..3]],
30 |       Nx.tensor([
31 |         [
32 |           [0.0000, 0.0169, -0.0698],
33 |           [0.0000, 0.0525, -0.1042],
34 |           [0.0000, 0.0667, -0.1078]
35 |         ]
36 |       ])
37 |     )
38 |   end
39 | end
40 |
--------------------------------------------------------------------------------
/test/bumblebee/text/phi3_test.exs:
--------------------------------------------------------------------------------
1 | defmodule Bumblebee.Text.Phi3Test do
2 |   use ExUnit.Case, async: true
3 |
4 |   import Bumblebee.TestHelpers
5 |
6 |   @moduletag model_test_tags()
7 |
8 |   test ":base" do
9 |     assert {:ok, %{model: model, params: params, spec: spec}} =
10 |              Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-Phi3Model"})
11 |
12 |     assert %Bumblebee.Text.Phi3{architecture: :base} = spec
13 |
14 |     outputs = Axon.predict(model, params, padded_inputs())
15 |
16 |     assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
17 |
18 |     assert_all_close(
19 |       outputs.hidden_state[[.., 1..3, 1..3]],
20 |       Nx.tensor([
21 |         [[-1.4514, 0.6000, 0.1565], [-0.2677, 1.9352, 0.5334], [1.1021, -0.1642, 0.5992]]
22 |       ])
23 |     )
24 |   end
25 |
26 |   test ":base rotary embedding scaling strategy :longrope" do
27 |     assert {:ok, %{model: model, params: params, spec: spec}} =
28 |              Bumblebee.load_model(
29 |                {:hf,
30 |                 "bumblebee-testing/tiny-random-Phi3Model-rope_scaling-longrope-original_max_position_embeddings-256"}
31 |              )
32 |
33 |     assert %Bumblebee.Text.Phi3{architecture: :base} = spec
34 |
35 |     outputs = Axon.predict(model, params, padded_inputs())
36 |
37 |     assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
38 |
39 |     assert_all_close(
40 |       outputs.hidden_state[[.., 1..3, 1..3]],
41 |       Nx.tensor([
42 |         [[-1.4528, 0.5995, 0.1573], [-0.2664, 1.9339, 0.5336], [1.1053, -0.1643, 0.5989]]
43 |       ])
44 |     )
45 |   end
46 |
47 |   test ":for_sequence_classification" do
48 |     assert {:ok, %{model: model, params: params, spec: spec}} =
49 |              Bumblebee.load_model(
50 |                {:hf, "bumblebee-testing/tiny-random-Phi3ForSequenceClassification"}
51 |              )
52 |
53 |     assert %Bumblebee.Text.Phi3{architecture: :for_sequence_classification} = spec
54 |
55 |     outputs = Axon.predict(model, params, padded_inputs())
56 |
57 |     assert Nx.shape(outputs.logits) == {1, 2}
58 |
59 |     assert_all_close(outputs.logits, Nx.tensor([[0.1249, 0.1090]]))
60 |   end
61 |
62 |   test ":for_token_classification" do
63 |     assert {:ok, %{model: model, params: params, spec: spec}} =
64 |              Bumblebee.load_model(
65 |                {:hf, "bumblebee-testing/tiny-random-Phi3ForTokenClassification"}
66 |              )
67 |
68 |     assert %Bumblebee.Text.Phi3{architecture: :for_token_classification} = spec
69 |
70 |     outputs = Axon.predict(model, params, padded_inputs())
71 |
72 |     assert Nx.shape(outputs.logits) == {1, 10, 2}
73 |
74 |     assert_all_close(
75 |       outputs.logits[[.., 1..3//1, ..]],
76 |       Nx.tensor([[[0.0588, -0.0997], [0.0494, -0.1636], [0.0402, 0.0486]]])
77 |     )
78 |   end
79 |
80 |   test ":for_causal_language_modeling" do
81 |     assert {:ok, %{model: model, params: params, spec: spec}} =
82 |              Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-Phi3ForCausalLM"})
83 |
84 |     assert %Bumblebee.Text.Phi3{architecture: :for_causal_language_modeling} = spec
85 |
86 |     outputs = Axon.predict(model, params, padded_inputs())
87 |
88 |     assert Nx.shape(outputs.logits) == {1, 10, 1024}
89 |
90 |     assert_all_close(
91 |       outputs.logits[[.., 1..3, 1..3]],
92 |       Nx.tensor([
93 |         [[-0.0893, 0.0890, -0.1252], [0.0574, 0.0197, -0.0580], [-0.0302, -0.0644, -0.1228]]
94 |       ])
95 |     )
96 |   end
97 |
98 |   # One-sequence batch shared by every test: eight real tokens followed by
99 |   # two zero-id positions that the attention mask excludes.
100 |   defp padded_inputs do
101 |     %{
102 |       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
103 |       "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
104 |     }
105 |   end
106 | end
107 |
--------------------------------------------------------------------------------
/test/bumblebee/text/phi_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Text.PhiTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  @moduletag model_test_tags()

  test ":base" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-PhiModel"})

    assert %Bumblebee.Text.Phi{architecture: :base} = spec

    outputs = predict(model, params)

    assert Nx.shape(outputs.hidden_state) == {1, 10, 32}

    expected =
      Nx.tensor([
        [[-0.3275, 0.5231, 0.5690], [0.2239, 0.5028, 0.4599], [-0.0979, 1.0183, 0.3350]]
      ])

    assert_all_close(outputs.hidden_state[[.., 1..3, 1..3]], expected)
  end

  test ":for_sequence_classification" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model(
               {:hf, "bumblebee-testing/tiny-random-PhiForSequenceClassification"}
             )

    assert %Bumblebee.Text.Phi{architecture: :for_sequence_classification} = spec

    outputs = predict(model, params)

    assert Nx.shape(outputs.logits) == {1, 2}
    assert_all_close(outputs.logits, Nx.tensor([[0.1403, -0.1382]]))
  end

  test ":for_token_classification" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model(
               {:hf, "bumblebee-testing/tiny-random-PhiForTokenClassification"}
             )

    assert %Bumblebee.Text.Phi{architecture: :for_token_classification} = spec

    outputs = predict(model, params)

    assert Nx.shape(outputs.logits) == {1, 10, 2}

    expected = Nx.tensor([[[-0.0364, -0.1207], [0.2520, 0.0755], [0.0243, 0.0269]]])

    assert_all_close(outputs.logits[[.., 1..3//1, ..]], expected)
  end

  test ":for_causal_language_modeling" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-PhiForCausalLM"})

    assert %Bumblebee.Text.Phi{architecture: :for_causal_language_modeling} = spec

    outputs = predict(model, params)

    assert Nx.shape(outputs.logits) == {1, 10, 1024}

    expected =
      Nx.tensor([
        [[0.2541, 0.0827, 0.0526], [0.1901, 0.1289, 0.0758], [0.1051, 0.0658, -0.1167]]
      ])

    assert_all_close(outputs.logits[[.., 1..3, 1..3]], expected)
  end

  # Runs the model on one sequence of ten token ids whose last two positions
  # are masked out by the attention mask.
  defp predict(model, params) do
    inputs = %{
      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
    }

    Axon.predict(model, params, inputs)
  end
end
96 |
--------------------------------------------------------------------------------
/test/bumblebee/text/question_answering_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Text.QuestionAnsweringTest do
  use ExUnit.Case, async: false

  import Bumblebee.TestHelpers

  @moduletag serving_test_tags()

  test "returns the most probable answer" do
    serving = roberta_serving()

    input = %{question: "What's my name?", context: "My name is Sarah and I live in London."}

    assert %{results: [%{text: "Sarah", start: 11, end: 16, score: score}]} =
             Nx.Serving.run(serving, input)

    assert_all_close(score, 0.8105)
  end

  test "supports multiple inputs" do
    serving = roberta_serving()

    inputs = [
      %{question: "What's my name?", context: "My name is Sarah and I live in London."},
      %{question: "Where do I live?", context: "My name is Clara and I live in Berkeley."}
    ]

    assert [
             %{results: [%{text: "Sarah", start: 11, end: 16, score: _}]},
             %{results: [%{text: "Berkeley", start: 31, end: 39, score: _}]}
           ] = Nx.Serving.run(serving, inputs)
  end

  # Builds a question-answering serving from the SQuAD2 RoBERTa checkpoint,
  # paired with the base RoBERTa tokenizer.
  defp roberta_serving do
    {:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "FacebookAI/roberta-base"})

    Bumblebee.Text.question_answering(roberta, tokenizer)
  end
end
47 |
--------------------------------------------------------------------------------
/test/bumblebee/text/text_classification_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Text.TextClassificationTest do
  use ExUnit.Case, async: false

  import Bumblebee.TestHelpers

  @moduletag serving_test_tags()

  test "returns top scored labels" do
    {:ok, model_info} = Bumblebee.load_model({:hf, "cardiffnlp/twitter-roberta-base-emotion"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "FacebookAI/roberta-base"})

    output =
      model_info
      |> Bumblebee.Text.TextClassification.text_classification(tokenizer)
      |> Nx.Serving.run("Cats are cute.")

    # The labels must appear in exactly this order (presumably descending by
    # score); the scores themselves are not pinned.
    assert %{
             predictions: [
               %{label: "optimism", score: _},
               %{label: "sadness", score: _},
               %{label: "anger", score: _},
               %{label: "joy", score: _}
             ]
           } = output
  end
end
26 |
--------------------------------------------------------------------------------
/test/bumblebee/text/text_embedding_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Text.TextEmbeddingTest do
  use ExUnit.Case, async: false

  import Bumblebee.TestHelpers

  @moduletag serving_test_tags()

  @e5_repo {:hf, "intfloat/e5-small-v2"}

  test "returns embedding for a piece of text" do
    {model_info, tokenizer} = load_e5()

    serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer)

    assert %{embedding: %Nx.Tensor{} = embedding} =
             Nx.Serving.run(serving, "query: Cats are cute.")

    assert Nx.shape(embedding) == {384}
    assert_all_close(embedding[1..3], Nx.tensor([0.0420, -0.0188, 0.1115]))
  end

  test "returns normalized embedding for a piece of text" do
    {model_info, tokenizer} = load_e5()

    serving =
      Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
        embedding_processor: :l2_norm
      )

    assert %{embedding: %Nx.Tensor{} = embedding} =
             Nx.Serving.run(serving, "query: Cats are cute.")

    assert Nx.shape(embedding) == {384}
    assert_all_close(embedding[1..3], Nx.tensor([0.0433, -0.0194, 0.1151]))

    # After L2 normalization the sum of squared components must be 1.
    assert_all_close(Nx.sum(Nx.pow(embedding, 2)), Nx.tensor(1.0), atol: 1.0e-6)
  end

  test "supports compilation for single or multiple sequence lengths" do
    {model_info, tokenizer} = load_e5()

    compiled_serving = fn sequence_length ->
      Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
        compile: [batch_size: 1, sequence_length: sequence_length]
      )
    end

    serving_short = compiled_serving.(8)
    serving_long = compiled_serving.(16)
    serving_both = compiled_serving.([8, 16])

    short_text = "short text"
    long_text = "definitely much longer text that should exceed 16 tokens"

    assert %{embedding: embedding_short} = Nx.Serving.run(serving_short, short_text)
    assert %{embedding: embedding_long} = Nx.Serving.run(serving_long, long_text)

    # A serving compiled for several lengths must agree with the single-length ones.
    assert %{embedding: embedding_short2} = Nx.Serving.run(serving_both, short_text)
    assert %{embedding: embedding_long2} = Nx.Serving.run(serving_both, long_text)

    assert_equal(embedding_short, embedding_short2)
    assert_equal(embedding_long, embedding_long2)
  end

  @tag :multi_device
  test "works with partitioned serving", %{test: test} do
    {model_info, tokenizer} = load_e5()

    serving =
      Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
        compile: [batch_size: 1, sequence_length: 16],
        defn_options: [compiler: EXLA, client: :other_host],
        preallocate_params: true
      )

    start_supervised!({Nx.Serving, serving: serving, name: test, partitions: true})

    text = "query: Cats are cute."

    assert [
             %{embedding: %Nx.Tensor{} = embedding1},
             %{embedding: %Nx.Tensor{} = embedding2}
           ] = Nx.Serving.batched_run(test, [text, text])

    assert_equal(embedding1, embedding2)
  end

  # Loads the e5-small-v2 model together with its matching tokenizer.
  defp load_e5 do
    {:ok, model_info} = Bumblebee.load_model(@e5_repo)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(@e5_repo)
    {model_info, tokenizer}
  end
end
104 |
--------------------------------------------------------------------------------
/test/bumblebee/text/token_classification_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Text.TokenClassificationTest do
  use ExUnit.Case, async: false

  import Bumblebee.TestHelpers

  @moduletag serving_test_tags()

  test "correctly extracts entities with :same aggregation" do
    serving = ner_serving(:same)

    text = "I went with Jane Doe to Atlanta and we talked to John Smith about Microsoft"

    assert %{entities: [jane, atlanta, john, microsoft]} = Nx.Serving.run(serving, text)

    assert %{
             label: "PER",
             score: _jane_score,
             phrase: "Jane Doe",
             start: 12,
             end: 20
           } = jane

    assert %{
             label: "LOC",
             score: _atlanta_score,
             phrase: "Atlanta",
             start: 24,
             end: 31
           } = atlanta

    assert %{
             label: "PER",
             score: _john_score,
             phrase: "John Smith",
             start: 49,
             end: 59
           } = john

    assert %{
             label: "ORG",
             score: _microsoft_score,
             phrase: "Microsoft",
             start: 66,
             end: 75
           } = microsoft

    # Offsets should be expressed in terms of bytes (note that é is 2 bytes)

    text = "Jane é John"

    assert %{
             entities: [%{start: 0, end: 4}, %{start: 8, end: 12}]
           } = Nx.Serving.run(serving, text)
  end

  for aggregation <- [:word_first, :word_max, :word_average] do
    test "correctly extracts entities with :#{aggregation} aggregation" do
      serving = ner_serving(unquote(aggregation))

      text = "I went with Janine Doe to Atlanta and we talked to John Smith about Microsoft"

      assert %{entities: [janine, atlanta, john, microsoft]} = Nx.Serving.run(serving, text)

      assert %{
               label: "PER",
               score: _janine_score,
               phrase: "Janine Doe",
               start: 12,
               end: 22
             } = janine

      assert %{
               label: "LOC",
               score: _atlanta_score,
               phrase: "Atlanta",
               start: 26,
               end: 33
             } = atlanta

      assert %{
               label: "PER",
               score: _john_score,
               phrase: "John Smith",
               start: 51,
               end: 61
             } = john

      assert %{
               label: "ORG",
               score: _microsoft_score,
               phrase: "Microsoft",
               start: 68,
               end: 77
             } = microsoft
    end
  end

  # The serving below is configured with `aggregation: :same`, so the title
  # names that strategy rather than the stale "simple" label.
  test "correctly extracts entities with :same aggregation on batched input" do
    serving = ner_serving(:same)

    texts = [
      "I went with Janine Doe to Atlanta and we talked to John Smith about Microsoft",
      "John went to Philadelphia"
    ]

    assert [_first, %{entities: [john, philadelphia]}] = Nx.Serving.run(serving, texts)

    assert %{label: "PER", phrase: "John"} = john
    assert %{label: "LOC", phrase: "Philadelphia"} = philadelphia
  end

  # Builds a token-classification serving from the BERT NER checkpoint and the
  # cased BERT tokenizer, using the given entity aggregation strategy.
  defp ner_serving(aggregation) do
    assert {:ok, model_info} = Bumblebee.load_model({:hf, "dslim/bert-base-NER"})
    assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-cased"})

    Bumblebee.Text.TokenClassification.token_classification(model_info, tokenizer,
      aggregation: aggregation
    )
  end
end
130 |
--------------------------------------------------------------------------------
/test/bumblebee/text/translation_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Text.TranslationTest do
  use ExUnit.Case, async: false

  import Bumblebee.TestHelpers

  @moduletag serving_test_tags()

  @nllb_repo {:hf, "facebook/nllb-200-distilled-600M"}

  test "generates text with greedy generation" do
    {:ok, model_info} = Bumblebee.load_model(@nllb_repo)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(@nllb_repo)
    {:ok, generation_config} = Bumblebee.load_generation_config(@nllb_repo)

    serving = Bumblebee.Text.translation(model_info, tokenizer, generation_config)

    # English -> Polish, selected through NLLB language tokens.
    input = %{
      text: "The bank of the river is beautiful in spring",
      source_language_token: "eng_Latn",
      target_language_token: "pol_Latn"
    }

    assert %{
             results: [
               %{
                 text: "W wiosnę brzeg rzeki jest piękny",
                 token_summary: %{input: 11, output: 13, padding: 0}
               }
             ]
           } = Nx.Serving.run(serving, input)
  end
end
35 |
--------------------------------------------------------------------------------
/test/bumblebee/text/xlm_roberta_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Text.XlmRobertaTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  @moduletag model_test_tags()

  test ":base" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-XLMRobertaModel"})

    # The XLM-RoBERTa checkpoint is loaded into the RoBERTa spec.
    assert %Bumblebee.Text.Roberta{architecture: :base} = spec

    input_ids = Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]])
    attention_mask = Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])

    outputs =
      Axon.predict(model, params, %{"input_ids" => input_ids, "attention_mask" => attention_mask})

    assert Nx.shape(outputs.hidden_state) == {1, 10, 32}

    expected =
      Nx.tensor([
        [[-0.6455, -0.4189, 0.3424], [-0.4303, -0.6731, 0.2534], [-0.5240, 0.0864, -0.5632]]
      ])

    assert_all_close(outputs.hidden_state[[.., 1..3, 1..3]], expected)
  end
end
31 |
--------------------------------------------------------------------------------
/test/bumblebee/text/zero_shot_classification_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Text.ZeroShotClassificationTest do
  use ExUnit.Case, async: false

  import Bumblebee.TestHelpers

  @moduletag serving_test_tags()

  @labels ["cooking", "traveling", "dancing"]

  test "correctly classifies labels with one sequence" do
    output = Nx.Serving.run(zero_shot_serving(), "one day I will see the world")

    assert %{
             predictions: [
               %{label: "traveling", score: _},
               %{label: "dancing", score: _},
               %{label: "cooking", score: _}
             ]
           } = output

    assert %{label: "traveling", score: score} = top_prediction(output)
    assert_all_close(score, 0.9874)
  end

  test "correctly classifies labels with multiple sequences" do
    assert [output1, output2] =
             Nx.Serving.run(zero_shot_serving(), [
               "one day I will see the world",
               "one day I will learn to salsa"
             ])

    assert %{label: "traveling", score: score1} = top_prediction(output1)
    assert_all_close(score1, 0.9874)

    assert %{label: "dancing", score: score2} = top_prediction(output2)
    assert_all_close(score2, 0.9585)
  end

  # `compile:` is given as a keyword list (batch_size 2, sequence_length 32),
  # not `true`, so the title says "compilation enabled".
  test "correctly classifies batch with compilation enabled" do
    serving =
      zero_shot_serving(
        compile: [batch_size: 2, sequence_length: 32],
        defn_options: [compiler: EXLA]
      )

    assert [output1, output2] =
             Nx.Serving.run(serving, [
               "one day I will see the world",
               "one day I will learn to salsa"
             ])

    assert %{label: "traveling", score: score1} = top_prediction(output1)
    assert_all_close(score1, 0.9874)

    assert %{label: "dancing", score: score2} = top_prediction(output2)
    assert_all_close(score2, 0.9585)
  end

  # Builds a zero-shot serving over @labels from the BART MNLI checkpoint,
  # forwarding `opts` to Bumblebee.Text.zero_shot_classification.
  defp zero_shot_serving(opts \\ []) do
    {:ok, model} = Bumblebee.load_model({:hf, "facebook/bart-large-mnli"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-mnli"})

    Bumblebee.Text.zero_shot_classification(model, tokenizer, @labels, opts)
  end

  # Returns the prediction with the highest score.
  defp top_prediction(output), do: Enum.max_by(output.predictions, & &1.score)
end
73 |
--------------------------------------------------------------------------------
/test/bumblebee/utils/image_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Utils.ImageTest do
  use ExUnit.Case, async: true

  # Runs the iex> examples from Bumblebee.Utils.Image docs as tests.
  # `import: false` is the default and is written out only for explicitness;
  # no alias is used so the doctests compile in an unmodified scope.
  doctest Bumblebee.Utils.Image, import: false
end
6 |
--------------------------------------------------------------------------------
/test/bumblebee/utils/nx_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Utils.NxTest do
  use ExUnit.Case, async: true

  # Runs the iex> examples from Bumblebee.Utils.Nx docs as tests.
  # `import: false` is the default and is written out only for explicitness;
  # aliasing the module as `Nx` would shadow the Nx library in the doctests.
  doctest Bumblebee.Utils.Nx, import: false
end
6 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/bit_featurizer_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.BitFeaturizerTest do
  use ExUnit.Case, async: true

  test "encodes image" do
    assert {:ok, %Bumblebee.Vision.BitFeaturizer{} = featurizer} =
             Bumblebee.load_featurizer({:hf, "google/bit-50"})

    # A 2x2 image whose single-channel values are replicated across 3 channels.
    image = Nx.broadcast(Nx.tensor([[[50], [100]], [[150], [200]]]), {2, 2, 3})

    pixel_values = Bumblebee.apply_featurizer(featurizer, image)["pixel_values"]

    assert Nx.shape(pixel_values) == {1, 448, 448, 3}
  end
end
16 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/blip_featurizer_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.BlipFeaturizerTest do
  use ExUnit.Case, async: true

  test "encodes image" do
    assert {:ok, %Bumblebee.Vision.BlipFeaturizer{} = featurizer} =
             Bumblebee.load_featurizer({:hf, "Salesforce/blip-image-captioning-base"})

    # A 2x2 image whose single-channel values are replicated across 3 channels.
    image = Nx.broadcast(Nx.tensor([[[50], [100]], [[150], [200]]]), {2, 2, 3})

    pixel_values = Bumblebee.apply_featurizer(featurizer, image)["pixel_values"]

    assert Nx.shape(pixel_values) == {1, 384, 384, 3}
  end
end
17 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/blip_vision_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.BlipVisionTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  @moduletag model_test_tags()

  test ":base" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-BlipModel"},
               module: Bumblebee.Vision.BlipVision,
               architecture: :base
             )

    assert %Bumblebee.Vision.BlipVision{architecture: :base} = spec

    pixel_values = Nx.broadcast(0.5, {1, 30, 30, 3})
    outputs = Axon.predict(model, params, %{"pixel_values" => pixel_values})

    assert Nx.shape(outputs.hidden_state) == {1, 226, 32}
    assert Nx.shape(outputs.pooled_state) == {1, 32}

    # Both slices are scaled up before comparison, presumably so that the
    # default assert_all_close tolerance is meaningful for these tiny values.
    hidden_slice = Nx.multiply(outputs.hidden_state[[.., 1..3, 1..3]], 1_000_000)

    assert_all_close(
      hidden_slice,
      Nx.tensor([
        [[-0.0272, -0.0129, 0.0174], [0.0069, -0.0429, -0.0334], [0.0428, -0.0797, -0.0353]]
      ])
    )

    pooled_slice = Nx.multiply(outputs.pooled_state[[.., 1..3]], 10_000)

    assert_all_close(pooled_slice, Nx.tensor([[-0.0128, -0.0792, -0.1011]]))
  end
end
39 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/clip_featurizer_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.ClipFeaturizerTest do
  use ExUnit.Case, async: true

  test "encodes image" do
    assert {:ok, %Bumblebee.Vision.ClipFeaturizer{} = featurizer} =
             Bumblebee.load_featurizer({:hf, "openai/clip-vit-base-patch32"})

    # A 2x2 image whose single-channel values are replicated across 3 channels.
    image = Nx.broadcast(Nx.tensor([[[50], [100]], [[150], [200]]]), {2, 2, 3})

    pixel_values = Bumblebee.apply_featurizer(featurizer, image)["pixel_values"]

    assert Nx.shape(pixel_values) == {1, 224, 224, 3}
  end
end
16 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/clip_vision_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.ClipVisionTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  @moduletag model_test_tags()

  test ":base" do
    assert {:ok, %{model: model, params: params, spec: spec}} = load_clip_vision(:base)

    assert %Bumblebee.Vision.ClipVision{architecture: :base} = spec

    outputs = predict(model, params)

    assert Nx.shape(outputs.hidden_state) == {1, 226, 32}
    assert Nx.shape(outputs.pooled_state) == {1, 32}

    expected_hidden =
      Nx.tensor([
        [[0.4483, 0.3736, -0.5581], [0.9376, -0.3424, -0.1002], [0.5782, 0.1069, -0.2953]]
      ])

    assert_all_close(outputs.hidden_state[[.., 1..3, 1..3]], expected_hidden)
    assert_all_close(outputs.pooled_state[[.., 1..3]], Nx.tensor([[-0.5059, 0.7391, 0.9252]]))
  end

  test ":for_embedding" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             load_clip_vision(:for_embedding)

    assert %Bumblebee.Vision.ClipVision{architecture: :for_embedding} = spec

    outputs = predict(model, params)

    assert Nx.shape(outputs.embedding) == {1, 64}
    assert_all_close(outputs.embedding[[.., 1..3]], Nx.tensor([[0.8865, -0.9042, -1.1233]]))
  end

  # Loads the vision tower of the tiny CLIP checkpoint with the given architecture.
  defp load_clip_vision(architecture) do
    Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-CLIPModel"},
      module: Bumblebee.Vision.ClipVision,
      architecture: architecture
    )
  end

  # Runs the model on a uniform 0.5-valued 30x30 input with 3 channels.
  defp predict(model, params) do
    Axon.predict(model, params, %{"pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3})})
  end
end
62 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/convnext_featurizer_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.ConvNextFeaturizerTest do
  use ExUnit.Case, async: true

  test "encodes image" do
    assert Nx.shape(featurize(3)) == {1, 224, 224, 3}
  end

  test "allows an alpha channel" do
    # A 4-channel input must still produce 3-channel pixel values.
    assert Nx.shape(featurize(4)) == {1, 224, 224, 3}
  end

  # Loads the ConvNeXt featurizer and applies it to a 2x2 image whose
  # single-channel values are replicated across `channels` channels.
  # Returns the "pixel_values" entry of the featurizer output.
  defp featurize(channels) do
    assert {:ok, %Bumblebee.Vision.ConvNextFeaturizer{} = featurizer} =
             Bumblebee.load_featurizer({:hf, "facebook/convnext-tiny-224"})

    image = Nx.broadcast(Nx.tensor([[[50], [100]], [[150], [200]]]), {2, 2, channels})

    Bumblebee.apply_featurizer(featurizer, image)["pixel_values"]
  end
end
28 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/convnext_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.ConvNextTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  @moduletag model_test_tags()

  test ":base" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-ConvNextModel"})

    assert %Bumblebee.Vision.ConvNext{architecture: :base} = spec

    outputs = predict(model, params)

    assert Nx.shape(outputs.hidden_state) == {1, 7, 7, 40}
    assert Nx.shape(outputs.pooled_state) == {1, 40}

    expected_hidden =
      Nx.tensor([
        [[[0.3924, -0.2330], [0.3924, -0.2330]], [[0.3924, -0.2330], [0.3924, -0.2330]]]
      ])

    assert_all_close(outputs.hidden_state[[.., 1..2, 1..2, 1..2]], expected_hidden)

    # The pooled state is compared with a looser absolute tolerance.
    assert_all_close(outputs.pooled_state[[.., 1..3]], Nx.tensor([[2.2793, -1.3236, -1.0714]]),
      atol: 1.0e-3
    )
  end

  test ":for_image_classification" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model(
               {:hf, "hf-internal-testing/tiny-random-ConvNextForImageClassification"}
             )

    assert %Bumblebee.Vision.ConvNext{architecture: :for_image_classification} = spec

    outputs = predict(model, params)

    assert Nx.shape(outputs.logits) == {1, 2}
    assert_all_close(outputs.logits, Nx.tensor([[0.0047, -0.1457]]))
  end

  # Runs the model on a uniform 0.5-valued 224x224 input with 3 channels.
  defp predict(model, params) do
    Axon.predict(model, params, %{"pixel_values" => Nx.broadcast(0.5, {1, 224, 224, 3})})
  end
end
57 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/deit_featurizer_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.DeitFeaturizerTest do
  use ExUnit.Case, async: true

  test "encodes image" do
    assert {:ok, %Bumblebee.Vision.DeitFeaturizer{} = featurizer} =
             Bumblebee.load_featurizer({:hf, "facebook/deit-base-distilled-patch16-224"})

    # A 2x2 image whose single-channel values are replicated across 3 channels.
    image = Nx.broadcast(Nx.tensor([[[50], [100]], [[150], [200]]]), {2, 2, 3})

    pixel_values = Bumblebee.apply_featurizer(featurizer, image)["pixel_values"]

    assert Nx.shape(pixel_values) == {1, 224, 224, 3}
  end
end
17 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/deit_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.DeitTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  @moduletag model_test_tags()

  test ":base" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-DeiTModel"})

    assert %Bumblebee.Vision.Deit{architecture: :base} = spec

    outputs = predict(model, params)

    assert Nx.shape(outputs.hidden_state) == {1, 227, 32}
    assert Nx.shape(outputs.pooled_state) == {1, 32}

    expected_hidden =
      Nx.tensor([
        [[-3.0866, 0.2350, 0.2003], [-1.2774, -0.1192, -1.0468], [-1.2774, -0.1192, -1.0468]]
      ])

    assert_all_close(outputs.hidden_state[[.., 1..3, 1..3]], expected_hidden)
    assert_all_close(outputs.pooled_state[[.., 1..3]], Nx.tensor([[0.1526, -0.1437, -0.0646]]))
  end

  test ":for_image_classification" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model(
               {:hf, "hf-internal-testing/tiny-random-DeiTForImageClassification"}
             )

    assert %Bumblebee.Vision.Deit{architecture: :for_image_classification} = spec

    outputs = predict(model, params)

    assert Nx.shape(outputs.logits) == {1, 2}
    assert_all_close(outputs.logits, Nx.tensor([[0.0481, 0.1008]]))
  end

  test ":for_image_classification_with_teacher" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model(
               {:hf, "hf-internal-testing/tiny-random-DeiTForImageClassificationWithTeacher"}
             )

    assert %Bumblebee.Vision.Deit{architecture: :for_image_classification_with_teacher} = spec

    outputs = predict(model, params)

    assert Nx.shape(outputs.logits) == {1, 2}
    assert_all_close(outputs.logits, Nx.tensor([[-0.0108, -0.0048]]))
  end

  test ":for_masked_image_modeling" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model(
               {:hf, "hf-internal-testing/tiny-random-DeiTForMaskedImageModeling"}
             )

    assert %Bumblebee.Vision.Deit{architecture: :for_masked_image_modeling} = spec

    outputs = predict(model, params)

    # The reconstruction has the same shape as the input image.
    assert Nx.shape(outputs.pixel_values) == {1, 30, 30, 3}

    expected_pixels =
      Nx.tensor([[[[0.1455, 0.1889], [0.0229, 0.0910]], [[-0.0097, -0.1083], [0.0525, -0.0244]]]])

    assert_all_close(outputs.pixel_values[[.., 1..2, 1..2, 1..2]], expected_pixels)
  end

  # Runs the model on a uniform 0.5-valued 30x30 input with 3 channels.
  defp predict(model, params) do
    Axon.predict(model, params, %{"pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3})})
  end
end
102 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/image_classification_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.ImageClassificationTest do
  use ExUnit.Case, async: false

  import Bumblebee.TestHelpers

  @moduletag serving_test_tags()

  @images_dir Path.expand("../../fixtures/images", __DIR__)

  test "returns top scored labels" do
    {model_info, featurizer} = load_resnet()

    serving = Bumblebee.Vision.ImageClassification.image_classification(model_info, featurizer)

    assert_expected_predictions(Nx.Serving.run(serving, coco_image()))
  end

  test "supports compilation" do
    {model_info, featurizer} = load_resnet()

    serving =
      Bumblebee.Vision.ImageClassification.image_classification(model_info, featurizer,
        compile: [batch_size: 1],
        defn_options: [compiler: EXLA]
      )

    assert_expected_predictions(Nx.Serving.run(serving, coco_image()))
  end

  # Loads the ResNet-50 model and its featurizer.
  defp load_resnet do
    {:ok, model_info} = Bumblebee.load_model({:hf, "microsoft/resnet-50"})
    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "microsoft/resnet-50"})
    {model_info, featurizer}
  end

  # Reads the COCO fixture image shared by both tests.
  defp coco_image, do: StbImage.read_file!(Path.join(@images_dir, "coco/39769.jpeg"))

  # Asserts the exact ordered list of five labels expected for the fixture image.
  defp assert_expected_predictions(output) do
    assert %{
             predictions: [
               %{label: "tiger cat", score: _},
               %{label: "tabby, tabby cat", score: _},
               %{label: "remote control, remote", score: _},
               %{label: "jinrikisha, ricksha, rickshaw", score: _},
               %{label: "Egyptian cat", score: _}
             ]
           } = output
  end
end
52 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/image_embedding_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.ImageEmbeddingTest do
  use ExUnit.Case, async: false

  import Bumblebee.TestHelpers

  @moduletag serving_test_tags()
  @images_dir Path.expand("../../fixtures/images", __DIR__)

  @repository {:hf, "openai/clip-vit-base-patch32"}

  test "returns embedding for an image" do
    {model_info, featurizer} = load_clip_vision()

    serving = Bumblebee.Vision.ImageEmbedding.image_embedding(model_info, featurizer)

    assert_embedding_slice(serving, Nx.tensor([0.0978, -0.7233, -0.7707]))
  end

  test "returns normalized embedding for an image" do
    {model_info, featurizer} = load_clip_vision()

    serving =
      Bumblebee.Vision.ImageEmbedding.image_embedding(model_info, featurizer,
        embedding_processor: :l2_norm
      )

    assert_embedding_slice(serving, Nx.tensor([0.0036, -0.0269, -0.0286]))
  end

  # Loads only the vision tower of CLIP plus the matching featurizer.
  defp load_clip_vision() do
    {:ok, model_info} = Bumblebee.load_model(@repository, module: Bumblebee.Vision.ClipVision)
    {:ok, featurizer} = Bumblebee.load_featurizer(@repository)
    {model_info, featurizer}
  end

  # Runs the serving on the fixture image and checks the embedding shape and
  # the values at indices 1..3.
  defp assert_embedding_slice(serving, expected) do
    image = StbImage.read_file!(Path.join(@images_dir, "coco/39769.jpeg"))

    assert %{embedding: %Nx.Tensor{} = embedding} = Nx.Serving.run(serving, image)
    assert Nx.shape(embedding) == {768}

    assert_all_close(embedding[1..3], expected)
  end
end
53 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/image_to_text_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.ImageToTextTest do
  use ExUnit.Case, async: false

  import Bumblebee.TestHelpers

  @moduletag serving_test_tags()

  @images_dir Path.expand("../../fixtures/images", __DIR__)
  @repository {:hf, "Salesforce/blip-image-captioning-base"}

  test "generates text describing an image" do
    # Every artifact comes from the same BLIP captioning repository.
    {:ok, blip} = Bumblebee.load_model(@repository)
    {:ok, featurizer} = Bumblebee.load_featurizer(@repository)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(@repository)
    {:ok, generation_config} = Bumblebee.load_generation_config(@repository)

    serving =
      Bumblebee.Vision.ImageToText.image_to_text(blip, featurizer, tokenizer, generation_config)

    image_path = Path.join(@images_dir, "coco/39769.jpeg")
    image = StbImage.read_file!(image_path)

    result = Nx.Serving.run(serving, image)

    assert %{results: [%{text: "two cats sleeping on a couch"}]} = result
  end
end
31 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/resnet_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.ResNetTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  @moduletag model_test_tags()

  # Shape of the constant 0.5-valued image fed to every model in this module.
  @input_shape {1, 224, 224, 3}

  test ":base" do
    {model, params, spec} = load_tiny_model("hf-internal-testing/tiny-random-ResNetModel")

    assert %Bumblebee.Vision.ResNet{architecture: :base} = spec

    %{hidden_state: hidden_state, pooled_state: pooled_state} = predict_gray(model, params)

    assert Nx.shape(hidden_state) == {1, 7, 7, 40}
    assert Nx.shape(pooled_state) == {1, 40}

    assert_all_close(
      hidden_state[[.., 2..3, 2..3, 2..3]],
      Nx.tensor([[[[0.0000, 0.9835], [0.0000, 0.9835]], [[0.0000, 0.9835], [0.0000, 0.9835]]]])
    )

    assert_all_close(Nx.sum(hidden_state), Nx.tensor(209.6328))
    assert_all_close(pooled_state[[.., 1..3]], Nx.tensor([[0.0275, 0.0095, 0.8921]]))
    assert_all_close(Nx.sum(pooled_state), Nx.tensor(4.2782))
  end

  test ":for_image_classification" do
    {model, params, spec} =
      load_tiny_model("hf-internal-testing/tiny-random-ResNetForImageClassification")

    assert %Bumblebee.Vision.ResNet{architecture: :for_image_classification} = spec

    %{logits: logits} = predict_gray(model, params)

    assert Nx.shape(logits) == {1, 3}
    assert_all_close(logits, Nx.tensor([[-0.1053, 0.2160, -0.0331]]))
  end

  # Loads a Hugging Face checkpoint, asserting that loading succeeds.
  defp load_tiny_model(repository_id) do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model({:hf, repository_id})

    {model, params, spec}
  end

  # Predicts on a uniform 0.5-valued image of shape @input_shape.
  defp predict_gray(model, params) do
    Axon.predict(model, params, %{"pixel_values" => Nx.broadcast(0.5, @input_shape)})
  end
end
60 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/swin_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.SwinTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  @moduletag model_test_tags()

  # Shape of the constant 0.5-valued image fed to every model in this module.
  @input_shape {1, 32, 32, 3}

  test ":base" do
    {model, params, spec} = load_tiny_model("hf-internal-testing/tiny-random-SwinModel")

    assert %Bumblebee.Vision.Swin{architecture: :base} = spec

    %{hidden_state: hidden_state, pooled_state: pooled_state} = predict_gray(model, params)

    assert Nx.shape(hidden_state) == {1, 16, 64}
    assert Nx.shape(pooled_state) == {1, 64}

    assert_all_close(
      hidden_state[[.., 1..3, 1..3]],
      Nx.tensor([
        [[-0.4605, 0.9336, -0.5528], [-0.4605, 0.9336, -0.5528], [-0.4605, 0.9336, -0.5528]]
      ])
    )

    assert_all_close(pooled_state[[.., 1..3]], Nx.tensor([[-0.4605, 0.9336, -0.5528]]))
  end

  test ":for_image_classification" do
    {model, params, spec} =
      load_tiny_model("hf-internal-testing/tiny-random-SwinForImageClassification")

    assert %Bumblebee.Vision.Swin{architecture: :for_image_classification} = spec

    %{logits: logits} = predict_gray(model, params)

    assert Nx.shape(logits) == {1, 2}
    assert_all_close(logits, Nx.tensor([[0.0361, 0.1352]]))
  end

  # Loads a Hugging Face checkpoint, asserting that loading succeeds.
  defp load_tiny_model(repository_id) do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model({:hf, repository_id})

    {model, params, spec}
  end

  # Predicts on a uniform 0.5-valued image of shape @input_shape.
  defp predict_gray(model, params) do
    Axon.predict(model, params, %{"pixel_values" => Nx.broadcast(0.5, @input_shape)})
  end
end
58 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/vit_featurizer_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.VitFeaturizerTest do
  use ExUnit.Case, async: true

  test "encodes image" do
    assert {:ok, %Bumblebee.Vision.VitFeaturizer{} = featurizer} =
             Bumblebee.load_featurizer({:hf, "google/vit-base-patch16-224"})

    # A 2x2 single-channel image, broadcast to 3 channels.
    image =
      [[[50], [100]], [[150], [200]]]
      |> Nx.tensor()
      |> Nx.broadcast({2, 2, 3})

    inputs = Bumblebee.apply_featurizer(featurizer, image)
    pixel_values = inputs["pixel_values"]

    assert Nx.shape(pixel_values) == {1, 224, 224, 3}
  end
end
16 |
--------------------------------------------------------------------------------
/test/bumblebee/vision/vit_test.exs:
--------------------------------------------------------------------------------
defmodule Bumblebee.Vision.VitTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  @moduletag model_test_tags()

  # Shape of the constant 0.5-valued image fed to every model in this module.
  @input_shape {1, 30, 30, 3}

  test ":base" do
    {model, params, spec} = load_tiny_model("hf-internal-testing/tiny-random-ViTModel")

    assert %Bumblebee.Vision.Vit{architecture: :base} = spec

    %{hidden_state: hidden_state, pooled_state: pooled_state} = predict_gray(model, params)

    assert Nx.shape(hidden_state) == {1, 226, 32}
    assert Nx.shape(pooled_state) == {1, 32}

    assert_all_close(
      hidden_state[[.., 1..3, 1..3]],
      Nx.tensor([
        [[-0.2075, 2.7865, 0.2361], [-0.3014, 2.5312, -0.6127], [-0.3460, 2.8741, 0.1988]]
      ])
    )

    assert_all_close(pooled_state[[.., 1..3]], Nx.tensor([[-0.0244, -0.0515, -0.1584]]))
  end

  test ":for_image_classification" do
    {model, params, spec} =
      load_tiny_model("hf-internal-testing/tiny-random-ViTForImageClassification")

    assert %Bumblebee.Vision.Vit{architecture: :for_image_classification} = spec

    %{logits: logits} = predict_gray(model, params)

    assert Nx.shape(logits) == {1, 2}
    assert_all_close(logits, Nx.tensor([[-0.1596, 0.1818]]))
  end

  test ":for_masked_image_modeling" do
    {model, params, spec} =
      load_tiny_model("hf-internal-testing/tiny-random-ViTForMaskedImageModeling")

    assert %Bumblebee.Vision.Vit{architecture: :for_masked_image_modeling} = spec

    %{pixel_values: pixel_values} = predict_gray(model, params)

    assert Nx.shape(pixel_values) == {1, 30, 30, 3}

    assert_all_close(
      pixel_values[[.., 1..2, 1..2, 1..2]],
      Nx.tensor([[[[0.0752, 0.0548], [-0.0192, -0.0216]], [[-0.0252, 0.0728], [0.0232, -0.1687]]]])
    )
  end

  # Loads a Hugging Face checkpoint, asserting that loading succeeds.
  defp load_tiny_model(repository_id) do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model({:hf, repository_id})

    {model, params, spec}
  end

  # Predicts on a uniform 0.5-valued image of shape @input_shape.
  defp predict_gray(model, params) do
    Axon.predict(model, params, %{"pixel_values" => Nx.broadcast(0.5, @input_shape)})
  end
end
80 |
--------------------------------------------------------------------------------
/test/bumblebee_test.exs:
--------------------------------------------------------------------------------
defmodule BumblebeeTest do
  use ExUnit.Case, async: true

  describe "load_model/2" do
    test "raises an error on invalid repository type" do
      assert_raise ArgumentError,
                   ~s/expected repository to be either {:hf, repository_id}, {:hf, repository_id, options} or {:local, directory}, got: "repo-id"/,
                   fn ->
                     Bumblebee.load_model("repo-id")
                   end
    end

    # `load_model/2` returns params wrapped in `%Axon.ModelState{}` (see the
    # :type test below). Calling `Map.keys/1` on the struct itself would only
    # list its struct fields, which are always equal, making the comparisons
    # vacuous. So every test here compares the keys of the underlying `data`
    # map, which holds the per-layer parameters.
    @tag :capture_log
    test "supports sharded params" do
      assert {:ok, %{params: %Axon.ModelState{data: params}}} =
               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"})

      # PyTorch format

      assert {:ok, %{params: %Axon.ModelState{data: sharded_params}}} =
               Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-GPT2Model-sharded"})

      assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(sharded_params))

      # Safetensors

      assert {:ok, %{params: %Axon.ModelState{data: sharded_params}}} =
               Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-GPT2Model-sharded"},
                 params_filename: "model.safetensors"
               )

      assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(sharded_params))
    end

    test "supports .safetensors params" do
      assert {:ok, %{params: %Axon.ModelState{data: params}}} =
               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"})

      assert {:ok, %{params: %Axon.ModelState{data: safetensors_params}}} =
               Bumblebee.load_model(
                 {:hf, "bumblebee-testing/tiny-random-GPT2Model-safetensors-only"}
               )

      assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(safetensors_params))
    end

    test "supports params variants" do
      assert {:ok, %{params: %Axon.ModelState{data: params}}} =
               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-bert-variant"},
                 params_variant: "v2"
               )

      assert {:ok, %{params: %Axon.ModelState{data: sharded_params}}} =
               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-bert-variant-sharded"},
                 params_variant: "v2"
               )

      assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(sharded_params))

      assert_raise ArgumentError,
                   ~s/parameters variant "v3" not found, available variants: "v2"/,
                   fn ->
                     Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-bert-variant"},
                       params_variant: "v3"
                     )
                   end

      assert_raise ArgumentError,
                   ~s/parameters variant "v3" not found, available variants: "v2"/,
                   fn ->
                     Bumblebee.load_model(
                       {:hf, "hf-internal-testing/tiny-random-bert-variant-sharded"},
                       params_variant: "v3"
                     )
                   end
    end

    test "passing :type casts params accordingly" do
      assert {:ok, %{params: %Axon.ModelState{data: params}}} =
               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
                 type: :bf16
               )

      assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:bf, 16}
      assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:bf, 16}

      assert {:ok, %{params: %Axon.ModelState{data: params}}} =
               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
                 type: Axon.MixedPrecision.create_policy(params: :f16)
               )

      assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:f, 16}
      assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:f, 16}
    end
  end
end
97 |
--------------------------------------------------------------------------------
/test/fixtures/audio/common_voice/a6c7706a220eeea7ee3687c1122fe7ac17962d2449d25b6db37cc41cdaace442683e11945b6f581e73941c3083cd4eecfafc938840459cd8c571dae7774ee687.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/audio/common_voice/a6c7706a220eeea7ee3687c1122fe7ac17962d2449d25b6db37cc41cdaace442683e11945b6f581e73941c3083cd4eecfafc938840459cd8c571dae7774ee687.wav
--------------------------------------------------------------------------------
/test/fixtures/audio/common_voice/a6c7706a220eeea7ee3687c1122fe7ac17962d2449d25b6db37cc41cdaace442683e11945b6f581e73941c3083cd4eecfafc938840459cd8c571dae7774ee687_pcm_f32le_16000.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/audio/common_voice/a6c7706a220eeea7ee3687c1122fe7ac17962d2449d25b6db37cc41cdaace442683e11945b6f581e73941c3083cd4eecfafc938840459cd8c571dae7774ee687_pcm_f32le_16000.bin
--------------------------------------------------------------------------------
/test/fixtures/audio/common_voice/info.md:
--------------------------------------------------------------------------------
1 | Source: https://huggingface.co/datasets/common_voice
2 |
--------------------------------------------------------------------------------
/test/fixtures/audio/generate.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Converts every .wav/.mp3 fixture below this directory into raw mono
# 16 kHz 32-bit little-endian float PCM, written next to the source file
# as <name>_pcm_f32le_16000.bin.

cd "$(dirname "$0")" || exit 1

# globstar makes ** recurse (bash >= 4; older bash falls back to matching a
# single directory level). nullglob makes a pattern with no matches expand to
# nothing instead of being passed through literally.
shopt -s globstar 2>/dev/null
shopt -s nullglob

for source in **/*.{wav,mp3}; do
  name="${source%.*}"
  # -y overwrites existing outputs; without it ffmpeg would block on an
  # overwrite prompt that -loglevel quiet hides.
  ffmpeg -y -i "$source" -ac 1 -ar 16000 -f f32le -hide_banner -loglevel quiet \
    "${name}_pcm_f32le_16000.bin"
done
9 |
--------------------------------------------------------------------------------
/test/fixtures/audio/librivox/46s.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/audio/librivox/46s.mp3
--------------------------------------------------------------------------------
/test/fixtures/audio/librivox/46s_pcm_f32le_16000.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/audio/librivox/46s_pcm_f32le_16000.bin
--------------------------------------------------------------------------------
/test/fixtures/audio/librivox/info.md:
--------------------------------------------------------------------------------
1 | Source: https://librivox.org/the-book-of-irish-poetry-by-various
2 |
--------------------------------------------------------------------------------
/test/fixtures/images/coco/39769.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/images/coco/39769.jpeg
--------------------------------------------------------------------------------
/test/fixtures/images/coco/info.md:
--------------------------------------------------------------------------------
1 | Source: https://cocodataset.org/#explore?id=39769
2 |
--------------------------------------------------------------------------------
/test/fixtures/pytorch/generate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from collections import OrderedDict
5 |
6 |
def save(name, fmt, data):
    """Write `data` with torch.save to `<this script's dir>/<name>.<fmt>.pt`.

    When `fmt` is "zip" the zip-based serialization is used; any other value
    (e.g. "legacy") selects the legacy serialization format.
    """
    target = os.path.join(os.path.dirname(__file__), "{}.{}.pt".format(name, fmt))
    torch.save(data, target, _use_new_zipfile_serialization=(fmt == "zip"))
12 |
13 |
# Serialization fixtures, written once per format ("zip" and "legacy").
for fmt in ["zip", "legacy"]:
    tensors = [
        torch.tensor([-1.0, 1.0], dtype=dtype)
        for dtype in (torch.float64, torch.float32, torch.float16)
    ]
    tensors += [
        torch.tensor([-1, 1], dtype=dtype)
        for dtype in (torch.int64, torch.int32, torch.int16, torch.int8)
    ]
    tensors += [
        torch.tensor([0, 1], dtype=torch.uint8),
        torch.tensor([0, 1, 0, 1], dtype=torch.bool),
        torch.tensor([-1.0, 1.0], dtype=torch.bfloat16),
    ]
    tensors += [
        torch.tensor([1 - 1j, 1 + 1j], dtype=dtype)
        for dtype in (torch.complex128, torch.complex64)
    ]
    save("tensors", fmt, tensors)

    arrays = [
        np.array([-1.0, 1.0], dtype=dtype)
        for dtype in (np.float64, np.float32, np.float16)
    ]
    arrays += [
        np.array([-1, 1], dtype=dtype)
        for dtype in (np.int64, np.int32, np.int16, np.int8)
    ]
    arrays += [
        np.array([0, 1], dtype=dtype)
        for dtype in (np.uint64, np.uint32, np.uint16, np.uint8, np.bool_)
    ]
    arrays += [
        np.array([1 - 1j, 1 + 1j], dtype=dtype)
        for dtype in (np.complex128, np.complex64)
    ]
    save("numpy_arrays", fmt, arrays)

    save("ordered_dict", fmt, OrderedDict([("x", 1), ("y", 2)]))

    # permute/transpose return views whose memory is not row-major contiguous.
    nested_values = [[[1, 1], [2, 2], [3, 3]], [[4, 4], [5, 5], [6, 6]]]
    transposed_tensor = torch.tensor(nested_values, dtype=torch.int64).permute(2, 0, 1)
    save("noncontiguous_tensor", fmt, transposed_tensor)

    transposed_array = np.transpose(np.array([[1, 2, 3], [4, 5, 6]]))
    save("noncontiguous_numpy_array", fmt, transposed_array)

# Model parameters

save("state_dict_base", "zip", OrderedDict([
    ("conv.weight", torch.ones(2, 3, 2, 2)),
    ("conv.bias", torch.zeros(2)),
]))

full_state_dict = OrderedDict()
full_state_dict["base.conv.weight"] = torch.ones(2, 3, 2, 2)
full_state_dict["base.conv.bias"] = torch.zeros(2)
# Unexpected shape
full_state_dict["classifier.layers.0.weight"] = torch.ones(1, 1)
full_state_dict["classifier.layers.0.bias"] = torch.zeros(1)
# Missing
# "classifier.layers.1.weight"
# "classifier.layers.1.bias"
# Extra
full_state_dict["extra.weight"] = torch.ones(1)
save("state_dict_full", "zip", full_state_dict)
75 |
--------------------------------------------------------------------------------
/test/fixtures/pytorch/noncontiguous_numpy_array.legacy.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/noncontiguous_numpy_array.legacy.pt
--------------------------------------------------------------------------------
/test/fixtures/pytorch/noncontiguous_numpy_array.zip.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/noncontiguous_numpy_array.zip.pt
--------------------------------------------------------------------------------
/test/fixtures/pytorch/noncontiguous_tensor.legacy.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/noncontiguous_tensor.legacy.pt
--------------------------------------------------------------------------------
/test/fixtures/pytorch/noncontiguous_tensor.zip.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/noncontiguous_tensor.zip.pt
--------------------------------------------------------------------------------
/test/fixtures/pytorch/numpy_arrays.legacy.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/numpy_arrays.legacy.pt
--------------------------------------------------------------------------------
/test/fixtures/pytorch/numpy_arrays.zip.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/numpy_arrays.zip.pt
--------------------------------------------------------------------------------
/test/fixtures/pytorch/ordered_dict.legacy.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/ordered_dict.legacy.pt
--------------------------------------------------------------------------------
/test/fixtures/pytorch/ordered_dict.zip.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/ordered_dict.zip.pt
--------------------------------------------------------------------------------
/test/fixtures/pytorch/state_dict_base.zip.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/state_dict_base.zip.pt
--------------------------------------------------------------------------------
/test/fixtures/pytorch/state_dict_full.zip.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/state_dict_full.zip.pt
--------------------------------------------------------------------------------
/test/fixtures/pytorch/storage_view.legacy.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/storage_view.legacy.pt
--------------------------------------------------------------------------------
/test/fixtures/pytorch/tensors.legacy.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/tensors.legacy.pt
--------------------------------------------------------------------------------
/test/fixtures/pytorch/tensors.zip.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/tensors.zip.pt
--------------------------------------------------------------------------------
/test/support/test_helpers.ex:
--------------------------------------------------------------------------------
defmodule Bumblebee.TestHelpers do
  @moduledoc false

  import ExUnit.Assertions

  # Asserts that two tensors are exactly equal.
  #
  # `right` is cast to the type of `left`, and both are copied to
  # Nx.BinaryBackend so that a failure shows a diff of the tensor contents.
  # The helper is called fully qualified, so the expansion does not depend
  # on the caller having imported this module.
  defmacro assert_equal(left, right) do
    quote do
      left = Bumblebee.TestHelpers.to_binary_backend(unquote(left))

      right =
        unquote(right)
        |> Nx.as_type(Nx.type(left))
        |> Bumblebee.TestHelpers.to_binary_backend()

      assert left == right
    end
  end

  # Copies `tensor` to Nx.BinaryBackend.
  def to_binary_backend(tensor) do
    Nx.backend_copy(tensor, Nx.BinaryBackend)
  end

  # Fails the test unless `left` and `right` are element-wise close according
  # to Nx.all_close/3. Tolerances default to 1.0e-4 and can be overridden
  # with the :atol and :rtol options.
  def assert_all_close(left, right, opts \\ []) do
    atol = opts[:atol] || 1.0e-4
    rtol = opts[:rtol] || 1.0e-4

    # Nx.all_close/3 returns a u8 scalar tensor: 1 when close, 0 otherwise.
    close? =
      left
      |> Nx.all_close(right, atol: atol, rtol: rtol)
      |> Nx.to_number()
      |> Kernel.==(1)

    if not close? do
      flunk("""
      expected

      #{inspect(left)}

      to be within tolerance of

      #{inspect(right)}
      """)
    end
  end

  def model_test_tags() do
    [model: true, capture_log: true, timeout: 60_000]
  end

  def serving_test_tags() do
    [serving: true, slow: true, capture_log: true, timeout: 600_000]
  end

  # Runs `scheduler` for `num_steps` over a dummy sample, feeding each
  # step the prediction of `dummy_model/2`. Returns the final sample.
  def scheduler_loop(scheduler, num_steps) do
    sample = dummy_sample()

    {state, timesteps} =
      Bumblebee.scheduler_init(scheduler, num_steps, Nx.to_template(sample), Nx.Random.key(0))

    # The explicit //1 step keeps the range empty when there are no
    # timesteps; a bare 0..-1 would iterate over indices 0 and -1.
    {_state, sample} =
      for i <- 0..(Nx.size(timesteps) - 1)//1, reduce: {state, sample} do
        {state, sample} ->
          prediction = dummy_model(sample, timesteps[i])
          Bumblebee.scheduler_step(scheduler, state, sample, prediction)
      end

    sample
  end

  # Returns the timesteps produced by initializing `scheduler` with
  # `num_steps` over the dummy sample.
  def scheduler_timesteps(scheduler, num_steps) do
    sample = dummy_sample()

    {_state, timesteps} =
      Bumblebee.scheduler_init(scheduler, num_steps, Nx.to_template(sample), Nx.Random.key(0))

    timesteps
  end

  # A deterministic 8x8x4 sample with values iota / size.
  defp dummy_sample() do
    shape = {_height = 8, _width = 8, _channels = 4}
    sample = Nx.iota(shape)
    Nx.divide(sample, Nx.size(sample))
  end

  # A stand-in for a denoising model: sample * timestep / (timestep + 1).
  defp dummy_model(sample, timestep) do
    sample
    |> Nx.multiply(timestep)
    |> Nx.divide(Nx.add(timestep, 1))
  end
end
86 |
--------------------------------------------------------------------------------
/test/test_helper.exs:
--------------------------------------------------------------------------------
Application.put_env(:bumblebee, :progress_bar_enabled, false)

host_client = EXLA.Client.fetch!(:host)

# Tests tagged :multi_device run only when the host client exposes more
# than one device.
exclude_multi_device =
  case host_client.device_count do
    count when count > 1 -> []
    _count -> [:multi_device]
  end

if host_client.device_count == 1 and System.schedulers_online() > 1 do
  IO.puts(
    "To run multi-device tests: XLA_FLAGS=--xla_force_host_platform_device_count=2 mix test"
  )
end

exla_clients = [
  host: [platform: :host],
  cuda: [platform: :cuda],
  rocm: [platform: :rocm],
  tpu: [platform: :tpu],
  other_host: [platform: :host, automatic_transfers: false]
]

Application.put_env(:exla, :clients, exla_clients)
Application.put_env(:exla, :preferred_clients, [:tpu, :cuda, :rocm, :other_host, :host])
Application.put_env(:nx, :default_backend, {EXLA.Backend, client: :host})

if System.get_env("BUMBLEBEE_OFFLINE") == nil do
  IO.puts("To run tests without hitting the network: BUMBLEBEE_OFFLINE=true mix test")
end

ExUnit.start(exclude: [:slow | exclude_multi_device])
30 |
--------------------------------------------------------------------------------