├── .formatter.exs ├── .github ├── images │ ├── background.jpg │ ├── kino_bumblebee_token_classification.png │ └── phx_image_classification.png └── workflows │ ├── nightly.yml │ └── test.yaml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── config └── config.exs ├── examples └── phoenix │ ├── README.md │ ├── image_classification.exs │ ├── speech_to_text.exs │ └── text_classification.exs ├── lib ├── bumblebee.ex └── bumblebee │ ├── application.ex │ ├── audio.ex │ ├── audio │ ├── speech_to_text_whisper.ex │ ├── whisper.ex │ └── whisper_featurizer.ex │ ├── configurable.ex │ ├── conversion │ ├── pytorch_loader.ex │ ├── pytorch_loader │ │ └── file_tensor.ex │ └── pytorch_params.ex │ ├── diffusion │ ├── controlnet.ex │ ├── ddim_scheduler.ex │ ├── layers.ex │ ├── layers │ │ └── unet.ex │ ├── lcm_scheduler.ex │ ├── pndm_scheduler.ex │ ├── scheduler_utils.ex │ ├── stable_diffusion.ex │ ├── stable_diffusion │ │ └── safety_checker.ex │ ├── stable_diffusion_controlnet.ex │ ├── unet_2d_conditional.ex │ └── vae_kl.ex │ ├── featurizer.ex │ ├── huggingface │ ├── hub.ex │ └── transformers │ │ ├── config.ex │ │ ├── model.ex │ │ └── utils.ex │ ├── layers.ex │ ├── layers │ ├── decoder.ex │ └── transformer.ex │ ├── model_spec.ex │ ├── multimodal │ ├── blip.ex │ ├── clip.ex │ └── layout_lm.ex │ ├── scheduler.ex │ ├── shared.ex │ ├── shared │ └── converters.ex │ ├── text.ex │ ├── text │ ├── albert.ex │ ├── bart.ex │ ├── bert.ex │ ├── blenderbot.ex │ ├── blip_text.ex │ ├── clip_text.ex │ ├── distilbert.ex │ ├── fill_mask.ex │ ├── gemma.ex │ ├── generation.ex │ ├── generation │ │ └── logits_processing.ex │ ├── generation_config.ex │ ├── gpt2.ex │ ├── gpt_big_code.ex │ ├── gpt_neo_x.ex │ ├── llama.ex │ ├── m2m100.ex │ ├── mbart.ex │ ├── mistral.ex │ ├── phi.ex │ ├── phi3.ex │ ├── pre_trained_tokenizer.ex │ ├── question_answering.ex │ ├── roberta.ex │ ├── t5.ex │ ├── text_classification.ex │ ├── text_embedding.ex │ ├── text_generation.ex │ ├── token_classification.ex │ ├── 
translation.ex │ ├── whisper_generation_config.ex │ └── zero_shot_classification.ex │ ├── tokenizer.ex │ ├── utils.ex │ ├── utils │ ├── axon.ex │ ├── http.ex │ ├── image.ex │ ├── model.ex │ └── nx.ex │ ├── vision.ex │ └── vision │ ├── bit_featurizer.ex │ ├── blip_featurizer.ex │ ├── blip_vision.ex │ ├── clip_featurizer.ex │ ├── clip_vision.ex │ ├── convnext.ex │ ├── convnext_featurizer.ex │ ├── deit.ex │ ├── deit_featurizer.ex │ ├── dino_v2.ex │ ├── image_classification.ex │ ├── image_embedding.ex │ ├── image_to_text.ex │ ├── resnet.ex │ ├── swin.ex │ ├── vit.ex │ └── vit_featurizer.ex ├── mix.exs ├── mix.lock ├── notebooks ├── examples.livemd ├── fine_tuning.livemd ├── llms.livemd ├── llms_rag.livemd └── stable_diffusion.livemd └── test ├── bumblebee ├── audio │ ├── speech_to_text_whisper_test.exs │ ├── whisper_featurizer_test.exs │ └── whisper_test.exs ├── conversion │ ├── pytorch_loader_test.exs │ └── pytorch_params_test.exs ├── diffusion │ ├── controlnet_test.exs │ ├── ddim_scheduler_test.exs │ ├── lcm_scheduler_test.exs │ ├── pndm_scheduler_test.exs │ ├── stable_diffusion │ │ └── safety_checker_test.exs │ ├── stable_diffusion_controlnet_test.exs │ ├── stable_diffusion_test.exs │ ├── unet_2d_conditional_test.exs │ └── vae_kl_test.exs ├── huggingface │ └── hub_test.exs ├── multimodal │ ├── blip_test.exs │ ├── clip_test.exs │ └── layout_lm_test.exs ├── shared_test.exs ├── text │ ├── albert_test.exs │ ├── bart_test.exs │ ├── bert_test.exs │ ├── blenderbot_test.exs │ ├── blip_text_test.exs │ ├── camembert_test.exs │ ├── clip_text_test.exs │ ├── distilbert_test.exs │ ├── fill_mask_test.exs │ ├── gemma_test.exs │ ├── generation │ │ └── logits_processing_test.exs │ ├── generation_config_test.exs │ ├── generation_test.exs │ ├── gpt2_test.exs │ ├── gpt_big_code_test.exs │ ├── gpt_neo_x_test.exs │ ├── llama_test.exs │ ├── m2m100_test.exs │ ├── mbart_test.exs │ ├── mistral_test.exs │ ├── nllb_test.exs │ ├── phi3_test.exs │ ├── phi_test.exs │ ├── 
pre_trained_tokenizer_test.exs │ ├── question_answering_test.exs │ ├── roberta_test.exs │ ├── t5_test.exs │ ├── text_classification_test.exs │ ├── text_embedding_test.exs │ ├── text_generation_test.exs │ ├── token_classification_test.exs │ ├── translation_test.exs │ ├── xlm_roberta_test.exs │ └── zero_shot_classification_test.exs ├── utils │ ├── image_test.exs │ └── nx_test.exs └── vision │ ├── bit_featurizer_test.exs │ ├── blip_featurizer_test.exs │ ├── blip_vision_test.exs │ ├── clip_featurizer_test.exs │ ├── clip_vision_test.exs │ ├── convnext_featurizer_test.exs │ ├── convnext_test.exs │ ├── deit_featurizer_test.exs │ ├── deit_test.exs │ ├── dino_v2_test.exs │ ├── image_classification_test.exs │ ├── image_embedding_test.exs │ ├── image_to_text_test.exs │ ├── resnet_test.exs │ ├── swin_test.exs │ ├── vit_featurizer_test.exs │ └── vit_test.exs ├── bumblebee_test.exs ├── fixtures ├── audio │ ├── common_voice │ │ ├── a6c7706a220eeea7ee3687c1122fe7ac17962d2449d25b6db37cc41cdaace442683e11945b6f581e73941c3083cd4eecfafc938840459cd8c571dae7774ee687.wav │ │ ├── a6c7706a220eeea7ee3687c1122fe7ac17962d2449d25b6db37cc41cdaace442683e11945b6f581e73941c3083cd4eecfafc938840459cd8c571dae7774ee687_pcm_f32le_16000.bin │ │ └── info.md │ ├── generate.sh │ └── librivox │ │ ├── 46s.mp3 │ │ ├── 46s_pcm_f32le_16000.bin │ │ └── info.md ├── images │ └── coco │ │ ├── 39769.jpeg │ │ └── info.md └── pytorch │ ├── generate.py │ ├── noncontiguous_numpy_array.legacy.pt │ ├── noncontiguous_numpy_array.zip.pt │ ├── noncontiguous_tensor.legacy.pt │ ├── noncontiguous_tensor.zip.pt │ ├── numpy_arrays.legacy.pt │ ├── numpy_arrays.zip.pt │ ├── ordered_dict.legacy.pt │ ├── ordered_dict.zip.pt │ ├── state_dict_base.zip.pt │ ├── state_dict_full.zip.pt │ ├── storage_view.legacy.pt │ ├── tensors.legacy.pt │ └── tensors.zip.pt ├── support └── test_helpers.ex └── test_helper.exs /.formatter.exs: -------------------------------------------------------------------------------- 1 | [ 2 | import_deps: [:nx], 3 | 
inputs: ["{mix,.formatter}.exs", "{config,lib,test,examples}/**/*.{ex,exs}"] 4 | ] 5 | -------------------------------------------------------------------------------- /.github/images/background.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/.github/images/background.jpg -------------------------------------------------------------------------------- /.github/images/kino_bumblebee_token_classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/.github/images/kino_bumblebee_token_classification.png -------------------------------------------------------------------------------- /.github/images/phx_image_classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/.github/images/phx_image_classification.png -------------------------------------------------------------------------------- /.github/workflows/nightly.yml: -------------------------------------------------------------------------------- 1 | name: Nightly 2 | on: 3 | schedule: 4 | - cron: "0 0 * * *" 5 | workflow_dispatch: 6 | env: 7 | elixir: 1.14.0 8 | otp: 24.0 9 | jobs: 10 | test: 11 | name: Test all 12 | runs-on: ubuntu-latest 13 | env: 14 | MIX_ENV: test 15 | XLA_CACHE_DIR: ${{ github.workspace }}/cache/xla 16 | LIBTORCH_DIR: ${{ github.workspace }}/cache/libtorch 17 | steps: 18 | - uses: actions/checkout@v3 19 | - uses: erlef/setup-beam@v1 20 | with: 21 | otp-version: ${{env.otp}} 22 | elixir-version: ${{env.elixir}} 23 | - uses: actions/cache@v3 24 | with: 25 | path: | 26 | deps 27 | _build 28 | cache 29 | key: ${{ runner.os }}-mix-${{env.elixir}}-${{env.otp}}-${{ hashFiles('**/mix.lock') }} 30 | 
restore-keys: | 31 | ${{ runner.os }}-mix- 32 | - run: mix deps.get 33 | - run: mix test --include slow 34 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | on: 3 | pull_request: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | main: 10 | name: "main (${{ matrix.pair.elixir }}, ${{ matrix.pair.otp }})" 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | include: 16 | - pair: 17 | elixir: "1.15.4" 18 | otp: "26.0.2" 19 | lint: true 20 | slow: true 21 | - pair: 22 | elixir: "1.14.5" 23 | otp: "25.3.2.2" 24 | env: 25 | MIX_ENV: test 26 | XLA_CACHE_DIR: ${{ github.workspace }}/cache/xla 27 | LIBTORCH_DIR: ${{ github.workspace }}/cache/torch 28 | steps: 29 | - uses: actions/checkout@v3 30 | with: 31 | # We need the previous commit for git diff later 32 | fetch-depth: 2 33 | - uses: erlef/setup-beam@v1 34 | with: 35 | otp-version: ${{ matrix.pair.otp }} 36 | elixir-version: ${{ matrix.pair.elixir }} 37 | - uses: actions/cache@v3 38 | with: 39 | path: | 40 | deps 41 | _build 42 | cache 43 | key: ${{ runner.os }}-mix-${{ matrix.pair.elixir }}-${{ matrix.pair.otp }}-${{ hashFiles('**/mix.lock') }} 44 | - run: mix deps.get 45 | - run: mix format --check-formatted 46 | if: ${{ matrix.lint }} 47 | - run: mix deps.unlock --check-unused 48 | if: ${{ matrix.lint }} 49 | - run: mix deps.compile 50 | - run: mix compile --warnings-as-errors 51 | if: ${{ matrix.lint }} 52 | - name: Restore bumblebee cache 53 | id: cache-bumblebee-restore 54 | uses: actions/cache/restore@v3 55 | with: 56 | path: bumblebee_cache 57 | key: ${{ runner.os }}-bumblebee-cache-${{ matrix.pair.elixir }}-${{ matrix.pair.otp }} 58 | - run: mix test 59 | env: 60 | BUMBLEBEE_CACHE_DIR: ${{ github.workspace }}/bumblebee_cache 61 | - name: Save bumblebee cache 62 | id: cache-bumblebee-save 63 | uses: 
actions/cache/save@v3 64 | with: 65 | path: bumblebee_cache 66 | key: ${{ steps.cache-bumblebee-restore.outputs.cache-primary-key }} 67 | - name: Diff tests 68 | run: | 69 | changed_tests="$(git diff --name-only --diff-filter=AMRC HEAD^1 'test/**/*_test.exs' | tr '\n' ' ')" 70 | echo "Changed test files: $changed_tests" 71 | echo "CHANGED_TESTS=$changed_tests" >> $GITHUB_ENV 72 | - name: Changed slow tests 73 | # mix test exits with a non-zero code if there are no matching tests, 74 | # so we make sure we fail only when the test suite fails 75 | run: mix test test/bumblebee_test.exs --only slow --exit-status 100 ${{ env.CHANGED_TESTS }} || [ $? -ne 100 ] 76 | if: ${{ matrix.slow && env.CHANGED_TESTS != '' }} 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # The directory Mix will write compiled artifacts to. 2 | /_build/ 3 | 4 | # If you run "mix test --cover", coverage assets end up here. 5 | /cover/ 6 | 7 | # The directory Mix downloads your dependencies sources to. 8 | /deps/ 9 | 10 | # Where third-party dependencies like ExDoc output generated docs. 11 | /doc/ 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | # Ignore package tarball (built via "mix hex.build"). 23 | bumblebee-*.tar 24 | 25 | # Temporary files, for example, from tests. 
26 | /tmp/ 27 | -------------------------------------------------------------------------------- /config/config.exs: -------------------------------------------------------------------------------- 1 | import Config 2 | 3 | config :exla, :add_backend_on_inspect, config_env() != :test 4 | -------------------------------------------------------------------------------- /examples/phoenix/text_classification.exs: -------------------------------------------------------------------------------- 1 | Mix.install([ 2 | {:phoenix_playground, "~> 0.1.7"}, 3 | {:bumblebee, "~> 0.6.0"}, 4 | {:nx, "~> 0.9.0"}, 5 | {:exla, "~> 0.9.0"} 6 | ]) 7 | 8 | Application.put_env(:nx, :default_backend, EXLA.Backend) 9 | 10 | defmodule DemoLive do 11 | use Phoenix.LiveView 12 | 13 | @impl true 14 | def mount(_params, _session, socket) do 15 | {:ok, assign(socket, text: "", label: nil)} 16 | end 17 | 18 | @impl true 19 | def render(assigns) do 20 | ~H""" 21 | 23 | 24 |
25 |
26 |
27 | 33 | 40 |
41 |
42 | Emotion: 43 | <.async_result :let={label} :if={@label} assign={@label}> 44 | <:loading> 45 | <.spinner /> 46 | 47 | <:failed :let={_reason}> 48 | Oops, something went wrong! 49 | 50 | <%= label %> 51 | 52 |
53 |
54 |
55 | """ 56 | end 57 | 58 | defp spinner(assigns) do 59 | ~H""" 60 | 66 | 70 | 74 | 75 | """ 76 | end 77 | 78 | @impl true 79 | def handle_event("predict", %{"text" => text}, socket) do 80 | socket = 81 | socket 82 | |> assign(:text, text) 83 | # Discard previous label so we show the loading state once more 84 | |> assign(:label, nil) 85 | |> assign_async(:label, fn -> 86 | output = Nx.Serving.batched_run(Demo.Serving, text) 87 | %{predictions: [%{label: label}]} = output 88 | {:ok, %{label: label}} 89 | end) 90 | 91 | {:noreply, socket} 92 | end 93 | end 94 | 95 | # Application startup 96 | 97 | {:ok, model_info} = Bumblebee.load_model({:hf, "finiteautomata/bertweet-base-emotion-analysis"}) 98 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "vinai/bertweet-base"}) 99 | 100 | serving = 101 | Bumblebee.Text.text_classification(model_info, tokenizer, 102 | top_k: 1, 103 | compile: [batch_size: 4, sequence_length: 100], 104 | defn_options: [ 105 | compiler: EXLA, 106 | cache: Path.join(System.tmp_dir!(), "bumblebee_examples/text_classification") 107 | ] 108 | ) 109 | 110 | Nx.Serving.start_link(serving: serving, name: Demo.Serving, batch_timeout: 100) 111 | 112 | PhoenixPlayground.start(live: DemoLive, port: 8080) 113 | -------------------------------------------------------------------------------- /lib/bumblebee/application.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Application do 2 | @moduledoc false 3 | 4 | use Application 5 | 6 | @impl true 7 | def start(_type, _args) do 8 | Bumblebee.Utils.HTTP.start_inets_profile() 9 | 10 | children = [] 11 | opts = [strategy: :one_for_one, name: Bumblebee.Supervisor] 12 | Supervisor.start_link(children, opts) 13 | end 14 | 15 | @impl true 16 | def stop(_state) do 17 | Bumblebee.Utils.HTTP.stop_inets_profile() 18 | end 19 | end 20 | -------------------------------------------------------------------------------- /lib/bumblebee/audio/whisper_featurizer.ex: 
-------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Audio.WhisperFeaturizer do 2 | alias Bumblebee.Shared 3 | 4 | import Nx.Defn 5 | 6 | options = [ 7 | feature_size: [ 8 | default: 80, 9 | doc: "the dimension of the extracted features. This corresponds to the number of Mel bins" 10 | ], 11 | sampling_rate: [ 12 | default: 16_000, 13 | doc: "the sampling rate at which the audio files should be digitally expressed in Hertz" 14 | ], 15 | num_seconds: [ 16 | default: 30, 17 | doc: """ 18 | the maximum duration of the audio sequence. This implies that the maximum length of the 19 | input sequence is `:num_seconds` * `:sampling_rate` 20 | """ 21 | ], 22 | hop_length: [ 23 | default: 160, 24 | doc: 25 | "the hop between consecutive overlapping windows for the STFT used to obtain Mel Frequency coefficients" 26 | ], 27 | fft_length: [ 28 | default: 400, 29 | doc: "the size of the fourier transform" 30 | ], 31 | padding_value: [ 32 | default: 0.0, 33 | doc: "the value used to pad the audio. Should correspond to silence" 34 | ] 35 | ] 36 | 37 | @moduledoc """ 38 | Whisper featurizer for audio data. 
39 | 40 | ## Configuration 41 | 42 | #{Shared.options_doc(options)} 43 | """ 44 | 45 | defstruct Shared.option_defaults(options) 46 | 47 | @behaviour Bumblebee.Featurizer 48 | @behaviour Bumblebee.Configurable 49 | 50 | @impl true 51 | def config(featurizer, opts) do 52 | Shared.put_config_attrs(featurizer, opts) 53 | end 54 | 55 | @impl true 56 | def process_input(featurizer, raw_samples) do 57 | max_length = featurizer.num_seconds * featurizer.sampling_rate 58 | 59 | samples = 60 | for sample <- List.wrap(raw_samples) do 61 | unless Nx.rank(sample) == 1 do 62 | raise ArgumentError, 63 | "expected sample to be a 1-rank tensor, got: #{Nx.rank(sample)}-rank" 64 | end 65 | 66 | pad_size = max_length - Nx.axis_size(sample, 0) 67 | Nx.pad(sample, featurizer.padding_value, [{0, pad_size, 0}]) 68 | end 69 | 70 | Nx.stack(samples) 71 | end 72 | 73 | @impl true 74 | def batch_template(featurizer, batch_size) do 75 | max_length = featurizer.num_seconds * featurizer.sampling_rate 76 | Nx.template({batch_size, max_length}, :f32) 77 | end 78 | 79 | @impl true 80 | def process_batch(featurizer, samples) do 81 | samples = 82 | samples 83 | |> Nx.vectorize(:batch) 84 | |> extract_fbank_features( 85 | fft_length: featurizer.fft_length, 86 | sampling_rate: featurizer.sampling_rate, 87 | mel_bins: featurizer.feature_size, 88 | hop_length: featurizer.hop_length 89 | ) 90 | |> Nx.devectorize() 91 | 92 | %{"input_features" => samples} 93 | end 94 | 95 | defnp extract_fbank_features(waveform, opts \\ []) do 96 | opts = keyword!(opts, [:fft_length, :sampling_rate, :mel_bins, :hop_length]) 97 | 98 | window = NxSignal.Windows.hann(n: opts[:fft_length], is_periodic: true) 99 | 100 | {stft, _, _} = 101 | NxSignal.stft(waveform, window, 102 | sampling_rate: opts[:sampling_rate], 103 | fft_length: opts[:fft_length], 104 | overlap_length: opts[:fft_length] - opts[:hop_length], 105 | window_padding: :reflect 106 | ) 107 | 108 | stft = stft[0..-2//1] 109 | 110 | # Magic numbers taken from the 
reference implementation. This yields 111 | # max_mel ~ 3016 112 | frequency_spacing = 200.0 / 3 113 | max_mel = frequency_spacing * 45.245640471924965 114 | 115 | NxSignal.stft_to_mel(stft, opts[:sampling_rate], 116 | fft_length: opts[:fft_length], 117 | mel_bins: opts[:mel_bins], 118 | max_mel: max_mel, 119 | mel_frequency_spacing: frequency_spacing 120 | ) 121 | end 122 | 123 | defimpl Bumblebee.HuggingFace.Transformers.Config do 124 | def load(featurizer, data) do 125 | import Shared.Converters 126 | 127 | opts = 128 | convert!(data, 129 | feature_size: {"feature_size", number()}, 130 | sampling_rate: {"sampling_rate", number()}, 131 | hop_length: {"hop_length", number()}, 132 | num_seconds: {"chunk_length", number()}, 133 | fft_length: {"n_fft", number()}, 134 | padding_value: {"padding_value", number()} 135 | ) 136 | 137 | @for.config(featurizer, opts) 138 | end 139 | end 140 | end 141 | -------------------------------------------------------------------------------- /lib/bumblebee/configurable.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Configurable do 2 | @moduledoc """ 3 | An interface for configurable entities. 4 | 5 | A module implementing this behaviour is expected to define a struct 6 | with configuration. 7 | """ 8 | 9 | @type t :: struct() 10 | 11 | @doc """ 12 | Configures the struct. 
13 | """ 14 | @callback config(t(), keyword()) :: t() 15 | end 16 | -------------------------------------------------------------------------------- /lib/bumblebee/conversion/pytorch_loader/file_tensor.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Conversion.PyTorchLoader.FileTensor do 2 | @moduledoc false 3 | 4 | defstruct [:shape, :type, :offset, :strides, :storage] 5 | end 6 | 7 | defimpl Nx.LazyContainer, for: Bumblebee.Conversion.PyTorchLoader.FileTensor do 8 | alias Bumblebee.Conversion.PyTorchLoader 9 | 10 | def traverse(lazy_tensor, acc, fun) do 11 | template = Nx.template(lazy_tensor.shape, lazy_tensor.type) 12 | 13 | load = fn -> 14 | binary = 15 | case lazy_tensor.storage do 16 | {:zip, path, file_name} -> 17 | PyTorchLoader.open_zip!(path, fn unzip -> 18 | PyTorchLoader.read_zip_file(unzip, file_name) 19 | end) 20 | 21 | {:file, path, offset, size} -> 22 | File.open!(path, [:read, :raw], fn file -> 23 | {:ok, binary} = :file.pread(file, offset, size) 24 | binary 25 | end) 26 | end 27 | 28 | %{offset: offset, shape: shape, type: type, strides: strides} = lazy_tensor 29 | 30 | {_, bit_unit} = type 31 | byte_unit = div(bit_unit, 8) 32 | size = Tuple.product(shape) 33 | binary = binary_part(binary, offset * byte_unit, size * byte_unit) 34 | binary |> Nx.from_binary(type) |> to_contiguous(shape, strides) 35 | end 36 | 37 | fun.(template, load, acc) 38 | end 39 | 40 | defp to_contiguous(tensor, shape, strides) do 41 | # PyTorch tensors may not be contiguous in memory, so strides are 42 | # used to indicate jumps necessary when traversing each axis. 
43 | # Since Nx doesn't have the notion of strides, we transpose the 44 | # tensor, in a way that makes it contiguous, which is equivalent 45 | # to strides being decreasing 46 | 47 | memory_axes_order = 48 | strides 49 | |> Tuple.to_list() 50 | |> Enum.with_index() 51 | |> Enum.sort_by(&elem(&1, 0), :desc) 52 | |> Enum.map(&elem(&1, 1)) 53 | 54 | if memory_axes_order == Nx.axes(shape) do 55 | Nx.reshape(tensor, shape) 56 | else 57 | memory_shape = 58 | memory_axes_order 59 | |> Enum.map(fn axis -> elem(shape, axis) end) 60 | |> List.to_tuple() 61 | 62 | tensor 63 | |> Nx.reshape(memory_shape) 64 | |> Nx.transpose(axes: inverse_permutation(memory_axes_order)) 65 | end 66 | end 67 | 68 | defp inverse_permutation(list) do 69 | list 70 | |> Enum.with_index() 71 | |> Enum.reduce(List.to_tuple(list), fn {src_idx, dest_idx}, inverse -> 72 | put_elem(inverse, src_idx, dest_idx) 73 | end) 74 | |> Tuple.to_list() 75 | end 76 | end 77 | -------------------------------------------------------------------------------- /lib/bumblebee/diffusion/scheduler_utils.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Diffusion.SchedulerUtils do 2 | @moduledoc false 3 | 4 | import Nx.Defn 5 | 6 | @pi :math.pi() 7 | 8 | @doc """ 9 | Returns a beta schedule of the given type. 10 | 11 | The supported types are: 12 | 13 | * `:linear` - a linear schedule from Ho et al. (https://arxiv.org/pdf/2006.11239.pdf) 14 | 15 | * `:quadratic` - a quadratic schedule specific to the latent diffusion models 16 | 17 | * `:squared_cosine` - a cosine schedule from Nichol et al. (https://arxiv.org/pdf/2102.09672.pdf), 18 | used in OpenAI GLIDE 19 | 20 | ## Options 21 | 22 | * `:start` - start for the linear and quadratic schedules. Defaults to `0.0001` 23 | 24 | * `:end` - end for the linear and quadratic schedules. 
Defaults to `0.02` 25 | 26 | """ 27 | deftransform beta_schedule(type, num_timesteps, opts \\ []) do 28 | opts = Keyword.validate!(opts, start: 0.0001, end: 0.02) 29 | beta_start = opts[:start] 30 | beta_end = opts[:end] 31 | 32 | case type do 33 | :linear -> 34 | Nx.linspace(beta_start, beta_end, n: num_timesteps) 35 | 36 | :quadratic -> 37 | Nx.linspace(Nx.sqrt(beta_start), Nx.sqrt(beta_end), n: num_timesteps) |> Nx.pow(2) 38 | 39 | :squared_cosine -> 40 | betas_for_alpha_bar(&squared_cosine_alpha_bar/1, num_timesteps: num_timesteps) 41 | end 42 | end 43 | 44 | defnp squared_cosine_alpha_bar(t) do 45 | s = 0.008 46 | Nx.cos((t + s) / (1 + s) * @pi / 2) ** 2 47 | end 48 | 49 | # Creates a beta schedule that discretizes the given alpha_t_bar function, 50 | # which defines the cumulative product of (1 - beta) over time t in [0, 1]. 51 | defnp betas_for_alpha_bar(alpha_t_bar_fun, opts \\ []) do 52 | opts = keyword!(opts, [:num_timesteps, max_beta: 0.999]) 53 | num_timesteps = opts[:num_timesteps] 54 | max_beta = opts[:max_beta] 55 | 56 | i = Nx.iota({num_timesteps}) 57 | t1 = i / num_timesteps 58 | t2 = (i + 1) / num_timesteps 59 | beta = 1 - alpha_t_bar_fun.(t2) / alpha_t_bar_fun.(t1) 60 | min(beta, max_beta) 61 | end 62 | 63 | @doc """ 64 | Returns evenly spaced timesteps as used in the DDIM schedule. 65 | """ 66 | deftransform ddim_timesteps(num_train_steps, num_steps, offset) do 67 | timestep_gap = div(num_train_steps, num_steps) 68 | 69 | Nx.iota({num_steps}) 70 | |> Nx.multiply(timestep_gap) 71 | |> Nx.add(offset) 72 | |> Nx.reverse() 73 | end 74 | end 75 | -------------------------------------------------------------------------------- /lib/bumblebee/featurizer.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Featurizer do 2 | @moduledoc """ 3 | An interface for configuring and applying featurizers. 4 | 5 | A featurizer is used to convert raw data into model input. 
6 | 7 | Every module implementing this behaviour is expected to also define 8 | a configuration struct. 9 | """ 10 | 11 | @type t :: Bumblebee.Configurable.t() 12 | 13 | @doc """ 14 | Converts the given input to a batched tensor (or a tensor container). 15 | 16 | Numerical batch processing should be moved to `c:process_batch/2` 17 | whenever possible. 18 | """ 19 | @callback process_input(t(), input :: any()) :: Nx.t() | Nx.Container.t() 20 | 21 | @doc """ 22 | Returns an input template for `c:process_batch/2`. 23 | 24 | The shape is effectively the same as the result of `c:process_input/2`, 25 | except for the batch size. 26 | """ 27 | @callback batch_template(t(), batch_size :: pos_integer()) :: Nx.t() | Nx.Container.t() 28 | 29 | @doc """ 30 | Optional batch processing stage. 31 | 32 | This is a numerical function. It receives the result of `c:process_input/2`, 33 | except the batch size may differ. 34 | 35 | When using featurizer as part of `Nx.Serving`, the batch stage can 36 | be merged with the model computation and compiled together. 37 | """ 38 | @callback process_batch(t(), input :: Nx.t() | Nx.Container.t()) :: Nx.t() | Nx.Container.t() 39 | 40 | @optional_callbacks batch_template: 2, process_batch: 2 41 | 42 | @doc """ 43 | Converts the given input to a batched tensor (or a tensor container). 44 | """ 45 | @spec process_input(t(), any()) :: Nx.t() | Nx.Container.t() 46 | def process_input(%module{} = featurizer, input) do 47 | module.process_input(featurizer, input) 48 | end 49 | 50 | @doc """ 51 | Returns an input template for `process_batch/2`. 52 | 53 | If the featurizer does not define batch processing, `nil` is returned. 
54 | """ 55 | @spec batch_template(t(), pos_integer()) :: Nx.t() | Nx.Container.t() | nil 56 | def batch_template(%module{} = featurizer, batch_size) do 57 | if Code.ensure_loaded?(module) and function_exported?(module, :batch_template, 2) do 58 | module.batch_template(featurizer, batch_size) 59 | end 60 | end 61 | 62 | @doc """ 63 | Optional batch processing stage. 64 | 65 | This is a numerical function. It receives the result of `c:process_input/2`, 66 | except the batch size may differ. 67 | 68 | If the featurizer does not define batch processing, the input is 69 | returned as is. 70 | """ 71 | @spec process_batch(t(), Nx.t() | Nx.Container.t()) :: Nx.t() | Nx.Container.t() 72 | def process_batch(%module{} = featurizer, batch) do 73 | if Code.ensure_loaded?(module) and function_exported?(module, :process_batch, 2) do 74 | module.process_batch(featurizer, batch) 75 | else 76 | batch 77 | end 78 | end 79 | end 80 | -------------------------------------------------------------------------------- /lib/bumblebee/huggingface/transformers/config.ex: -------------------------------------------------------------------------------- 1 | defprotocol Bumblebee.HuggingFace.Transformers.Config do 2 | @moduledoc false 3 | 4 | # This protocol defines a bridge between Bumblebee and huggingface/transformers 5 | # configuration. 6 | 7 | @doc """ 8 | Updates configuration based on a parsed JSON data. 9 | """ 10 | @spec load(t(), map()) :: Bumblebee.ModelSpec.t() 11 | def load(config, data) 12 | end 13 | -------------------------------------------------------------------------------- /lib/bumblebee/huggingface/transformers/model.ex: -------------------------------------------------------------------------------- 1 | defprotocol Bumblebee.HuggingFace.Transformers.Model do 2 | @moduledoc false 3 | 4 | # This protocol defines details related to loading Bumblebee model 5 | # from huggingface/transformers model. 
6 | 7 | @type params_mapping :: %{layer_name() => params_source()} 8 | 9 | @type params_source :: layer_name() | list(layer_name()) | param_builders() 10 | 11 | @type param_builders :: %{param_name() => param_builder()} 12 | 13 | @type param_builder :: 14 | {list(param_source()), (list(Nx.tensor()) -> Nx.Tensor.t() | Nx.Container.t())} 15 | 16 | @type param_source :: param_ref() | list(param_ref()) 17 | @type param_ref :: {layer_name(), param_name()} 18 | 19 | @type layer_name :: String.t() 20 | @type param_name :: String.t() 21 | 22 | @doc """ 23 | Returns a map describing layers/parameters relationship between an 24 | Axon model and a corresponding huggingface/transformers model. 25 | 26 | ## Mapping format 27 | 28 | The basic mapping format is a map with Axon layer names (target) as 29 | keys and PyTorch layer names (source) as values. For example: 30 | 31 | %{ 32 | "embedder.token_embedding" => "bert.embeddings.word_embeddings", 33 | ... 34 | } 35 | 36 | The mapping should always use the longest names, that is, depending on 37 | the architecture, the PyTorch layer name could be either 38 | `"bert.embeddings.word_embeddings"` or `"embeddings.word_embeddings"`. 39 | The longer version should generally be used. Prefixes are removed/added 40 | as necessary, so loading partial models is supported automatically. 41 | 42 | The layer names may include simple substitutions, useful for lists 43 | of layers: 44 | 45 | %{ 46 | "encoder.blocks.{n}.self_attention.query" => "bert.encoder.layer.{n}.attention.self.query", 47 | ... 48 | } 49 | 50 | Both param names and values for corresponding layers may not match 51 | exactly, so they require further transformations. For example, the 52 | convolution `"kernel"` in Axon corresponds to a transposed `"weight"` 53 | from PyTorch. For most common layers such conversions are handled 54 | automatically. 
55 | 56 | In some cases, particularly with model-specific layers/parameters, 57 | we may need more control over the parameter mapping. In such cases, 58 | instead of source layer name, a map with parameter-level transformations 59 | may be specified: 60 | 61 | %{ 62 | "embedder.class_embedding" => %{ 63 | "embedding" => { 64 | [{"vit.embeddings", "cls_token"}], 65 | fn [value] -> Nx.squeeze(value, axes: [0, 1]) end 66 | } 67 | }, 68 | ... 69 | } 70 | 71 | For each parameter, we specify a list of source parameters in the 72 | form of `{source_layer_name, source_param_name}`, then a function 73 | to build our parameter value. Instead of a single tuple, we can 74 | specify a list of those to try one by one. With the explicit 75 | transformation we can handle arbitrary parameter name and value 76 | transformations. 77 | """ 78 | @spec params_mapping(t()) :: params_mapping() 79 | def params_mapping(spec) 80 | end 81 | -------------------------------------------------------------------------------- /lib/bumblebee/huggingface/transformers/utils.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.HuggingFace.Transformers.Utils do 2 | @moduledoc false 3 | 4 | import Bumblebee.Utils.Model, only: [join: 2] 5 | 6 | alias Bumblebee.HuggingFace.Transformers 7 | 8 | @doc """ 9 | Prefixes target and source layer names in the given params mapping. 10 | """ 11 | @spec prefix_params_mapping( 12 | Transformers.Model.params_mapping(), 13 | String.t() | nil, 14 | String.t() | nil 15 | ) :: Transformers.Model.params_mapping() 16 | def prefix_params_mapping(params_mapping, target_prefix, source_prefix) do 17 | Map.new(params_mapping, fn {target_layer_name, params_source} -> 18 | { 19 | join(target_prefix, target_layer_name), 20 | map_params_source_layer_names(params_source, &join(source_prefix, &1)) 21 | } 22 | end) 23 | end 24 | 25 | @doc """ 26 | Maps layer names in a params mapping value. 
27 | """ 28 | @spec map_params_source_layer_names( 29 | Transformers.Model.params_source(), 30 | (String.t() -> String.t()) 31 | ) :: Transformers.Model.params_source() 32 | def map_params_source_layer_names(%{} = param_builders, fun) do 33 | Map.new(param_builders, fn {param_name, {sources, builder_fun}} -> 34 | sources = 35 | for ref_or_refs <- sources do 36 | case ref_or_refs do 37 | {layer_name, param_name} -> 38 | {fun.(layer_name), param_name} 39 | 40 | refs -> 41 | for {layer_name, param_name} <- refs, do: {fun.(layer_name), param_name} 42 | end 43 | end 44 | 45 | {param_name, {sources, builder_fun}} 46 | end) 47 | end 48 | 49 | def map_params_source_layer_names(layer_names, fun) when is_list(layer_names) do 50 | Enum.map(layer_names, fun) 51 | end 52 | 53 | def map_params_source_layer_names(layer_name, fun) when is_binary(layer_name) do 54 | fun.(layer_name) 55 | end 56 | end 57 | -------------------------------------------------------------------------------- /lib/bumblebee/model_spec.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.ModelSpec do 2 | @moduledoc """ 3 | An interface for configuring and building models based on the same 4 | architecture. 5 | 6 | Every module implementing this behaviour is expected to also define 7 | a configuration struct. 8 | """ 9 | 10 | @type t :: Bumblebee.Configurable.t() 11 | 12 | @doc """ 13 | Returns the list of supported model architectures. 14 | """ 15 | @callback architectures :: list(atom()) 16 | 17 | @doc """ 18 | Builds a template input for the model. 19 | 20 | The template is used to compile the model when initializing parameters. 21 | """ 22 | @callback input_template(t()) :: map() 23 | 24 | @doc """ 25 | Builds an `Axon` model according to the given configuration. 
26 | """ 27 | @callback model(t()) :: Axon.t() 28 | end 29 | -------------------------------------------------------------------------------- /lib/bumblebee/scheduler.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Scheduler do 2 | @moduledoc """ 3 | An interface for configuring and using schedulers. 4 | 5 | A scheduler defines a sampling method, usually used for multi-step 6 | denoising process, as in stable diffusion. 7 | 8 | Every module implementing this behaviour is expected to also define 9 | a configuration struct. 10 | 11 | ## Context 12 | 13 | Imagine a denoising model trained in 1000 steps. During training, 14 | we take some original data and add random noise 1000 times, this 15 | way we obtain 1000 steps with increasing level of noise. Then, the 16 | model learns to predict noise at each timestep, given data at that 17 | step (sample) and the timestep. 18 | 19 | Once such model is trained, we can obtain brand new data (such as 20 | image) by generating random data and denoising it with our model in 21 | 1000 steps. 22 | 23 | Doing 1000 forward passes of the model for a single generation can 24 | be expensive, hence multiple methods have been developed to reduce 25 | the number of steps during denoising, with no changes to the model. 26 | 27 | Each method specifies a subset of the original timesteps, at each 28 | timestep we need to do a forward pass of the model (or possibly a 29 | few), then the method extrapolates the sample to the next selected 30 | timestep, possibly skipping a lot of timesteps in between. 31 | 32 | ## Note on wording 33 | 34 | Throughout the docs and APIs the word "steps" refers to diffusion 35 | steps, whereas "timesteps" is more specific and refers to the exact 36 | values $t$ (points in time). 37 | """ 38 | 39 | @type t :: Bumblebee.Configurable.t() 40 | 41 | @type state :: Nx.Container.t() 42 | 43 | @doc """ 44 | Initializes state for a new scheduler loop. 
45 | 46 | Returns a pair of `{state, timesteps}`, where `state` is an opaque 47 | `Nx.Container` and `timesteps` is a tensor with the subsequent 48 | timesteps for the model forward pass. 49 | """ 50 | @callback init( 51 | t(), 52 | num_steps :: pos_integer(), 53 | sample_template :: Nx.Tensor.t(), 54 | prng_key :: Nx.Tensor.t() 55 | ) :: {state :: state(), timesteps :: Nx.Tensor.t()} 56 | 57 | @doc """ 58 | Predicts sample at the previous timestep. 59 | 60 | Takes the current `sample` and `prediction` (usually noise) returned 61 | by the model at the current timestep. Returns `{state, prev_sample}`, 62 | where `state` is the updated state and `prev_sample` is the predicted 63 | sample at the previous timestep. 64 | """ 65 | @callback step( 66 | t(), 67 | state(), 68 | sample :: Nx.Tensor.t(), 69 | prediction :: Nx.Tensor.t() 70 | ) :: {state :: state(), prev_sample :: Nx.Tensor.t()} 71 | end 72 | -------------------------------------------------------------------------------- /lib/bumblebee/text/text_classification.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.TextClassification do 2 | @moduledoc false 3 | 4 | alias Bumblebee.Shared 5 | 6 | def text_classification(model_info, tokenizer, opts \\ []) do 7 | %{model: model, params: params, spec: spec} = model_info 8 | Shared.validate_architecture!(spec, :for_sequence_classification) 9 | 10 | opts = 11 | Keyword.validate!(opts, [ 12 | :compile, 13 | top_k: 5, 14 | scores_function: :softmax, 15 | defn_options: [], 16 | preallocate_params: false 17 | ]) 18 | 19 | top_k = opts[:top_k] 20 | scores_function = opts[:scores_function] 21 | preallocate_params = opts[:preallocate_params] 22 | defn_options = opts[:defn_options] 23 | 24 | compile = 25 | if compile = opts[:compile] do 26 | compile 27 | |> Keyword.validate!([:batch_size, :sequence_length]) 28 | |> Shared.require_options!([:batch_size, :sequence_length]) 29 | end 30 | 31 | batch_size = compile[:batch_size]
32 | sequence_length = compile[:sequence_length] 33 | 34 | tokenizer = 35 | Bumblebee.configure(tokenizer, length: sequence_length, return_token_type_ids: false) 36 | 37 | {_init_fun, predict_fun} = Axon.build(model) 38 | 39 | scores_fun = fn params, input -> 40 | outputs = predict_fun.(params, input) 41 | scores = Shared.logits_to_scores(outputs.logits, scores_function) 42 | k = min(top_k, Nx.axis_size(scores, 1)) 43 | {top_scores, top_indices} = Nx.top_k(scores, k: k) 44 | {top_scores, top_indices} 45 | end 46 | 47 | batch_keys = Shared.sequence_batch_keys(sequence_length) 48 | 49 | Nx.Serving.new( 50 | fn batch_key, defn_options -> 51 | params = Shared.maybe_preallocate(params, preallocate_params, defn_options) 52 | 53 | scope = {:scores, batch_key} 54 | 55 | scores_fun = 56 | Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn -> 57 | {:sequence_length, sequence_length} = batch_key 58 | 59 | inputs = %{ 60 | "input_ids" => Nx.template({batch_size, sequence_length}, :u32), 61 | "attention_mask" => Nx.template({batch_size, sequence_length}, :u32) 62 | } 63 | 64 | [params, inputs] 65 | end) 66 | 67 | fn inputs -> 68 | inputs = Shared.maybe_pad(inputs, batch_size) 69 | scores_fun.(params, inputs) |> Shared.serving_post_computation() 70 | end 71 | end, 72 | defn_options 73 | ) 74 | |> Nx.Serving.batch_size(batch_size) 75 | |> Nx.Serving.process_options(batch_keys: batch_keys) 76 | |> Nx.Serving.client_preprocessing(fn input -> 77 | {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string_or_pairs/1) 78 | 79 | inputs = 80 | Nx.with_default_backend(Nx.BinaryBackend, fn -> 81 | Bumblebee.apply_tokenizer(tokenizer, texts) 82 | end) 83 | 84 | batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) 85 | batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) 86 | 87 | {batch, multi?} 88 | end) 89 | |> Nx.Serving.client_postprocessing(fn {{top_scores, top_indices}, _metadata}, multi? 
-> 90 | Enum.zip_with( 91 | Nx.to_list(top_scores), 92 | Nx.to_list(top_indices), 93 | fn top_scores, top_indices -> 94 | predictions = 95 | Enum.zip_with(top_scores, top_indices, fn score, idx -> 96 | label = spec.id_to_label[idx] || "LABEL_#{idx}" 97 | %{score: score, label: label} 98 | end) 99 | 100 | %{predictions: predictions} 101 | end 102 | ) 103 | |> Shared.normalize_output(multi?) 104 | end) 105 | end 106 | end 107 | -------------------------------------------------------------------------------- /lib/bumblebee/text/whisper_generation_config.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.WhisperGenerationConfig do 2 | alias Bumblebee.Shared 3 | 4 | options = [ 5 | no_timestamps_token_id: [ 6 | default: nil, 7 | doc: "the id of the no-timestamps token" 8 | ], 9 | language_to_token_id: [ 10 | default: %{}, 11 | doc: "a map from language code to token id corresponding to that language" 12 | ], 13 | task_to_token_id: [ 14 | default: %{}, 15 | doc: "a map from task to token id corresponding to that task" 16 | ] 17 | ] 18 | 19 | @moduledoc """ 20 | A set of Whisper-specific configuration options controlling text 21 | generation. 22 | 23 | This struct is used in the `Bumblebee.Text.GenerationConfig` struct 24 | under the `:extra_config` attribute. 
25 | 26 | ## Configuration 27 | 28 | #{Shared.options_doc(options)} 29 | """ 30 | 31 | defstruct Shared.option_defaults(options) 32 | 33 | @behaviour Bumblebee.Configurable 34 | 35 | @type t :: %__MODULE__{} 36 | 37 | @impl true 38 | def config(config, opts \\ []) do 39 | Shared.put_config_attrs(config, opts) 40 | end 41 | 42 | defimpl Bumblebee.HuggingFace.Transformers.Config do 43 | def load(config, data) do 44 | import Shared.Converters 45 | 46 | language_converter = fn name, value -> 47 | with {:ok, value} <- string().(name, value) do 48 | {:ok, 49 | value 50 | |> String.replace_prefix("<|", "") 51 | |> String.replace_suffix("|>", "")} 52 | end 53 | end 54 | 55 | opts = 56 | convert!(data, 57 | no_timestamps_token_id: {"no_timestamps_token_id", number()}, 58 | language_to_token_id: {"lang_to_id", map(language_converter, number())}, 59 | task_to_token_id: {"task_to_id", map(atom(), number())} 60 | ) 61 | 62 | @for.config(config, opts) 63 | end 64 | end 65 | end 66 | -------------------------------------------------------------------------------- /lib/bumblebee/tokenizer.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Tokenizer do 2 | @moduledoc """ 3 | An interface for configuring and applying tokenizers. 4 | 5 | A tokenizer is used to convert raw text data into model input. 6 | 7 | Every module implementing this behaviour is expected to also define 8 | a configuration struct. 9 | """ 10 | 11 | @type t :: struct() 12 | 13 | @type input :: String.t() | {String.t(), String.t()} 14 | @type token :: String.t() 15 | @type token_id :: non_neg_integer() 16 | 17 | @typedoc """ 18 | A type corresponding to a special token in the vocabulary. 
19 | 20 | ## Common types 21 | 22 | * `:bos` - a token representing the beginning of a sentence 23 | 24 | * `:eos` - a token representing the end of a sentence 25 | 26 | * `:unk` - a token representing an out-of-vocabulary token 27 | 28 | * `:sep` - a token separating two different sentences in the same 29 | input 30 | 31 | * `:pad` - a token added when processing a batch of sequences with 32 | different lengths 33 | 34 | * `:cls` - a token representing the class of the input 35 | 36 | * `:mask` - a token representing a masked token, used for masked 37 | language modeling tasks 38 | 39 | """ 40 | @type special_token_type :: atom() 41 | 42 | @doc """ 43 | Performs tokenization and encoding on the given input. 44 | """ 45 | @callback apply(t(), input() | list(input())) :: any() 46 | 47 | @doc """ 48 | Decodes token ids into a sentence, or a batch of id lists into a list of sentences. 49 | """ 50 | @callback decode(t(), list(token_id()) | list(list(token_id()))) :: String.t() | [String.t()] 51 | 52 | @doc """ 53 | Converts the given token into the corresponding numeric id. 54 | """ 55 | @callback token_to_id(t(), token()) :: token_id() | nil 56 | 57 | @doc """ 58 | Converts the given token id to the corresponding token. 59 | """ 60 | @callback id_to_token(t(), token_id()) :: token() | nil 61 | 62 | @doc """ 63 | Returns a map with special tokens. 64 | """ 65 | @callback special_tokens(t()) :: %{special_token_type() => token()} 66 | 67 | @doc """ 68 | Returns a set with extra special tokens, in addition to the named 69 | `special_tokens/1`. 70 | """ 71 | @callback additional_special_tokens(t()) :: MapSet.t(token()) 72 | 73 | @doc """ 74 | Decodes token ids into a sentence, or a batch of id lists into a list of sentences.
75 | """ 76 | @spec decode( 77 | t(), 78 | token_id() | list(token_id()) | list(list(token_id())) | Nx.Tensor.t() 79 | ) :: String.t() | list(String.t()) 80 | def decode(%module{} = tokenizer, ids) do 81 | ids = with %Nx.Tensor{} <- ids, do: Nx.to_list(ids) 82 | ids = List.wrap(ids) 83 | module.decode(tokenizer, ids) 84 | end 85 | 86 | @doc """ 87 | Converts the given token into the corresponding numeric id. 88 | """ 89 | @spec token_to_id(t(), token()) :: token_id() | nil 90 | def token_to_id(%module{} = tokenizer, token) do 91 | module.token_to_id(tokenizer, token) 92 | end 93 | 94 | @doc """ 95 | Converts the given token id to the corresponding token. 96 | """ 97 | @spec id_to_token(t(), token_id()) :: token() | nil 98 | def id_to_token(%module{} = tokenizer, id) do 99 | module.id_to_token(tokenizer, id) 100 | end 101 | 102 | @doc """ 103 | Returns a special token by name. 104 | """ 105 | @spec special_token(t(), special_token_type()) :: token() | nil 106 | def special_token(%module{} = tokenizer, type) do 107 | special_tokens = module.special_tokens(tokenizer) 108 | special_tokens[type] 109 | end 110 | 111 | @doc """ 112 | Returns id of a special token by name. 113 | """ 114 | @spec special_token_id(t(), special_token_type()) :: token_id() | nil 115 | def special_token_id(tokenizer, type) do 116 | if token = special_token(tokenizer, type) do 117 | token_to_id(tokenizer, token) 118 | end 119 | end 120 | 121 | @doc """ 122 | Returns a set with all special tokens, including any extra tokens.
123 | """ 124 | @spec all_special_tokens(t()) :: MapSet.t(token()) 125 | def all_special_tokens(%module{} = tokenizer) do 126 | special_tokens = module.special_tokens(tokenizer) 127 | additional_special_tokens = module.additional_special_tokens(tokenizer) 128 | for {_type, token} <- special_tokens, do: token, into: additional_special_tokens 129 | end 130 | end 131 | -------------------------------------------------------------------------------- /lib/bumblebee/utils.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Utils do 2 | @moduledoc false 3 | 4 | @doc """ 5 | Checks if the progress bar is enabled globally. 6 | """ 7 | @spec progress_bar_enabled? :: boolean() 8 | def progress_bar_enabled?() do 9 | Application.get_env(:bumblebee, :progress_bar_enabled, true) 10 | end 11 | end 12 | -------------------------------------------------------------------------------- /lib/bumblebee/utils/image.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Utils.Image do 2 | @moduledoc false 3 | 4 | import Nx.Defn 5 | 6 | @doc """ 7 | Converts a rank-3 image or a rank-4 batch of images to a batched tensor. 8 | """ 9 | defn to_batched_tensor(image) do 10 | case Nx.rank(image) do 11 | 3 -> 12 | Nx.new_axis(image, 0, :batch) 13 | 14 | 4 -> 15 | image 16 | 17 | rank -> 18 | raise ArgumentError, 19 | "expected image to be a rank-3 image or a rank-4 batch, got rank: #{rank}" 20 | end 21 | end 22 | 23 | @doc """ 24 | Normalizes an image size to a `{height, width}` tuple. 25 | 26 | Accepts either an existing tuple or a single integer used for both 27 | dimensions. 28 | """ 29 | def normalize_size(size) 30 | 31 | def normalize_size({height, width}), do: {height, width} 32 | def normalize_size(size) when is_integer(size), do: {size, size} 33 | 34 | @doc """ 35 | Matches image against the desired number of channels and applies 36 | automatic conversions if applicable.
37 | """ 38 | def normalize_channels(input, channels) do 39 | channel_axis = Nx.axis_index(input, -1) 40 | 41 | case {Nx.axis_size(input, channel_axis), channels} do 42 | {channels, channels} -> 43 | input 44 | 45 | {4, 3} -> 46 | Nx.slice_along_axis(input, 0, 3, axis: channel_axis) 47 | 48 | {1, 3} -> 49 | shape = input |> Nx.shape() |> put_elem(channel_axis, 3) 50 | Nx.broadcast(input, shape) 51 | 52 | {actual, expected} -> 53 | raise ArgumentError, 54 | "expected image with #{expected} channels, but got #{actual} and no automatic conversion applies" 55 | end 56 | end 57 | end 58 | -------------------------------------------------------------------------------- /lib/bumblebee/utils/model.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Utils.Model do 2 | @moduledoc false 3 | 4 | @doc """ 5 | Adds another word to a hierarchical name. 6 | """ 7 | @spec join(String.t() | nil, String.Chars.t()) :: String.t() 8 | def join(name, suffix) 9 | 10 | def join(nil, suffix), do: to_string(suffix) 11 | def join(name, suffix), do: name <> "." <> to_string(suffix) 12 | 13 | @doc """ 14 | Converts a list of inputs to a map with input names as keys. 
15 | """ 16 | @spec inputs_to_map(list(Axon.t())) :: %{String.t() => Axon.t()} 17 | def inputs_to_map(inputs) when is_list(inputs) do 18 | for %Axon{output: id, nodes: nodes} = axon <- inputs, into: %{} do 19 | input = nodes[id] 20 | {input.name.(:input, %{}), axon} 21 | end 22 | end 23 | end 24 | -------------------------------------------------------------------------------- /lib/bumblebee/vision/blip_featurizer.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.BlipFeaturizer do 2 | alias Bumblebee.Shared 3 | 4 | options = [ 5 | resize: [ 6 | default: true, 7 | doc: "whether to resize the input to the given `:size`" 8 | ], 9 | size: [ 10 | default: %{height: 384, width: 384}, 11 | doc: """ 12 | the size to resize the input to, given as `%{height: ..., width: ...}`. Only has 13 | an effect if `:resize` is `true` 14 | """ 15 | ], 16 | resize_method: [ 17 | default: :bicubic, 18 | doc: 19 | "the resizing method, either of `:nearest`, `:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`" 20 | ], 21 | normalize: [ 22 | default: true, 23 | doc: "whether or not to normalize the input with mean and standard deviation" 24 | ], 25 | image_mean: [ 26 | default: [0.48145466, 0.4578275, 0.40821073], 27 | doc: "the sequence of mean values for each channel, to be used when normalizing images" 28 | ], 29 | image_std: [ 30 | default: [0.26862954, 0.26130258, 0.27577711], 31 | doc: 32 | "the sequence of standard deviations for each channel, to be used when normalizing images" 33 | ] 34 | ] 35 | 36 | @moduledoc """ 37 | BLIP featurizer for image data. 
38 | 39 | ## Configuration 40 | 41 | #{Shared.options_doc(options)} 42 | """ 43 | 44 | defstruct Shared.option_defaults(options) 45 | 46 | @behaviour Bumblebee.Featurizer 47 | @behaviour Bumblebee.Configurable 48 | 49 | alias Bumblebee.Utils.Image 50 | 51 | @impl true 52 | def config(featurizer, opts) do 53 | Shared.put_config_attrs(featurizer, opts) 54 | end 55 | 56 | @impl true 57 | def process_input(featurizer, images) do 58 | images = List.wrap(images) 59 | 60 | for image <- images do 61 | image = 62 | image 63 | |> Image.to_batched_tensor() 64 | |> Nx.as_type(:f32) 65 | |> Image.normalize_channels(length(featurizer.image_mean)) 66 | 67 | if featurizer.resize do 68 | %{height: height, width: width} = featurizer.size 69 | NxImage.resize(image, {height, width}, method: featurizer.resize_method) 70 | else 71 | image 72 | end 73 | end 74 | |> Nx.concatenate() 75 | end 76 | 77 | @impl true 78 | def batch_template(featurizer, batch_size) do 79 | %{height: height, width: width} = featurizer.size 80 | num_channels = length(featurizer.image_mean) 81 | Nx.template({batch_size, height, width, num_channels}, :f32) 82 | end 83 | 84 | @impl true 85 | def process_batch(featurizer, images) do 86 | images = NxImage.to_continuous(images, 0, 1) 87 | 88 | images = 89 | if featurizer.normalize do 90 | NxImage.normalize( 91 | images, 92 | Nx.tensor(featurizer.image_mean), 93 | Nx.tensor(featurizer.image_std) 94 | ) 95 | else 96 | images 97 | end 98 | 99 | %{"pixel_values" => images} 100 | end 101 | 102 | defimpl Bumblebee.HuggingFace.Transformers.Config do 103 | def load(featurizer, data) do 104 | import Shared.Converters 105 | 106 | opts = 107 | convert!(data, 108 | resize: {"do_resize", boolean()}, 109 | size: {"size", image_size()}, 110 | resize_method: {"resample", resize_method()}, 111 | normalize: {"do_normalize", boolean()}, 112 | image_mean: {"image_mean", list(number())}, 113 | image_std: {"image_std", list(number())} 114 | ) 115 | 116 | @for.config(featurizer, opts) 117 | 
end 118 | end 119 | end 120 | -------------------------------------------------------------------------------- /lib/bumblebee/vision/deit_featurizer.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.DeitFeaturizer do 2 | alias Bumblebee.Shared 3 | 4 | options = [ 5 | resize: [ 6 | default: true, 7 | doc: "whether to resize the input to the given `:size`" 8 | ], 9 | size: [ 10 | default: %{height: 256, width: 256}, 11 | doc: """ 12 | the size to resize the input to, given as `%{height: ..., width: ...}`. Only has 13 | an effect if `:resize` is `true` 14 | """ 15 | ], 16 | resize_method: [ 17 | default: :bicubic, 18 | doc: 19 | "the resizing method, either of `:nearest`, `:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`" 20 | ], 21 | center_crop: [ 22 | default: true, 23 | doc: """ 24 | whether to crop the input at the center. If the input size is smaller than `:crop_size` along 25 | any edge, the image is padded with zeros and then center cropped 26 | """ 27 | ], 28 | crop_size: [ 29 | default: %{height: 224, width: 224}, 30 | doc: """ 31 | the size to center crop the image to, given as `%{height: ..., width: ...}`. Only has an effect 32 | if `:center_crop` is `true` 33 | """ 34 | ], 35 | normalize: [ 36 | default: true, 37 | doc: "whether or not to normalize the input with mean and standard deviation" 38 | ], 39 | image_mean: [ 40 | default: [0.485, 0.456, 0.406], 41 | doc: "the sequence of mean values for each channel, to be used when normalizing images" 42 | ], 43 | image_std: [ 44 | default: [0.229, 0.224, 0.225], 45 | doc: 46 | "the sequence of standard deviations for each channel, to be used when normalizing images" 47 | ] 48 | ] 49 | 50 | @moduledoc """ 51 | DeiT featurizer for image data. 
52 | 53 | ## Configuration 54 | 55 | #{Shared.options_doc(options)} 56 | """ 57 | 58 | defstruct Shared.option_defaults(options) 59 | 60 | @behaviour Bumblebee.Featurizer 61 | @behaviour Bumblebee.Configurable 62 | 63 | alias Bumblebee.Utils.Image 64 | 65 | @impl true 66 | def config(featurizer, opts) do 67 | Shared.put_config_attrs(featurizer, opts) 68 | end 69 | 70 | @impl true 71 | def process_input(featurizer, images) do 72 | images = List.wrap(images) 73 | 74 | for image <- images do 75 | image = 76 | image 77 | |> Image.to_batched_tensor() 78 | |> Nx.as_type(:f32) 79 | |> Image.normalize_channels(length(featurizer.image_mean)) 80 | 81 | if featurizer.resize do 82 | %{height: height, width: width} = featurizer.size 83 | NxImage.resize(image, {height, width}, method: featurizer.resize_method) 84 | else 85 | image 86 | end 87 | end 88 | |> Nx.concatenate() 89 | end 90 | 91 | @impl true 92 | def batch_template(featurizer, batch_size) do 93 | %{height: height, width: width} = featurizer.size 94 | num_channels = length(featurizer.image_mean) 95 | Nx.template({batch_size, height, width, num_channels}, :f32) 96 | end 97 | 98 | @impl true 99 | def process_batch(featurizer, images) do 100 | images = 101 | if featurizer.center_crop do 102 | %{height: height, width: width} = featurizer.crop_size 103 | NxImage.center_crop(images, {height, width}) 104 | else 105 | images 106 | end 107 | 108 | images = NxImage.to_continuous(images, 0, 1) 109 | 110 | images = 111 | if featurizer.normalize do 112 | NxImage.normalize( 113 | images, 114 | Nx.tensor(featurizer.image_mean), 115 | Nx.tensor(featurizer.image_std) 116 | ) 117 | else 118 | images 119 | end 120 | 121 | %{"pixel_values" => images} 122 | end 123 | 124 | defimpl Bumblebee.HuggingFace.Transformers.Config do 125 | def load(featurizer, data) do 126 | import Shared.Converters 127 | 128 | opts = 129 | convert!(data, 130 | resize: {"do_resize", boolean()}, 131 | size: {"size", image_size()}, 132 | resize_method:
{"resample", resize_method()}, 133 | center_crop: {"do_center_crop", boolean()}, 134 | crop_size: {"crop_size", image_size()}, 135 | normalize: {"do_normalize", boolean()}, 136 | image_mean: {"image_mean", list(number())}, 137 | image_std: {"image_std", list(number())} 138 | ) 139 | 140 | @for.config(featurizer, opts) 141 | end 142 | end 143 | end 144 | -------------------------------------------------------------------------------- /lib/bumblebee/vision/image_classification.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.ImageClassification do 2 | @moduledoc false 3 | 4 | alias Bumblebee.Shared 5 | 6 | def image_classification(model_info, featurizer, opts \\ []) do 7 | %{model: model, params: params, spec: spec} = model_info 8 | 9 | Shared.validate_architecture!(spec, [ 10 | :for_image_classification, 11 | :for_image_classification_with_teacher 12 | ]) 13 | 14 | opts = 15 | Keyword.validate!(opts, [ 16 | :compile, 17 | top_k: 5, 18 | scores_function: :softmax, 19 | defn_options: [], 20 | preallocate_params: false 21 | ]) 22 | 23 | top_k = opts[:top_k] 24 | scores_function = opts[:scores_function] 25 | preallocate_params = opts[:preallocate_params] 26 | defn_options = opts[:defn_options] 27 | 28 | compile = 29 | if compile = opts[:compile] do 30 | compile 31 | |> Keyword.validate!([:batch_size]) 32 | |> Shared.require_options!([:batch_size]) 33 | end 34 | 35 | batch_size = compile[:batch_size] 36 | 37 | {_init_fun, predict_fun} = Axon.build(model) 38 | 39 | scores_fun = fn params, input -> 40 | input = Bumblebee.Featurizer.process_batch(featurizer, input) 41 | outputs = predict_fun.(params, input) 42 | scores = Shared.logits_to_scores(outputs.logits, scores_function) 43 | k = min(top_k, Nx.axis_size(scores, 1)) 44 | {top_scores, top_indices} = Nx.top_k(scores, k: k) 45 | {top_scores, top_indices} 46 | end 47 | 48 | Nx.Serving.new( 49 | fn defn_options -> 50 | params = 
Shared.maybe_preallocate(params, preallocate_params, defn_options) 51 | 52 | scores_fun = 53 | Shared.compile_or_jit(scores_fun, :scores, defn_options, compile != nil, fn -> 54 | inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size) 55 | [params, inputs] 56 | end) 57 | 58 | fn inputs -> 59 | inputs = Shared.maybe_pad(inputs, batch_size) 60 | scores_fun.(params, inputs) |> Shared.serving_post_computation() 61 | end 62 | end, 63 | defn_options 64 | ) 65 | |> Nx.Serving.batch_size(batch_size) 66 | |> Nx.Serving.client_preprocessing(fn input -> 67 | {images, multi?} = Shared.validate_serving_input!(input, &Shared.validate_image/1) 68 | inputs = Bumblebee.Featurizer.process_input(featurizer, images) 69 | {Nx.Batch.concatenate([inputs]), multi?} 70 | end) 71 | |> Nx.Serving.client_postprocessing(fn {{top_scores, top_indices}, _metadata}, multi? -> 72 | Enum.zip_with( 73 | Nx.to_list(top_scores), 74 | Nx.to_list(top_indices), 75 | fn top_scores, top_indices -> 76 | predictions = 77 | Enum.zip_with(top_scores, top_indices, fn score, idx -> 78 | label = spec.id_to_label[idx] || "LABEL_#{idx}" 79 | %{score: score, label: label} 80 | end) 81 | 82 | %{predictions: predictions} 83 | end 84 | ) 85 | |> Shared.normalize_output(multi?) 
86 | end) 87 | end 88 | end 89 | -------------------------------------------------------------------------------- /lib/bumblebee/vision/image_embedding.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.ImageEmbedding do 2 | @moduledoc false 3 | 4 | alias Bumblebee.Shared 5 | 6 | def image_embedding(model_info, featurizer, opts \\ []) do 7 | %{model: model, params: params} = model_info 8 | 9 | opts = 10 | Keyword.validate!(opts, [ 11 | :compile, 12 | output_attribute: :pooled_state, 13 | embedding_processor: nil, 14 | defn_options: [], 15 | preallocate_params: false 16 | ]) 17 | 18 | output_attribute = opts[:output_attribute] 19 | embedding_processor = opts[:embedding_processor] 20 | preallocate_params = opts[:preallocate_params] 21 | defn_options = opts[:defn_options] 22 | 23 | compile = 24 | if compile = opts[:compile] do 25 | compile 26 | |> Keyword.validate!([:batch_size]) 27 | |> Shared.require_options!([:batch_size]) 28 | end 29 | 30 | batch_size = compile[:batch_size] 31 | 32 | {_init_fun, encoder} = Axon.build(model) 33 | 34 | embedding_fun = fn params, inputs -> 35 | inputs = Bumblebee.Featurizer.process_batch(featurizer, inputs) 36 | 37 | output = encoder.(params, inputs) 38 | 39 | output = 40 | case output do 41 | %{^output_attribute => output} -> 42 | output 43 | 44 | %{} -> 45 | keys = output |> Map.keys() |> Enum.sort() 46 | 47 | raise ArgumentError, 48 | "key #{inspect(output_attribute)} not found in the output map," <> 49 | " you may want to set :output_attribute to one of the map keys: #{inspect(keys)}" 50 | 51 | _ -> 52 | output 53 | end 54 | 55 | output = 56 | case embedding_processor do 57 | nil -> 58 | output 59 | 60 | :l2_norm -> 61 | Bumblebee.Utils.Nx.normalize(output) 62 | 63 | other -> 64 | raise ArgumentError, 65 | "expected :embedding_processor to be one of nil or :l2_norm, got: #{inspect(other)}" 66 | end 67 | 68 | output 69 | end 70 | 71 | Nx.Serving.new( 72 | fn 
defn_options -> 73 | params = Shared.maybe_preallocate(params, preallocate_params, defn_options) 74 | 75 | embedding_fun = 76 | Shared.compile_or_jit(embedding_fun, :embedding, defn_options, compile != nil, fn -> 77 | inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size) 78 | [params, inputs] 79 | end) 80 | 81 | fn inputs -> 82 | inputs = Shared.maybe_pad(inputs, batch_size) 83 | embedding_fun.(params, inputs) |> Shared.serving_post_computation() 84 | end 85 | end, 86 | defn_options 87 | ) 88 | |> Nx.Serving.batch_size(batch_size) 89 | |> Nx.Serving.client_preprocessing(fn input -> 90 | {images, multi?} = Shared.validate_serving_input!(input, &Shared.validate_image/1) 91 | 92 | inputs = Bumblebee.Featurizer.process_input(featurizer, images) 93 | 94 | {Nx.Batch.concatenate([inputs]), multi?} 95 | end) 96 | |> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? -> 97 | for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do 98 | %{embedding: embedding} 99 | end 100 | |> Shared.normalize_output(multi?) 
101 | end) 102 | end 103 | end 104 | -------------------------------------------------------------------------------- /lib/bumblebee/vision/image_to_text.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.ImageToText do 2 | @moduledoc false 3 | 4 | alias Bumblebee.Shared 5 | alias Bumblebee.Text 6 | 7 | # Builds a serving that generates text for the given images. Each 8 | # input is either an image or a map with an `:image` key and an 9 | # optional `:seed` key (see `validate_input/1`). 10 | def image_to_text( 11 | model_info, 12 | featurizer, 13 | tokenizer, 14 | %Text.GenerationConfig{} = generation_config, 15 | opts \\ [] 16 | ) do 17 | opts = Keyword.validate!(opts, [:compile, defn_options: [], preallocate_params: false]) 18 | 19 | %{model: model, params: params, spec: spec} = model_info 20 | Shared.validate_architecture!(spec, [:for_conditional_generation]) 21 | 22 | preallocate_params = opts[:preallocate_params] 23 | defn_options = opts[:defn_options] 24 | compile = validate_compile_options(opts[:compile]) 25 | batch_size = compile[:batch_size] 26 | 27 | generate_fun = build_generate_fun(model, spec, generation_config, featurizer) 28 | 29 | Nx.Serving.new( 30 | fn defn_options -> 31 | params = Shared.maybe_preallocate(params, preallocate_params, defn_options) 32 | 33 | compiled_fun = 34 | Shared.compile_or_jit(generate_fun, :generate, defn_options, compile != nil, fn -> 35 | inputs_template = Bumblebee.Featurizer.batch_template(featurizer, batch_size) 36 | seed_template = Nx.template({batch_size}, :s64) 37 | [params, {inputs_template, seed_template}] 38 | end) 39 | 40 | fn batch -> 41 | batch = Shared.maybe_pad(batch, batch_size) 42 | compiled_fun.(params, batch) |> Shared.serving_post_computation() 43 | end 44 | end, 45 | defn_options 46 | ) 47 | |> Nx.Serving.batch_size(batch_size) 48 | |> Nx.Serving.client_preprocessing(&preprocess_input(&1, featurizer)) 49 | |> Nx.Serving.client_postprocessing(&postprocess_output(&1, &2, tokenizer)) 50 | end 51 | 52 | # Normalizes the `:compile` option; returns nil when compilation is disabled. 53 | defp validate_compile_options(compile) when compile in [nil, false], do: nil 54 | 55 | defp validate_compile_options(compile) do 56 | compile 57 | |> Keyword.validate!([:batch_size]) 58 | |> Shared.require_options!([:batch_size]) 59 | end 60 | 61 | # Wraps the generation function so that it takes a batch of raw pixel 62 | # inputs plus per-example seeds and returns only the generated token ids. 63 | defp build_generate_fun(model, spec, generation_config, featurizer) do 64 | generate = Text.Generation.build_generate(model, spec, generation_config) 65 | 66 | fn params, {pixels, seed} -> 67 | model_inputs = 68 | featurizer 69 | |> Bumblebee.Featurizer.process_batch(pixels) 70 | |> Map.put("seed", seed) 71 | 72 | %{token_ids: token_ids} = generate.(params, model_inputs) 73 | token_ids 74 | end 75 | end 76 | 77 | # Validates client input and builds a single batch of featurized images 78 | # together with their seeds. 79 | defp preprocess_input(input, featurizer) do 80 | {entries, multi?} = Shared.validate_serving_input!(input, &validate_input/1) 81 | 82 | images = for entry <- entries, do: entry.image 83 | 84 | seeds = 85 | entries 86 | |> Enum.map(& &1.seed) 87 | |> Nx.tensor(type: :s64, backend: Nx.BinaryBackend) 88 | 89 | inputs = Bumblebee.Featurizer.process_input(featurizer, images) 90 | {Nx.Batch.concatenate([{inputs, seeds}]), multi?} 91 | end 92 | 93 | # Decodes generated token ids into one result map per input. 94 | defp postprocess_output({token_ids, _metadata}, multi?, tokenizer) do 95 | tokenizer 96 | |> Bumblebee.Tokenizer.decode(token_ids) 97 | |> Enum.map(fn text -> %{results: [%{text: text}]} end) 98 | |> Shared.normalize_output(multi?) 99 | end 100 | 101 | # Accepts either a bare image or a map with `:image` (and optional 102 | # `:seed`); a missing seed defaults to the current system time. 103 | defp validate_input(%{image: image} = input) do 104 | if Shared.image?(image) do 105 | seed = input[:seed] || :erlang.system_time() 106 | {:ok, %{image: image, seed: seed}} 107 | else 108 | {:error, "expected an image, got: #{inspect(image)}"} 109 | end 110 | end 111 | 112 | defp validate_input(image), do: validate_input(%{image: image}) 113 | end 114 |
-------------------------------------------------------------------------------- /lib/bumblebee/vision/vit_featurizer.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.VitFeaturizer do 2 | alias Bumblebee.Shared 3 | 4 | options = [ 5 | resize: [ 6 | default: true, 7 | doc: "whether to resize the input to the given `:size`" 8 | ], 9 | size: [ 10 | default: %{height: 224, width: 224}, 11 | doc: """ 12 | the size to resize the input to, given as `%{height: ..., width: ...}`.
Only has 13 | an effect if `:resize` is `true` 14 | """ 15 | ], 16 | resize_method: [ 17 | default: :bilinear, 18 | doc: 19 | "the resizing method, either of `:nearest`, `:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`" 20 | ], 21 | normalize: [ 22 | default: true, 23 | doc: "whether or not to normalize the input with mean and standard deviation" 24 | ], 25 | image_mean: [ 26 | default: [0.5, 0.5, 0.5], 27 | doc: "the sequence of mean values for each channel, to be used when normalizing images" 28 | ], 29 | image_std: [ 30 | default: [0.5, 0.5, 0.5], 31 | doc: 32 | "the sequence of standard deviations for each channel, to be used when normalizing images" 33 | ] 34 | ] 35 | 36 | @moduledoc """ 37 | ViT featurizer for image data. 38 | 39 | ## Configuration 40 | 41 | #{Shared.options_doc(options)} 42 | """ 43 | 44 | defstruct Shared.option_defaults(options) 45 | 46 | @behaviour Bumblebee.Featurizer 47 | @behaviour Bumblebee.Configurable 48 | 49 | alias Bumblebee.Utils.Image 50 | 51 | @impl true 52 | def config(featurizer, opts), do: Shared.put_config_attrs(featurizer, opts) 53 | 54 | @impl true 55 | def process_input(featurizer, images) do 56 | images 57 | |> List.wrap() 58 | |> Enum.map(&prepare_image(&1, featurizer)) 59 | |> Nx.concatenate() 60 | end 61 | 62 | # Converts one image into a batched f32 tensor with as many channels as 63 | # `:image_mean` has entries, resizing it when `:resize` is enabled. 64 | defp prepare_image(image, featurizer) do 65 | num_channels = length(featurizer.image_mean) 66 | 67 | image 68 | |> Image.to_batched_tensor() 69 | |> Nx.as_type(:f32) 70 | |> Image.normalize_channels(num_channels) 71 | |> maybe_resize(featurizer) 72 | end 73 | 74 | defp maybe_resize(image, featurizer) do 75 | if featurizer.resize do 76 | %{height: height, width: width} = featurizer.size 77 | NxImage.resize(image, {height, width}, method: featurizer.resize_method) 78 | else 79 | image 80 | end 81 | end 82 | 83 | @impl true 84 | def batch_template(featurizer, batch_size) do 85 | %{height: height, width: width} = featurizer.size 86 | Nx.template({batch_size, height, width, length(featurizer.image_mean)}, :f32) 87 | end 88 | 89 | @impl true 90 | def process_batch(featurizer, images) do 91 | pixel_values = 92 | images 93 | |> NxImage.to_continuous(0, 1) 94 | |> maybe_normalize(featurizer) 95 | 96 | %{"pixel_values" => pixel_values} 97 | end 98 | 99 | # Applies mean/std normalization only when `:normalize` is enabled. 100 | defp maybe_normalize(images, featurizer) do 101 | if featurizer.normalize do 102 | mean = Nx.tensor(featurizer.image_mean) 103 | std = Nx.tensor(featurizer.image_std) 104 | NxImage.normalize(images, mean, std) 105 | else 106 | images 107 | end 108 | end 109 | 110 | defimpl Bumblebee.HuggingFace.Transformers.Config do 111 | def load(featurizer, data) do 112 | import Shared.Converters 113 | 114 | attrs = 115 | convert!(data, 116 | resize: {"do_resize", boolean()}, 117 | size: {"size", image_size()}, 118 | resize_method: {"resample", resize_method()}, 119 | normalize: {"do_normalize", boolean()}, 120 | image_mean: {"image_mean", list(number())}, 121 | image_std: {"image_std", list(number())} 122 | ) 123 | 124 | @for.config(featurizer, attrs) 125 | end 126 | end 127 | end 128 |
-------------------------------------------------------------------------------- /notebooks/stable_diffusion.livemd: -------------------------------------------------------------------------------- 1 | # Stable Diffusion 2 | 3 | ```elixir 4 | Mix.install([ 5 | {:bumblebee, "~> 0.6.0"}, 6 | {:nx, "~> 0.9.0"}, 7 | {:exla, "~> 0.9.0"}, 8 | {:kino, "~> 0.14.0"} 9 | ]) 10 | 11 | Nx.global_default_backend({EXLA.Backend, client: :host}) 12 | ``` 13 | 14 | ## Introduction 15 | 16 | Stable Diffusion is a latent text-to-image diffusion model, primarily used to generate images based on a text prompt. Ever since it [became open-source](https://stability.ai/blog/stable-diffusion-public-release), the research, applications and tooling around it exploded. You can find a ton of resources and examples online, meanwhile let's see how to run Stable Diffusion using Bumblebee! 17 | 18 | 19 | 20 | > **Note:** Stable Diffusion is a very involved model, so the generation can take a long time if you run it on a CPU. Also, running on the GPU currently requires at least 5GiB of VRAM (or 3GiB with lower speed, see below).
21 | 22 | 23 | 24 | ## Text to image 25 | 26 | Stable Diffusion is composed of several separate models and preprocessors, so we will load all of them. 27 | 28 | ```elixir 29 | repo_id = "CompVis/stable-diffusion-v1-4" 30 | opts = [params_variant: "fp16", type: :bf16, backend: EXLA.Backend] 31 | 32 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"}) 33 | {:ok, clip} = Bumblebee.load_model({:hf, repo_id, subdir: "text_encoder"}, opts) 34 | {:ok, unet} = Bumblebee.load_model({:hf, repo_id, subdir: "unet"}, opts) 35 | {:ok, vae} = Bumblebee.load_model({:hf, repo_id, subdir: "vae"}, [architecture: :decoder] ++ opts) 36 | {:ok, scheduler} = Bumblebee.load_scheduler({:hf, repo_id, subdir: "scheduler"}) 37 | {:ok, featurizer} = Bumblebee.load_featurizer({:hf, repo_id, subdir: "feature_extractor"}) 38 | {:ok, safety_checker} = Bumblebee.load_model({:hf, repo_id, subdir: "safety_checker"}, opts) 39 | 40 | :ok 41 | ``` 42 | 43 | > **Note:** some checkpoints, such as [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5), require a license agreement. In those cases, sign up on Hugging Face, accept the license on the repository page, generate an access token in [the settings](https://huggingface.co/settings/tokens) and add it to the repository specification via `:auth_token`. You can use Livebook secrets to pass the token securely. 44 | 45 | 46 | 47 | With all the models loaded, we can now configure a serving implementation of the text-to-image task.
48 | 49 | ```elixir 50 | serving = 51 | Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler, 52 | num_steps: 20, 53 | num_images_per_prompt: 1, 54 | safety_checker: safety_checker, 55 | safety_checker_featurizer: featurizer, 56 | compile: [batch_size: 1, sequence_length: 60], 57 | # Option 1 58 | defn_options: [compiler: EXLA] 59 | # Option 2 (reduces GPU memory usage, but runs noticeably slower) 60 | # Also remove `backend: EXLA.Backend` from the loading options above 61 | # defn_options: [compiler: EXLA, lazy_transfers: :always] 62 | ) 63 | 64 | Kino.start_child({Nx.Serving, name: StableDiffusion, serving: serving}) 65 | ``` 66 | 67 | ```elixir 68 | prompt_input = 69 | Kino.Input.text("Prompt", default: "numbat, forest, high quality, detailed, digital art") 70 | 71 | negative_prompt_input = Kino.Input.text("Negative Prompt", default: "darkness, rainy, foggy") 72 | 73 | Kino.Layout.grid([prompt_input, negative_prompt_input]) 74 | ``` 75 | 76 | We are ready to generate images! 77 | 78 | ```elixir 79 | prompt = Kino.Input.read(prompt_input) 80 | negative_prompt = Kino.Input.read(negative_prompt_input) 81 | 82 | output = 83 | Nx.Serving.batched_run(StableDiffusion, %{prompt: prompt, negative_prompt: negative_prompt}) 84 | 85 | for result <- output.results do 86 | Kino.Image.new(result.image) 87 | end 88 | |> Kino.Layout.grid(columns: 2) 89 | ``` 90 | 91 | To achieve better quality, you can increase the number of steps (`:num_steps`) and the number of images generated per prompt (`:num_images_per_prompt`).
92 | -------------------------------------------------------------------------------- /test/bumblebee/audio/whisper_featurizer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Audio.WhisperFeaturizerTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | describe "integration" do 7 | test "encodes audio" do 8 | assert {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"}) 9 | 10 | assert %Bumblebee.Audio.WhisperFeaturizer{} = featurizer 11 | 12 | audio = Nx.sin(Nx.iota({100}, type: :f32)) # short synthetic waveform used as the audio input 13 | 14 | inputs = Bumblebee.apply_featurizer(featurizer, audio, defn_options: [compiler: EXLA]) 15 | 16 | assert_all_close( 17 | inputs["input_features"][[0, 0..3, 0..3]], 18 | Nx.tensor([ 19 | [ 20 | [0.7313, 0.7820, 0.7391, 0.6787], 21 | [0.4332, 0.4861, 0.4412, 0.3497], 22 | [-0.5938, -0.5938, -0.5938, -0.5938], 23 | [-0.5938, -0.5938, -0.5938, -0.5938] 24 | ] 25 | ]) 26 | ) 27 | end 28 | end 29 | end 30 | -------------------------------------------------------------------------------- /test/bumblebee/audio/whisper_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Audio.WhisperTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-WhisperModel"}) 11 | 12 | assert %Bumblebee.Audio.Whisper{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_features" => Nx.sin(Nx.iota({1, 60, 80}, type: :f32)), 16 | "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]), 17 | "decoder_attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 0, 0]]) 18 | } 19 | 20 | outputs = Axon.predict(model, params, inputs) 21 | 22 | assert Nx.shape(outputs.hidden_state) == {1, 8, 16} 23 | 24 |
assert_all_close( 25 | outputs.hidden_state[[.., 1..3, 1..3]], 26 | Nx.tensor([ 27 | [[-0.3791, -1.6131, -0.6913], [0.1247, -1.3631, 0.0034], [-0.0097, 0.2039, 1.9897]] 28 | ]) 29 | ) 30 | end 31 | 32 | test ":for_conditional_generation" do 33 | assert {:ok, %{model: model, params: params, spec: spec}} = 34 | Bumblebee.load_model( 35 | {:hf, "hf-internal-testing/tiny-random-WhisperForConditionalGeneration"} 36 | ) 37 | 38 | assert %Bumblebee.Audio.Whisper{architecture: :for_conditional_generation} = spec 39 | 40 | inputs = %{ 41 | "input_features" => Nx.sin(Nx.iota({1, 60, 80}, type: :f32)), 42 | "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]), 43 | "decoder_attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 0, 0]]) 44 | } 45 | 46 | outputs = Axon.predict(model, params, inputs) 47 | 48 | assert Nx.shape(outputs.logits) == {1, 8, 50257} 49 | 50 | assert_all_close( 51 | outputs.logits[[.., 1..3, 1..3]], 52 | Nx.tensor([ 53 | [[0.0942, 0.1288, 0.0243], [-0.1667, -0.1401, 0.1191], [0.0398, -0.0449, -0.0574]] 54 | ]) 55 | ) 56 | end 57 | end 58 | -------------------------------------------------------------------------------- /test/bumblebee/conversion/pytorch_loader_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Conversion.PyTorchLoaderTest do 2 | use ExUnit.Case, async: true 3 | 4 | alias Bumblebee.Conversion.PyTorchLoader 5 | 6 | setup do 7 | Nx.default_backend(Nx.BinaryBackend) 8 | :ok 9 | end 10 | 11 | @dir Path.expand("../../fixtures/pytorch", __DIR__) 12 | 13 | for format <- ["zip", "legacy"] do 14 | @format format 15 | 16 | describe "#{format} format" do 17 | test "tensors" do 18 | path = Path.join(@dir, "tensors.#{@format}.pt") 19 | 20 | assert path |> PyTorchLoader.load!() |> Enum.map(&Nx.to_tensor/1) == [ 21 | Nx.tensor([-1.0, 1.0], type: :f64), 22 | Nx.tensor([-1.0, 1.0], type: :f32), 23 | Nx.tensor([-1.0, 1.0], type: :f16), 24 | Nx.tensor([-1, 1], type: :s64), 25 |
Nx.tensor([-1, 1], type: :s32), 26 | Nx.tensor([-1, 1], type: :s16), 27 | Nx.tensor([-1, 1], type: :s8), 28 | Nx.tensor([0, 1], type: :u8), 29 | Nx.tensor([0, 1, 0, 1], type: :u8), 30 | Nx.tensor([-1.0, 1.0], type: :bf16), 31 | Nx.tensor([Complex.new(1, -1), Complex.new(1, 1)], type: :c128), 32 | Nx.tensor([Complex.new(1, -1), Complex.new(1, 1)], type: :c64) 33 | ] 34 | end 35 | 36 | test "numpy arrays" do 37 | path = Path.join(@dir, "numpy_arrays.#{@format}.pt") 38 | 39 | assert PyTorchLoader.load!(path) == [ 40 | Nx.tensor([-1.0, 1.0], type: :f64), 41 | Nx.tensor([-1.0, 1.0], type: :f32), 42 | Nx.tensor([-1.0, 1.0], type: :f16), 43 | Nx.tensor([-1, 1], type: :s64), 44 | Nx.tensor([-1, 1], type: :s32), 45 | Nx.tensor([-1, 1], type: :s16), 46 | Nx.tensor([-1, 1], type: :s8), 47 | Nx.tensor([0, 1], type: :u64), 48 | Nx.tensor([0, 1], type: :u32), 49 | Nx.tensor([0, 1], type: :u16), 50 | Nx.tensor([0, 1], type: :u8), 51 | Nx.tensor([0, 1], type: :u8), 52 | Nx.tensor([Complex.new(1, -1), Complex.new(1, 1)], type: :c128), 53 | Nx.tensor([Complex.new(1, -1), Complex.new(1, 1)], type: :c64) 54 | ] 55 | end 56 | 57 | test "ordered dict" do 58 | path = Path.join(@dir, "ordered_dict.#{@format}.pt") 59 | 60 | assert PyTorchLoader.load!(path) == %{"x" => 1, "y" => 2} 61 | end 62 | 63 | test "noncontiguous tensor" do 64 | path = Path.join(@dir, "noncontiguous_tensor.#{@format}.pt") 65 | 66 | assert path |> PyTorchLoader.load!() |> Nx.to_tensor() == 67 | Nx.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], type: :s64) 68 | end 69 | 70 | test "numpy array in Fortran order" do 71 | path = Path.join(@dir, "noncontiguous_numpy_array.#{@format}.pt") 72 | 73 | assert PyTorchLoader.load!(path) == 74 | Nx.tensor([[1, 4], [2, 5], [3, 6]], type: :s64) 75 | end 76 | end 77 | end 78 | 79 | test "legacy format storage view" do 80 | # Note that storage views have been removed in PyTorch v0.4.0, 81 | # this test is based on
https://github.com/pytorch/pytorch/blob/v1.11.0/test/test_serialization.py#L554-L575 82 | path = Path.join(@dir, "storage_view.legacy.pt") 83 | 84 | assert { 85 | {:storage, %Unpickler.Global{scope: "torch", name: "FloatStorage"}, storage1}, 86 | {:storage, %Unpickler.Global{scope: "torch", name: "FloatStorage"}, storage2} 87 | } = PyTorchLoader.load!(path) 88 | # each storage is a {:file, path, offset, size} reference into a file on disk 89 | assert {:file, storage_path, offset, size} = storage1 90 | assert storage_path |> File.read!() |> binary_part(offset, size) == <<0, 0, 0, 0>> 91 | 92 | assert {:file, storage_path, offset, size} = storage2 93 | assert storage_path |> File.read!() |> binary_part(offset, size) == <<0, 0, 0, 0>> 94 | end 95 | 96 | test "raises if the file does not exist" do 97 | assert_raise File.Error, ~r/no such file or directory/, fn -> 98 | PyTorchLoader.load!("nonexistent") 99 | end 100 | end 101 | end 102 | -------------------------------------------------------------------------------- /test/bumblebee/conversion/pytorch_params_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Conversion.PyTorchParamsTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | alias Bumblebee.Conversion.PyTorchParams 7 | 8 | @dir Path.expand("../../fixtures/pytorch", __DIR__) 9 | 10 | describe "load_params!/3" do 11 | defp base_model() do 12 | Axon.input("input", shape: {nil, 4, 4, 3}) 13 | |> Axon.conv(2, kernel_size: 2, name: "conv") 14 | end 15 | 16 | defp full_model() do 17 | Axon.input("input", shape: {nil, 4, 4, 3}) 18 | |> Axon.conv(2, kernel_size: 2, name: "base.conv") 19 | |> Axon.flatten() 20 | |> Axon.dense(2, name: "classifier.intermediate") 21 | |> Axon.dense(1, name: "classifier.output") 22 | end 23 | 24 | defp input_template() do 25 | Nx.broadcast(1, {1, 4, 4, 3}) 26 | end 27 | 28 | defp params_mapping() do 29 | %{ 30 | "base.conv" => "base.conv", 31 | "classifier.intermediate" => "classifier.layers.0", 32 | "classifier.output" => "classifier.layers.1" 33 | } 34 |
34 | end 35 | 36 | test "silently loads parameters if all match" do 37 | model = base_model() 38 | path = Path.join(@dir, "state_dict_base.zip.pt") 39 | 40 | log = 41 | ExUnit.CaptureLog.capture_log(fn -> 42 | %Axon.ModelState{data: params} = 43 | PyTorchParams.load_params!(model, input_template(), path, 44 | params_mapping: params_mapping() 45 | ) 46 | 47 | assert_equal(params["conv"]["kernel"], Nx.broadcast(1.0, {2, 2, 3, 2})) 48 | assert_equal(params["conv"]["bias"], Nx.broadcast(0.0, {2})) 49 | end) 50 | 51 | refute log =~ "parameters" 52 | end 53 | 54 | test "logs parameters diff" do 55 | model = full_model() 56 | path = Path.join(@dir, "state_dict_full.zip.pt") 57 | 58 | log = 59 | ExUnit.CaptureLog.capture_log(fn -> 60 | PyTorchParams.load_params!(model, input_template(), path, 61 | params_mapping: params_mapping() 62 | ) 63 | end) 64 | 65 | assert log =~ """ 66 | the following parameters were missing: 67 | 68 | * classifier.output.kernel 69 | * classifier.output.bias 70 | """ 71 | 72 | assert log =~ """ 73 | the following PyTorch parameters were unused: 74 | 75 | * extra.weight 76 | """ 77 | 78 | assert log =~ """ 79 | the following parameters were ignored, because of non-matching shape: 80 | 81 | * classifier.intermediate.kernel (expected {18, 2}, got: {1, 1}) 82 | * classifier.intermediate.bias (expected {2}, got: {1}) 83 | """ 84 | end 85 | 86 | test "loads parameters without prefix into a specialised model" do 87 | model = base_model() 88 | path = Path.join(@dir, "state_dict_full.zip.pt") 89 | 90 | log = 91 | ExUnit.CaptureLog.capture_log(fn -> 92 | %Axon.ModelState{data: params} = 93 | PyTorchParams.load_params!(model, input_template(), path, 94 | params_mapping: params_mapping() 95 | ) 96 | 97 | assert_equal(params["conv"]["kernel"], Nx.broadcast(1.0, {2, 2, 3, 2})) 98 | assert_equal(params["conv"]["bias"], Nx.broadcast(0.0, {2})) 99 | end) 100 | 101 | refute log =~ "conv" 102 | end 103 | 104 | test "loads parameters with prefix into a base model" do 
105 | model = full_model() 106 | path = Path.join(@dir, "state_dict_base.zip.pt") 107 | 108 | log = 109 | ExUnit.CaptureLog.capture_log(fn -> 110 | %Axon.ModelState{data: params} = 111 | PyTorchParams.load_params!(model, input_template(), path, 112 | params_mapping: params_mapping() 113 | ) 114 | 115 | assert_equal(params["base.conv"]["kernel"], Nx.broadcast(1.0, {2, 2, 3, 2})) 116 | assert_equal(params["base.conv"]["bias"], Nx.broadcast(0.0, {2})) 117 | end) 118 | 119 | refute log =~ "conv" # nothing about the conv layer should be logged 120 | end 121 | end 122 | end 123 | -------------------------------------------------------------------------------- /test/bumblebee/diffusion/controlnet_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Diffusion.ControlNetTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-controlnet"}) 11 | 12 | assert %Bumblebee.Diffusion.ControlNet{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "sample" => Nx.broadcast(0.5, {1, 32, 32, 4}), 16 | "timestep" => Nx.tensor(1), 17 | "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}), 18 | "conditioning" => Nx.broadcast(0.5, {1, 64, 64, 3}) 19 | } 20 | 21 | outputs = Axon.predict(model, params, inputs) 22 | 23 | assert Nx.shape(outputs.mid_block_state) == {1, 16, 16, 64} 24 | 25 | assert_all_close( # 3x3x3 slice vs. reference values (NOTE(review): presumably produced by the Python implementation — confirm) 26 | outputs.mid_block_state[[.., 1..3, 1..3, 1..3]], 27 | Nx.tensor([ 28 | [ 29 | [[-0.2818, 1.6207, -0.7002], [0.2391, 1.1387, 0.9682], [-0.6386, 0.7026, -0.4218]], 30 | [[1.0681, 1.8418, -1.0586], [0.9387, 0.5971, 1.2284], [1.2914, 0.4060, -0.9559]], 31 | [[0.5841, 1.2935, 0.0081], [0.7306, 0.2915, 0.7736], [0.0875, 0.9619, 0.4108]] 32 | ] 33 | ]) 34 | ) 35 | 36 | assert tuple_size(outputs.down_block_states) == 6 # only the first (index 0) and last (index 5) states are checked below 37 | 38 | first_down_block_state =
elem(outputs.down_block_states, 0) 39 | assert Nx.shape(first_down_block_state) == {1, 32, 32, 32} 40 | 41 | assert_all_close( 42 | first_down_block_state[[.., 1..3, 1..3, 1..3]], 43 | Nx.tensor([ 44 | [ 45 | [[-0.1423, 0.2804, -0.0497], [-0.1425, 0.2798, -0.0485], [-0.1426, 0.2794, -0.0488]], 46 | [[-0.1419, 0.2810, -0.0493], [-0.1427, 0.2803, -0.0479], [-0.1427, 0.2800, -0.0486]], 47 | [[-0.1417, 0.2812, -0.0494], [-0.1427, 0.2807, -0.0480], [-0.1426, 0.2804, -0.0486]] 48 | ] 49 | ]) 50 | ) 51 | 52 | last_down_block_state = elem(outputs.down_block_states, 5) 53 | assert Nx.shape(last_down_block_state) == {1, 16, 16, 64} 54 | 55 | assert_all_close( 56 | last_down_block_state[[.., 1..3, 1..3, 1..3]], 57 | Nx.tensor([ 58 | [ 59 | [[-1.1169, 0.8087, 0.1024], [0.4832, 0.0686, 1.0149], [-0.3314, 0.1486, 0.4445]], 60 | [[0.5770, 0.3195, -0.2008], [1.5692, -0.1771, 0.7669], [0.4908, 0.1258, 0.0694]], 61 | [[0.4694, -0.3723, 0.1505], [1.7356, -0.4214, 0.8929], [0.4702, 0.2400, 0.1213]] 62 | ] 63 | ]) 64 | ) 65 | end 66 | end 67 | -------------------------------------------------------------------------------- /test/bumblebee/diffusion/ddim_scheduler_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Diffusion.DdimSchedulerTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | test "timesteps" do # 10 inference steps, expected as evenly spaced timesteps in descending order 7 | scheduler = Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler) 8 | 9 | timesteps = scheduler_timesteps(scheduler, 10) 10 | 11 | assert_equal(timesteps, Nx.tensor([900, 800, 700, 600, 500, 400, 300, 200, 100, 0])) 12 | end 13 | 14 | test "default configuration" do 15 | scheduler = Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler) 16 | 17 | sample = scheduler_loop(scheduler, 10) # scheduler_loop/2 is not defined here; presumably it comes from Bumblebee.TestHelpers 18 | 19 | assert_all_close( 20 | sample[[1..3, 1..3, 1..2]], 21 | Nx.tensor([ 22 | [[0.0648, 0.0666], [0.0718, 0.0736], [0.0788, 0.0806]], 23 | [[0.1209, 0.1226], [0.1279, 0.1296], [0.1349, 0.1367]], 24 |
24 | [[0.1770, 0.1787], [0.1840, 0.1857], [0.1910, 0.1927]] 25 | ]) 26 | ) 27 | 28 | assert_all_close(Nx.sum(sample), Nx.tensor(57.1861)) 29 | end 30 | 31 | test ":quadratic beta schedule" do 32 | scheduler = Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler, beta_schedule: :quadratic) 33 | 34 | sample = scheduler_loop(scheduler, 10) 35 | 36 | assert_all_close( 37 | sample[[1..3, 1..3, 1..2]], 38 | Nx.tensor([ 39 | [[0.0675, 0.0693], [0.0747, 0.0766], [0.0820, 0.0839]], 40 | [[0.1258, 0.1276], [0.1331, 0.1349], [0.1404, 0.1422]], 41 | [[0.1841, 0.1860], [0.1914, 0.1932], [0.1987, 0.2005]] 42 | ]) 43 | ) 44 | 45 | assert_all_close(Nx.sum(sample), Nx.tensor(59.5049)) 46 | end 47 | 48 | test ":squared_cosine beta schedule" do 49 | scheduler = 50 | Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler, beta_schedule: :squared_cosine) 51 | 52 | sample = scheduler_loop(scheduler, 10) 53 | 54 | assert_all_close( 55 | sample[[1..3, 1..3, 1..2]], 56 | Nx.tensor([ 57 | [[0.0684, 0.0702], [0.0757, 0.0776], [0.0831, 0.0850]], 58 | [[0.1275, 0.1293], [0.1349, 0.1367], [0.1422, 0.1441]], 59 | [[0.1866, 0.1884], [0.1940, 0.1958], [0.2014, 0.2032]] 60 | ]) 61 | ) 62 | 63 | assert_all_close(Nx.sum(sample), Nx.tensor(60.2983)) 64 | end 65 | 66 | test "beta schedule range" do 67 | scheduler = 68 | Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler, beta_start: 0.001, beta_end: 0.02) 69 | 70 | sample = scheduler_loop(scheduler, 10) 71 | 72 | assert_all_close( 73 | sample[[1..3, 1..3, 1..2]], 74 | Nx.tensor([ 75 | [[0.0653, 0.0670], [0.0723, 0.0741], [0.0794, 0.0812]], 76 | [[0.1217, 0.1235], [0.1288, 0.1306], [0.1359, 0.1376]], 77 | [[0.1782, 0.1800], [0.1853, 0.1870], [0.1923, 0.1941]] 78 | ]) 79 | ) 80 | 81 | assert_all_close(Nx.sum(sample), Nx.tensor(57.5869)) 82 | end 83 | 84 | test ":angular_velocity prediction type" do 85 | scheduler = 86 | Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler, prediction_type: :angular_velocity) 87 | 88 | sample =
scheduler_loop(scheduler, 10) 89 | 90 | assert_all_close( 91 | sample[[1..3, 1..3, 1..2]], 92 | Nx.tensor([ 93 | [[0.0198, 0.0203], [0.0219, 0.0225], [0.0241, 0.0246]], 94 | [[0.0369, 0.0375], [0.0391, 0.0396], [0.0412, 0.0417]], 95 | [[0.0540, 0.0546], [0.0562, 0.0567], [0.0583, 0.0589]] 96 | ]) 97 | ) 98 | 99 | assert_all_close(Nx.sum(sample), Nx.tensor(17.4644)) 100 | end 101 | end 102 | -------------------------------------------------------------------------------- /test/bumblebee/diffusion/stable_diffusion_controlnet_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Diffusion.StableDiffusionControlNetTest do 2 | use ExUnit.Case, async: false 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag serving_test_tags() 7 | 8 | describe "text_to_image/7" do 9 | test "generates image for a text prompt with controlnet" do 10 | # Since we don't assert on the result in this case, we use 11 | # a tiny random checkpoint. This test is basically to verify 12 | # the whole generation computation end-to-end 13 | 14 | repository_id = "bumblebee-testing/tiny-stable-diffusion" 15 | 16 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"}) 17 | {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"}) 18 | 19 | {:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"}) 20 | 21 | {:ok, controlnet} = Bumblebee.load_model({:hf, "bumblebee-testing/tiny-controlnet"}) 22 | 23 | {:ok, vae} = 24 | Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder) 25 | 26 | {:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"}) 27 | 28 | {:ok, featurizer} = 29 | Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"}) 30 | 31 | {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"}) 32 | 33 | # The conditioning image side is sample_size * 2^(n - 1), where n is the 34 | # number of conditioning embedding hidden sizes (presumably one 2x 35 | # downsampling per block after the first — confirm against ControlNet) 36 | conditioning_size = 37 | controlnet.spec.sample_size * 38 | 2 ** (length(controlnet.spec.conditioning_embedding_hidden_sizes) - 1) 39 | 40 | conditioning = 41 | Nx.broadcast(Nx.tensor(50, type: :u8), {conditioning_size, conditioning_size, 3}) 42 | 43 | # The same input is reused for every serving variant below 44 | input = %{prompt: "numbat in forest, detailed, digital art", conditioning: conditioning} 45 | 46 | serving = 47 | Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( 48 | clip, 49 | unet, 50 | vae, 51 | controlnet, 52 | tokenizer, 53 | scheduler, 54 | num_steps: 3, 55 | safety_checker: safety_checker, 56 | safety_checker_featurizer: featurizer 57 | ) 58 | 59 | assert %{ 60 | results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] 61 | } = Nx.Serving.run(serving, input) 62 | 63 | # Without safety checker 64 | 65 | serving = 66 | Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( 67 | clip, 68 | unet, 69 | vae, 70 | controlnet, 71 | tokenizer, 72 | scheduler, 73 | num_steps: 3 74 | ) 75 | 76 | assert %{results: [%{image: %Nx.Tensor{}}]} = Nx.Serving.run(serving, input) 77 | 78 | # With compilation 79 | 80 | serving = 81 | Bumblebee.Diffusion.StableDiffusionControlNet.text_to_image( 82 | clip, 83 | unet, 84 | vae, 85 | controlnet, 86 | tokenizer, 87 | scheduler, 88 | num_steps: 3, 89 | safety_checker: safety_checker, 90 | safety_checker_featurizer: featurizer, 91 | compile: [batch_size: 1, sequence_length: 60], 92 | defn_options: [compiler: EXLA] 93 | ) 94 | 95 | assert %{ 96 | results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] 97 | } = Nx.Serving.run(serving, input) 98 | end 99 | end 100 | end 101 | -------------------------------------------------------------------------------- /test/bumblebee/diffusion/stable_diffusion_test.exs:
-------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Diffusion.StableDiffusionTest do 2 | use ExUnit.Case, async: false 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag serving_test_tags() 7 | 8 | describe "text_to_image/6" do 9 | test "generates image for a text prompt" do 10 | # Since we don't assert on the result in this case, we use 11 | # a tiny random checkpoint. This test is basically to verify 12 | # the whole generation computation end-to-end 13 | 14 | repository_id = "bumblebee-testing/tiny-stable-diffusion" 15 | 16 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"}) 17 | {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"}) 18 | {:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"}) 19 | 20 | {:ok, vae} = 21 | Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder) 22 | 23 | {:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"}) 24 | 25 | {:ok, featurizer} = 26 | Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"}) 27 | 28 | {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"}) 29 | 30 | # The same prompt is reused for every serving variant below 31 | prompt = "numbat in forest, detailed, digital art" 32 | 33 | serving = 34 | Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler, 35 | num_steps: 3, 36 | safety_checker: safety_checker, 37 | safety_checker_featurizer: featurizer 38 | ) 39 | 40 | assert %{ 41 | results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] 42 | } = Nx.Serving.run(serving, prompt) 43 | 44 | # Without safety checker 45 | 46 | serving = 47 | Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler, 48 | num_steps: 3 49 | ) 50 | 51 | assert %{results: [%{image: %Nx.Tensor{}}]} = Nx.Serving.run(serving, prompt) 52 | 53 | # With compilation 54 | 55 | serving = 56 | Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler, 57 | num_steps: 3, 58 | safety_checker: safety_checker, 59 | safety_checker_featurizer: featurizer, 60 | compile: [batch_size: 1, sequence_length: 50], 61 | defn_options: [compiler: EXLA] 62 | ) 63 | 64 | assert %{ 65 | results: [%{image: %Nx.Tensor{}, is_safe: _boolean}] 66 | } = Nx.Serving.run(serving, prompt) 67 | end 68 | end 69 | end 70 |
-------------------------------------------------------------------------------- /test/bumblebee/diffusion/unet_2d_conditional_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Diffusion.UNet2DConditionalTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model( 11 | {:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "unet"} 12 | ) 13 | 14 | assert %Bumblebee.Diffusion.UNet2DConditional{architecture: :base} = spec 15 | 16 | inputs = %{ 17 | "sample" => Nx.broadcast(0.5, {1, 32, 32, 4}), 18 | "timestep" => Nx.tensor(1), 19 | "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}) 20 | } 21 | 22 | outputs = Axon.predict(model, params, inputs) 23 | 24 | assert Nx.shape(outputs.sample) == {1, 32, 32, 4} 25 | 26 | assert_all_close( 27 | outputs.sample[[.., 1..3, 1..3, 1..3]], 28 | Nx.tensor([ 29 | [ 30 | [[-1.0813, -0.2179, 1.3359], [-0.5109, -0.2799, 0.8373], [-0.1545, -1.0922, -0.2392]], 31 | [[-0.8094, -0.9485, 0.9448], [-1.2588, -0.8376, -0.0478], [-0.8355, 0.0843, 0.6881]], 32 | [[-0.9218, -0.9650, -0.0154], [-1.2142, -0.7105, -0.5304], [-0.6982, -0.3920, 0.2081]] 33 | ] 34 | ]) 35 | ) 36 | end 37 | 38 | test ":base with additional states for skip connection" do 39 | tiny = "bumblebee-testing/tiny-stable-diffusion" 40 | 41 | assert {:ok, %{model: model, params: params, spec: spec}} = 42 | Bumblebee.load_model({:hf, tiny, subdir: "unet"}) 43 | 44 | assert %Bumblebee.Diffusion.UNet2DConditional{architecture: :base} = spec 45 | 46 | down_block_states = 47 | [ 48 | {1, 32, 32, 32}, 49 | {1, 32, 32, 32}, 50 | {1, 32, 32, 32}, 51 | {1, 16, 16, 32}, 52 | {1, 16, 16, 64}, 53 | {1, 16, 16, 64} 54 | ] 55 | |> Enum.map(&Nx.broadcast(0.5, &1)) 56 | |> List.to_tuple() 57 | 58 | mid_block_state = Nx.broadcast(0.5, {1, 16, 16, 64}) 59 | 60 | inputs = 61 | %{ 62 | "sample" => Nx.broadcast(0.5, {1, 32, 32, 4}), 63 | "timestep" => Nx.tensor(1), 64 | "encoder_hidden_state" => Nx.broadcast(0.5, {1, 1, 32}), 65 | "additional_down_block_states" => down_block_states, 66 | "additional_mid_block_state" => mid_block_state 67 | } 68 | 69 | outputs = Axon.predict(model, params, inputs) 70 | 71 | assert Nx.shape(outputs.sample) == {1, 32, 32, 4} 72 | 73 | assert_all_close( 74 | outputs.sample[[.., 1..3, 1..3, 1..3]], 75 | Nx.tensor([ 76 | [ 77 | [[-0.9457, -0.2378, 1.4223], [-0.5736, -0.2456, 0.7603], [-0.4346, -1.1370, -0.1988]], 78 | [[-0.5274, -1.0902, 0.5937], [-1.2290, -0.7996, 0.0264], [-0.3006, -0.1181, 0.7059]], 79 | [[-0.8336, -1.1615, -0.1906], [-1.0489, -0.3815, -0.5497], [-0.6255, 0.0863, 0.3285]] 80 | ] 81 | ]) 82 | ) 83 | end 84 | end 85 | -------------------------------------------------------------------------------- /test/bumblebee/diffusion/vae_kl_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Diffusion.VaeKlTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model( 11 | {:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "vae"} 12 | ) 13 | 14 | assert
%Bumblebee.Diffusion.VaeKl{architecture: :base} = spec 15 | 16 | inputs = %{ 17 | "sample" => Nx.broadcast(0.5, {1, 32, 32, 3}) 18 | } 19 | 20 | outputs = Axon.predict(model, params, inputs) 21 | 22 | assert Nx.shape(outputs.sample) == {1, 32, 32, 3} 23 | 24 | assert_all_close( 25 | outputs.sample[[.., 1..3, 1..3, ..]], 26 | Nx.tensor([ 27 | [ 28 | [[0.0164, 0.3587, -0.2398], [-0.1439, 0.4220, 0.2247], [0.4768, 0.1088, -0.2082]], 29 | [[0.3165, 0.4741, -0.1440], [0.0599, 0.4139, 0.0256], [0.1729, 0.6284, -0.0120]], 30 | [[0.1148, 0.4739, -0.0982], [0.5428, 0.1454, -0.3666], [0.6126, 0.3089, -0.1221]] 31 | ] 32 | ]) 33 | ) 34 | end 35 | 36 | test ":decoder" do 37 | assert {:ok, %{model: model, params: params, spec: spec}} = 38 | Bumblebee.load_model( 39 | {:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "vae"}, 40 | architecture: :decoder 41 | ) 42 | 43 | assert %Bumblebee.Diffusion.VaeKl{architecture: :decoder} = spec 44 | 45 | inputs = %{ 46 | "sample" => Nx.broadcast(0.5, {1, 16, 16, 4}) 47 | } 48 | 49 | outputs = Axon.predict(model, params, inputs) 50 | 51 | assert Nx.shape(outputs.sample) == {1, 32, 32, 3} 52 | 53 | assert_all_close( 54 | outputs.sample[[.., 1..3, 1..3, ..]], 55 | Nx.tensor([ 56 | [ 57 | [[-0.1682, -0.6000, 0.2776], [-0.1015, -0.0538, 0.6985], [-0.4158, -0.6703, 0.2960]], 58 | [[-0.4621, 0.4113, 0.4759], [0.5176, 0.3203, -0.3528], [-0.0999, -0.5005, -0.7306]], 59 | [[-0.0685, 0.2073, 0.5656], [0.7141, -0.1205, -0.5857], [0.3287, 0.2487, -0.2490]] 60 | ] 61 | ]) 62 | ) 63 | end 64 | 65 | test ":encoder" do 66 | assert {:ok, %{model: model, params: params, spec: spec}} = 67 | Bumblebee.load_model( 68 | {:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "vae"}, 69 | architecture: :encoder 70 | ) 71 | 72 | assert %Bumblebee.Diffusion.VaeKl{architecture: :encoder} = spec 73 | 74 | inputs = %{ 75 | "sample" => Nx.broadcast(0.5, {1, 32, 32, 3}) 76 | } 77 | 78 | outputs = Axon.predict(model, params, inputs) 79 | 80 | 
assert Nx.shape(outputs.latent_dist.mean) == {1, 16, 16, 4} 81 | assert Nx.shape(outputs.latent_dist.var) == {1, 16, 16, 4} 82 | assert Nx.shape(outputs.latent_dist.logvar) == {1, 16, 16, 4} 83 | assert Nx.shape(outputs.latent_dist.std) == {1, 16, 16, 4} 84 | 85 | assert_all_close( 86 | outputs.latent_dist.mean[[.., 1..3, 1..3, 1..3]], 87 | Nx.tensor([ 88 | [ 89 | [[0.1788, -0.6560, 0.2527], [0.2526, -0.1389, 0.6616], [0.3464, -0.1010, -0.1320]], 90 | [[0.1255, -0.0494, 0.2834], [0.4318, -0.5862, -0.1787], [0.0935, -0.2144, -0.1887]], 91 | [[-0.3859, 0.1139, 0.2339], [-0.1090, -0.5287, 0.6370], [-0.1257, -0.3207, -0.1075]] 92 | ] 93 | ]) 94 | ) 95 | 96 | assert_all_close( 97 | outputs.latent_dist.var[[.., 1..3, 1..3, 1..3]], 98 | Nx.tensor([ 99 | [ 100 | [[0.7926, 0.4830, 1.7315], [0.6405, 1.2762, 1.8338], [0.8108, 1.1277, 1.7099]], 101 | [[0.4721, 0.9835, 1.6843], [0.9543, 1.1715, 1.5880], [0.8660, 1.5034, 1.2972]], 102 | [[0.5069, 1.2341, 1.2979], [0.7749, 1.0105, 1.3841], [0.9574, 1.2950, 1.1591]] 103 | ] 104 | ]) 105 | ) 106 | end 107 | end 108 | -------------------------------------------------------------------------------- /test/bumblebee/multimodal/blip_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Multimodal.BlipTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":for_conditional_generation" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model( 11 | {:hf, "hf-internal-testing/tiny-random-BlipForConditionalGeneration"} 12 | ) 13 | 14 | assert %Bumblebee.Multimodal.Blip{architecture: :for_conditional_generation} = spec 15 | 16 | inputs = %{ 17 | "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]), 18 | "decoder_attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 0, 0]]), 19 | "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) 20 | } 21 | 22 | outputs = 
Axon.predict(model, params, inputs) 23 | 24 | assert Nx.shape(outputs.logits) == {1, 8, 1124} 25 | 26 | assert_all_close( 27 | outputs.logits[[.., 1..3, 1..3]], 28 | Nx.tensor([[[0.1215, 0.0226, -0.1134], [0.1472, 0.1118, 0.1031], [-0.0687, 0.0104, 0.1781]]]) 29 | ) 30 | end 31 | end 32 | -------------------------------------------------------------------------------- /test/bumblebee/multimodal/clip_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Multimodal.ClipTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-CLIPModel"}) 11 | 12 | assert %Bumblebee.Multimodal.Clip{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => 16 | Nx.tensor([ 17 | [10, 20, 30, 40, 50, 60, 70, 80, 0, 0], 18 | [15, 25, 35, 45, 55, 65, 75, 85, 0, 0] 19 | ]), 20 | "attention_mask" => 21 | Nx.tensor([ 22 | [1, 1, 1, 1, 1, 1, 1, 1, 0, 0], 23 | [1, 1, 1, 1, 1, 1, 1, 1, 0, 0] 24 | ]), 25 | "pixel_values" => 26 | Nx.concatenate([ 27 | Nx.broadcast(0.25, {1, 30, 30, 3}), 28 | Nx.broadcast(0.75, {1, 30, 30, 3}) 29 | ]) 30 | } 31 | 32 | outputs = Axon.predict(model, params, inputs) 33 | 34 | assert Nx.shape(outputs.logits_per_text) == {2, 2} 35 | assert Nx.shape(outputs.logits_per_image) == {2, 2} 36 | 37 | assert_all_close( 38 | outputs.logits_per_text, 39 | Nx.tensor([[0.5381, 0.1981], [0.5212, 0.3291]]) 40 | ) 41 | 42 | assert_all_close( 43 | outputs.logits_per_image, 44 | Nx.tensor([[0.5381, 0.5212], [0.1981, 0.3291]]) 45 | ) 46 | end 47 | end 48 | -------------------------------------------------------------------------------- /test/bumblebee/shared_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.SharedTest do 2 | use ExUnit.Case, 
async: true 3 | 4 | alias Bumblebee.Shared 5 | 6 | describe "validate_label_options/1" do 7 | test "passes when :id_to_label is empty" do 8 | spec = %{__struct__: TestConfig, num_labels: 3, id_to_label: %{}} 9 | 10 | assert Shared.validate_label_options(spec) == spec 11 | end 12 | 13 | test "passes when :id_to_label is matches :num_labels" do 14 | id_to_label = %{0 => "cat", 1 => "dog", 2 => "squirrel"} 15 | spec = %{__struct__: TestConfig, num_labels: 3, id_to_label: id_to_label} 16 | 17 | assert Shared.validate_label_options(spec) == spec 18 | end 19 | 20 | test "raises an error if mismatched :num_labels and :id_to_label are given" do 21 | id_to_label = %{0 => "cat", 1 => "dog"} 22 | spec = %{__struct__: TestConfig, num_labels: 3, id_to_label: id_to_label} 23 | 24 | assert_raise ArgumentError, 25 | ~s/size mismatch between :num_labels (3) and :id_to_label (%{0 => "cat", 1 => "dog"})/, 26 | fn -> 27 | Shared.validate_label_options(spec) 28 | end 29 | end 30 | end 31 | end 32 | -------------------------------------------------------------------------------- /test/bumblebee/text/bart_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.BartTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-BartModel"}) 11 | 12 | assert %Bumblebee.Text.Bart{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 16 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 17 | } 18 | 19 | outputs = Axon.predict(model, params, inputs) 20 | 21 | assert Nx.shape(outputs.hidden_state) == {1, 10, 16} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([ 26 | [[0.9984, -0.0751, 0.4176], [0.0095, 
-0.3245, -0.4237], [-0.8061, -0.3498, 0.9201]] 27 | ]) 28 | ) 29 | end 30 | 31 | test ":for_conditional_generation" do 32 | assert {:ok, %{model: model, params: params, spec: spec}} = 33 | Bumblebee.load_model( 34 | {:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"} 35 | ) 36 | 37 | assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = spec 38 | 39 | inputs = %{ 40 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 41 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 42 | } 43 | 44 | outputs = Axon.predict(model, params, inputs) 45 | 46 | assert Nx.shape(outputs.logits) == {1, 10, 1024} 47 | 48 | assert_all_close( 49 | outputs.logits[[.., 1..3, 1..3]], 50 | Nx.tensor([ 51 | [[0.0000, -0.0601, -0.0501], [0.0000, 0.0443, 0.0813], [0.0000, -0.1303, 0.0968]] 52 | ]) 53 | ) 54 | end 55 | 56 | test ":for_sequence_classification" do 57 | assert {:ok, %{model: model, params: params, spec: spec}} = 58 | Bumblebee.load_model( 59 | {:hf, "hf-internal-testing/tiny-random-BartForSequenceClassification"} 60 | ) 61 | 62 | assert %Bumblebee.Text.Bart{architecture: :for_sequence_classification} = spec 63 | 64 | inputs = %{ 65 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 2, 0]]), 66 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]) 67 | } 68 | 69 | outputs = Axon.predict(model, params, inputs) 70 | 71 | assert Nx.shape(outputs.logits) == {1, 3} 72 | 73 | assert_all_close( 74 | outputs.logits, 75 | Nx.tensor([[-0.0075, -0.0078, -0.0073]]) 76 | ) 77 | end 78 | 79 | test ":for_question_answering" do 80 | assert {:ok, %{model: model, params: params, spec: spec}} = 81 | Bumblebee.load_model( 82 | {:hf, "hf-internal-testing/tiny-random-BartForQuestionAnswering"} 83 | ) 84 | 85 | assert %Bumblebee.Text.Bart{architecture: :for_question_answering} = spec 86 | 87 | inputs = %{ 88 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 89 | "attention_mask" => 
Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 90 | } 91 | 92 | outputs = Axon.predict(model, params, inputs) 93 | 94 | assert Nx.shape(outputs.start_logits) == {1, 10} 95 | assert Nx.shape(outputs.end_logits) == {1, 10} 96 | 97 | assert_all_close( 98 | outputs.start_logits[[.., 1..3]], 99 | Nx.tensor([[0.0474, -0.0767, 0.0278]]) 100 | ) 101 | 102 | assert_all_close( 103 | outputs.end_logits[[.., 1..3]], 104 | Nx.tensor([[0.1557, -0.1034, -0.1271]]) 105 | ) 106 | end 107 | 108 | test ":for_causal_language_modeling" do 109 | assert {:ok, %{model: model, params: params, spec: spec}} = 110 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-BartForCausalLM"}) 111 | 112 | assert %Bumblebee.Text.Bart{architecture: :for_causal_language_modeling} = spec 113 | 114 | inputs = %{ 115 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 116 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 117 | } 118 | 119 | outputs = Axon.predict(model, params, inputs) 120 | 121 | assert Nx.shape(outputs.logits) == {1, 10, 1024} 122 | 123 | assert_all_close( 124 | outputs.logits[[.., 1..3, 1..3]], 125 | Nx.tensor([ 126 | [[0.0000, -0.2084, -0.0013], [0.0000, -0.0502, 0.0656], [0.0000, -0.1301, -0.1234]] 127 | ]) 128 | ) 129 | end 130 | end 131 | -------------------------------------------------------------------------------- /test/bumblebee/text/blenderbot_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.BlenderbotTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-BlenderbotModel"}) 11 | 12 | assert %Bumblebee.Text.Blenderbot{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 16 | "attention_mask" 
=> Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), 17 | "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]), 18 | "decoder_attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 0, 0]]) 19 | } 20 | 21 | outputs = Axon.predict(model, params, inputs) 22 | 23 | assert Nx.shape(outputs.hidden_state) == {1, 8, 16} 24 | 25 | assert_all_close( 26 | outputs.hidden_state[[.., 1..3, 1..3]], 27 | Nx.tensor([ 28 | [[0.6578, 1.9730, 0.6908], [-1.8067, 0.0553, -0.7491], [0.1820, -0.4390, -0.8273]] 29 | ]) 30 | ) 31 | end 32 | 33 | test ":for_conditional_generation" do 34 | assert {:ok, %{model: model, params: params, spec: spec}} = 35 | Bumblebee.load_model( 36 | {:hf, "hf-internal-testing/tiny-random-BlenderbotForConditionalGeneration"} 37 | ) 38 | 39 | assert %Bumblebee.Text.Blenderbot{architecture: :for_conditional_generation} = spec 40 | 41 | inputs = %{ 42 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 43 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), 44 | "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]), 45 | "decoder_attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 0, 0]]) 46 | } 47 | 48 | outputs = Axon.predict(model, params, inputs) 49 | 50 | assert Nx.shape(outputs.logits) == {1, 8, 1024} 51 | 52 | assert_all_close( 53 | outputs.logits[[.., 1..3, 1..3]], 54 | Nx.tensor([ 55 | [[0.0440, -0.0115, -0.0004], [0.0772, 0.0327, -0.0667], [-0.0419, 0.1483, 0.0140]] 56 | ]) 57 | ) 58 | end 59 | end 60 | -------------------------------------------------------------------------------- /test/bumblebee/text/blip_text_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.BlipTextTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model( 11 | {:hf, 
"hf-internal-testing/tiny-random-BlipModel"}, 12 | module: Bumblebee.Text.BlipText, 13 | architecture: :base 14 | ) 15 | 16 | assert %Bumblebee.Text.BlipText{architecture: :base} = spec 17 | 18 | inputs = %{ 19 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 20 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 21 | } 22 | 23 | outputs = Axon.predict(model, params, inputs) 24 | 25 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 26 | 27 | assert_all_close( 28 | outputs.hidden_state[[.., 1..3, 1..3]], 29 | Nx.tensor([ 30 | [[-0.9281, 1.2373, 0.4223], [-1.1549, 2.1187, -0.9194], [0.0237, -0.7517, 0.5720]] 31 | ]) 32 | ) 33 | end 34 | 35 | test ":for_causal_language_modeling" do 36 | assert {:ok, %{model: model, params: params, spec: spec}} = 37 | Bumblebee.load_model( 38 | {:hf, "hf-internal-testing/tiny-random-BlipForConditionalGeneration"}, 39 | module: Bumblebee.Text.BlipText, 40 | architecture: :for_causal_language_modeling 41 | ) 42 | 43 | assert %Bumblebee.Text.BlipText{architecture: :for_causal_language_modeling} = spec 44 | 45 | inputs = %{ 46 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 47 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 48 | } 49 | 50 | outputs = Axon.predict(model, params, inputs) 51 | 52 | assert Nx.shape(outputs.logits) == {1, 10, 1124} 53 | 54 | assert_all_close( 55 | outputs.logits[[.., 1..3, 1..3]], 56 | Nx.tensor([ 57 | [[0.0736, -0.0142, 0.2178], [0.0744, 0.0990, 0.1510], [-0.1186, -0.1449, -0.0643]] 58 | ]) 59 | ) 60 | end 61 | end 62 | -------------------------------------------------------------------------------- /test/bumblebee/text/camembert_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.CamembertTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, 
params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-camembert"}) 11 | 12 | assert %Bumblebee.Text.Roberta{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 16 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 17 | } 18 | 19 | outputs = Axon.predict(model, params, inputs) 20 | 21 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([ 26 | [[-0.1734, -0.5058, 0.6278], [-0.2506, -0.3877, -0.0394], [-0.4477, 1.9433, -0.7990]] 27 | ]) 28 | ) 29 | end 30 | end 31 | -------------------------------------------------------------------------------- /test/bumblebee/text/clip_text_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.ClipTextTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-CLIPModel"}, 11 | module: Bumblebee.Text.ClipText, 12 | architecture: :base 13 | ) 14 | 15 | assert %Bumblebee.Text.ClipText{architecture: :base} = spec 16 | 17 | inputs = %{ 18 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 19 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 20 | } 21 | 22 | outputs = Axon.predict(model, params, inputs) 23 | 24 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 25 | assert Nx.shape(outputs.pooled_state) == {1, 32} 26 | 27 | assert_all_close( 28 | outputs.hidden_state[[.., 1..3, 1..3]], 29 | Nx.tensor([ 30 | [[0.1696, -0.2324, -0.1659], [-0.0525, -0.3103, 0.1557], [-0.2566, -0.4519, 0.6398]] 31 | ]) 32 | ) 33 | 34 | assert_all_close( 35 | outputs.pooled_state[[.., 1..3]], 36 | Nx.tensor([[-0.6903, -1.2524, 1.5328]]) 37 
| ) 38 | end 39 | 40 | test ":for_embedding" do 41 | assert {:ok, %{model: model, params: params, spec: spec}} = 42 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-CLIPModel"}, 43 | module: Bumblebee.Text.ClipText, 44 | architecture: :for_embedding 45 | ) 46 | 47 | assert %Bumblebee.Text.ClipText{architecture: :for_embedding} = spec 48 | 49 | inputs = %{ 50 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 51 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 52 | } 53 | 54 | outputs = Axon.predict(model, params, inputs) 55 | 56 | assert Nx.shape(outputs.embedding) == {1, 64} 57 | 58 | assert_all_close( 59 | outputs.embedding[[.., 1..3]], 60 | Nx.tensor([[1.1069, -0.0839, -1.6185]]) 61 | ) 62 | end 63 | end 64 | -------------------------------------------------------------------------------- /test/bumblebee/text/fill_mask_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.FillMaskTest do 2 | use ExUnit.Case, async: false 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag serving_test_tags() 7 | 8 | test "returns top scored tokens" do 9 | {:ok, model_info} = Bumblebee.load_model({:hf, "google-bert/bert-base-uncased"}) 10 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-uncased"}) 11 | 12 | serving = Bumblebee.Text.FillMask.fill_mask(model_info, tokenizer) 13 | 14 | text = "The capital of [MASK] is Paris." 
15 | 16 | assert %{ 17 | predictions: [ 18 | %{score: _, token: "france"}, 19 | %{score: _, token: "brittany"}, 20 | %{score: _, token: "algeria"}, 21 | %{score: _, token: "department"}, 22 | %{score: _, token: "reunion"} 23 | ] 24 | } = Nx.Serving.run(serving, text) 25 | end 26 | 27 | test "raises when there isn't exactly one mask token" do 28 | {:ok, model_info} = Bumblebee.load_model({:hf, "google-bert/bert-base-uncased"}) 29 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-uncased"}) 30 | 31 | serving = Bumblebee.Text.FillMask.fill_mask(model_info, tokenizer) 32 | 33 | assert_raise ArgumentError, 34 | ~s/expected exactly one occurrence of [MASK], got: 0 in "The capital of France is Paris."/, 35 | fn -> 36 | Nx.Serving.run(serving, "The capital of France is Paris.") 37 | end 38 | 39 | assert_raise ArgumentError, 40 | ~s/expected exactly one occurrence of [MASK], got: 2 in "The [MASK] of [MASK] is Paris."/, 41 | fn -> 42 | Nx.Serving.run(serving, "The [MASK] of [MASK] is Paris.") 43 | end 44 | end 45 | end 46 | -------------------------------------------------------------------------------- /test/bumblebee/text/gemma_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.GemmaTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-GemmaModel"}) 11 | 12 | assert %Bumblebee.Text.Gemma{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 16 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 17 | } 18 | 19 | outputs = Axon.predict(model, params, inputs) 20 | 21 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 
25 | Nx.tensor([ 26 | [[-2.0724, -1.1056, -0.4123], [0.7144, -1.5150, -1.8728], [0.7222, 0.7930, 0.3218]] 27 | ]) 28 | ) 29 | end 30 | 31 | test ":for_sequence_classification" do 32 | assert {:ok, %{model: model, params: params, spec: spec}} = 33 | Bumblebee.load_model( 34 | {:hf, "bumblebee-testing/tiny-random-GemmaForSequenceClassification"} 35 | ) 36 | 37 | assert %Bumblebee.Text.Gemma{architecture: :for_sequence_classification} = spec 38 | 39 | inputs = %{ 40 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 41 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 42 | } 43 | 44 | outputs = Axon.predict(model, params, inputs) 45 | 46 | assert Nx.shape(outputs.logits) == {1, 2} 47 | 48 | assert_all_close( 49 | outputs.logits, 50 | Nx.tensor([[-0.1422, 0.0613]]) 51 | ) 52 | end 53 | 54 | test ":for_causal_language_modeling" do 55 | assert {:ok, %{model: model, params: params, spec: spec}} = 56 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-GemmaForCausalLM"}) 57 | 58 | assert %Bumblebee.Text.Gemma{architecture: :for_causal_language_modeling} = spec 59 | 60 | inputs = %{ 61 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 62 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 63 | } 64 | 65 | outputs = Axon.predict(model, params, inputs) 66 | 67 | assert Nx.shape(outputs.logits) == {1, 10, 1024} 68 | 69 | assert_all_close( 70 | outputs.logits[[.., 1..3, 1..3]], 71 | Nx.tensor([ 72 | [[0.0924, 0.1602, -0.0448], [-0.0934, -0.0543, -0.1045], [-0.1467, 0.0339, -0.1926]] 73 | ]) 74 | ) 75 | end 76 | end 77 | -------------------------------------------------------------------------------- /test/bumblebee/text/generation_config_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.GenerationConfigTest do 2 | use ExUnit.Case, async: true 3 | 4 | alias Bumblebee.Text.GenerationConfig 5 | 6 | describe "config/2" do 7 | test 
"ensures either of length and new token options are set" do 8 | assert %GenerationConfig{max_length: 10, max_new_tokens: nil} = 9 | GenerationConfig.config(%GenerationConfig{max_new_tokens: 10}, max_length: 10) 10 | 11 | assert %GenerationConfig{max_length: nil, max_new_tokens: 10} = 12 | GenerationConfig.config(%GenerationConfig{max_length: 10}, max_new_tokens: 10) 13 | 14 | assert %GenerationConfig{min_length: 10, min_new_tokens: nil} = 15 | GenerationConfig.config(%GenerationConfig{min_new_tokens: 10}, min_length: 10) 16 | 17 | assert %GenerationConfig{min_length: nil, min_new_tokens: 10} = 18 | GenerationConfig.config(%GenerationConfig{min_length: 10}, min_new_tokens: 10) 19 | end 20 | 21 | test "raises if both length and new token options are set" do 22 | assert_raise ArgumentError, 23 | "only one of :max_new_tokens or :max_length options must be given, but got both", 24 | fn -> 25 | GenerationConfig.config(%GenerationConfig{}, 26 | max_length: 10, 27 | max_new_tokens: 10 28 | ) 29 | end 30 | 31 | assert_raise ArgumentError, 32 | "only one of :min_new_tokens or :min_length options must be given, but got both", 33 | fn -> 34 | GenerationConfig.config(%GenerationConfig{}, 35 | min_length: 10, 36 | min_new_tokens: 10 37 | ) 38 | end 39 | end 40 | 41 | test "raises on invalid strategy" do 42 | assert_raise ArgumentError, 43 | "expected strategy type to be either :greedy_search or :contrastive_search, got: :invalid", 44 | fn -> 45 | GenerationConfig.config(%GenerationConfig{}, strategy: %{type: :invalid}) 46 | end 47 | 48 | assert_raise ArgumentError, 49 | "expected strategy to have :type, but was not present in %{}", 50 | fn -> 51 | GenerationConfig.config(%GenerationConfig{}, strategy: %{}) 52 | end 53 | 54 | assert_raise ArgumentError, 55 | "expected strategy to be a map, but got: :greedy_search", 56 | fn -> 57 | GenerationConfig.config(%GenerationConfig{}, strategy: :greedy_search) 58 | end 59 | 60 | assert_raise ArgumentError, 61 | "missing keys [:alpha, 
:top_k] for strategy :contrastive_search", 62 | fn -> 63 | GenerationConfig.config(%GenerationConfig{}, 64 | strategy: %{type: :contrastive_search} 65 | ) 66 | end 67 | 68 | assert_raise ArgumentError, 69 | "unexpected keys [:unexpected] for strategy :contrastive_search", 70 | fn -> 71 | GenerationConfig.config(%GenerationConfig{}, 72 | strategy: %{ 73 | type: :contrastive_search, 74 | top_k: 4, 75 | alpha: 0.6, 76 | unexpected: true 77 | } 78 | ) 79 | end 80 | end 81 | end 82 | end 83 | -------------------------------------------------------------------------------- /test/bumblebee/text/generation_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.GenerationTest do 2 | use ExUnit.Case, async: false 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test "decoder model" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) 11 | 12 | {:ok, generation_config} = 13 | Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) 14 | 15 | assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec 16 | 17 | inputs = %{ 18 | "input_ids" => Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]]), 19 | "attention_mask" => Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]), 20 | "seed" => Nx.tensor([0]) 21 | } 22 | 23 | generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3) 24 | 25 | generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config) 26 | %{token_ids: token_ids} = generate.(params, inputs) 27 | 28 | assert_equal(token_ids, Nx.tensor([[80, 80, 80]])) 29 | end 30 | 31 | test "encoder-decoder model" do 32 | assert {:ok, %{model: model, params: params, spec: spec}} = 33 | Bumblebee.load_model( 34 | {:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"} 35 | ) 36 | 37 | {:ok, 
generation_config} = 38 | Bumblebee.load_generation_config( 39 | {:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"} 40 | ) 41 | 42 | assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = spec 43 | 44 | inputs = %{ 45 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 46 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), 47 | "seed" => Nx.tensor([0]) 48 | } 49 | 50 | generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3) 51 | 52 | generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config) 53 | %{token_ids: token_ids} = generate.(params, inputs) 54 | 55 | assert_equal(token_ids, Nx.tensor([[988, 988, 988]])) 56 | end 57 | 58 | test "encoder-decoder model and lower precision" do 59 | assert {:ok, %{model: model, params: params, spec: spec}} = 60 | Bumblebee.load_model( 61 | {:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"}, 62 | type: :f16 63 | ) 64 | 65 | {:ok, generation_config} = 66 | Bumblebee.load_generation_config( 67 | {:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"} 68 | ) 69 | 70 | assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = spec 71 | 72 | inputs = %{ 73 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 74 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), 75 | "seed" => Nx.tensor([0]) 76 | } 77 | 78 | generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3) 79 | 80 | generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config) 81 | %{token_ids: token_ids} = generate.(params, inputs) 82 | 83 | assert_equal(token_ids, Nx.tensor([[988, 988, 988]])) 84 | end 85 | 86 | test "multiple end-of-sequence token ids" do 87 | assert {:ok, %{model: model, params: params, spec: spec}} = 88 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) 89 | 90 | {:ok, generation_config} = 
91 | Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) 92 | 93 | assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec 94 | 95 | inputs = %{ 96 | "input_ids" => Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]]), 97 | "attention_mask" => Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]), 98 | "seed" => Nx.tensor([0]) 99 | } 100 | 101 | generation_config = 102 | Bumblebee.configure(generation_config, max_new_tokens: 3, eos_token_id: [0, 80]) 103 | 104 | generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config) 105 | %{token_ids: token_ids} = generate.(params, inputs) 106 | 107 | assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]])) 108 | end 109 | end 110 | -------------------------------------------------------------------------------- /test/bumblebee/text/gpt2_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.Gpt2Test do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"}) 11 | 12 | assert %Bumblebee.Text.Gpt2{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 16 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 17 | } 18 | 19 | outputs = Axon.predict(model, params, inputs) 20 | 21 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([ 26 | [[-0.8136, -0.2392, 0.2378], [0.9714, -0.4651, 0.8788], [-0.0980, 0.2294, -1.1416]] 27 | ]) 28 | ) 29 | end 30 | 31 | test ":for_causal_language_modeling" do 32 | assert {:ok, %{model: model, params: params, spec: spec}} = 33 | Bumblebee.load_model({:hf, 
"hf-internal-testing/tiny-random-GPT2LMHeadModel"}) 34 | 35 | assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec 36 | 37 | inputs = %{ 38 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 39 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 40 | } 41 | 42 | outputs = Axon.predict(model, params, inputs) 43 | 44 | assert Nx.shape(outputs.logits) == {1, 10, 1024} 45 | 46 | assert_all_close( 47 | outputs.logits[[.., 1..3, 1..3]], 48 | Nx.tensor([[[0.1184, -0.0259, 0.1688], [0.1064, 0.1412, 0.1120], [0.1421, -0.2010, 0.3757]]]) 49 | ) 50 | end 51 | 52 | test ":for_token_classification" do 53 | assert {:ok, %{model: model, params: params, spec: spec}} = 54 | Bumblebee.load_model( 55 | {:hf, "hf-internal-testing/tiny-random-GPT2ForTokenClassification"} 56 | ) 57 | 58 | assert %Bumblebee.Text.Gpt2{architecture: :for_token_classification} = spec 59 | 60 | inputs = %{ 61 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 62 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 63 | } 64 | 65 | outputs = Axon.predict(model, params, inputs) 66 | 67 | assert Nx.shape(outputs.logits) == {1, 10, 2} 68 | 69 | assert_all_close( 70 | outputs.logits[[.., 1..3//1, ..]], 71 | Nx.tensor([[[0.0207, 0.1338], [-0.1582, -0.0384], [-0.2225, -0.0400]]]) 72 | ) 73 | end 74 | 75 | test ":for_sequence_classification" do 76 | assert {:ok, %{model: model, params: params, spec: spec}} = 77 | Bumblebee.load_model( 78 | {:hf, "hf-internal-testing/tiny-random-GPT2ForSequenceClassification"} 79 | ) 80 | 81 | assert %Bumblebee.Text.Gpt2{architecture: :for_sequence_classification} = spec 82 | 83 | inputs = %{ 84 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 1023, 1023]]), 85 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 86 | } 87 | 88 | outputs = Axon.predict(model, params, inputs) 89 | 90 | assert Nx.shape(outputs.logits) == {1, 2} 91 | 92 | assert_all_close( 93 | 
outputs.logits, 94 | Nx.tensor([[-0.0098, -0.0456]]) 95 | ) 96 | end 97 | end 98 | -------------------------------------------------------------------------------- /test/bumblebee/text/gpt_big_code_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.GptBigCodeTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPTBigCodeModel"}) 11 | 12 | assert %Bumblebee.Text.GptBigCode{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 16 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 17 | } 18 | 19 | outputs = Axon.predict(model, params, inputs) 20 | 21 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([ 26 | [[-0.8193, 0.5945, -0.2915], [0.0150, 0.4736, 0.5148], [-0.4247, -1.8000, -1.6479]] 27 | ]) 28 | ) 29 | end 30 | 31 | test ":base without multi-query attention" do 32 | # We have a separate test to test parameter loading without 33 | # multi-query attention, because the parameters layout differs 34 | 35 | assert {:ok, %{model: model, params: params, spec: spec}} = 36 | Bumblebee.load_model( 37 | {:hf, "bumblebee-testing/tiny-random-GPTBigCodeModel-multi_query-False"} 38 | ) 39 | 40 | assert %Bumblebee.Text.GptBigCode{architecture: :base} = spec 41 | 42 | inputs = %{ 43 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 44 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 45 | } 46 | 47 | outputs = Axon.predict(model, params, inputs) 48 | 49 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 50 | 51 | assert_all_close( 52 | outputs.hidden_state[[.., 1..3, 1..3]], 53 | 
Nx.tensor([ 54 | [[-1.3966, 0.6641, -1.3937], [-0.5489, 0.3397, 0.4567], [-0.6488, -1.6745, -1.1570]] 55 | ]) 56 | ) 57 | end 58 | 59 | test ":for_causal_language_modeling" do 60 | assert {:ok, %{model: model, params: params, spec: spec}} = 61 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPTBigCodeForCausalLM"}) 62 | 63 | assert %Bumblebee.Text.GptBigCode{architecture: :for_causal_language_modeling} = spec 64 | 65 | inputs = %{ 66 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 67 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 68 | } 69 | 70 | outputs = Axon.predict(model, params, inputs) 71 | 72 | assert Nx.shape(outputs.logits) == {1, 10, 1024} 73 | 74 | assert_all_close( 75 | outputs.logits[[.., 1..3, 1..3]], 76 | Nx.tensor([ 77 | [[-0.1509, -0.1751, 0.1848], [-0.0860, -0.2476, 0.3373], [-0.2671, -0.2028, -0.0896]] 78 | ]) 79 | ) 80 | end 81 | 82 | test ":for_token_classification" do 83 | assert {:ok, %{model: model, params: params, spec: spec}} = 84 | Bumblebee.load_model( 85 | {:hf, "hf-internal-testing/tiny-random-GPTBigCodeForTokenClassification"} 86 | ) 87 | 88 | assert %Bumblebee.Text.GptBigCode{architecture: :for_token_classification} = spec 89 | 90 | inputs = %{ 91 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 92 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 93 | } 94 | 95 | outputs = Axon.predict(model, params, inputs) 96 | 97 | assert Nx.shape(outputs.logits) == {1, 10, 2} 98 | 99 | assert_all_close( 100 | outputs.logits[[.., 1..3//1, ..]], 101 | Nx.tensor([[[-0.0775, -0.0276], [0.0634, 0.0396], [-0.0695, 0.1575]]]) 102 | ) 103 | end 104 | 105 | test ":for_sequence_classification" do 106 | assert {:ok, %{model: model, params: params, spec: spec}} = 107 | Bumblebee.load_model( 108 | {:hf, "hf-internal-testing/tiny-random-GPTBigCodeForSequenceClassification"} 109 | ) 110 | 111 | assert %Bumblebee.Text.GptBigCode{architecture: 
:for_sequence_classification} = spec 112 | 113 | inputs = %{ 114 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 1021, 1021]]), 115 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 116 | } 117 | 118 | outputs = Axon.predict(model, params, inputs) 119 | 120 | assert Nx.shape(outputs.logits) == {1, 2} 121 | 122 | assert_all_close( 123 | outputs.logits, 124 | Nx.tensor([[0.1722, 0.1999]]) 125 | ) 126 | end 127 | end 128 | -------------------------------------------------------------------------------- /test/bumblebee/text/gpt_neo_x_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.GptNeoXTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPTNeoXModel"}) 11 | 12 | assert %Bumblebee.Text.GptNeoX{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 16 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 17 | } 18 | 19 | outputs = Axon.predict(model, params, inputs) 20 | 21 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([ 26 | [[0.4428, 0.3349, -1.1917], [-0.1550, -0.4439, -0.5855], [0.3737, 3.4893, -0.6499]] 27 | ]) 28 | ) 29 | end 30 | 31 | test ":for_sequence_classification" do 32 | assert {:ok, %{model: model, params: params, spec: spec}} = 33 | Bumblebee.load_model( 34 | {:hf, "hf-internal-testing/tiny-random-GPTNeoXForSequenceClassification"} 35 | ) 36 | 37 | assert %Bumblebee.Text.GptNeoX{architecture: :for_sequence_classification} = spec 38 | 39 | inputs = %{ 40 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 41 | "attention_mask" => Nx.tensor([[1, 1, 1, 
1, 1, 1, 1, 1, 0, 0]]) 42 | } 43 | 44 | outputs = Axon.predict(model, params, inputs) 45 | 46 | assert Nx.shape(outputs.logits) == {1, 2} 47 | 48 | assert_all_close( 49 | outputs.logits, 50 | Nx.tensor([[0.1089, -0.3733]]) 51 | ) 52 | end 53 | 54 | test ":for_token_classification" do 55 | assert {:ok, %{model: model, params: params, spec: spec}} = 56 | Bumblebee.load_model( 57 | {:hf, "hf-internal-testing/tiny-random-GPTNeoXForTokenClassification"} 58 | ) 59 | 60 | assert %Bumblebee.Text.GptNeoX{architecture: :for_token_classification} = spec 61 | 62 | inputs = %{ 63 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 64 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 65 | } 66 | 67 | outputs = Axon.predict(model, params, inputs) 68 | 69 | assert Nx.shape(outputs.logits) == {1, 10, 2} 70 | 71 | assert_all_close( 72 | outputs.logits[[.., 1..3//1, ..]], 73 | Nx.tensor([[[-0.0900, -0.1853], [0.0567, -0.0443], [-0.0104, -0.1112]]]) 74 | ) 75 | end 76 | 77 | test ":for_causal_language_modeling" do 78 | assert {:ok, %{model: model, params: params, spec: spec}} = 79 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"}) 80 | 81 | assert %Bumblebee.Text.GptNeoX{architecture: :for_causal_language_modeling} = spec 82 | 83 | inputs = %{ 84 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 85 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 86 | } 87 | 88 | outputs = Axon.predict(model, params, inputs) 89 | 90 | assert Nx.shape(outputs.logits) == {1, 10, 1024} 91 | 92 | assert_all_close( 93 | outputs.logits[[.., 1..3, 1..3]], 94 | Nx.tensor([ 95 | [[0.1134, 0.0507, -0.0534], [-0.1113, 0.0035, -0.0319], [0.0019, -0.0273, -0.0151]] 96 | ]) 97 | ) 98 | end 99 | end 100 | -------------------------------------------------------------------------------- /test/bumblebee/text/llama_test.exs: -------------------------------------------------------------------------------- 1 | defmodule 
Bumblebee.Text.LlamaTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-LlamaModel"}) 11 | 12 | assert %Bumblebee.Text.Llama{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 16 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 17 | } 18 | 19 | outputs = Axon.predict(model, params, inputs) 20 | 21 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([ 26 | [[1.4799, -2.0333, 0.4759], [2.3749, -0.8369, -0.0206], [0.5767, -0.0515, -1.1795]] 27 | ]) 28 | ) 29 | end 30 | 31 | test ":base rotary embedding scaling strategy :llama3" do 32 | assert {:ok, %{model: model, params: params, spec: spec}} = 33 | Bumblebee.load_model( 34 | {:hf, 35 | "bumblebee-testing/tiny-random-LlamaModel-rope_scaling-llama3-original_max_position_embeddings-64"} 36 | ) 37 | 38 | assert %Bumblebee.Text.Llama{architecture: :base} = spec 39 | 40 | inputs = %{ 41 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 42 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 43 | } 44 | 45 | outputs = Axon.predict(model, params, inputs) 46 | 47 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 48 | 49 | assert_all_close( 50 | outputs.hidden_state[[.., 1..3, 1..3]], 51 | Nx.tensor([ 52 | [[1.4802, -2.0331, 0.4759], [2.3749, -0.8367, -0.0205], [0.5762, -0.0517, -1.1795]] 53 | ]) 54 | ) 55 | end 56 | 57 | test ":for_sequence_classification" do 58 | assert {:ok, %{model: model, params: params, spec: spec}} = 59 | Bumblebee.load_model( 60 | {:hf, "bumblebee-testing/tiny-random-LlamaForSequenceClassification"} 61 | ) 62 | 63 | assert %Bumblebee.Text.Llama{architecture: 
:for_sequence_classification} = spec 64 | 65 | inputs = %{ 66 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 67 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 68 | } 69 | 70 | outputs = Axon.predict(model, params, inputs) 71 | 72 | assert Nx.shape(outputs.logits) == {1, 2} 73 | 74 | assert_all_close( 75 | outputs.logits, 76 | Nx.tensor([[-0.1964, -0.1069]]) 77 | ) 78 | end 79 | 80 | test ":for_causal_language_modeling" do 81 | assert {:ok, %{model: model, params: params, spec: spec}} = 82 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-LlamaForCausalLM"}) 83 | 84 | assert %Bumblebee.Text.Llama{architecture: :for_causal_language_modeling} = spec 85 | 86 | inputs = %{ 87 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 88 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 89 | } 90 | 91 | outputs = Axon.predict(model, params, inputs) 92 | 93 | assert Nx.shape(outputs.logits) == {1, 10, 1024} 94 | 95 | assert_all_close( 96 | outputs.logits[[.., 1..3, 1..3]], 97 | Nx.tensor([ 98 | [[0.0469, -0.0751, 0.0349], [0.0617, -0.1357, -0.0204], [-0.1495, 0.0557, -0.0737]] 99 | ]) 100 | ) 101 | end 102 | end 103 | -------------------------------------------------------------------------------- /test/bumblebee/text/m2m100_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.M2m100Test do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-M2M100Model"}, 11 | architecture: :base 12 | ) 13 | 14 | assert %Bumblebee.Text.M2m100{architecture: :base} = spec 15 | 16 | input = %{ 17 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 18 | "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]), 19 | 
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 20 | } 21 | 22 | output = Axon.predict(model, params, input) 23 | 24 | assert Nx.shape(output.hidden_state) == {1, 8, 16} 25 | 26 | assert_all_close( 27 | output.hidden_state[[.., 1..3, 1..3]], 28 | Nx.tensor([ 29 | [ 30 | [0.7856, -0.3174, -0.4792], 31 | [0.7265, -0.2752, -0.4823], 32 | [1.0580, -0.3263, -0.7994] 33 | ] 34 | ]) 35 | ) 36 | end 37 | 38 | test ":for_conditional_generation" do 39 | assert {:ok, %{model: model, params: params, spec: spec}} = 40 | Bumblebee.load_model( 41 | {:hf, "hf-internal-testing/tiny-random-M2M100ForConditionalGeneration"}, 42 | architecture: :for_conditional_generation 43 | ) 44 | 45 | assert %Bumblebee.Text.M2m100{architecture: :for_conditional_generation} = spec 46 | 47 | input = %{ 48 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 49 | "decoder_input_ids" => Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]), 50 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 51 | } 52 | 53 | output = Axon.predict(model, params, input) 54 | 55 | assert Nx.shape(output.logits) == {1, 8, 128_112} 56 | 57 | assert_all_close( 58 | output.logits[[.., 1..3, 1..3]], 59 | Nx.tensor([ 60 | [ 61 | [0.0000, -0.0323, 0.0527], 62 | [0.0000, -0.0404, 0.0713], 63 | [0.0000, -0.0660, 0.0758] 64 | ] 65 | ]) 66 | ) 67 | end 68 | end 69 | -------------------------------------------------------------------------------- /test/bumblebee/text/mbart_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.MbartTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MBartModel"}) 11 | 12 | assert %Bumblebee.Text.Mbart{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => Nx.tensor([[10, 
20, 30, 40, 50, 60, 70, 80, 0, 0]]), 16 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 17 | } 18 | 19 | outputs = Axon.predict(model, params, inputs) 20 | 21 | assert Nx.shape(outputs.hidden_state) == {1, 10, 16} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([ 26 | [[0.8300, -0.4815, 0.4641], [-1.6583, 0.9162, -0.3562], [-0.6983, -0.7699, 1.0282]] 27 | ]) 28 | ) 29 | end 30 | 31 | test ":for_conditional_generation" do 32 | assert {:ok, %{model: model, params: params, spec: spec}} = 33 | Bumblebee.load_model( 34 | {:hf, "hf-internal-testing/tiny-random-MBartForConditionalGeneration"} 35 | ) 36 | 37 | assert %Bumblebee.Text.Mbart{architecture: :for_conditional_generation} = spec 38 | 39 | inputs = %{ 40 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 41 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 42 | } 43 | 44 | outputs = Axon.predict(model, params, inputs) 45 | 46 | assert Nx.shape(outputs.logits) == {1, 10, 250_027} 47 | 48 | assert_all_close( 49 | outputs.logits[[.., 1..3, 1..3]], 50 | Nx.tensor([[[0.0000, 0.0923, 0.0841], [0.0000, 0.1023, -0.0938], [0.0000, 0.0703, 0.1231]]]) 51 | ) 52 | end 53 | 54 | test ":for_sequence_classification" do 55 | assert {:ok, %{model: model, params: params, spec: spec}} = 56 | Bumblebee.load_model( 57 | {:hf, "hf-internal-testing/tiny-random-MBartForSequenceClassification"} 58 | ) 59 | 60 | assert %Bumblebee.Text.Mbart{architecture: :for_sequence_classification} = spec 61 | 62 | inputs = %{ 63 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 2, 0]]), 64 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]) 65 | } 66 | 67 | outputs = Axon.predict(model, params, inputs) 68 | 69 | assert Nx.shape(outputs.logits) == {1, 2} 70 | 71 | assert_all_close( 72 | outputs.logits, 73 | Nx.tensor([[0.0085, 0.0054]]) 74 | ) 75 | end 76 | 77 | test ":for_question_answering" do 78 | assert {:ok, %{model: model, params: 
params, spec: spec}} = 79 | Bumblebee.load_model( 80 | {:hf, "hf-internal-testing/tiny-random-MBartForQuestionAnswering"} 81 | ) 82 | 83 | assert %Bumblebee.Text.Mbart{architecture: :for_question_answering} = spec 84 | 85 | inputs = %{ 86 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 87 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 88 | } 89 | 90 | outputs = Axon.predict(model, params, inputs) 91 | 92 | assert Nx.shape(outputs.start_logits) == {1, 10} 93 | assert Nx.shape(outputs.end_logits) == {1, 10} 94 | 95 | assert_all_close( 96 | outputs.start_logits[[.., 1..3]], 97 | Nx.tensor([[0.1063, -0.1271, -0.1534]]) 98 | ) 99 | 100 | assert_all_close( 101 | outputs.end_logits[[.., 1..3]], 102 | Nx.tensor([[0.0268, 0.0238, 0.0857]]) 103 | ) 104 | end 105 | 106 | test ":for_causal_language_modeling" do 107 | assert {:ok, %{model: model, params: params, spec: spec}} = 108 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MBartForCausalLM"}) 109 | 110 | assert %Bumblebee.Text.Mbart{architecture: :for_causal_language_modeling} = spec 111 | 112 | inputs = %{ 113 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 114 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 115 | } 116 | 117 | outputs = Axon.predict(model, params, inputs) 118 | 119 | assert Nx.shape(outputs.logits) == {1, 10, 250_027} 120 | 121 | assert_all_close( 122 | outputs.logits[[.., 1..3, 1..3]], 123 | Nx.tensor([ 124 | [[0.0000, -0.0236, -0.0043], [0.0000, -0.0101, 0.0510], [0.0000, 0.0404, 0.0327]] 125 | ]) 126 | ) 127 | end 128 | end 129 | -------------------------------------------------------------------------------- /test/bumblebee/text/mistral_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.MistralTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | 
assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MistralModel"}) 11 | 12 | assert %Bumblebee.Text.Mistral{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 16 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 17 | } 18 | 19 | outputs = Axon.predict(model, params, inputs) 20 | 21 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([ 26 | [[0.9450, -1.3945, 0.7331], [-2.1118, -1.3091, -0.7834], [-1.7609, -1.3034, 1.0634]] 27 | ]) 28 | ) 29 | end 30 | 31 | test ":base with attention sliding window" do 32 | assert {:ok, spec} = 33 | Bumblebee.load_spec({:hf, "hf-internal-testing/tiny-random-MistralModel"}) 34 | 35 | spec = Bumblebee.configure(spec, attention_window_size: 2) 36 | 37 | assert {:ok, %{model: model, params: params, spec: spec}} = 38 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MistralModel"}, 39 | spec: spec 40 | ) 41 | 42 | assert %Bumblebee.Text.Mistral{architecture: :base} = spec 43 | 44 | inputs = %{ 45 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 46 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 47 | } 48 | 49 | outputs = Axon.predict(model, params, inputs) 50 | 51 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 52 | 53 | assert_all_close( 54 | outputs.hidden_state[[.., 1..3, 1..3]], 55 | Nx.tensor([ 56 | [[0.9450, -1.3945, 0.7331], [-2.1118, -1.3091, -0.7834], [-1.3033, -1.3374, 0.8919]] 57 | ]) 58 | ) 59 | end 60 | 61 | test ":for_sequence_classification" do 62 | assert {:ok, %{model: model, params: params, spec: spec}} = 63 | Bumblebee.load_model( 64 | {:hf, "hf-internal-testing/tiny-random-MistralForSequenceClassification"} 65 | ) 66 | 67 | assert %Bumblebee.Text.Mistral{architecture: :for_sequence_classification} = spec 68 | 69 
| inputs = %{ 70 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 71 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 72 | } 73 | 74 | outputs = Axon.predict(model, params, inputs) 75 | 76 | assert Nx.shape(outputs.logits) == {1, 2} 77 | 78 | assert_all_close( 79 | outputs.logits, 80 | Nx.tensor([[0.0035, -0.0357]]) 81 | ) 82 | end 83 | 84 | test ":for_causal_language_modeling" do 85 | assert {:ok, %{model: model, params: params, spec: spec}} = 86 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MistralForCausalLM"}) 87 | 88 | assert %Bumblebee.Text.Mistral{architecture: :for_causal_language_modeling} = spec 89 | 90 | inputs = %{ 91 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 92 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 93 | } 94 | 95 | outputs = Axon.predict(model, params, inputs) 96 | 97 | assert Nx.shape(outputs.logits) == {1, 10, 32000} 98 | 99 | assert_all_close( 100 | outputs.logits[[.., 1..3, 1..3]], 101 | Nx.tensor([ 102 | [[-0.1054, 0.0026, 0.0450], [0.1400, 0.1388, 0.0265], [0.0060, -0.1150, -0.1463]] 103 | ]) 104 | ) 105 | end 106 | end 107 | -------------------------------------------------------------------------------- /test/bumblebee/text/nllb_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.NllbTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":for_conditional_generation" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-nllb"}, 11 | module: Bumblebee.Text.M2m100, 12 | architecture: :for_conditional_generation 13 | ) 14 | 15 | assert %Bumblebee.Text.M2m100{architecture: :for_conditional_generation} = spec 16 | 17 | input = %{ 18 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 19 | "decoder_input_ids" 
=> Nx.tensor([[15, 25, 35, 45, 55, 65, 0, 0]]), 20 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 21 | } 22 | 23 | output = Axon.predict(model, params, input) 24 | 25 | assert Nx.shape(output.logits) == {1, 8, 128_112} 26 | 27 | assert_all_close( 28 | output.logits[[.., 1..3, 1..3]], 29 | Nx.tensor([ 30 | [ 31 | [0.0000, 0.0169, -0.0698], 32 | [0.0000, 0.0525, -0.1042], 33 | [0.0000, 0.0667, -0.1078] 34 | ] 35 | ]) 36 | ) 37 | end 38 | end 39 | -------------------------------------------------------------------------------- /test/bumblebee/text/phi3_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.Phi3Test do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-Phi3Model"}) 11 | 12 | assert %Bumblebee.Text.Phi3{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 16 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 17 | } 18 | 19 | outputs = Axon.predict(model, params, inputs) 20 | 21 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([ 26 | [[-1.4514, 0.6000, 0.1565], [-0.2677, 1.9352, 0.5334], [1.1021, -0.1642, 0.5992]] 27 | ]) 28 | ) 29 | end 30 | 31 | test ":base rotary embedding scaling strategy :longrope" do 32 | assert {:ok, %{model: model, params: params, spec: spec}} = 33 | Bumblebee.load_model( 34 | {:hf, 35 | "bumblebee-testing/tiny-random-Phi3Model-rope_scaling-longrope-original_max_position_embeddings-256"} 36 | ) 37 | 38 | assert %Bumblebee.Text.Phi3{architecture: :base} = spec 39 | 40 | inputs = %{ 41 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 42 | 
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 43 | } 44 | 45 | outputs = Axon.predict(model, params, inputs) 46 | 47 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 48 | 49 | assert_all_close( 50 | outputs.hidden_state[[.., 1..3, 1..3]], 51 | Nx.tensor([ 52 | [[-1.4528, 0.5995, 0.1573], [-0.2664, 1.9339, 0.5336], [1.1053, -0.1643, 0.5989]] 53 | ]) 54 | ) 55 | end 56 | 57 | test ":for_sequence_classification" do 58 | assert {:ok, %{model: model, params: params, spec: spec}} = 59 | Bumblebee.load_model( 60 | {:hf, "bumblebee-testing/tiny-random-Phi3ForSequenceClassification"} 61 | ) 62 | 63 | assert %Bumblebee.Text.Phi3{architecture: :for_sequence_classification} = spec 64 | 65 | inputs = %{ 66 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 67 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 68 | } 69 | 70 | outputs = Axon.predict(model, params, inputs) 71 | 72 | assert Nx.shape(outputs.logits) == {1, 2} 73 | 74 | assert_all_close( 75 | outputs.logits, 76 | Nx.tensor([[0.1249, 0.1090]]) 77 | ) 78 | end 79 | 80 | test ":for_token_classification" do 81 | assert {:ok, %{model: model, params: params, spec: spec}} = 82 | Bumblebee.load_model( 83 | {:hf, "bumblebee-testing/tiny-random-Phi3ForTokenClassification"} 84 | ) 85 | 86 | assert %Bumblebee.Text.Phi3{architecture: :for_token_classification} = spec 87 | 88 | inputs = %{ 89 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 90 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 91 | } 92 | 93 | outputs = Axon.predict(model, params, inputs) 94 | 95 | assert Nx.shape(outputs.logits) == {1, 10, 2} 96 | 97 | assert_all_close( 98 | outputs.logits[[.., 1..3//1, ..]], 99 | Nx.tensor([[[0.0588, -0.0997], [0.0494, -0.1636], [0.0402, 0.0486]]]) 100 | ) 101 | end 102 | 103 | test ":for_causal_language_modeling" do 104 | assert {:ok, %{model: model, params: params, spec: spec}} = 105 | Bumblebee.load_model({:hf, 
"bumblebee-testing/tiny-random-Phi3ForCausalLM"}) 106 | 107 | assert %Bumblebee.Text.Phi3{architecture: :for_causal_language_modeling} = spec 108 | 109 | inputs = %{ 110 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 111 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 112 | } 113 | 114 | outputs = Axon.predict(model, params, inputs) 115 | 116 | assert Nx.shape(outputs.logits) == {1, 10, 1024} 117 | 118 | assert_all_close( 119 | outputs.logits[[.., 1..3, 1..3]], 120 | Nx.tensor([ 121 | [[-0.0893, 0.0890, -0.1252], [0.0574, 0.0197, -0.0580], [-0.0302, -0.0644, -0.1228]] 122 | ]) 123 | ) 124 | end 125 | end 126 | -------------------------------------------------------------------------------- /test/bumblebee/text/phi_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.PhiTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-PhiModel"}) 11 | 12 | assert %Bumblebee.Text.Phi{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 16 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 17 | } 18 | 19 | outputs = Axon.predict(model, params, inputs) 20 | 21 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([[[-0.3275, 0.5231, 0.5690], [0.2239, 0.5028, 0.4599], [-0.0979, 1.0183, 0.3350]]]) 26 | ) 27 | end 28 | 29 | test ":for_sequence_classification" do 30 | assert {:ok, %{model: model, params: params, spec: spec}} = 31 | Bumblebee.load_model( 32 | {:hf, "bumblebee-testing/tiny-random-PhiForSequenceClassification"} 33 | ) 34 | 35 | assert %Bumblebee.Text.Phi{architecture: 
:for_sequence_classification} = spec 36 | 37 | inputs = %{ 38 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 39 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 40 | } 41 | 42 | outputs = Axon.predict(model, params, inputs) 43 | 44 | assert Nx.shape(outputs.logits) == {1, 2} 45 | 46 | assert_all_close( 47 | outputs.logits, 48 | Nx.tensor([[0.1403, -0.1382]]) 49 | ) 50 | end 51 | 52 | test ":for_token_classification" do 53 | assert {:ok, %{model: model, params: params, spec: spec}} = 54 | Bumblebee.load_model( 55 | {:hf, "bumblebee-testing/tiny-random-PhiForTokenClassification"} 56 | ) 57 | 58 | assert %Bumblebee.Text.Phi{architecture: :for_token_classification} = spec 59 | 60 | inputs = %{ 61 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 62 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 63 | } 64 | 65 | outputs = Axon.predict(model, params, inputs) 66 | 67 | assert Nx.shape(outputs.logits) == {1, 10, 2} 68 | 69 | assert_all_close( 70 | outputs.logits[[.., 1..3//1, ..]], 71 | Nx.tensor([[[-0.0364, -0.1207], [0.2520, 0.0755], [0.0243, 0.0269]]]) 72 | ) 73 | end 74 | 75 | test ":for_causal_language_modeling" do 76 | assert {:ok, %{model: model, params: params, spec: spec}} = 77 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-PhiForCausalLM"}) 78 | 79 | assert %Bumblebee.Text.Phi{architecture: :for_causal_language_modeling} = spec 80 | 81 | inputs = %{ 82 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 83 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 84 | } 85 | 86 | outputs = Axon.predict(model, params, inputs) 87 | 88 | assert Nx.shape(outputs.logits) == {1, 10, 1024} 89 | 90 | assert_all_close( 91 | outputs.logits[[.., 1..3, 1..3]], 92 | Nx.tensor([[[0.2541, 0.0827, 0.0526], [0.1901, 0.1289, 0.0758], [0.1051, 0.0658, -0.1167]]]) 93 | ) 94 | end 95 | end 96 | 
-------------------------------------------------------------------------------- /test/bumblebee/text/question_answering_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.QuestionAnsweringTest do 2 | use ExUnit.Case, async: false 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag serving_test_tags() 7 | 8 | test "returns the most probable answer" do 9 | {:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"}) 10 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "FacebookAI/roberta-base"}) 11 | 12 | serving = Bumblebee.Text.question_answering(roberta, tokenizer) 13 | 14 | input = %{question: "What's my name?", context: "My name is Sarah and I live in London."} 15 | 16 | assert %{ 17 | results: [ 18 | %{ 19 | text: "Sarah", 20 | start: 11, 21 | end: 16, 22 | score: score 23 | } 24 | ] 25 | } = Nx.Serving.run(serving, input) 26 | 27 | assert_all_close(score, 0.8105) 28 | end 29 | 30 | test "supports multiple inputs" do 31 | {:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"}) 32 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "FacebookAI/roberta-base"}) 33 | 34 | serving = Bumblebee.Text.question_answering(roberta, tokenizer) 35 | 36 | inputs = [ 37 | %{question: "What's my name?", context: "My name is Sarah and I live in London."}, 38 | %{question: "Where do I live?", context: "My name is Clara and I live in Berkeley."} 39 | ] 40 | 41 | assert [ 42 | %{results: [%{text: "Sarah", start: 11, end: 16, score: _}]}, 43 | %{results: [%{text: "Berkeley", start: 31, end: 39, score: _}]} 44 | ] = Nx.Serving.run(serving, inputs) 45 | end 46 | end 47 | -------------------------------------------------------------------------------- /test/bumblebee/text/text_classification_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.TextClassificationTest do 2 | use ExUnit.Case, async: false 3 | 4 | 
import Bumblebee.TestHelpers 5 | 6 | @moduletag serving_test_tags() 7 | 8 | test "returns top scored labels" do 9 | {:ok, model_info} = Bumblebee.load_model({:hf, "cardiffnlp/twitter-roberta-base-emotion"}) 10 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "FacebookAI/roberta-base"}) 11 | 12 | serving = Bumblebee.Text.TextClassification.text_classification(model_info, tokenizer) 13 | 14 | text = "Cats are cute." 15 | 16 | assert %{ 17 | predictions: [ 18 | %{label: "optimism", score: _}, 19 | %{label: "sadness", score: _}, 20 | %{label: "anger", score: _}, 21 | %{label: "joy", score: _} 22 | ] 23 | } = Nx.Serving.run(serving, text) 24 | end 25 | end 26 | -------------------------------------------------------------------------------- /test/bumblebee/text/text_embedding_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.TextEmbeddingTest do 2 | use ExUnit.Case, async: false 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag serving_test_tags() 7 | 8 | test "returns embedding for a piece of text" do 9 | {:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-small-v2"}) 10 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-small-v2"}) 11 | 12 | serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer) 13 | 14 | text = "query: Cats are cute." 
15 | 16 | assert %{embedding: %Nx.Tensor{} = embedding} = Nx.Serving.run(serving, text) 17 | 18 | assert Nx.shape(embedding) == {384} 19 | 20 | assert_all_close( 21 | embedding[1..3], 22 | Nx.tensor([0.0420, -0.0188, 0.1115]) 23 | ) 24 | end 25 | 26 | test "returns normalized embedding for a piece of text" do 27 | {:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-small-v2"}) 28 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-small-v2"}) 29 | 30 | options = [embedding_processor: :l2_norm] 31 | 32 | serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer, options) 33 | 34 | text = "query: Cats are cute." 35 | 36 | assert %{embedding: %Nx.Tensor{} = embedding} = Nx.Serving.run(serving, text) 37 | 38 | assert Nx.shape(embedding) == {384} 39 | 40 | assert_all_close( 41 | embedding[1..3], 42 | Nx.tensor([0.0433, -0.0194, 0.1151]) 43 | ) 44 | 45 | assert_all_close(Nx.sum(Nx.pow(embedding, 2)), Nx.tensor(1.0), atol: 1.0e-6) 46 | end 47 | 48 | test "supports compilation for single or multiple sequence lengths" do 49 | {:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-small-v2"}) 50 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-small-v2"}) 51 | 52 | serving_short = 53 | Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer, 54 | compile: [batch_size: 1, sequence_length: 8] 55 | ) 56 | 57 | serving_long = 58 | Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer, 59 | compile: [batch_size: 1, sequence_length: 16] 60 | ) 61 | 62 | serving_both = 63 | Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer, 64 | compile: [batch_size: 1, sequence_length: [8, 16]] 65 | ) 66 | 67 | short_text = "short text" 68 | long_text = "definitely much longer text that should exceed 16 tokens" 69 | 70 | assert %{embedding: embedding_short} = Nx.Serving.run(serving_short, short_text) 71 | assert %{embedding: embedding_long} = Nx.Serving.run(serving_long, long_text) 72 | 73 | assert 
%{embedding: embedding_short2} = Nx.Serving.run(serving_both, short_text) 74 | assert %{embedding: embedding_long2} = Nx.Serving.run(serving_both, long_text) 75 | 76 | assert_equal(embedding_short, embedding_short2) 77 | assert_equal(embedding_long, embedding_long2) 78 | end 79 | 80 | @tag :multi_device 81 | test "works with partitioned serving", %{test: test} do 82 | {:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-small-v2"}) 83 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-small-v2"}) 84 | 85 | serving = 86 | Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer, 87 | compile: [batch_size: 1, sequence_length: 16], 88 | defn_options: [compiler: EXLA, client: :other_host], 89 | preallocate_params: true 90 | ) 91 | 92 | start_supervised!({Nx.Serving, serving: serving, name: test, partitions: true}) 93 | 94 | text = "query: Cats are cute." 95 | 96 | assert [ 97 | %{embedding: %Nx.Tensor{} = embedding1}, 98 | %{embedding: %Nx.Tensor{} = embedding2} 99 | ] = Nx.Serving.batched_run(test, [text, text]) 100 | 101 | assert_equal(embedding1, embedding2) 102 | end 103 | end 104 | -------------------------------------------------------------------------------- /test/bumblebee/text/token_classification_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.TokenClassificationTest do 2 | use ExUnit.Case, async: false 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag serving_test_tags() 7 | 8 | test "correctly extracts entities with :same aggregation" do 9 | assert {:ok, model_info} = Bumblebee.load_model({:hf, "dslim/bert-base-NER"}) 10 | assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-cased"}) 11 | 12 | serving = 13 | Bumblebee.Text.TokenClassification.token_classification(model_info, tokenizer, 14 | aggregation: :same 15 | ) 16 | 17 | text = "I went with Jane Doe to Atlanta and we talked to John Smith about Microsoft" 18 | 19 | 
assert %{entities: [jane, atlanta, john, microsoft]} = Nx.Serving.run(serving, text) 20 | 21 | assert %{ 22 | label: "PER", 23 | score: _jane_score, 24 | phrase: "Jane Doe", 25 | start: 12, 26 | end: 20 27 | } = jane 28 | 29 | assert %{ 30 | label: "LOC", 31 | score: _atlanta_score, 32 | phrase: "Atlanta", 33 | start: 24, 34 | end: 31 35 | } = atlanta 36 | 37 | assert %{ 38 | label: "PER", 39 | score: _john_score, 40 | phrase: "John Smith", 41 | start: 49, 42 | end: 59 43 | } = john 44 | 45 | assert %{ 46 | label: "ORG", 47 | score: _microsoft_score, 48 | phrase: "Microsoft", 49 | start: 66, 50 | end: 75 51 | } = microsoft 52 | 53 | # Offsets should be expressed in terms of bytes (note that é is 2 bytes) 54 | 55 | text = "Jane é John" 56 | 57 | assert %{ 58 | entities: [%{start: 0, end: 4}, %{start: 8, end: 12}] 59 | } = Nx.Serving.run(serving, text) 60 | end 61 | 62 | for aggregation <- [:word_first, :word_max, :word_average] do 63 | test "correctly extracts entities with :#{aggregation} aggregation" do 64 | assert {:ok, model_info} = Bumblebee.load_model({:hf, "dslim/bert-base-NER"}) 65 | assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-cased"}) 66 | 67 | serving = 68 | Bumblebee.Text.TokenClassification.token_classification(model_info, tokenizer, 69 | aggregation: unquote(aggregation) 70 | ) 71 | 72 | text = "I went with Janine Doe to Atlanta and we talked to John Smith about Microsoft" 73 | 74 | assert %{entities: [jane, atlanta, john, microsoft]} = Nx.Serving.run(serving, text) 75 | 76 | assert %{ 77 | label: "PER", 78 | score: _janine_score, 79 | phrase: "Janine Doe", 80 | start: 12, 81 | end: 22 82 | } = jane 83 | 84 | assert %{ 85 | label: "LOC", 86 | score: _atlanta_score, 87 | phrase: "Atlanta", 88 | start: 26, 89 | end: 33 90 | } = atlanta 91 | 92 | assert %{ 93 | label: "PER", 94 | score: _john_score, 95 | phrase: "John Smith", 96 | start: 51, 97 | end: 61 98 | } = john 99 | 100 | assert %{ 101 | label: "ORG", 102 | score: 
_microsoft_score, 103 | phrase: "Microsoft", 104 | start: 68, 105 | end: 77 106 | } = microsoft 107 | end 108 | end 109 | 110 | test "correctly extracts entities with simple aggregation on batched input" do 111 | assert {:ok, model_info} = Bumblebee.load_model({:hf, "dslim/bert-base-NER"}) 112 | assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-cased"}) 113 | 114 | serving = 115 | Bumblebee.Text.TokenClassification.token_classification(model_info, tokenizer, 116 | aggregation: :same 117 | ) 118 | 119 | texts = [ 120 | "I went with Janine Doe to Atlanta and we talked to John Smith about Microsoft", 121 | "John went to Philadelphia" 122 | ] 123 | 124 | assert [_first, %{entities: [john, philadelphia]}] = Nx.Serving.run(serving, texts) 125 | 126 | assert %{label: "PER", phrase: "John"} = john 127 | assert %{label: "LOC", phrase: "Philadelphia"} = philadelphia 128 | end 129 | end 130 | -------------------------------------------------------------------------------- /test/bumblebee/text/translation_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.TranslationTest do 2 | use ExUnit.Case, async: false 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag serving_test_tags() 7 | 8 | test "generates text with greedy generation" do 9 | {:ok, model_info} = Bumblebee.load_model({:hf, "facebook/nllb-200-distilled-600M"}) 10 | 11 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/nllb-200-distilled-600M"}) 12 | 13 | {:ok, generation_config} = 14 | Bumblebee.load_generation_config({:hf, "facebook/nllb-200-distilled-600M"}) 15 | 16 | serving = Bumblebee.Text.translation(model_info, tokenizer, generation_config) 17 | 18 | text = "The bank of the river is beautiful in spring" 19 | 20 | assert %{ 21 | results: [ 22 | %{ 23 | text: "W wiosnę brzeg rzeki jest piękny", 24 | token_summary: %{input: 11, output: 13, padding: 0} 25 | } 26 | ] 27 | } = 28 | Nx.Serving.run(serving, 
%{ 29 | text: text, 30 | source_language_token: "eng_Latn", 31 | target_language_token: "pol_Latn" 32 | }) 33 | end 34 | end 35 | -------------------------------------------------------------------------------- /test/bumblebee/text/xlm_roberta_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.XlmRobertaTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-XLMRobertaModel"}) 11 | 12 | assert %Bumblebee.Text.Roberta{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), 16 | "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) 17 | } 18 | 19 | outputs = Axon.predict(model, params, inputs) 20 | 21 | assert Nx.shape(outputs.hidden_state) == {1, 10, 32} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([ 26 | [[-0.6455, -0.4189, 0.3424], [-0.4303, -0.6731, 0.2534], [-0.5240, 0.0864, -0.5632]] 27 | ]) 28 | ) 29 | end 30 | end 31 | -------------------------------------------------------------------------------- /test/bumblebee/text/zero_shot_classification_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Text.ZeroShotClassificationTest do 2 | use ExUnit.Case, async: false 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag serving_test_tags() 7 | 8 | test "correctly classifies labels with one sequence" do 9 | {:ok, model} = Bumblebee.load_model({:hf, "facebook/bart-large-mnli"}) 10 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-mnli"}) 11 | labels = ["cooking", "traveling", "dancing"] 12 | 13 | zero_shot_serving = Bumblebee.Text.zero_shot_classification(model, tokenizer, labels) 14 | 15 
| output = Nx.Serving.run(zero_shot_serving, "one day I will see the world") 16 | 17 | assert %{ 18 | predictions: [ 19 | %{label: "traveling", score: _}, 20 | %{label: "dancing", score: _}, 21 | %{label: "cooking", score: _} 22 | ] 23 | } = output 24 | 25 | assert %{label: "traveling", score: score} = Enum.max_by(output.predictions, & &1.score) 26 | assert_all_close(score, 0.9874) 27 | end 28 | 29 | test "correctly classifies labels with multiple sequences" do 30 | {:ok, model} = Bumblebee.load_model({:hf, "facebook/bart-large-mnli"}) 31 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-mnli"}) 32 | labels = ["cooking", "traveling", "dancing"] 33 | 34 | zero_shot_serving = Bumblebee.Text.zero_shot_classification(model, tokenizer, labels) 35 | 36 | assert [output1, output2] = 37 | Nx.Serving.run(zero_shot_serving, [ 38 | "one day I will see the world", 39 | "one day I will learn to salsa" 40 | ]) 41 | 42 | assert %{label: "traveling", score: score1} = Enum.max_by(output1.predictions, & &1.score) 43 | assert_all_close(score1, 0.9874) 44 | 45 | assert %{label: "dancing", score: score2} = Enum.max_by(output2.predictions, & &1.score) 46 | assert_all_close(score2, 0.9585) 47 | end 48 | 49 | test "correctly classifies batch with compilation set to true" do 50 | {:ok, model} = Bumblebee.load_model({:hf, "facebook/bart-large-mnli"}) 51 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-mnli"}) 52 | labels = ["cooking", "traveling", "dancing"] 53 | 54 | zero_shot_serving = 55 | Bumblebee.Text.zero_shot_classification(model, tokenizer, labels, 56 | compile: [batch_size: 2, sequence_length: 32], 57 | defn_options: [compiler: EXLA] 58 | ) 59 | 60 | assert [output1, output2] = 61 | Nx.Serving.run(zero_shot_serving, [ 62 | "one day I will see the world", 63 | "one day I will learn to salsa" 64 | ]) 65 | 66 | assert %{label: "traveling", score: score1} = Enum.max_by(output1.predictions, & &1.score) 67 | assert_all_close(score1, 0.9874) 
68 | 69 | assert %{label: "dancing", score: score2} = Enum.max_by(output2.predictions, & &1.score) 70 | assert_all_close(score2, 0.9585) 71 | end 72 | end 73 | -------------------------------------------------------------------------------- /test/bumblebee/utils/image_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Utils.ImageTest do 2 | use ExUnit.Case, async: true 3 | 4 | doctest Bumblebee.Utils.Image 5 | end 6 | -------------------------------------------------------------------------------- /test/bumblebee/utils/nx_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Utils.NxTest do 2 | use ExUnit.Case, async: true 3 | 4 | doctest Bumblebee.Utils.Nx 5 | end 6 | -------------------------------------------------------------------------------- /test/bumblebee/vision/bit_featurizer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.BitFeaturizerTest do 2 | use ExUnit.Case, async: true 3 | 4 | test "encodes image" do 5 | assert {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "google/bit-50"}) 6 | 7 | assert %Bumblebee.Vision.BitFeaturizer{} = featurizer 8 | 9 | image = Nx.tensor([[[50], [100]], [[150], [200]]]) |> Nx.broadcast({2, 2, 3}) 10 | 11 | inputs = Bumblebee.apply_featurizer(featurizer, image) 12 | 13 | assert Nx.shape(inputs["pixel_values"]) == {1, 448, 448, 3} 14 | end 15 | end 16 | -------------------------------------------------------------------------------- /test/bumblebee/vision/blip_featurizer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.BlipFeaturizerTest do 2 | use ExUnit.Case, async: true 3 | 4 | test "encodes image" do 5 | assert {:ok, featurizer} = 6 | Bumblebee.load_featurizer({:hf, "Salesforce/blip-image-captioning-base"}) 7 | 8 | assert 
%Bumblebee.Vision.BlipFeaturizer{} = featurizer 9 | 10 | image = Nx.tensor([[[50], [100]], [[150], [200]]]) |> Nx.broadcast({2, 2, 3}) 11 | 12 | inputs = Bumblebee.apply_featurizer(featurizer, image) 13 | 14 | assert Nx.shape(inputs["pixel_values"]) == {1, 384, 384, 3} 15 | end 16 | end 17 | -------------------------------------------------------------------------------- /test/bumblebee/vision/blip_vision_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.BlipVisionTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-BlipModel"}, 11 | module: Bumblebee.Vision.BlipVision, 12 | architecture: :base 13 | ) 14 | 15 | assert %Bumblebee.Vision.BlipVision{architecture: :base} = spec 16 | 17 | inputs = %{ 18 | "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) 19 | } 20 | 21 | outputs = Axon.predict(model, params, inputs) 22 | 23 | assert Nx.shape(outputs.hidden_state) == {1, 226, 32} 24 | assert Nx.shape(outputs.pooled_state) == {1, 32} 25 | 26 | assert_all_close( 27 | outputs.hidden_state[[.., 1..3, 1..3]] |> Nx.multiply(1_000_000), 28 | Nx.tensor([ 29 | [[-0.0272, -0.0129, 0.0174], [0.0069, -0.0429, -0.0334], [0.0428, -0.0797, -0.0353]] 30 | ]) 31 | ) 32 | 33 | assert_all_close( 34 | outputs.pooled_state[[.., 1..3]] |> Nx.multiply(10_000), 35 | Nx.tensor([[-0.0128, -0.0792, -0.1011]]) 36 | ) 37 | end 38 | end 39 | -------------------------------------------------------------------------------- /test/bumblebee/vision/clip_featurizer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.ClipFeaturizerTest do 2 | use ExUnit.Case, async: true 3 | 4 | test "encodes image" do 5 | assert {:ok, featurizer} = 
Bumblebee.load_featurizer({:hf, "openai/clip-vit-base-patch32"}) 6 | 7 | assert %Bumblebee.Vision.ClipFeaturizer{} = featurizer 8 | 9 | image = Nx.tensor([[[50], [100]], [[150], [200]]]) |> Nx.broadcast({2, 2, 3}) 10 | 11 | inputs = Bumblebee.apply_featurizer(featurizer, image) 12 | 13 | assert Nx.shape(inputs["pixel_values"]) == {1, 224, 224, 3} 14 | end 15 | end 16 | -------------------------------------------------------------------------------- /test/bumblebee/vision/clip_vision_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.ClipVisionTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-CLIPModel"}, 11 | module: Bumblebee.Vision.ClipVision, 12 | architecture: :base 13 | ) 14 | 15 | assert %Bumblebee.Vision.ClipVision{architecture: :base} = spec 16 | 17 | inputs = %{ 18 | "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) 19 | } 20 | 21 | outputs = Axon.predict(model, params, inputs) 22 | 23 | assert Nx.shape(outputs.hidden_state) == {1, 226, 32} 24 | assert Nx.shape(outputs.pooled_state) == {1, 32} 25 | 26 | assert_all_close( 27 | outputs.hidden_state[[.., 1..3, 1..3]], 28 | Nx.tensor([ 29 | [[0.4483, 0.3736, -0.5581], [0.9376, -0.3424, -0.1002], [0.5782, 0.1069, -0.2953]] 30 | ]) 31 | ) 32 | 33 | assert_all_close( 34 | outputs.pooled_state[[.., 1..3]], 35 | Nx.tensor([[-0.5059, 0.7391, 0.9252]]) 36 | ) 37 | end 38 | 39 | test ":for_embedding" do 40 | assert {:ok, %{model: model, params: params, spec: spec}} = 41 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-CLIPModel"}, 42 | module: Bumblebee.Vision.ClipVision, 43 | architecture: :for_embedding 44 | ) 45 | 46 | assert %Bumblebee.Vision.ClipVision{architecture: :for_embedding} = spec 47 | 48 | inputs = %{ 49 | 
"pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) 50 | } 51 | 52 | outputs = Axon.predict(model, params, inputs) 53 | 54 | assert Nx.shape(outputs.embedding) == {1, 64} 55 | 56 | assert_all_close( 57 | outputs.embedding[[.., 1..3]], 58 | Nx.tensor([[0.8865, -0.9042, -1.1233]]) 59 | ) 60 | end 61 | end 62 | -------------------------------------------------------------------------------- /test/bumblebee/vision/convnext_featurizer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.ConvNextFeaturizerTest do 2 | use ExUnit.Case, async: true 3 | 4 | test "encodes image" do 5 | assert {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "facebook/convnext-tiny-224"}) 6 | 7 | assert %Bumblebee.Vision.ConvNextFeaturizer{} = featurizer 8 | 9 | image = Nx.tensor([[[50], [100]], [[150], [200]]]) |> Nx.broadcast({2, 2, 3}) 10 | 11 | inputs = Bumblebee.apply_featurizer(featurizer, image) 12 | 13 | assert Nx.shape(inputs["pixel_values"]) == {1, 224, 224, 3} 14 | end 15 | 16 | test "allows an alpha channel" do 17 | assert {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "facebook/convnext-tiny-224"}) 18 | 19 | assert %Bumblebee.Vision.ConvNextFeaturizer{} = featurizer 20 | 21 | image = Nx.tensor([[[50], [100]], [[150], [200]]]) |> Nx.broadcast({2, 2, 4}) 22 | 23 | inputs = Bumblebee.apply_featurizer(featurizer, image) 24 | 25 | assert Nx.shape(inputs["pixel_values"]) == {1, 224, 224, 3} 26 | end 27 | end 28 | -------------------------------------------------------------------------------- /test/bumblebee/vision/convnext_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.ConvNextTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, 
"hf-internal-testing/tiny-random-ConvNextModel"}) 11 | 12 | assert %Bumblebee.Vision.ConvNext{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "pixel_values" => Nx.broadcast(0.5, {1, 224, 224, 3}) 16 | } 17 | 18 | outputs = Axon.predict(model, params, inputs) 19 | 20 | assert Nx.shape(outputs.hidden_state) == {1, 7, 7, 40} 21 | assert Nx.shape(outputs.pooled_state) == {1, 40} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..2, 1..2, 1..2]], 25 | Nx.tensor([[[[0.3924, -0.2330], [0.3924, -0.2330]], [[0.3924, -0.2330], [0.3924, -0.2330]]]]) 26 | ) 27 | 28 | assert_all_close( 29 | outputs.pooled_state[[.., 1..3]], 30 | Nx.tensor([[2.2793, -1.3236, -1.0714]]), 31 | atol: 1.0e-3 32 | ) 33 | end 34 | 35 | test ":for_image_classification" do 36 | assert {:ok, %{model: model, params: params, spec: spec}} = 37 | Bumblebee.load_model( 38 | {:hf, "hf-internal-testing/tiny-random-ConvNextForImageClassification"} 39 | ) 40 | 41 | assert %Bumblebee.Vision.ConvNext{architecture: :for_image_classification} = spec 42 | 43 | inputs = %{ 44 | "pixel_values" => Nx.broadcast(0.5, {1, 224, 224, 3}) 45 | } 46 | 47 | outputs = Axon.predict(model, params, inputs) 48 | 49 | assert Nx.shape(outputs.logits) == {1, 2} 50 | 51 | assert_all_close( 52 | outputs.logits, 53 | Nx.tensor([[0.0047, -0.1457]]) 54 | ) 55 | end 56 | end 57 | -------------------------------------------------------------------------------- /test/bumblebee/vision/deit_featurizer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.DeitFeaturizerTest do 2 | use ExUnit.Case, async: true 3 | 4 | test "encodes image" do 5 | assert {:ok, featurizer} = 6 | Bumblebee.load_featurizer({:hf, "facebook/deit-base-distilled-patch16-224"}) 7 | 8 | assert %Bumblebee.Vision.DeitFeaturizer{} = featurizer 9 | 10 | image = Nx.tensor([[[50], [100]], [[150], [200]]]) |> Nx.broadcast({2, 2, 3}) 11 | 12 | inputs = Bumblebee.apply_featurizer(featurizer, image) 
13 | 14 | assert Nx.shape(inputs["pixel_values"]) == {1, 224, 224, 3} 15 | end 16 | end 17 | -------------------------------------------------------------------------------- /test/bumblebee/vision/deit_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.DeitTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-DeiTModel"}) 11 | 12 | assert %Bumblebee.Vision.Deit{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) 16 | } 17 | 18 | outputs = Axon.predict(model, params, inputs) 19 | 20 | assert Nx.shape(outputs.hidden_state) == {1, 227, 32} 21 | assert Nx.shape(outputs.pooled_state) == {1, 32} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([ 26 | [[-3.0866, 0.2350, 0.2003], [-1.2774, -0.1192, -1.0468], [-1.2774, -0.1192, -1.0468]] 27 | ]) 28 | ) 29 | 30 | assert_all_close( 31 | outputs.pooled_state[[.., 1..3]], 32 | Nx.tensor([[0.1526, -0.1437, -0.0646]]) 33 | ) 34 | end 35 | 36 | test ":for_image_classification" do 37 | assert {:ok, %{model: model, params: params, spec: spec}} = 38 | Bumblebee.load_model( 39 | {:hf, "hf-internal-testing/tiny-random-DeiTForImageClassification"} 40 | ) 41 | 42 | assert %Bumblebee.Vision.Deit{architecture: :for_image_classification} = spec 43 | 44 | inputs = %{ 45 | "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) 46 | } 47 | 48 | outputs = Axon.predict(model, params, inputs) 49 | 50 | assert Nx.shape(outputs.logits) == {1, 2} 51 | 52 | assert_all_close( 53 | outputs.logits, 54 | Nx.tensor([[0.0481, 0.1008]]) 55 | ) 56 | end 57 | 58 | test ":for_image_classification_with_teacher" do 59 | assert {:ok, %{model: model, params: params, spec: spec}} = 60 | 
Bumblebee.load_model( 61 | {:hf, "hf-internal-testing/tiny-random-DeiTForImageClassificationWithTeacher"} 62 | ) 63 | 64 | assert %Bumblebee.Vision.Deit{architecture: :for_image_classification_with_teacher} = spec 65 | 66 | inputs = %{ 67 | "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) 68 | } 69 | 70 | outputs = Axon.predict(model, params, inputs) 71 | 72 | assert Nx.shape(outputs.logits) == {1, 2} 73 | 74 | assert_all_close( 75 | outputs.logits, 76 | Nx.tensor([[-0.0108, -0.0048]]) 77 | ) 78 | end 79 | 80 | test ":for_masked_image_modeling" do 81 | assert {:ok, %{model: model, params: params, spec: spec}} = 82 | Bumblebee.load_model( 83 | {:hf, "hf-internal-testing/tiny-random-DeiTForMaskedImageModeling"} 84 | ) 85 | 86 | assert %Bumblebee.Vision.Deit{architecture: :for_masked_image_modeling} = spec 87 | 88 | inputs = %{ 89 | "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) 90 | } 91 | 92 | outputs = Axon.predict(model, params, inputs) 93 | 94 | assert Nx.shape(outputs.pixel_values) == {1, 30, 30, 3} 95 | 96 | assert_all_close( 97 | outputs.pixel_values[[.., 1..2, 1..2, 1..2]], 98 | Nx.tensor([[[[0.1455, 0.1889], [0.0229, 0.0910]], [[-0.0097, -0.1083], [0.0525, -0.0244]]]]) 99 | ) 100 | end 101 | end 102 | -------------------------------------------------------------------------------- /test/bumblebee/vision/image_classification_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.ImageClassificationTest do 2 | use ExUnit.Case, async: false 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag serving_test_tags() 7 | 8 | @images_dir Path.expand("../../fixtures/images", __DIR__) 9 | 10 | test "returns top scored labels" do 11 | {:ok, model_info} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}) 12 | {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "microsoft/resnet-50"}) 13 | 14 | serving = Bumblebee.Vision.ImageClassification.image_classification(model_info, featurizer) 15 | 16 | 
image = StbImage.read_file!(Path.join(@images_dir, "coco/39769.jpeg")) 17 | 18 | assert %{ 19 | predictions: [ 20 | %{label: "tiger cat", score: _}, 21 | %{label: "tabby, tabby cat", score: _}, 22 | %{label: "remote control, remote", score: _}, 23 | %{label: "jinrikisha, ricksha, rickshaw", score: _}, 24 | %{label: "Egyptian cat", score: _} 25 | ] 26 | } = Nx.Serving.run(serving, image) 27 | end 28 | 29 | test "supports compilation" do 30 | {:ok, model_info} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}) 31 | {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "microsoft/resnet-50"}) 32 | 33 | serving = 34 | Bumblebee.Vision.ImageClassification.image_classification(model_info, featurizer, 35 | compile: [batch_size: 1], 36 | defn_options: [compiler: EXLA] 37 | ) 38 | 39 | image = StbImage.read_file!(Path.join(@images_dir, "coco/39769.jpeg")) 40 | 41 | assert %{ 42 | predictions: [ 43 | %{label: "tiger cat", score: _}, 44 | %{label: "tabby, tabby cat", score: _}, 45 | %{label: "remote control, remote", score: _}, 46 | %{label: "jinrikisha, ricksha, rickshaw", score: _}, 47 | %{label: "Egyptian cat", score: _} 48 | ] 49 | } = Nx.Serving.run(serving, image) 50 | end 51 | end 52 | -------------------------------------------------------------------------------- /test/bumblebee/vision/image_embedding_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.ImageEmbeddingTest do 2 | use ExUnit.Case, async: false 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag serving_test_tags() 7 | @images_dir Path.expand("../../fixtures/images", __DIR__) 8 | 9 | test "returns embedding for an image" do 10 | {:ok, model_info} = 11 | Bumblebee.load_model({:hf, "openai/clip-vit-base-patch32"}, 12 | module: Bumblebee.Vision.ClipVision 13 | ) 14 | 15 | {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/clip-vit-base-patch32"}) 16 | 17 | serving = Bumblebee.Vision.ImageEmbedding.image_embedding(model_info, 
featurizer) 18 | image = StbImage.read_file!(Path.join(@images_dir, "coco/39769.jpeg")) 19 | 20 | assert %{embedding: %Nx.Tensor{} = embedding} = Nx.Serving.run(serving, image) 21 | assert Nx.shape(embedding) == {768} 22 | 23 | assert_all_close( 24 | embedding[1..3], 25 | Nx.tensor([0.0978, -0.7233, -0.7707]) 26 | ) 27 | end 28 | 29 | test "returns normalized embedding for an image" do 30 | {:ok, model_info} = 31 | Bumblebee.load_model({:hf, "openai/clip-vit-base-patch32"}, 32 | module: Bumblebee.Vision.ClipVision 33 | ) 34 | 35 | {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/clip-vit-base-patch32"}) 36 | 37 | options = [ 38 | embedding_processor: :l2_norm 39 | ] 40 | 41 | serving = Bumblebee.Vision.ImageEmbedding.image_embedding(model_info, featurizer, options) 42 | image = StbImage.read_file!(Path.join(@images_dir, "coco/39769.jpeg")) 43 | 44 | assert %{embedding: %Nx.Tensor{} = embedding} = Nx.Serving.run(serving, image) 45 | assert Nx.shape(embedding) == {768} 46 | 47 | assert_all_close( 48 | embedding[1..3], 49 | Nx.tensor([0.0036, -0.0269, -0.0286]) 50 | ) 51 | end 52 | end 53 | -------------------------------------------------------------------------------- /test/bumblebee/vision/image_to_text_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.ImageToTextTest do 2 | use ExUnit.Case, async: false 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag serving_test_tags() 7 | 8 | @images_dir Path.expand("../../fixtures/images", __DIR__) 9 | 10 | test "generates text describing an image" do 11 | {:ok, blip} = Bumblebee.load_model({:hf, "Salesforce/blip-image-captioning-base"}) 12 | 13 | {:ok, featurizer} = 14 | Bumblebee.load_featurizer({:hf, "Salesforce/blip-image-captioning-base"}) 15 | 16 | {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Salesforce/blip-image-captioning-base"}) 17 | 18 | {:ok, generation_config} = 19 | Bumblebee.load_generation_config({:hf, 
"Salesforce/blip-image-captioning-base"}) 20 | 21 | serving = 22 | Bumblebee.Vision.ImageToText.image_to_text(blip, featurizer, tokenizer, generation_config) 23 | 24 | image = StbImage.read_file!(Path.join(@images_dir, "coco/39769.jpeg")) 25 | 26 | assert %{ 27 | results: [%{text: "two cats sleeping on a couch"}] 28 | } = Nx.Serving.run(serving, image) 29 | end 30 | end 31 | -------------------------------------------------------------------------------- /test/bumblebee/vision/resnet_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.ResNetTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-ResNetModel"}) 11 | 12 | assert %Bumblebee.Vision.ResNet{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "pixel_values" => Nx.broadcast(0.5, {1, 224, 224, 3}) 16 | } 17 | 18 | outputs = Axon.predict(model, params, inputs) 19 | 20 | assert Nx.shape(outputs.hidden_state) == {1, 7, 7, 40} 21 | assert Nx.shape(outputs.pooled_state) == {1, 40} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 2..3, 2..3, 2..3]], 25 | Nx.tensor([[[[0.0000, 0.9835], [0.0000, 0.9835]], [[0.0000, 0.9835], [0.0000, 0.9835]]]]) 26 | ) 27 | 28 | assert_all_close(Nx.sum(outputs.hidden_state), Nx.tensor(209.6328)) 29 | 30 | assert_all_close( 31 | outputs.pooled_state[[.., 1..3]], 32 | Nx.tensor([[0.0275, 0.0095, 0.8921]]) 33 | ) 34 | 35 | assert_all_close(Nx.sum(outputs.pooled_state), Nx.tensor(4.2782)) 36 | end 37 | 38 | test ":for_image_classification" do 39 | assert {:ok, %{model: model, params: params, spec: spec}} = 40 | Bumblebee.load_model( 41 | {:hf, "hf-internal-testing/tiny-random-ResNetForImageClassification"} 42 | ) 43 | 44 | assert %Bumblebee.Vision.ResNet{architecture: :for_image_classification} =
spec 45 | 46 | inputs = %{ 47 | "pixel_values" => Nx.broadcast(0.5, {1, 224, 224, 3}) 48 | } 49 | 50 | outputs = Axon.predict(model, params, inputs) 51 | 52 | assert Nx.shape(outputs.logits) == {1, 3} 53 | 54 | assert_all_close( 55 | outputs.logits, 56 | Nx.tensor([[-0.1053, 0.2160, -0.0331]]) 57 | ) 58 | end 59 | end 60 | -------------------------------------------------------------------------------- /test/bumblebee/vision/swin_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.SwinTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-SwinModel"}) 11 | 12 | assert %Bumblebee.Vision.Swin{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "pixel_values" => Nx.broadcast(0.5, {1, 32, 32, 3}) 16 | } 17 | 18 | outputs = Axon.predict(model, params, inputs) 19 | 20 | assert Nx.shape(outputs.hidden_state) == {1, 16, 64} 21 | assert Nx.shape(outputs.pooled_state) == {1, 64} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([ 26 | [[-0.4605, 0.9336, -0.5528], [-0.4605, 0.9336, -0.5528], [-0.4605, 0.9336, -0.5528]] 27 | ]) 28 | ) 29 | 30 | assert_all_close( 31 | outputs.pooled_state[[.., 1..3]], 32 | Nx.tensor([[-0.4605, 0.9336, -0.5528]]) 33 | ) 34 | end 35 | 36 | test ":for_image_classification" do 37 | assert {:ok, %{model: model, params: params, spec: spec}} = 38 | Bumblebee.load_model( 39 | {:hf, "hf-internal-testing/tiny-random-SwinForImageClassification"} 40 | ) 41 | 42 | assert %Bumblebee.Vision.Swin{architecture: :for_image_classification} = spec 43 | 44 | inputs = %{ 45 | "pixel_values" => Nx.broadcast(0.5, {1, 32, 32, 3}) 46 | } 47 | 48 | outputs = Axon.predict(model, params, inputs) 49 | 50 | assert Nx.shape(outputs.logits) == {1, 2} 51 | 52
| assert_all_close( 53 | outputs.logits, 54 | Nx.tensor([[0.0361, 0.1352]]) 55 | ) 56 | end 57 | end 58 | -------------------------------------------------------------------------------- /test/bumblebee/vision/vit_featurizer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.VitFeaturizerTest do 2 | use ExUnit.Case, async: true 3 | 4 | test "encodes image" do 5 | assert {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "google/vit-base-patch16-224"}) 6 | 7 | assert %Bumblebee.Vision.VitFeaturizer{} = featurizer 8 | 9 | image = Nx.tensor([[[50], [100]], [[150], [200]]]) |> Nx.broadcast({2, 2, 3}) 10 | 11 | inputs = Bumblebee.apply_featurizer(featurizer, image) 12 | 13 | assert Nx.shape(inputs["pixel_values"]) == {1, 224, 224, 3} 14 | end 15 | end 16 | -------------------------------------------------------------------------------- /test/bumblebee/vision/vit_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.Vision.VitTest do 2 | use ExUnit.Case, async: true 3 | 4 | import Bumblebee.TestHelpers 5 | 6 | @moduletag model_test_tags() 7 | 8 | test ":base" do 9 | assert {:ok, %{model: model, params: params, spec: spec}} = 10 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-ViTModel"}) 11 | 12 | assert %Bumblebee.Vision.Vit{architecture: :base} = spec 13 | 14 | inputs = %{ 15 | "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) 16 | } 17 | 18 | outputs = Axon.predict(model, params, inputs) 19 | 20 | assert Nx.shape(outputs.hidden_state) == {1, 226, 32} 21 | assert Nx.shape(outputs.pooled_state) == {1, 32} 22 | 23 | assert_all_close( 24 | outputs.hidden_state[[.., 1..3, 1..3]], 25 | Nx.tensor([ 26 | [[-0.2075, 2.7865, 0.2361], [-0.3014, 2.5312, -0.6127], [-0.3460, 2.8741, 0.1988]] 27 | ]) 28 | ) 29 | 30 | assert_all_close( 31 | outputs.pooled_state[[.., 1..3]], 32 | Nx.tensor([[-0.0244, -0.0515, -0.1584]]) 33 | ) 34 | end 35
| 36 | test ":for_image_classification" do 37 | assert {:ok, %{model: model, params: params, spec: spec}} = 38 | Bumblebee.load_model( 39 | {:hf, "hf-internal-testing/tiny-random-ViTForImageClassification"} 40 | ) 41 | 42 | assert %Bumblebee.Vision.Vit{architecture: :for_image_classification} = spec 43 | 44 | inputs = %{ 45 | "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) 46 | } 47 | 48 | outputs = Axon.predict(model, params, inputs) 49 | 50 | assert Nx.shape(outputs.logits) == {1, 2} 51 | 52 | assert_all_close( 53 | outputs.logits, 54 | Nx.tensor([[-0.1596, 0.1818]]) 55 | ) 56 | end 57 | 58 | test ":for_masked_image_modeling" do 59 | assert {:ok, %{model: model, params: params, spec: spec}} = 60 | Bumblebee.load_model( 61 | {:hf, "hf-internal-testing/tiny-random-ViTForMaskedImageModeling"} 62 | ) 63 | 64 | assert %Bumblebee.Vision.Vit{architecture: :for_masked_image_modeling} = spec 65 | 66 | inputs = %{ 67 | "pixel_values" => Nx.broadcast(0.5, {1, 30, 30, 3}) 68 | } 69 | 70 | outputs = Axon.predict(model, params, inputs) 71 | 72 | assert Nx.shape(outputs.pixel_values) == {1, 30, 30, 3} 73 | 74 | assert_all_close( 75 | outputs.pixel_values[[.., 1..2, 1..2, 1..2]], 76 | Nx.tensor([[[[0.0752, 0.0548], [-0.0192, -0.0216]], [[-0.0252, 0.0728], [0.0232, -0.1687]]]]) 77 | ) 78 | end 79 | end 80 | -------------------------------------------------------------------------------- /test/bumblebee_test.exs: -------------------------------------------------------------------------------- 1 | defmodule BumblebeeTest do 2 | use ExUnit.Case, async: true 3 | 4 | describe "load_model/2" do 5 | test "raises an error on invalid repository type" do 6 | assert_raise ArgumentError, 7 | ~s/expected repository to be either {:hf, repository_id}, {:hf, repository_id, options} or {:local, directory}, got: "repo-id"/, 8 | fn -> 9 | Bumblebee.load_model("repo-id") 10 | end 11 | end 12 | 13 | @tag :capture_log 14 | test "supports sharded params" do # load_model returns params as an %Axon.ModelState{}; match its :data map so the Map.keys/1 comparisons below check parameter names rather than the struct's own fields (which are always equal) 15 | assert {:ok, %{params: %Axon.ModelState{data: params}}} = 16
| Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"}) 17 | 18 | # PyTorch format 19 | 20 | assert {:ok, %{params: %Axon.ModelState{data: sharded_params}}} = 21 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-GPT2Model-sharded"}) 22 | 23 | assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(sharded_params)) 24 | 25 | # Safetensors 26 | 27 | assert {:ok, %{params: %Axon.ModelState{data: sharded_params}}} = 28 | Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-GPT2Model-sharded"}, 29 | params_filename: "model.safetensors" 30 | ) 31 | 32 | assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(sharded_params)) 33 | end 34 | 35 | test "supports .safetensors params" do # compare parameter names from the ModelState :data maps, not the struct fields 36 | assert {:ok, %{params: %Axon.ModelState{data: params}}} = 37 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"}) 38 | 39 | assert {:ok, %{params: %Axon.ModelState{data: safetensors_params}}} = 40 | Bumblebee.load_model( 41 | {:hf, "bumblebee-testing/tiny-random-GPT2Model-safetensors-only"} 42 | ) 43 | 44 | assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(safetensors_params)) 45 | end 46 | 47 | test "supports params variants" do # compare parameter names from the ModelState :data maps, not the struct fields 48 | assert {:ok, %{params: %Axon.ModelState{data: params}}} = 49 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-bert-variant"}, 50 | params_variant: "v2" 51 | ) 52 | 53 | assert {:ok, %{params: %Axon.ModelState{data: sharded_params}}} = 54 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-bert-variant-sharded"}, 55 | params_variant: "v2" 56 | ) 57 | 58 | assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(sharded_params)) 59 | 60 | assert_raise ArgumentError, 61 | ~s/parameters variant "v3" not found, available variants: "v2"/, 62 | fn -> 63 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-bert-variant"}, 64 | params_variant: "v3" 65 | ) 66 | end 67 | 68 | assert_raise ArgumentError, 69 | ~s/parameters variant "v3" not found, available variants: "v2"/, 70 | fn -> 71 | Bumblebee.load_model( 72 | {:hf, "hf-internal-testing/tiny-random-bert-variant-sharded"}, 73 | params_variant: "v3" 74 | ) 75
| end 76 | end 77 | 78 | test "passing :type casts params accordingly" do 79 | assert {:ok, %{params: %Axon.ModelState{data: params}}} = 80 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"}, 81 | type: :bf16 82 | ) 83 | 84 | assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:bf, 16} 85 | assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:bf, 16} 86 | 87 | assert {:ok, %{params: %Axon.ModelState{data: params}}} = 88 | Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"}, 89 | type: Axon.MixedPrecision.create_policy(params: :f16) 90 | ) 91 | 92 | assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:f, 16} 93 | assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:f, 16} 94 | end 95 | end 96 | end 97 | -------------------------------------------------------------------------------- /test/fixtures/audio/common_voice/a6c7706a220eeea7ee3687c1122fe7ac17962d2449d25b6db37cc41cdaace442683e11945b6f581e73941c3083cd4eecfafc938840459cd8c571dae7774ee687.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/audio/common_voice/a6c7706a220eeea7ee3687c1122fe7ac17962d2449d25b6db37cc41cdaace442683e11945b6f581e73941c3083cd4eecfafc938840459cd8c571dae7774ee687.wav -------------------------------------------------------------------------------- /test/fixtures/audio/common_voice/a6c7706a220eeea7ee3687c1122fe7ac17962d2449d25b6db37cc41cdaace442683e11945b6f581e73941c3083cd4eecfafc938840459cd8c571dae7774ee687_pcm_f32le_16000.bin: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/audio/common_voice/a6c7706a220eeea7ee3687c1122fe7ac17962d2449d25b6db37cc41cdaace442683e11945b6f581e73941c3083cd4eecfafc938840459cd8c571dae7774ee687_pcm_f32le_16000.bin -------------------------------------------------------------------------------- /test/fixtures/audio/common_voice/info.md: -------------------------------------------------------------------------------- 1 | Source: https://huggingface.co/datasets/common_voice 2 | -------------------------------------------------------------------------------- /test/fixtures/audio/generate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Converts each audio fixture (<dir>/<name>.wav or .mp3, one directory deep) into 4 | # raw mono 16 kHz f32le PCM, written next to it as <name>_pcm_f32le_16000.bin. 5 | 6 | set -euo pipefail 7 | 8 | cd "$(dirname "$0")" 9 | 10 | # Glob directly instead of parsing `ls` output so paths containing whitespace 11 | # stay intact; nullglob makes a pattern with no matches expand to nothing. 12 | shopt -s nullglob 13 | 14 | for source in */*.{wav,mp3}; do 15 | name="${source%.*}" 16 | # -y overwrites existing outputs so fixtures can be regenerated unattended 17 | ffmpeg -y -i "$source" -ac 1 -ar 16000 -f f32le -hide_banner -loglevel quiet "${name}_pcm_f32le_16000.bin" 18 | done 19 | -------------------------------------------------------------------------------- /test/fixtures/audio/librivox/46s.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/audio/librivox/46s.mp3 -------------------------------------------------------------------------------- /test/fixtures/audio/librivox/46s_pcm_f32le_16000.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/audio/librivox/46s_pcm_f32le_16000.bin -------------------------------------------------------------------------------- /test/fixtures/audio/librivox/info.md: -------------------------------------------------------------------------------- 1 | Source: https://librivox.org/the-book-of-irish-poetry-by-various 2 |
-------------------------------------------------------------------------------- /test/fixtures/images/coco/39769.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/images/coco/39769.jpeg -------------------------------------------------------------------------------- /test/fixtures/images/coco/info.md: -------------------------------------------------------------------------------- 1 | Source: https://cocodataset.org/#explore?id=39769 2 | -------------------------------------------------------------------------------- /test/fixtures/pytorch/generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from collections import OrderedDict 5 | 6 | 7 | def save(name, fmt, data): # Writes data with torch.save to <this script's dir>/<name>.<fmt>.pt; fmt "zip" selects the zipfile serialization, any other value the older non-zip format 8 | use_zip = fmt == "zip" 9 | dirname = os.path.dirname(__file__) 10 | path = os.path.join(dirname, f"{name}.{fmt}.pt") 11 | torch.save(data, path, _use_new_zipfile_serialization=use_zip) 12 | 13 | 14 | for fmt in ["zip", "legacy"]: # each fixture saved in this loop is written once per format 15 | save("tensors", fmt, [ # one small tensor per listed torch dtype 16 | torch.tensor([-1.0, 1.0], dtype=torch.float64), 17 | torch.tensor([-1.0, 1.0], dtype=torch.float32), 18 | torch.tensor([-1.0, 1.0], dtype=torch.float16), 19 | torch.tensor([-1, 1], dtype=torch.int64), 20 | torch.tensor([-1, 1], dtype=torch.int32), 21 | torch.tensor([-1, 1], dtype=torch.int16), 22 | torch.tensor([-1, 1], dtype=torch.int8), 23 | torch.tensor([0, 1], dtype=torch.uint8), 24 | torch.tensor([0, 1, 0, 1], dtype=torch.bool), 25 | torch.tensor([-1.0, 1.0], dtype=torch.bfloat16), 26 | torch.tensor([1 - 1j, 1 + 1j], dtype=torch.complex128), 27 | torch.tensor([1 - 1j, 1 + 1j], dtype=torch.complex64) 28 | ]) 29 | 30 | save("numpy_arrays", fmt, [ # numpy arrays (one per listed dtype) serialized through torch.save 31 | np.array([-1.0, 1.0], dtype=np.float64), 32 | np.array([-1.0, 1.0], dtype=np.float32), 33 | np.array([-1.0, 1.0], dtype=np.float16), 34 |
np.array([-1, 1], dtype=np.int64), 35 | np.array([-1, 1], dtype=np.int32), 36 | np.array([-1, 1], dtype=np.int16), 37 | np.array([-1, 1], dtype=np.int8), 38 | np.array([0, 1], dtype=np.uint64), 39 | np.array([0, 1], dtype=np.uint32), 40 | np.array([0, 1], dtype=np.uint16), 41 | np.array([0, 1], dtype=np.uint8), 42 | np.array([0, 1], dtype=np.bool_), 43 | np.array([1 - 1j, 1 + 1j], dtype=np.complex128), 44 | np.array([1 - 1j, 1 + 1j], dtype=np.complex64) 45 | ]) 46 | 47 | save("ordered_dict", fmt, OrderedDict([("x", 1), ("y", 2)])) 48 | 49 | transposed_tensor = torch.tensor( # permute(2, 0, 1) yields a non-contiguous view of this tensor 50 | [[[1, 1], [2, 2], [3, 3]], [[4, 4], [5, 5], [6, 6]]], dtype=torch.int64).permute(2, 0, 1) 51 | save("noncontiguous_tensor", fmt, transposed_tensor) 52 | 53 | transposed_array = np.transpose(np.array([[1, 2, 3], [4, 5, 6]])) # transpose yields a view that is not C-contiguous 54 | save("noncontiguous_numpy_array", fmt, transposed_array) 55 | 56 | # Model parameters (these are written in the zip format only) 57 | 58 | save("state_dict_base", "zip", OrderedDict([ 59 | ("conv.weight", torch.ones(2, 3, 2, 2)), 60 | ("conv.bias", torch.zeros(2)), 61 | ])) 62 | 63 | save("state_dict_full", "zip", OrderedDict([ # presumably used to test how mismatched params (unexpected shape, missing, extra) are reported; confirm against the loader tests 64 | ("base.conv.weight", torch.ones(2, 3, 2, 2)), 65 | ("base.conv.bias", torch.zeros(2)), 66 | # Unexpected shape 67 | ("classifier.layers.0.weight", torch.ones(1, 1)), 68 | ("classifier.layers.0.bias", torch.zeros(1)), 69 | # Missing 70 | # "classifier.layers.1.weight" 71 | # "classifier.layers.1.bias" 72 | # Extra 73 | ("extra.weight", torch.ones(1)) 74 | ])) # NOTE(review): storage_view.legacy.pt sits among these fixtures but is not generated by this script 75 | -------------------------------------------------------------------------------- /test/fixtures/pytorch/noncontiguous_numpy_array.legacy.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/noncontiguous_numpy_array.legacy.pt -------------------------------------------------------------------------------- /test/fixtures/pytorch/noncontiguous_numpy_array.zip.pt:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/noncontiguous_numpy_array.zip.pt -------------------------------------------------------------------------------- /test/fixtures/pytorch/noncontiguous_tensor.legacy.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/noncontiguous_tensor.legacy.pt -------------------------------------------------------------------------------- /test/fixtures/pytorch/noncontiguous_tensor.zip.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/noncontiguous_tensor.zip.pt -------------------------------------------------------------------------------- /test/fixtures/pytorch/numpy_arrays.legacy.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/numpy_arrays.legacy.pt -------------------------------------------------------------------------------- /test/fixtures/pytorch/numpy_arrays.zip.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/numpy_arrays.zip.pt -------------------------------------------------------------------------------- /test/fixtures/pytorch/ordered_dict.legacy.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/ordered_dict.legacy.pt 
-------------------------------------------------------------------------------- /test/fixtures/pytorch/ordered_dict.zip.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/ordered_dict.zip.pt -------------------------------------------------------------------------------- /test/fixtures/pytorch/state_dict_base.zip.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/state_dict_base.zip.pt -------------------------------------------------------------------------------- /test/fixtures/pytorch/state_dict_full.zip.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/state_dict_full.zip.pt -------------------------------------------------------------------------------- /test/fixtures/pytorch/storage_view.legacy.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/storage_view.legacy.pt -------------------------------------------------------------------------------- /test/fixtures/pytorch/tensors.legacy.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/tensors.legacy.pt -------------------------------------------------------------------------------- /test/fixtures/pytorch/tensors.zip.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/elixir-nx/bumblebee/710a645222948f80208c348d3a2589cbd3ab8e7d/test/fixtures/pytorch/tensors.zip.pt -------------------------------------------------------------------------------- /test/support/test_helpers.ex: -------------------------------------------------------------------------------- 1 | defmodule Bumblebee.TestHelpers do 2 | @moduledoc false 3 | 4 | import ExUnit.Assertions 5 | 6 | defmacro assert_equal(left, right) do 7 | # Assert against binary backend tensors to show diff on failure 8 | quote do 9 | left = unquote(left) |> to_binary_backend() 10 | right = unquote(right) |> Nx.as_type(Nx.type(left)) |> to_binary_backend() # right is cast to left's type first, so a type difference alone does not fail the assertion 11 | assert left == right 12 | end 13 | end 14 | 15 | def to_binary_backend(tensor) do # copies tensor onto Nx.BinaryBackend 16 | Nx.backend_copy(tensor, Nx.BinaryBackend) 17 | end 18 | 19 | def assert_all_close(left, right, opts \\ []) do # flunks unless left and right are element-wise close per Nx.all_close/3; :atol and :rtol default to 1.0e-4 20 | atol = opts[:atol] || 1.0e-4 21 | rtol = opts[:rtol] || 1.0e-4 22 | 23 | equals = 24 | left 25 | |> Nx.all_close(right, atol: atol, rtol: rtol) 26 | |> Nx.backend_transfer(Nx.BinaryBackend) 27 | 28 | if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do # both sides are binary-backend tensors, so plain term comparison against the u8 scalar 1 works 29 | flunk(""" 30 | expected 31 | 32 | #{inspect(left)} 33 | 34 | to be within tolerance of 35 | 36 | #{inspect(right)} 37 | """) 38 | end 39 | end 40 | 41 | def model_test_tags() do # module tags for model tests 42 | [model: true, capture_log: true, timeout: 60_000] 43 | end 44 | 45 | def serving_test_tags() do # module tags for serving tests; :slow is excluded by the ExUnit.start call in test_helper.exs 46 | [serving: true, slow: true, capture_log: true, timeout: 600_000] 47 | end 48 | 49 | def scheduler_loop(scheduler, num_steps) do # runs every scheduler step on a deterministic dummy sample, with dummy_model standing in for the model prediction, and returns the final sample 50 | sample = dummy_sample() 51 | 52 | {state, timesteps} = 53 | Bumblebee.scheduler_init(scheduler, num_steps, Nx.to_template(sample), Nx.Random.key(0)) 54 | 55 | {_state, sample} = 56 | for i <- 0..(Nx.size(timesteps) - 1)//1, reduce: {state, sample} do # explicit //1 step: with zero timesteps, 0..-1 would otherwise be a decreasing range yielding i = 0 and -1 57 | {state, sample} -> 58 | prediction = dummy_model(sample, timesteps[i]) 59 | Bumblebee.scheduler_step(scheduler, state, sample, prediction) 60 | end 61 | 62 | sample 63 | end
64 | 65 | def scheduler_timesteps(scheduler, num_steps) do # returns only the timesteps produced by scheduler_init for the dummy sample 66 | sample = dummy_sample() 67 | 68 | {_state, timesteps} = 69 | Bumblebee.scheduler_init(scheduler, num_steps, Nx.to_template(sample), Nx.Random.key(0)) 70 | 71 | timesteps 72 | end 73 | 74 | defp dummy_sample() do # deterministic {8, 8, 4} tensor: iota divided by its size, so values lie in [0, 1) 75 | shape = {_height = 8, _width = 8, _channels = 4} 76 | sample = Nx.iota(shape) 77 | Nx.divide(sample, Nx.size(sample)) 78 | end 79 | 80 | defp dummy_model(sample, timestep) do # stand-in model: sample * timestep / (timestep + 1) 81 | sample 82 | |> Nx.multiply(timestep) 83 | |> Nx.divide(Nx.add(timestep, 1)) 84 | end 85 | end 86 | -------------------------------------------------------------------------------- /test/test_helper.exs: -------------------------------------------------------------------------------- 1 | Application.put_env(:bumblebee, :progress_bar_enabled, false) 2 | 3 | client = EXLA.Client.fetch!(:host) # only used below to count the available host devices 4 | 5 | exclude_multi_device = if client.device_count > 1, do: [], else: [:multi_device] # :multi_device tests need more than one host device 6 | 7 | if client.device_count == 1 and System.schedulers_online() > 1 do 8 | IO.puts( 9 | "To run multi-device tests: XLA_FLAGS=--xla_force_host_platform_device_count=2 mix test" 10 | ) 11 | end 12 | 13 | Application.put_env(:exla, :clients, 14 | host: [platform: :host], 15 | cuda: [platform: :cuda], 16 | rocm: [platform: :rocm], 17 | tpu: [platform: :tpu], 18 | other_host: [platform: :host, automatic_transfers: false] 19 | ) 20 | 21 | Application.put_env(:exla, :preferred_clients, [:tpu, :cuda, :rocm, :other_host, :host]) 22 | 23 | Application.put_env(:nx, :default_backend, {EXLA.Backend, client: :host}) 24 | 25 | if System.fetch_env("BUMBLEBEE_OFFLINE") == :error do 26 | IO.puts("To run tests without hitting the network: BUMBLEBEE_OFFLINE=true mix test") 27 | end 28 | 29 | ExUnit.start(exclude: [:slow] ++ exclude_multi_device) # :slow tests (e.g. those tagged via serving_test_tags) run only when explicitly included 30 | --------------------------------------------------------------------------------