├── tests
│   ├── __init__.py
│   ├── run_tests.py
│   ├── models
│   │   ├── test_base.py
│   │   ├── test_gemma3.py
│   │   ├── test_llama.py
│   │   ├── test_t5.py
│   │   ├── test_gemma2.py
│   │   └── test_generation_utils.py
│   ├── functional
│   │   └── test_attention.py
│   └── attention
│       ├── test_RoPEMultiHeadAttention.py
│       └── test_multi_head_attention.py
├── .python-version
├── docs
│   ├── .gitignore
│   ├── modules
│   │   └── index.rst
│   ├── Makefile
│   ├── make.bat
│   ├── index.rst
│   └── conf.py
├── examples
│   ├── __init__.py
│   ├── t5_inference_example.py
│   ├── gemma2_inference_example.py
│   ├── llama_inference_example.py
│   ├── gemma3_inference_example.py
│   └── multi_head_attention_example.py
├── assets
│   └── logo.png
├── .pre-commit-config.yaml
├── jaxgarden
│   ├── attention
│   │   ├── __init__.py
│   │   ├── multi_head_attention.py
│   │   └── rope_multi_head_attention.py
│   ├── functional
│   │   ├── __init__.py
│   │   └── attention.py
│   ├── models
│   │   ├── __init__.py
│   │   └── base.py
│   └── __init__.py
├── .github
│   └── workflows
│       ├── ruff.yml
│       ├── mypy.yml
│       ├── tests.yml
│       └── docs.yml
├── .gitignore
├── LICENSE
├── .devcontainer
│   ├── devcontainer.json
│   ├── Dockerfile
│   ├── README.md
│   └── verify_jax_cuda.py
├── pyproject.toml
├── .cursor
│   └── rules
│       ├── guide.mdc
│       ├── baseconfig.mdc
│       ├── rotary_position_embeddings__rope_.mdc
│       ├── generationmixin.mdc
│       ├── tokenizer.mdc
│       ├── llamaforcausallm.mdc
│       └── modernbertformaskedlm.mdc
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.10
2 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | _build/
2 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | """Example scripts for the jax-layers library."""
2 |
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-gde/jaxgarden/HEAD/assets/logo.png
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/astral-sh/ruff-pre-commit
3 | rev: v0.9.10
4 | hooks:
5 | - id: ruff
6 | - id: ruff-format
7 |
--------------------------------------------------------------------------------
/jaxgarden/attention/__init__.py:
--------------------------------------------------------------------------------
1 | """Attention modules for JAXgarden."""
2 |
3 | from jaxgarden.attention.multi_head_attention import MultiHeadAttention
4 |
5 | __all__ = [
6 | "MultiHeadAttention",
7 | ]
8 |
--------------------------------------------------------------------------------
/jaxgarden/functional/__init__.py:
--------------------------------------------------------------------------------
1 | """Functional implementations for JAXgarden."""
2 |
3 | from jaxgarden.functional.attention import dot_product_attention
4 |
5 | __all__ = [
6 | "dot_product_attention",
7 | ]
8 |
--------------------------------------------------------------------------------
/.github/workflows/ruff.yml:
--------------------------------------------------------------------------------
1 | name: Ruff
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | ruff:
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - uses: actions/checkout@v4
15 |
16 | - name: Set up Python
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: "3.10"
20 |
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install ruff
25 |
26 | - name: Run Ruff
27 | run: |
28 | ruff check .
29 | ruff format --check .
--------------------------------------------------------------------------------
/docs/modules/index.rst:
--------------------------------------------------------------------------------
1 | API Documentation
2 | =================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 |
7 | attention
8 | functional
9 | models
10 |
11 | Attention Modules
12 | -----------------
13 |
14 | .. automodule:: jaxgarden.attention
15 | :members:
16 | :undoc-members:
17 | :show-inheritance:
18 |
19 | Functional Interfaces
20 | ---------------------
21 |
22 | .. automodule:: jaxgarden.functional
23 | :members:
24 | :undoc-members:
25 | :show-inheritance:
26 |
27 | Models
28 | ------------------
29 |
30 | .. automodule:: jaxgarden.models
31 | :members:
32 | :undoc-members:
33 | :show-inheritance:
--------------------------------------------------------------------------------
/.github/workflows/mypy.yml:
--------------------------------------------------------------------------------
1 | name: Type Checking
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | mypy:
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - uses: actions/checkout@v4
15 |
16 | - name: Set up Python
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: "3.10"
20 |
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install -e ".[dev]"
25 | pip install mypy
26 |
27 | - name: Run mypy
28 | run: |
29 | mypy jaxgarden tests
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 |
3 | # You can set these variables from the command line, and also
4 | # from the environment for the first two.
5 | SPHINXOPTS ?=
6 | SPHINXBUILD ?= sphinx-build
7 | SOURCEDIR = .
8 | BUILDDIR = _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/examples/t5_inference_example.py:
--------------------------------------------------------------------------------
1 | from flax import nnx
2 |
3 | from jaxgarden import T5Config, T5ForCausalLM, Tokenizer
4 |
5 | if __name__ == "__main__":
6 | config = T5Config()
7 | model = T5ForCausalLM(config, rngs=nnx.Rngs(0))
8 | model_id = "google-t5/t5-base"
9 |
10 | # download checkpoint from HuggingFace Hub
11 | model.from_hf(model_id, force_download=True)
12 |
13 | tokenizer = Tokenizer.from_pretrained(model_id)
14 |
15 | text = "The meaning of life is"
16 | model_inputs = tokenizer.encode(text)
17 | output = model.generate(**model_inputs, max_length=20, do_sample=True)
18 | output_text = tokenizer.decode(output)
19 |
20 | print(output, output.shape)
21 | print(output_text)
22 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | test:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: ["3.10", "3.11", "3.12"]
15 |
16 | steps:
17 | - uses: actions/checkout@v4
18 |
19 | - name: Set up Python ${{ matrix.python-version }}
20 | uses: actions/setup-python@v4
21 | with:
22 | python-version: ${{ matrix.python-version }}
23 |
24 | - name: Install dependencies
25 | run: |
26 | python -m pip install --upgrade pip
27 | pip install -e ".[dev]"
28 |
29 | - name: Run tests
30 | run: |
31 | python tests/run_tests.py
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: Documentation
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | docs:
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - uses: actions/checkout@v4
15 |
16 | - name: Set up Python
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: "3.10"
20 |
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install -e ".[dev]"
25 |
26 | - name: Build documentation
27 | run: |
28 | cd docs
29 | make html
30 |
31 | - name: Deploy to GitHub Pages
32 | if: github.event_name == 'push' && github.ref == 'refs/heads/main'
33 | uses: peaceiris/actions-gh-pages@v3
34 | with:
35 | github_token: ${{ secrets.GITHUB_TOKEN }}
36 | publish_dir: ./docs/_build/html
--------------------------------------------------------------------------------
/examples/gemma2_inference_example.py:
--------------------------------------------------------------------------------
1 | from flax import nnx
2 |
3 | from jaxgarden import Gemma2Config, Gemma2ForCausalLM, Tokenizer
4 |
5 | # HF repo id of the Gemma variant that you want to use
6 | model_id = "google/gemma-2-2b-it"
7 |
8 | # initialize the Gemma architecture
9 | config = Gemma2Config()
10 | model = Gemma2ForCausalLM(config, rngs=nnx.Rngs(0))
11 |
12 | # This is a one-liner to download HF checkpoint from HuggingFace Hub,
13 | # convert it to jaxgarden format,
14 | # save it in an Orbax checkpoint,
15 | # and then remove the HF checkpoint.
16 | model.from_hf(model_id, force_download=True)
17 |
18 | # this works just like `transformers.AutoTokenizer`,
19 | # but without depending on the whole `transformers` library.
20 | # Instead, we simply extend the `tokenizers` package and add some convenience code for JAX.
21 | tokenizer = Tokenizer.from_pretrained(model_id)
22 |
23 | text = "The meaning of life is"
24 | model_inputs = tokenizer.encode(text)
25 | output = model.generate(**model_inputs, max_length=20, do_sample=True)
26 | output_text = tokenizer.decode(output)
27 | print(output_text)
28 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python-generated files
2 | __pycache__/
3 | *.py[oc]
4 | build/
5 | dist/
6 | wheels/
7 | *.egg-info
8 |
9 | # Virtual environments
10 | .venv
11 | jax_env
12 |
13 | # Byte-compiled / optimized / DLL files
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Jupyter Notebook
52 | .ipynb_checkpoints
53 |
54 | # Environments
55 | .env
56 | .venv
57 | env/
58 | venv/
59 | ENV/
60 | env.bak/
61 | venv.bak/
62 |
63 | # IDE specific files
64 | .idea/
65 | .vscode/
66 | *.swp
67 | *.swo
68 |
69 | # Library specific
70 | *.npy
71 | *.npz
72 | .ruff_cache/
73 | mypy_cache/
74 | pytest_cache/
75 |
76 | # OS specific
77 | .DS_Store
78 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to JAXgarden documentation!
2 | =====================================
3 |
4 | JAXgarden provides high-performance and hackable neural network model implementations in JAX, leveraging optimized kernels and layers such as FlashAttention.
5 |
6 | .. toctree::
7 | :maxdepth: 2
8 | :caption: Contents:
9 |
10 | readme
11 | modules/index
12 | contributing
13 | changelog
14 |
15 | Features
16 | --------
17 |
18 | - **MultiHeadAttention**: A Flax NNX-compatible implementation with support for different attention backends.
19 |
20 | Installation
21 | ------------
22 |
23 | .. code-block:: bash
24 |
25 | pip install git+https://github.com/ml-gde/jax-layers.git
26 |
27 | For development installation:
28 |
29 | .. code-block:: bash
30 |
31 | # first, fork the repository to your account.
32 | # Then, clone it to your machine.
33 | git clone https://github.com/yourusername/jax-layers.git
34 | cd jax-layers
35 | pip install -e ".[dev]"
36 |
37 | Indices and tables
38 | ==================
39 |
40 | * :ref:`genindex`
41 | * :ref:`modindex`
42 | * :ref:`search`
--------------------------------------------------------------------------------
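To complement the Features list in docs/index.rst above, here is a minimal usage sketch of the functional attention wrapper defined in jaxgarden/functional/attention.py later in this dump; it uses only the call signature shown there and the public `jaxgarden` export, with illustrative shapes.

```python
# Minimal sketch: calling jaxgarden's functional attention wrapper with an
# explicit backend. Shapes follow the docstring in jaxgarden/functional/attention.py.
import jax
import jax.numpy as jnp

from jaxgarden import dot_product_attention

key = jax.random.PRNGKey(0)
q_key, k_key, v_key = jax.random.split(key, 3)

batch, seq_len, num_heads, head_dim = 2, 16, 4, 8
query = jax.random.normal(q_key, (batch, seq_len, num_heads, head_dim))
key_t = jax.random.normal(k_key, (batch, seq_len, num_heads, head_dim))
value = jax.random.normal(v_key, (batch, seq_len, num_heads, head_dim))

# Boolean mask: True = attend, False = masked (converted internally to an additive bias).
mask = jnp.tril(jnp.ones((batch, 1, seq_len, seq_len), dtype=bool))

out = dot_product_attention(query, key_t, value, mask=mask, implementation="xla")
print(out.shape)  # (2, 16, 4, 8)
```

Passing `implementation="flash"` (or `"cudnn"`) selects the cuDNN Flash Attention path when the hardware supports it; `None` lets JAX pick automatically.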
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 JAX Layers Contributors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/tests/run_tests.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Script to run all tests for the jax-layers project."""
3 |
4 | import os
5 | import sys
6 |
7 | import pytest
8 |
9 |
10 | def main():
11 | """Run all tests for the jax-layers project."""
12 | # Get the directory of this script
13 | script_dir = os.path.dirname(os.path.abspath(__file__))
14 |
15 | # Get the root directory of the project
16 | root_dir = os.path.dirname(script_dir)
17 |
18 | # Add the root directory to the Python path
19 | sys.path.insert(0, root_dir)
20 |
21 | # Run the tests
22 | args = [
23 | "-xvs", # -x: exit on first failure, -v: verbose, -s: don't capture stdout
24 | script_dir, # Run all tests in the tests directory
25 | "--cov=jax_layers", # Generate coverage report for jax_layers
26 | "--cov-report=term", # Output coverage report to terminal
27 | ]
28 |
29 | # Add any additional arguments passed to this script
30 | args.extend(sys.argv[1:])
31 |
32 | # Run pytest with the arguments
33 | return pytest.main(args)
34 |
35 |
36 | if __name__ == "__main__":
37 | sys.exit(main())
38 |
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "JAX Layers Development",
3 | "build": {
4 | "dockerfile": "Dockerfile",
5 | "context": ".."
6 | },
7 | "runArgs": [
8 | "--gpus=all"
9 | ],
10 | "customizations": {
11 | "vscode": {
12 | "extensions": [
13 | "ms-python.python",
14 | "ms-python.vscode-pylance",
15 | "charliermarsh.ruff",
16 | "matangover.mypy"
17 | ],
18 | "settings": {
19 | "python.defaultInterpreterPath": "/usr/local/bin/python",
20 | "python.linting.enabled": true,
21 | "editor.formatOnSave": true,
22 | "editor.codeActionsOnSave": {
23 | "source.organizeImports": "always",
24 | "source.fixAll": "always"
25 | },
26 | "python.formatting.provider": "none",
27 | "[python]": {
28 | "editor.defaultFormatter": "charliermarsh.ruff"
29 | }
30 | }
31 | }
32 | },
33 | "forwardPorts": [],
34 | "postCreateCommand": "pip install -e '.[dev]'",
35 | "remoteUser": "vscode"
36 | }
--------------------------------------------------------------------------------
/tests/models/test_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import jax
5 | import numpy as np
6 | from flax import nnx
7 |
8 | from jaxgarden.models.base import BaseConfig, BaseModel
9 |
10 |
11 | class TwoLayerMLPConfig(BaseConfig):
12 | dim: int = 4
13 | use_bias: bool = False
14 |
15 |
16 | class TwoLayerMLP(BaseModel):
17 | def __init__(self, config, rngs: nnx.Rngs):
18 | self.linear1 = nnx.Linear(config.dim, config.dim, rngs=rngs, use_bias=config.use_bias)
19 | self.linear2 = nnx.Linear(config.dim, config.dim, rngs=rngs, use_bias=config.use_bias)
20 |
21 | def __call__(self, x):
22 | x = self.linear1(x)
23 | return self.linear2(x)
24 |
25 |
26 | def test_save_and_load():
27 | ckpt_dir = "/tmp/jaxgarden_test_ckpt"
28 | if os.path.exists(ckpt_dir):
29 | shutil.rmtree(ckpt_dir)
30 |
31 | config = TwoLayerMLPConfig()
32 | model = TwoLayerMLP(config, rngs=nnx.Rngs(0))
33 | x = jax.random.normal(jax.random.key(42), (3, 4))
34 | assert model(x).shape == (3, 4)
35 | state = model.state_dict
36 | model.save(ckpt_dir)
37 | model_restored = TwoLayerMLP(config, rngs=nnx.Rngs(1)).load(ckpt_dir)
38 | state_restored = model_restored.state_dict
39 | jax.tree.map(np.testing.assert_array_equal, state, state_restored)
40 |
--------------------------------------------------------------------------------
/examples/llama_inference_example.py:
--------------------------------------------------------------------------------
1 | from flax import nnx
2 |
3 | from jaxgarden import LlamaConfig, LlamaForCausalLM, Tokenizer
4 |
5 | if __name__ == "__main__":
6 | # initialize a config object (with defaults for the 1B variant)
7 | # support for other variants is to be added.
8 | config = LlamaConfig()
9 | model = LlamaForCausalLM(config, rngs=nnx.Rngs(0))
10 | model_id = "meta-llama/Llama-3.2-1B"
11 |
12 | # this will download HF checkpoint from HuggingFace Hub,
13 | # convert it to jaxgarden format,
14 | # save it in an Orbax checkpoint,
15 | # and then remove the HF checkpoint.
16 | # If you didn't set your HF token globally,
17 | # you may need to pass your token as an argument to this method.
18 | model.from_hf(model_id, force_download=True)
19 |
20 | # this works just like `transformers.AutoTokenizer`,
21 | # but without depending on the whole `transformers` library.
22 | # Instead, we simply extend the `tokenizers` package and add some convenience code for JAX.
23 | tokenizer = Tokenizer.from_pretrained(model_id)
24 |
25 | text = "The meaning of life is"
26 | model_inputs = tokenizer.encode(text)
27 | output = model.generate(**model_inputs, max_length=20, do_sample=True)
28 | output_text = tokenizer.decode(output)
29 | print(output, output.shape)
30 | print(output_text)
31 |
--------------------------------------------------------------------------------
/.devcontainer/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04
2 |
3 | # Avoid prompts from apt
4 | ENV DEBIAN_FRONTEND=noninteractive
5 |
6 | # Install Python and other dependencies
7 | RUN apt-get update && apt-get install -y \
8 | python3.10 \
9 | python3.10-dev \
10 | python3-pip \
11 | python3-venv \
12 | git \
13 | curl \
14 | wget \
15 | build-essential \
16 | && rm -rf /var/lib/apt/lists/*
17 |
18 | # Create symbolic links for python
19 | RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \
20 | ln -sf /usr/bin/python3.10 /usr/bin/python3
21 |
22 | # Create a non-root user
23 | ARG USERNAME=vscode
24 | ARG USER_UID=1000
25 | ARG USER_GID=$USER_UID
26 |
27 | RUN groupadd --gid $USER_GID $USERNAME \
28 | && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
29 | && apt-get update \
30 | && apt-get install -y sudo \
31 | && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
32 | && chmod 0440 /etc/sudoers.d/$USERNAME
33 |
34 | # Set up Python environment
35 | RUN python -m pip install --upgrade pip setuptools wheel
36 |
37 | # Install JAX with CUDA support
38 | RUN pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
39 |
40 | # Install matplotlib for benchmarking script
41 | RUN pip install matplotlib
42 |
43 | # Set working directory
44 | WORKDIR /workspace
45 |
46 | # Switch to non-root user
47 | USER $USERNAME
48 |
49 | # Note: The project dependencies will be installed by the postCreateCommand in devcontainer.json
50 | # which runs `pip install -e '.[dev]'` after the container is created
--------------------------------------------------------------------------------
/tests/models/test_gemma3.py:
--------------------------------------------------------------------------------
1 | """Tests for Gemma3 model."""
2 |
3 | import jax.numpy as jnp
4 | import numpy as np
5 | from flax import nnx
6 |
7 | from jaxgarden.models.gemma3 import (
8 | Gemma3Attention,
9 | Gemma3Config,
10 | Gemma3DecoderLayer,
11 | Gemma3ForCausalLM,
12 | Gemma3MLP,
13 | Gemma3RMSNorm,
14 | Gemma3RotaryEmbedding,
15 | )
16 |
17 |
18 | def test_gemma3_config():
19 | """Test Gemma3Config initialization and validation."""
20 | config = Gemma3Config()
21 | assert config.num_attention_heads % config.num_key_value_heads == 0
22 | assert config.head_dim == 256 # Default head_dim
23 |
24 |
25 | def test_gemma3_rms_norm():
26 | """Test RMSNorm layer."""
27 | rng = nnx.Rngs(0)
28 | dim = 32
29 | batch_size = 2
30 | seq_len = 4
31 |
32 | norm = Gemma3RMSNorm(dim, eps=1e-6, rngs=rng)
33 | x = jnp.ones((batch_size, seq_len, dim), dtype=jnp.float32)
34 | out = norm(x)
35 |
36 | assert out.shape == x.shape
37 | # Output should be normalized
38 | variance = jnp.mean(jnp.square(out), axis=-1)
39 | np.testing.assert_allclose(variance, 1.0, rtol=1e-5)
40 |
41 |
42 | def test_gemma3_rotary_embedding():
43 | """Test RotaryEmbedding module."""
44 | dim = 32
45 | batch_size = 2
46 | seq_len = 4
47 | num_heads = 2
48 |
49 | rope = Gemma3RotaryEmbedding(dim=dim, max_position_embeddings=8192)
50 | x = jnp.ones((batch_size, seq_len, num_heads, dim), dtype=jnp.float32)
51 | position_ids = jnp.arange(seq_len, dtype=jnp.int32)[None, :] # [1, seq_len]
52 | position_ids = jnp.broadcast_to(position_ids, (batch_size, seq_len))
53 |
54 | out = rope(x, position_ids)
55 | assert out.shape == (batch_size, seq_len, num_heads, dim)
56 |
--------------------------------------------------------------------------------
/jaxgarden/models/__init__.py:
--------------------------------------------------------------------------------
1 | from jaxgarden.models.base import BaseConfig, BaseModel
2 | from jaxgarden.models.gemma2 import (
3 | Gemma2Attention,
4 | Gemma2Config,
5 | Gemma2ForCausalLM,
6 | Gemma2MLP,
7 | Gemma2RMSNorm,
8 | Gemma2RotaryEmbedding,
9 | )
10 | from jaxgarden.models.gemma3 import (
11 | Gemma3Attention,
12 | Gemma3Config,
13 | Gemma3ForCausalLM,
14 | Gemma3MLP,
15 | Gemma3RMSNorm,
16 | Gemma3RotaryEmbedding,
17 | )
18 | from jaxgarden.models.generation_utils import GenerationMixin
19 | from jaxgarden.models.llama import (
20 | LlamaAttention,
21 | LlamaConfig,
22 | LlamaForCausalLM,
23 | LlamaMLP,
24 | LlamaRMSNorm,
25 | LlamaRotaryEmbedding,
26 | LlamaTransformerBlock,
27 | )
28 | from jaxgarden.models.modernbert import (
29 | ModernBertAttention,
30 | ModernBertEmbeddings,
31 | ModernBERTEncoder,
32 | ModernBERTForMaskedLM,
33 | ModernBertLayer,
34 | ModernBertMLP,
35 | )
36 |
37 | __all__ = [
38 | "BaseConfig",
39 | "BaseModel",
40 | "Gemma2Attention",
41 | "Gemma2Config",
42 | "Gemma2ForCausalLM",
43 | "Gemma2MLP",
44 | "Gemma2RMSNorm",
45 | "Gemma2RotaryEmbedding",
46 | "Gemma3Attention",
47 | "Gemma3Config",
48 | "Gemma3ForCausalLM",
49 | "Gemma3MLP",
50 | "Gemma3RMSNorm",
51 | "Gemma3RotaryEmbedding",
52 | "GenerationMixin",
53 | "LlamaAttention",
54 | "LlamaConfig",
55 | "LlamaForCausalLM",
56 | "LlamaMLP",
57 | "LlamaRMSNorm",
58 | "LlamaRotaryEmbedding",
59 | "LlamaTransformerBlock",
60 | "ModernBERTEncoder",
61 | "ModernBERTForMaskedLM",
62 | "ModernBertAttention",
63 | "ModernBertEmbeddings",
64 | "ModernBertLayer",
65 | "ModernBertMLP",
66 | ]
67 |
--------------------------------------------------------------------------------
/tests/models/test_llama.py:
--------------------------------------------------------------------------------
1 | """Tests for LLama model."""
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | from flax import nnx
6 |
7 | from jaxgarden.models.llama import LlamaConfig, LlamaForCausalLM
8 |
9 |
10 | def test_llama_initialization():
11 | """Test that the LLama model can be properly initialized."""
12 | # Set configuration for a tiny model for testing
13 | config = LlamaConfig(
14 | dim=32,
15 | n_layers=2,
16 | n_heads=4,
17 | n_kv_heads=2,
18 | head_dim=8,
19 | intermediate_size=64,
20 | vocab_size=100,
21 | )
22 |
23 | # Initialize the model
24 | key = jax.random.PRNGKey(0)
25 | rngs = nnx.Rngs(params=key)
26 |
27 | model = LlamaForCausalLM(config, rngs=rngs)
28 |
29 | # Verify the model was created with expected structure
30 | assert len(model.layers) == 2
31 | assert isinstance(model, LlamaForCausalLM)
32 |
33 |
34 | def test_llama_inference():
35 | """Test that the LLama model can run inference end-to-end."""
36 | # Set configuration for a tiny model for testing
37 | config = LlamaConfig(
38 | dim=32,
39 | n_layers=2,
40 | n_heads=4,
41 | n_kv_heads=2,
42 | head_dim=8,
43 | intermediate_size=64,
44 | vocab_size=100,
45 | )
46 |
47 | # Initialize the model
48 | key = jax.random.PRNGKey(0)
49 | rngs = nnx.Rngs(params=key)
50 |
51 | model = LlamaForCausalLM(config, rngs=rngs)
52 |
53 | # Create sample input
54 | batch_size = 1
55 | seq_len = 4
56 | input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
57 | position_ids = jnp.arange(seq_len)[None, :]
58 |
59 | # Run forward pass
60 | logits = model(input_ids, position_ids)
61 |
62 | # Check output shape
63 | assert logits.shape == (batch_size, seq_len, config.vocab_size)
64 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | """Configuration file for the Sphinx documentation builder."""
2 |
3 | import os
4 | import sys
5 |
6 | # Add the project root directory to the Python path
7 | sys.path.insert(0, os.path.abspath(".."))
8 |
9 | # Project information
10 | project = "JAXgarden"
11 | copyright = "2025, JAXgarden Contributors"
12 | author = "JAXgarden Contributors"
13 |
14 | # The full version, including alpha/beta/rc tags
15 | release = "0.1.0"
16 |
17 | # list of Sphinx extension module names
18 | extensions = [
19 | "sphinx.ext.autodoc",
20 | "sphinx.ext.napoleon",
21 | "sphinx.ext.viewcode",
22 | "sphinx.ext.intersphinx",
23 | "myst_parser",
24 | ]
25 |
26 | # Add any paths that contain templates here
27 | templates_path = ["_templates"]
28 |
29 | # List of patterns, relative to source directory, that match files and
30 | # directories to ignore when looking for source files
31 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
32 |
33 | # The theme to use for HTML and HTML Help pages
34 | html_theme = "sphinx_rtd_theme"
35 |
36 | # Add any paths that contain custom static files
37 | html_static_path = ["_static"]
38 |
39 | # Intersphinx configuration
40 | intersphinx_mapping = {
41 | "python": ("https://docs.python.org/3", None),
42 | "jax": ("https://jax.readthedocs.io/en/latest/", None),
43 | "flax": ("https://flax.readthedocs.io/en/latest/", None),
44 | }
45 |
46 | # Napoleon settings
47 | napoleon_google_docstring = True
48 | napoleon_numpy_docstring = True
49 | napoleon_include_init_with_doc = True
50 | napoleon_include_private_with_doc = False
51 | napoleon_include_special_with_doc = True
52 | napoleon_use_admonition_for_examples = False
53 | napoleon_use_admonition_for_notes = False
54 | napoleon_use_admonition_for_references = False
55 | napoleon_use_ivar = False
56 | napoleon_use_param = True
57 | napoleon_use_rtype = True
58 | napoleon_type_aliases = None
59 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "jaxgarden"
7 | version = "0.2.0"
8 | description = "High-performance model implementations in JAX"
9 | readme = "README.md"
10 | requires-python = ">=3.10"
11 | license = {file = "LICENSE"}
12 | authors = [
13 | {name = "JAXgarden Contributors"}
14 | ]
15 | classifiers = [
16 | "Development Status :: 3 - Alpha",
17 | "Intended Audience :: Developers",
18 | "Intended Audience :: Science/Research",
19 | "License :: OSI Approved :: MIT License",
20 | "Programming Language :: Python :: 3",
21 | "Programming Language :: Python :: 3.10",
22 | "Programming Language :: Python :: 3.11",
23 | "Programming Language :: Python :: 3.12",
24 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
25 | ]
26 | dependencies = [
27 | "jax==0.5.1",
28 | "flax>=0.10.4",
29 | "safetensors>=0.5.3",
30 | "huggingface-hub>=0.30.2",
31 | "orbax>=0.1.9",
32 | "jinja2>=3.1.6",
33 | "tokenizers>=0.21.1",
34 | ]
35 |
36 | [project.optional-dependencies]
37 | dev = [
38 | "pytest>=7.0.0",
39 | "pytest-cov>=4.0.0",
40 | "ruff>=0.1.0",
41 | "mypy>=1.0.0",
42 | "pre-commit",
43 | "sphinx",
44 | "sphinx-rtd-theme",
45 | "myst-parser",
46 | "torch",
47 | "torchvision",
48 | "torchaudio",
49 | "transformers",
50 | ]
51 |
52 | [project.urls]
53 | "Homepage" = "https://github.com/ml-gde/jax-layers"
54 | "Bug Tracker" = "https://github.com/ml-gde/jax-layers/issues"
55 |
56 | [tool.setuptools.packages.find]
57 | include = ["jaxgarden*"]
58 |
59 | [tool.ruff]
60 | line-length = 100
61 | target-version = "py310"
62 |
63 | [tool.ruff.lint]
64 | select = ["E", "F", "I", "N", "W", "B", "UP", "C4", "PT", "RUF", "SIM", "TID"]
65 | ignore = []
66 |
67 | [tool.mypy]
68 | python_version = "3.10"
69 | warn_return_any = true
70 | warn_unused_configs = true
71 | disallow_untyped_defs = true
72 | disallow_incomplete_defs = true
73 | disable_error_code = "var-annotated"
74 |
75 | [[tool.mypy.overrides]]
76 | module = ["tests.*", "examples.*"]
77 | disallow_untyped_defs = false
78 | disallow_incomplete_defs = false
79 |
--------------------------------------------------------------------------------
/.devcontainer/README.md:
--------------------------------------------------------------------------------
1 | # Development Container for JAX Layers
2 |
3 | This directory contains configuration files for setting up a development container with JAX and CUDA support, which is especially useful for Windows users where JAX doesn't natively support CUDA.
4 |
5 | ## Prerequisites
6 |
7 | To use this development container, you need:
8 |
9 | 1. [Docker Desktop](https://www.docker.com/products/docker-desktop/) installed and configured with WSL 2 backend
10 | 2. [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) installed
11 | 3. [Visual Studio Code](https://code.visualstudio.com/) with the [Remote - Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension
12 |
13 | ## GPU Support
14 |
15 | The container is configured to use all available GPUs. Make sure your NVIDIA drivers are up-to-date and that Docker has access to your GPUs.
16 |
17 | ## Usage
18 |
19 | 1. Open the project in Visual Studio Code
20 | 2. Click on the green icon in the bottom-left corner of VS Code
21 | 3. Select "Reopen in Container" from the menu
22 | 4. Wait for the container to build and start (this may take a while the first time)
23 |
24 | Once the container is running, you'll have a fully configured development environment with:
25 |
26 | - Python 3.10
27 | - CUDA 12.4 (per the base image in the Dockerfile)
28 | - JAX with CUDA support
29 | - All dependencies from pyproject.toml
30 |
31 | ## Dependency Management
32 |
33 | The container installs dependencies directly from your project's `pyproject.toml` file using the `pip install -e '.[dev]'` command, ensuring consistency between your development environment and the container.
34 |
35 | ## Customization
36 |
37 | You can customize the container by modifying:
38 |
39 | - `devcontainer.json`: VS Code settings, extensions, and container configuration
40 | - `Dockerfile`: Base image, dependencies, and environment setup
41 |
42 | ## Troubleshooting
43 |
44 | If you encounter issues with GPU access:
45 |
46 | 1. Verify that Docker Desktop is configured to use WSL 2
47 | 2. Check that NVIDIA Container Toolkit is properly installed
48 | 3. Ensure your NVIDIA drivers are up-to-date
49 | 4. Run `nvidia-smi` in WSL to verify GPU access
50 | 5. Check Docker logs for any error messages related to GPU access
51 |
--------------------------------------------------------------------------------
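Before running the full benchmark in `.devcontainer/verify_jax_cuda.py`, a quicker sanity check can confirm that JAX inside the container sees the GPU. This is a minimal sketch using only stock JAX calls:

```python
# Quick check that JAX inside the dev container can see the GPU.
import jax

print("JAX version:", jax.__version__)
for i, device in enumerate(jax.devices()):
    print(f"Device {i}: {device}")

if any(d.platform == "gpu" for d in jax.devices()):
    print("CUDA backend is active.")
else:
    print("No GPU visible - check the NVIDIA Container Toolkit and driver setup.")
```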
/examples/gemma3_inference_example.py:
--------------------------------------------------------------------------------
1 | """Example script demonstrating Gemma3 model usage."""
2 |
3 | import jax.numpy as jnp
4 | from flax import nnx
5 | from transformers import AutoTokenizer # type: ignore
6 |
7 | from jaxgarden.models.gemma3 import Gemma3Config, Gemma3ForCausalLM
8 |
9 |
10 | def main():
11 | """Run Gemma3 inference example."""
12 | # Initialize model with correct configuration
13 | print("Initializing model...")
14 | config = Gemma3Config(
15 | vocab_size=262_208,
16 | hidden_size=2304,
17 | intermediate_size=9216,
18 | num_hidden_layers=26,
19 | num_attention_heads=8,
20 | num_key_value_heads=4,
21 | head_dim=256,
22 | rope_theta=1_000_000.0,
23 | rope_local_base_freq=10_000.0,
24 | max_position_embeddings=131_072,
25 | sliding_window=4096,
26 | sliding_window_pattern=6,
27 | hidden_activation="gelu_pytorch_tanh",
28 | # Optional RoPE scaling for longer sequences
29 | rope_scaling={
30 | "rope_type": "linear",
31 | "factor": 2.0, # Enables processing sequences twice the original length
32 | },
33 | )
34 | rng = nnx.Rngs(0)
35 | model = Gemma3ForCausalLM(config, rngs=rng)
36 |
37 | # Load tokenizer
38 | print("Loading tokenizer...")
39 | tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
40 |
41 | # Prepare input
42 | prompt = "Write a short story about a robot learning to paint:"
43 | print(f"\nPrompt: {prompt}")
44 |
45 | # Tokenize input
46 | input_ids = tokenizer(prompt, return_tensors="np").input_ids
47 | input_ids = jnp.array(input_ids)
48 |
49 | # Generate text
50 | print("\nGenerating response...")
51 | max_new_tokens = 100
52 | eos_token_id = config.eos_token_id
53 |
54 | # Initialize cache (note: this simple loop does not feed the returned cache back into the model)
55 | cache = None
56 | generated = input_ids
57 |
58 | for _ in range(max_new_tokens):
59 | # Get logits and updated cache
60 | logits, cache = model(
61 | generated,
62 | use_cache=True, # Enable KV caching
63 | deterministic=True, # No dropout during inference
64 | )
65 | # Get next token (use argmax for simplicity)
66 | next_token = jnp.argmax(logits[:, -1, :], axis=-1)
67 | # Check if we hit the end of sequence
68 | if next_token[0] == eos_token_id:
69 | break
70 | # Append next token
71 | generated = jnp.concatenate([generated, next_token[:, None]], axis=1)
72 |
73 | # Decode generated text
74 | generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)
75 | print(f"\nGenerated text:\n{generated_text}")
76 |
77 |
78 | if __name__ == "__main__":
79 | main()
80 |
--------------------------------------------------------------------------------
/jaxgarden/__init__.py:
--------------------------------------------------------------------------------
1 | """JAXgarden - High-performance neural network layers for JAX."""
2 |
3 | from jaxgarden.attention.multi_head_attention import MultiHeadAttention
4 | from jaxgarden.functional.attention import dot_product_attention
5 | from jaxgarden.models.base import BaseConfig, BaseModel
6 | from jaxgarden.models.gemma2 import (
7 | Gemma2Attention,
8 | Gemma2Config,
9 | Gemma2ForCausalLM,
10 | Gemma2MLP,
11 | Gemma2RMSNorm,
12 | Gemma2RotaryEmbedding,
13 | )
14 | from jaxgarden.models.gemma3 import (
15 | Gemma3Attention,
16 | Gemma3Config,
17 | Gemma3ForCausalLM,
18 | Gemma3MLP,
19 | Gemma3RMSNorm,
20 | Gemma3RotaryEmbedding,
21 | )
22 | from jaxgarden.models.generation_utils import GenerationMixin
23 | from jaxgarden.models.llama import (
24 | LlamaAttention,
25 | LlamaConfig,
26 | LlamaForCausalLM,
27 | LlamaMLP,
28 | LlamaRMSNorm,
29 | LlamaRotaryEmbedding,
30 | LlamaTransformerBlock,
31 | )
32 | from jaxgarden.models.modernbert import (
33 | ModernBertAttention,
34 | ModernBertEmbeddings,
35 | ModernBERTEncoder,
36 | ModernBERTForMaskedLM,
37 | ModernBertLayer,
38 | ModernBertMLP,
39 | )
40 | from jaxgarden.models.t5 import (
41 | T5MLP,
42 | T5Attention,
43 | T5Block,
44 | T5Config,
45 | T5CrossAttention,
46 | T5ForCausalLM,
47 | T5LayerNorm,
48 | T5SelfAttention,
49 | T5Stack,
50 | )
51 | from jaxgarden.tokenization import Tokenizer # type: ignore
52 |
53 | __all__ = [
54 | "T5MLP",
55 | # Base classes
56 | "BaseConfig",
57 | "BaseModel",
58 | # Gemma Models
59 | "Gemma2Attention",
60 | "Gemma2Config",
61 | "Gemma2ForCausalLM",
62 | "Gemma2MLP",
63 | "Gemma2RMSNorm",
64 | "Gemma2RotaryEmbedding",
65 | # Gemma3 Models
66 | "Gemma3Attention",
67 | "Gemma3Config",
68 | "Gemma3ForCausalLM",
69 | "Gemma3MLP",
70 | "Gemma3RMSNorm",
71 | "Gemma3RotaryEmbedding",
72 | # Mixins
73 | "GenerationMixin",
74 | # Llama Models
75 | "LlamaAttention",
76 | "LlamaConfig",
77 | "LlamaForCausalLM",
78 | "LlamaMLP",
79 | "LlamaRMSNorm",
80 | "LlamaRotaryEmbedding",
81 | "LlamaTransformerBlock",
82 | "ModernBERTEncoder",
83 | "ModernBERTForMaskedLM",
84 | "ModernBertAttention",
85 | "ModernBertEmbeddings",
86 | "ModernBertLayer",
87 | "ModernBertMLP",
88 | # Attention modules
89 | "MultiHeadAttention",
90 | # T5 Models
91 | "T5Attention",
92 | "T5Block",
93 | "T5Config",
94 | "T5CrossAttention",
95 | "T5ForCausalLM",
96 | "T5LayerNorm",
97 | "T5SelfAttention",
98 | "T5Stack",
99 | # tokenization
100 | "Tokenizer",
101 | # Functional interfaces
102 | "dot_product_attention",
103 | ]
104 |
105 | __version__ = "0.2.0"
106 |
--------------------------------------------------------------------------------
/.devcontainer/verify_jax_cuda.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Script to verify that JAX with CUDA is working correctly.
4 | Run this script after the container is built to confirm GPU access.
5 | """
6 |
7 | import re
8 | import subprocess
9 | import time
10 |
11 | import jax
12 | import jax.numpy as jnp
13 |
14 |
15 | def get_cuda_version():
16 | """Get the CUDA version from nvcc."""
17 | try:
18 | result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True)
19 | version_match = re.search(r"release (\d+\.\d+)", result.stdout)
20 | if version_match:
21 | return version_match.group(1)
22 | return "Unknown"
23 | except Exception:
24 | return "Unknown"
25 |
26 |
27 | def main():
28 | print("\n" + "=" * 50)
29 | print("JAX CUDA Verification Script")
30 | print("=" * 50)
31 |
32 | # Check JAX version
33 | print(f"JAX version: {jax.__version__}")
34 |
35 | # Check CUDA version
36 | cuda_version = get_cuda_version()
37 | print(f"CUDA version: {cuda_version}")
38 |
39 | # Check available devices
40 | print("\nAvailable devices:")
41 | for i, device in enumerate(jax.devices()):
42 | print(f" Device {i}: {device}")
43 |
44 | # Check if GPU is available
45 | gpu_available = any(d.platform == "gpu" for d in jax.devices())
46 | print(f"\nGPU available: {gpu_available}")
47 |
48 | if not gpu_available:
49 | print("\n⚠️ No GPU devices found! JAX is not using CUDA.")
50 | print("Please check your installation and GPU configuration.")
51 | return
52 |
53 | # Run a simple benchmark
54 | print("\nRunning simple matrix multiplication benchmark...")
55 |
56 | # Create large matrices
57 | n = 5000
58 | print(f"Creating {n}x{n} matrices...")
59 |
60 | # CPU benchmark
61 | with jax.default_device(jax.devices("cpu")[0]):
62 | x_cpu = jnp.ones((n, n))
63 | y_cpu = jnp.ones((n, n))
64 |
65 | # Warm-up
66 | _ = jnp.dot(x_cpu, y_cpu)
67 | jax.block_until_ready(_)
68 |
69 | # Benchmark
70 | start = time.time()
71 | result_cpu = jnp.dot(x_cpu, y_cpu)
72 | jax.block_until_ready(result_cpu)
73 | cpu_time = time.time() - start
74 |
75 | # GPU benchmark
76 | with jax.default_device(jax.devices("gpu")[0]):
77 | x_gpu = jnp.ones((n, n))
78 | y_gpu = jnp.ones((n, n))
79 |
80 | # Warm-up
81 | _ = jnp.dot(x_gpu, y_gpu)
82 | jax.block_until_ready(_)
83 |
84 | # Benchmark
85 | start = time.time()
86 | result_gpu = jnp.dot(x_gpu, y_gpu)
87 | jax.block_until_ready(result_gpu)
88 | gpu_time = time.time() - start
89 |
90 | # Print results
91 | print(f"\nCPU time: {cpu_time:.4f} seconds")
92 | print(f"GPU time: {gpu_time:.4f} seconds")
93 | print(f"Speedup: {cpu_time / gpu_time:.2f}x")
94 |
95 | if cpu_time > gpu_time:
96 | print("\n✅ GPU is faster than CPU! JAX with CUDA is working correctly.")
97 | else:
98 | print("\n⚠️ GPU is not faster than CPU. Something might be wrong with the CUDA setup.")
99 |
100 | print("\n" + "=" * 50)
101 |
102 |
103 | if __name__ == "__main__":
104 | main()
105 |
--------------------------------------------------------------------------------
/.cursor/rules/guide.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description: Guidelines for using jaxgarden
3 | globs:
4 | alwaysApply: true
5 | ---
6 | The `jaxgarden` project provides a library for building and utilizing neural network models, primarily focused on **Transformer architectures**, using *JAX* and *Flax NNX*. It establishes a structured framework through abstract base classes: `BaseModel` defines the core model interface, including state management (parameters, mutable states via `nnx.State`), checkpointing (*Serialization*) with Orbax, and a standardized mechanism for importing and converting weights from Hugging Face's *Safetensors* format. `BaseConfig` offers a consistent way to manage model hyperparameters (*Configuration Management*).
7 |
8 | The library implements specific models like `LlamaForCausalLM` (a decoder-only model) and `ModernBERTForMaskedLM` (an encoder model), showcasing modern techniques. Key architectural components are modularized:
9 | - `Attention Mechanism`: Provides multi-head attention, potentially leveraging hardware acceleration like *Flash Attention* via cuDNN backend selection in `dot_product_attention`.
10 | - `Rotary Position Embeddings (RoPE)`: Implements relative position encoding within the attention mechanism itself, used by both Llama and ModernBERT variants.
11 |
12 | Functionality includes:
13 | - **Text Generation**: The `GenerationMixin` adds autoregressive text generation capabilities to causal models like Llama, supporting various sampling strategies (temperature, top-k, top-p, min-p) and efficient implementation via `jax.lax.scan`.
14 | - **Tokenization**: A `Tokenizer` class wraps the Hugging Face `tokenizers` library, providing a JAX-friendly API for encoding/decoding text and applying chat templates.
15 | - **Interoperability**: Facilitates using pretrained models from the Hugging Face Hub.
16 |
17 | Overall, `jaxgarden` aims for high performance and *modularity*, enabling researchers and developers to work with modern NLP models within the JAX ecosystem, promoting *code reuse* and *extensibility* through its base classes and focused components.
18 |
19 |
20 | **Source Repository:** [https://github.com/ml-gde/jaxgarden.git](https://github.com/ml-gde/jaxgarden.git)
21 |
22 | ```mermaid
23 | flowchart TD
24 | A0["BaseModel"]
25 | A1["BaseConfig"]
26 | A2["Tokenizer"]
27 | A3["Attention Mechanism (MultiHeadAttention / dot_product_attention)"]
28 | A4["GenerationMixin"]
29 | A5["LlamaForCausalLM"]
30 | A6["ModernBERTForMaskedLM"]
31 | A7["Rotary Position Embeddings (RoPE)"]
32 | A0 -- "Manages" --> A1
33 | A5 -- "Inherits From" --> A0
34 | A6 -- "Inherits From" --> A0
35 | A5 -- "Uses Config" --> A1
36 | A6 -- "Uses Config" --> A1
37 | A4 -- "Uses Token IDs" --> A2
38 | A5 -- "Uses Attention" --> A3
39 | A6 -- "Uses Attention" --> A3
40 | A5 -- "Inherits From" --> A4
41 | A4 -- "Calls Model" --> A5
42 | A5 -- "Uses RoPE" --> A7
43 | A6 -- "Uses RoPE" --> A7
44 | ```
45 |
46 | ## Chapters
47 |
48 | [Tokenizer](tokenizer.mdc)
49 | [BaseModel](basemodel.mdc)
50 | [BaseConfig](baseconfig.mdc)
51 | [LlamaForCausalLM](llamaforcausallm.mdc)
52 | [GenerationMixin](generationmixin.mdc)
53 | [ModernBERTForMaskedLM](modernbertformaskedlm.mdc)
54 | [Attention Mechanism (MultiHeadAttention / dot_product_attention)](attention_mechanism__multiheadattention___dot_product_attention_.mdc)
55 | [Rotary Position Embeddings (RoPE)](rotary_position_embeddings__rope_.mdc)
56 |
57 |
58 | ---
59 |
60 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai)
--------------------------------------------------------------------------------
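To make the flow described in guide.mdc concrete, here is a short sketch mirroring examples/llama_inference_example.py elsewhere in this repository; the model id is illustrative, and per the guide GenerationMixin also exposes temperature, top-k, top-p, and min-p sampling controls.

```python
# Config -> model -> HF checkpoint conversion -> tokenizer -> generation,
# as described in guide.mdc and shown in examples/llama_inference_example.py.
from flax import nnx

from jaxgarden import LlamaConfig, LlamaForCausalLM, Tokenizer

config = LlamaConfig()  # defaults for the 1B variant
model = LlamaForCausalLM(config, rngs=nnx.Rngs(0))

model_id = "meta-llama/Llama-3.2-1B"
model.from_hf(model_id, force_download=True)  # download, convert, save as Orbax checkpoint

tokenizer = Tokenizer.from_pretrained(model_id)
model_inputs = tokenizer.encode("The meaning of life is")

output = model.generate(**model_inputs, max_length=20, do_sample=True)
print(tokenizer.decode(output))
```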
/jaxgarden/functional/attention.py:
--------------------------------------------------------------------------------
1 | """Attention functions for JAX Layers.
2 |
3 | This module provides attention functions that are compatible with both JAX and Flax NNX.
4 | """
5 |
6 | from typing import Literal
7 |
8 | import jax
9 | import jax.numpy as jnp
10 |
11 |
12 | def dot_product_attention(
13 | query: jnp.ndarray,
14 | key: jnp.ndarray,
15 | value: jnp.ndarray,
16 | bias: jnp.ndarray | None = None,
17 | mask: jnp.ndarray | None = None,
18 | broadcast_dropout: bool = True,
19 | dropout_rng: jnp.ndarray | None = None,
20 | dropout_rate: float = 0.0,
21 | deterministic: bool = False,
22 | dtype: jnp.dtype | None = None,
23 | precision: jax.lax.Precision | str | None = None,
24 | implementation: Literal["xla", "cudnn", "flash"] | None = None,
25 | module: object | None = None,
26 | ) -> jnp.ndarray:
27 | """Computes dot-product attention with optional Flash Attention support.
28 |
29 | This function provides a wrapper around JAX's dot_product_attention with the option
30 | to use Flash Attention when available. It follows the Flax NNX interface while
31 | allowing the use of different implementations through the implementation parameter.
32 |
33 | Args:
34 | query: queries for calculating attention with shape of
35 | `[batch..., q_length, num_heads, qk_depth_per_head]`.
36 | key: keys for calculating attention with shape of
37 | `[batch..., kv_length, num_heads, qk_depth_per_head]`.
38 | value: values to be used in attention with shape of
39 | `[batch..., kv_length, num_heads, v_depth_per_head]`.
40 | bias: bias for the attention weights. This should be broadcastable to
41 | the shape [batch..., num_heads, q_length, kv_length].
42 | mask: mask for the attention weights. This should be broadcastable to
43 | the shape [batch..., num_heads, q_length, kv_length].
44 | broadcast_dropout: bool: use a broadcasted dropout along batch dims.
45 | dropout_rng: JAX PRNGKey: to be used for dropout.
46 | dropout_rate: dropout rate.
47 | deterministic: bool, deterministic or not (to apply dropout).
48 | dtype: the dtype of the computation (default: infer from inputs).
49 | precision: numerical precision of the computation.
50 | implementation: which implementation to use. Options are:
51 | - "xla": Use XLA's default implementation
52 | - "cudnn": Use cuDNN's Flash Attention implementation (if available)
53 | - "flash": Alias for "cudnn"
54 | - None: Automatically select the best available implementation
55 | module: the Module that will sow the attention weights.
56 |
57 | Returns:
58 | Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
59 | """
60 | # Map "flash" to "cudnn" for clarity
61 | if implementation == "flash":
62 | implementation = "cudnn"
63 |
64 | # Convert mask to bias if needed
65 | if mask is not None:
66 | # A boolean mask uses True for positions that may be attended to, so convert
67 | # it to an additive bias: 0.0 where attention is allowed, -1e10 where masked out.
68 | if bias is None:
69 | bias = jnp.where(mask, 0.0, -1e10)
70 | else:
71 | bias = jnp.where(mask, bias, -1e10)
72 |
73 | # Call JAX's dot_product_attention with the implementation parameter
74 | return jax.nn.dot_product_attention(
75 | query=query,
76 | key=key,
77 | value=value,
78 | bias=bias,
79 | # JAX-specific parameters
80 | scale=None, # Use default scaling
81 | is_causal=False, # We handle causal masking through the bias/mask
82 | query_seq_lengths=None,
83 | key_value_seq_lengths=None,
84 | local_window_size=None,
85 | implementation=implementation,
86 | )
87 |
--------------------------------------------------------------------------------
/tests/functional/test_attention.py:
--------------------------------------------------------------------------------
1 | """Tests for the jax_layers/functional/attention.py implementations."""
2 |
3 | from unittest.mock import MagicMock
4 |
5 | import jax
6 | import jax.numpy as jnp
7 |
8 | from jaxgarden.functional import dot_product_attention
9 |
10 |
11 | def test_dot_product_attention():
12 | """Test that the dot_product_attention function works with different implementations."""
13 | batch_size = 2
14 | seq_len = 16
15 | num_heads = 4
16 | head_dim = 8
17 |
18 | key = jax.random.PRNGKey(0)
19 | key1, key2, key3 = jax.random.split(key, 3)
20 |
21 | query = jax.random.normal(key1, (batch_size, seq_len, num_heads, head_dim))
22 | key_tensor = jax.random.normal(key2, (batch_size, seq_len, num_heads, head_dim))
23 | value = jax.random.normal(key3, (batch_size, seq_len, num_heads, head_dim))
24 |
25 | # Execute using default (XLA) implementation
26 | output_default = dot_product_attention(query, key_tensor, value)
27 | output_xla = dot_product_attention(query, key_tensor, value, implementation="xla")
28 |
29 | # Verify that the output shapes and values are identical
30 | assert output_default.shape == output_xla.shape
31 | assert jnp.allclose(output_default, output_xla, rtol=1e-5, atol=1e-5)
32 |
33 |
34 | def test_dot_product_attention_flash_mapping(monkeypatch):
35 | """Test that when implementation is 'flash', it is remapped to 'cudnn'
36 | before calling jax.nn.dot_product_attention."""
37 |
38 | mock_fn = MagicMock(return_value=jnp.array(0))
39 | monkeypatch.setattr(jax.nn, "dot_product_attention", mock_fn)
40 |
41 | key = jax.random.PRNGKey(42)
42 | key1, key2, key3 = jax.random.split(key, 3)
43 |
44 | query = jax.random.normal(key1, (2, 16, 4, 8))
45 | key_tensor = jax.random.normal(key2, (2, 16, 4, 8))
46 | value = jax.random.normal(key3, (2, 16, 4, 8))
47 |
48 | _ = dot_product_attention(query, key_tensor, value, implementation="flash")
49 |
50 | mock_fn.assert_called_once_with(
51 | query=query,
52 | key=key_tensor,
53 | value=value,
54 | bias=None,
55 | scale=None,
56 | is_causal=False,
57 | query_seq_lengths=None,
58 | key_value_seq_lengths=None,
59 | local_window_size=None,
60 | implementation="cudnn",
61 | )
62 |
63 |
64 | def test_dot_product_attention_mask_without_bias():
65 | """Test the mask branch when bias is None:
66 | bias should be created as jnp.where(mask, 0.0, -1e10)."""
67 | batch_size = 2
68 | seq_len = 16
69 | num_heads = 4
70 | head_dim = 8
71 |
72 | key = jax.random.PRNGKey(1)
73 | key1, key2, key3 = jax.random.split(key, 3)
74 |
75 | query = jax.random.normal(key1, (batch_size, seq_len, num_heads, head_dim))
76 | key_tensor = jax.random.normal(key2, (batch_size, seq_len, num_heads, head_dim))
77 | value = jax.random.normal(key3, (batch_size, seq_len, num_heads, head_dim))
78 |
79 | # Create a causal mask with boolean values.
80 | mask = jnp.tril(jnp.ones((batch_size, 1, seq_len, seq_len), dtype=bool))
81 |
82 | # Compute output using the mask (without providing bias explicitly).
83 | output_with_mask = dot_product_attention(query, key_tensor, value, mask=mask)
84 |
85 | # Manually create the bias as per conversion: if mask is True -> 0.0, else -1e10.
86 | bias_manual = jnp.where(mask, 0.0, -1e10)
87 | output_with_manual_bias = dot_product_attention(query, key_tensor, value, bias=bias_manual)
88 |
89 | assert output_with_mask.shape == output_with_manual_bias.shape
90 | assert jnp.allclose(output_with_mask, output_with_manual_bias, rtol=1e-5, atol=1e-5)
91 |
92 |
93 | def test_dot_product_attention_mask_with_bias():
94 | """Test the mask branch when bias is provided:
95 | bias should be converted using jnp.where(mask, bias, -1e10)."""
96 | batch_size = 2
97 | seq_len = 16
98 | num_heads = 4
99 | head_dim = 8
100 |
101 | key = jax.random.PRNGKey(2)
102 | key1, key2, key3 = jax.random.split(key, 3)
103 |
104 | query = jax.random.normal(key1, (batch_size, seq_len, num_heads, head_dim))
105 | key_tensor = jax.random.normal(key2, (batch_size, seq_len, num_heads, head_dim))
106 | value = jax.random.normal(key3, (batch_size, seq_len, num_heads, head_dim))
107 |
108 | # Create a causal mask with boolean values.
109 | mask = jnp.tril(jnp.ones((batch_size, 1, seq_len, seq_len), dtype=bool))
110 | # Provide a custom bias (e.g., constant value 5.0).
111 | custom_bias = jnp.full(mask.shape, 5.0)
112 |
113 | # Expected bias after conversion: jnp.where(mask, custom_bias, -1e10)
114 | bias_manual = jnp.where(mask, custom_bias, -1e10)
115 |
116 | output_with_mask_bias = dot_product_attention(
117 | query,
118 | key_tensor,
119 | value,
120 | mask=mask,
121 | bias=custom_bias,
122 | )
123 | output_with_manual_bias = dot_product_attention(query, key_tensor, value, bias=bias_manual)
124 |
125 | assert output_with_mask_bias.shape == output_with_manual_bias.shape
126 | assert jnp.allclose(output_with_mask_bias, output_with_manual_bias, rtol=1e-5, atol=1e-5)
127 |
--------------------------------------------------------------------------------
/tests/attention/test_RoPEMultiHeadAttention.py:
--------------------------------------------------------------------------------
1 | """Tests for the RoPEMultiHeadAttention class."""
2 |
3 | import flax.linen as nn
4 | import jax
5 | import jax.numpy as jnp
6 | import pytest
7 |
8 | from jaxgarden.attention.rope_multi_head_attention import (
9 | RoPEMultiHeadAttention,
10 | apply_rotary_pos_emb,
11 | precompute_rotary_embeddings,
12 | rotate_half,
13 | )
14 |
15 |
16 | def test_rotate_half():
17 | """Tests the rotate_half function."""
18 | key = jax.random.PRNGKey(0)
19 | x = jax.random.normal(key, (2, 4, 6, 8)) # batch, seq, heads, dim
20 | rotated_x = rotate_half(x)
21 |
22 | assert rotated_x.shape == x.shape
23 | # Check specific values after rotation
24 | x1 = x[..., ::2]
25 | x2 = x[..., 1::2]
26 | expected = jnp.concatenate((-x2, x1), axis=-1)
27 | assert jnp.allclose(rotated_x, expected)
28 |
29 |
30 | def test_precompute_rotary_embeddings():
31 | """Tests the precompute_rotary_embeddings function."""
32 | seq_len = 16
33 | head_dim = 8
34 | base = 10000.0
35 |
36 | cos_emb, sin_emb = precompute_rotary_embeddings(seq_len, head_dim, base)
37 |
38 | assert cos_emb.shape == (1, seq_len, 1, head_dim)
39 | assert sin_emb.shape == (1, seq_len, 1, head_dim)
40 |
41 | # Check properties - e.g., cos^2 + sin^2 = 1
42 | assert jnp.allclose(cos_emb**2 + sin_emb**2, jnp.ones_like(cos_emb), atol=1e-6)
43 |
44 | # Check different base value
45 | cos_emb_b2, sin_emb_b2 = precompute_rotary_embeddings(seq_len, head_dim, base=500.0)
46 | assert not jnp.allclose(cos_emb, cos_emb_b2)
47 |
48 | # Test with odd head_dim (should raise error)
49 | with pytest.raises(ValueError, match="head_dim must be even"):
50 | precompute_rotary_embeddings(seq_len, head_dim=7)
51 |
52 |
53 | def test_apply_rotary_pos_emb():
54 | """Tests the apply_rotary_pos_emb function."""
55 | key = jax.random.PRNGKey(1)
56 | batch, seq_len, num_heads, head_dim = 2, 16, 4, 8
57 | x = jax.random.normal(key, (batch, seq_len, num_heads, head_dim))
58 |
59 | cos_emb, sin_emb = precompute_rotary_embeddings(seq_len, head_dim)
60 |
61 | rotated_x = apply_rotary_pos_emb(x, cos_emb, sin_emb)
62 |
63 | assert rotated_x.shape == x.shape
64 |     # Applying RoPE should change x (only position 0 is left unrotated, so the full sequence differs)
65 | assert not jnp.allclose(rotated_x, x)
66 |
67 |
68 | # --- Test RoPEMultiHeadAttention Module ---
69 |
70 |
71 | @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
72 | def test_rope_mha_forward_pass(dtype):
73 | """Tests the forward pass of RoPEMultiHeadAttention."""
74 | key = jax.random.PRNGKey(2)
75 | batch_size = 2
76 | seq_len = 16
77 | num_heads = 4
78 | head_dim = 8
79 | embed_dim = num_heads * head_dim
80 |
81 | x = jax.random.normal(key, (batch_size, seq_len, embed_dim), dtype=dtype)
82 |
83 | rope_mha = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim, dtype=dtype)
84 | params = rope_mha.init(key, x)["params"]
85 | output = rope_mha.apply({"params": params}, x)
86 |
87 | assert output.shape == (batch_size, seq_len, embed_dim)
88 | assert output.dtype == dtype
89 |
90 |
91 | def test_rope_mha_masking():
92 | """Tests causal masking in RoPEMultiHeadAttention."""
93 | key = jax.random.PRNGKey(3)
94 | batch_size = 1
95 | seq_len = 4
96 | num_heads = 2
97 | head_dim = 4
98 | embed_dim = num_heads * head_dim
99 |
100 | x = jax.random.normal(key, (batch_size, seq_len, embed_dim))
101 |     # Create a causal (lower-triangular) mask with flax.linen's helper
102 |     causal_mask = nn.make_causal_mask(x[:, :, 0])  # Shape (batch, 1, seq, seq)
103 |
104 | rope_mha = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim)
105 | params = rope_mha.init(key, x, causal_mask)["params"]
106 |
107 | # Apply without mask
108 | output_unmasked = rope_mha.apply({"params": params}, x)
109 |
110 | # Apply with mask
111 | output_masked = rope_mha.apply({"params": params}, x, mask=causal_mask)
112 |
113 | # Basic check: outputs should differ if mask has an effect
114 | assert not jnp.allclose(output_unmasked, output_masked, atol=1e-5)
115 |
116 | # More rigorous check (requires inspecting attention weights, omitted for brevity)
117 |
118 |
119 | def test_rope_mha_errors():
120 | """Tests error conditions for RoPEMultiHeadAttention."""
121 | key = jax.random.PRNGKey(4)
122 | rope_mha_odd_dim = RoPEMultiHeadAttention(num_heads=8, head_dim=7)
123 | x_dummy_odd = jax.random.normal(key, (2, 16, 8 * 7))
124 | # Test with odd head_dim (should raise error during initialization/setup)
125 | with pytest.raises(ValueError, match=r"head_dim \(\d+\) must be even"):
126 | rope_mha_odd_dim.init(key, x_dummy_odd)
127 |
128 | # Test with mismatched embed_dim (should raise error during forward pass / init)
129 | rope_mha = RoPEMultiHeadAttention(num_heads=4, head_dim=8) # Expects embed_dim 32
130 | x_mismatch = jax.random.normal(key, (2, 16, 100)) # Incorrect embed_dim
131 |
132 | with pytest.raises(ValueError, match=r"embed_dim \(\d+\) must equal"):
133 | rope_mha.init(key, x_mismatch)
134 |
--------------------------------------------------------------------------------
/examples/multi_head_attention_example.py:
--------------------------------------------------------------------------------
1 | """Example demonstrating the use of MultiHeadAttention with different implementations."""
2 |
3 | import time
4 |
5 | import flax.nnx as nnx
6 | import jax
7 | import jax.numpy as jnp
8 |
9 | from jaxgarden.attention import MultiHeadAttention
10 |
11 |
12 | def benchmark_attention(implementation=None, batch_size=2, seq_len=1024, num_heads=8, head_dim=64):
13 | """Benchmark MultiHeadAttention with different implementations."""
14 | print(f"\nBenchmarking MultiHeadAttention with implementation={implementation}")
15 | print(f"Input shape (b, s, h, d) = ({batch_size}, {seq_len}, {num_heads}, {head_dim})")
16 |
17 | # Create random input data
18 | key = jax.random.PRNGKey(0)
19 | key1, key2 = jax.random.split(key)
20 | x = jax.random.normal(key1, (batch_size, seq_len, num_heads * head_dim))
21 |
22 | # Create a causal attention mask
23 | mask = jnp.tril(jnp.ones((batch_size, 1, seq_len, seq_len)))
24 |
25 | # Create the MultiHeadAttention module
26 | attention = MultiHeadAttention(
27 | num_heads=num_heads,
28 | in_features=num_heads * head_dim,
29 | implementation=implementation,
30 | rngs=nnx.Rngs(key2),
31 | )
32 |
33 | # Compile the forward pass
34 | @nnx.jit
35 | def forward(x, mask):
36 | return attention(x, mask=mask)
37 |
38 | # Warm-up
39 | _ = forward(x, mask)
40 |
41 | # Benchmark
42 | start_time = time.time()
43 | num_runs = 10
44 | for _ in range(num_runs):
45 | output = forward(x, mask)
46 | output.block_until_ready() # Ensure computation is complete
47 | end_time = time.time()
48 |
49 | avg_time = (end_time - start_time) / num_runs
50 | print(f"Average time per run: {avg_time:.6f} seconds")
51 |
52 | return output, avg_time
53 |
54 |
55 | def compare_implementations():
56 | """Compare different implementations of MultiHeadAttention."""
57 | print("Comparing different implementations of MultiHeadAttention")
58 |
59 | # Parameters for the comparison
60 | batch_size = 2
61 | seq_len = 1024
62 | num_heads = 8
63 | head_dim = 64
64 |
65 | # Benchmark each implementation
66 | implementations = [None, "xla", "cudnn"]
67 | outputs = {}
68 | times = {}
69 |
70 | for impl in implementations:
71 | try:
72 | output, avg_time = benchmark_attention(
73 | implementation=impl,
74 | batch_size=batch_size,
75 | seq_len=seq_len,
76 | num_heads=num_heads,
77 | head_dim=head_dim,
78 | )
79 | outputs[impl] = output
80 | times[impl] = avg_time
81 | except Exception as e:
82 | print(f"Error with implementation {impl}: {e}")
83 |
84 | # Compare outputs for correctness
85 | print("\nComparing outputs for correctness:")
86 | for impl1 in outputs:
87 | for impl2 in outputs:
88 | if impl1 != impl2:
89 | max_diff = jnp.max(jnp.abs(outputs[impl1] - outputs[impl2]))
90 | print(f"Max difference between {impl1} and {impl2}: {max_diff}")
91 |
92 | # Compare performance
93 | if len(times) > 1:
94 | print("\nPerformance comparison:")
95 | baseline = times[None] # Default implementation as baseline
96 | for impl, avg_time in times.items():
97 | if impl is not None:
98 | speedup = baseline / avg_time
99 | print(f"{impl} implementation: {speedup:.2f}x speedup over default")
100 |
101 |
102 | def main():
103 | """Run the example."""
104 |     # Enable JAX compilation logging (useful when comparing implementations)
105 | jax.config.update("jax_log_compiles", True)
106 |
107 | # Compare different implementations
108 | compare_implementations()
109 |
110 | # Show an example of using MultiHeadAttention with a specific implementation
111 | print("\nExample of using MultiHeadAttention with Flash Attention:")
112 | key = jax.random.PRNGKey(0)
113 | key1, key2 = jax.random.split(key)
114 |
115 | # Create input data
116 | batch_size = 2
117 | seq_len = 128
118 | num_heads = 8
119 | head_dim = 64
120 | hidden_dim = num_heads * head_dim
121 |
122 | x = jax.random.normal(key1, (batch_size, seq_len, hidden_dim))
123 |
124 | # Create a causal attention mask
125 | mask = jnp.tril(jnp.ones((batch_size, 1, seq_len, seq_len)))
126 |
127 | # Create the MultiHeadAttention module with Flash Attention
128 | attention = MultiHeadAttention(
129 | num_heads=num_heads,
130 | in_features=hidden_dim,
131 | implementation="flash", # Use Flash Attention (alias for "cudnn")
132 | rngs=nnx.Rngs(key2),
133 | )
134 |
135 | # Apply the attention
136 | output = attention(x, mask=mask)
137 |
138 | print(f"Input shape: {x.shape}")
139 | print(f"Output shape: {output.shape}")
140 | print(f"Output mean: {jnp.mean(output)}")
141 | print(f"Output std: {jnp.std(output)}")
142 |
143 |
144 | if __name__ == "__main__":
145 | main()
146 |
--------------------------------------------------------------------------------
/tests/attention/test_multi_head_attention.py:
--------------------------------------------------------------------------------
1 | """Tests for the MultiHeadAttention class."""
2 |
3 | import flax.nnx as nnx
4 | import jax
5 | import jax.numpy as jnp
6 |
7 | from jaxgarden.attention import MultiHeadAttention
8 |
9 |
10 | def test_multi_head_attention():
11 | """Test that the MultiHeadAttention module works with different implementations."""
12 | # Set up parameters
13 | batch_size = 2
14 | seq_len = 16
15 | num_heads = 4
16 | head_dim = 8
17 | hidden_dim = num_heads * head_dim
18 |
19 | # Create random input data
20 | key = jax.random.PRNGKey(0)
21 | key1, key2, key3 = jax.random.split(key, 3)
22 | x = jax.random.normal(key1, (batch_size, seq_len, hidden_dim))
23 |
24 | # Create the MultiHeadAttention modules with different implementations
25 | attention_default = MultiHeadAttention(
26 | num_heads=num_heads,
27 | in_features=hidden_dim,
28 | decode=False,
29 | rngs=nnx.Rngs(key2),
30 | )
31 |
32 | attention_xla = MultiHeadAttention(
33 | num_heads=num_heads,
34 | in_features=hidden_dim,
35 | decode=False,
36 | implementation="xla",
37 | rngs=nnx.Rngs(key3),
38 | )
39 |
40 | # Apply the attention with different implementations
41 | output_default = attention_default(x, x, x)
42 | output_xla = attention_xla(x, x, x)
43 |
44 |     # Check that the outputs have matching shapes (values differ due to different initializations)
45 | assert output_default.shape == output_xla.shape, f"{type(output_xla)}, {len(output_xla)}"
46 |
47 |
48 | def test_multi_head_attention_shape():
49 | """Test that the MultiHeadAttention module produces the expected output shape."""
50 | # Set up parameters
51 | batch_size = 2
52 | seq_len = 16
53 | num_heads = 4
54 | head_dim = 8
55 | hidden_dim = num_heads * head_dim
56 |
57 | # Create random input data
58 | key = jax.random.PRNGKey(0)
59 | key1, key2 = jax.random.split(key)
60 | x = jax.random.normal(key1, (batch_size, seq_len, hidden_dim))
61 |
62 | # Create the MultiHeadAttention module
63 | attention = MultiHeadAttention(
64 | num_heads=num_heads,
65 | in_features=hidden_dim,
66 | decode=False,
67 | rngs=nnx.Rngs(key2),
68 | )
69 |
70 | # Apply the attention
71 | output = attention(x, x, x)
72 |
73 | # Check the output shape
74 | assert output.shape == (batch_size, seq_len, hidden_dim)
75 |
76 |
77 | def test_multi_head_attention_mask():
78 | """Test that the MultiHeadAttention module correctly applies attention masks."""
79 | # Set up parameters
80 | batch_size = 2
81 | seq_len = 16
82 | num_heads = 4
83 | head_dim = 8
84 | hidden_dim = num_heads * head_dim
85 |
86 | # Create random input data
87 | key = jax.random.PRNGKey(0)
88 | key1, key2 = jax.random.split(key)
89 | x = jax.random.normal(key1, (batch_size, seq_len, hidden_dim))
90 |
91 | # Create a causal attention mask
92 | mask = jnp.tril(jnp.ones((batch_size, 1, seq_len, seq_len)))
93 |
94 | # Create the MultiHeadAttention module
95 | attention = MultiHeadAttention(
96 | num_heads=num_heads,
97 | in_features=hidden_dim,
98 | decode=False,
99 | rngs=nnx.Rngs(key2),
100 | )
101 |
102 | # Apply the attention with and without mask
103 | output_with_mask = attention(x, x, x, mask=mask)
104 | output_without_mask = attention(x, x, x)
105 |
106 | # Check that the outputs are different
107 | assert not jnp.allclose(output_with_mask, output_without_mask)
108 |
109 |
110 | def test_multi_head_attention_self_attention():
111 | """Test that the MultiHeadAttention module works as a self-attention module."""
112 | # Set up parameters
113 | batch_size = 2
114 | seq_len = 16
115 | num_heads = 4
116 | head_dim = 8
117 | hidden_dim = num_heads * head_dim
118 |
119 | # Create random input data
120 | key = jax.random.PRNGKey(0)
121 | key1, key2 = jax.random.split(key)
122 | x = jax.random.normal(key1, (batch_size, seq_len, hidden_dim))
123 |
124 | # Create the MultiHeadAttention module
125 | attention = MultiHeadAttention(
126 | num_heads=num_heads,
127 | in_features=hidden_dim,
128 | decode=False,
129 | rngs=nnx.Rngs(key2),
130 | )
131 |
132 | # Apply the attention as self-attention
133 | output = attention(x)
134 |
135 | # Check the output shape
136 | assert output.shape == (batch_size, seq_len, hidden_dim)
137 |
138 |
139 | def test_multi_head_attention_cross_attention():
140 | """Test that the MultiHeadAttention module works as a cross-attention module."""
141 | # Set up parameters
142 | batch_size = 2
143 | q_seq_len = 16
144 | kv_seq_len = 32
145 | num_heads = 4
146 | head_dim = 8
147 | hidden_dim = num_heads * head_dim
148 |
149 | # Create random input data
150 | key = jax.random.PRNGKey(0)
151 | key1, key2, key3, key4 = jax.random.split(key, 4)
152 | q = jax.random.normal(key1, (batch_size, q_seq_len, hidden_dim))
153 | k = jax.random.normal(key2, (batch_size, kv_seq_len, hidden_dim))
154 | v = jax.random.normal(key3, (batch_size, kv_seq_len, hidden_dim))
155 |
156 | # Create the MultiHeadAttention module
157 | attention = MultiHeadAttention(
158 | num_heads=num_heads,
159 | in_features=hidden_dim,
160 | decode=False,
161 | rngs=nnx.Rngs(key4),
162 | )
163 |
164 | # Apply the attention as cross-attention
165 | output = attention(q, k, v)
166 |
167 | # Check the output shape
168 | assert output.shape == (batch_size, q_seq_len, hidden_dim)
169 |
--------------------------------------------------------------------------------
/jaxgarden/attention/multi_head_attention.py:
--------------------------------------------------------------------------------
1 | """MultiHeadAttention implementation with Flash Attention support for JAXgarden."""
2 |
3 | from collections.abc import Callable
4 | from typing import Any, Literal
5 |
6 | import flax.nnx as nnx
7 | import jax
8 | import jax.numpy as jnp
9 | from flax.nnx.nn.linear import default_kernel_init
10 |
11 | from jaxgarden.functional.attention import dot_product_attention
12 |
13 |
14 | class MultiHeadAttention(nnx.MultiHeadAttention):
15 | """Multi-head attention with support for Flash Attention.
16 |
17 | This class extends Flax NNX's MultiHeadAttention to support Flash Attention
18 | through JAX's dot_product_attention implementation parameter.
19 |
20 | Example usage:
21 |
22 | ```python
23 | import jax
24 | import jax.numpy as jnp
25 | import flax.nnx as nnx
26 |     from jaxgarden.attention import MultiHeadAttention
27 |
28 | # Create a MultiHeadAttention module with Flash Attention support
29 | attention = MultiHeadAttention(
30 | num_heads=8,
31 | in_features=512,
32 | implementation="cudnn", # Use cuDNN's Flash Attention if available
33 | rngs=nnx.Rngs(0),
34 | )
35 |
36 | # Initialize parameters
37 | key = jax.random.PRNGKey(0)
38 | x = jax.random.normal(key, (2, 128, 512)) # (batch, seq_length, hidden_dim)
39 |
40 | # Create a causal attention mask
41 | mask = jnp.tril(jnp.ones((2, 1, 128, 128))) # (batch, 1, q_len, kv_len)
42 |
43 | # Apply the model
44 | output = attention(x, mask=mask)
45 | ```
46 | """
47 |
48 | def __init__(
49 | self,
50 | num_heads: int,
51 | in_features: int,
52 | qkv_features: int | None = None,
53 | out_features: int | None = None,
54 | *,
55 | dtype: jnp.dtype | None = None,
56 | param_dtype: jnp.dtype = jnp.float32,
57 | broadcast_dropout: bool = True,
58 | dropout_rate: float = 0.0,
59 | deterministic: bool | None = None,
60 | precision: jax.lax.Precision | str | None = None,
61 | kernel_init: Callable = default_kernel_init,
62 | out_kernel_init: Callable | None = None,
63 | bias_init: Callable = nnx.initializers.zeros,
64 | out_bias_init: Callable | None = None,
65 | use_bias: bool = True,
66 | attention_fn: Callable | None = None,
67 | decode: bool | None = None,
68 | normalize_qk: bool = False,
69 | qkv_dot_general: Callable | None = None,
70 | out_dot_general: Callable | None = None,
71 | qkv_dot_general_cls: type | None = None,
72 | out_dot_general_cls: type | None = None,
73 | implementation: Literal["xla", "cudnn", "flash"] | None = None,
74 | rngs: nnx.Rngs,
75 | ):
76 | """Initialize the MultiHeadAttention module.
77 |
78 | Args:
79 | num_heads: number of attention heads.
80 | in_features: int or tuple with number of input features.
81 | qkv_features: dimension of the key, query, and value.
82 | out_features: dimension of the last projection.
83 | dtype: the dtype of the computation.
84 | param_dtype: the dtype passed to parameter initializers.
85 | broadcast_dropout: bool: use a broadcasted dropout along batch dims.
86 | dropout_rate: dropout rate.
87 | deterministic: if false, the attention weight is masked randomly using dropout.
88 | precision: numerical precision of the computation.
89 | kernel_init: initializer for the kernel of the Dense layers.
90 | out_kernel_init: initializer for the kernel of the output Dense layer.
91 | bias_init: initializer for the bias of the Dense layers.
92 | out_bias_init: initializer for the bias of the output Dense layer.
93 | use_bias: bool: whether pointwise QKVO dense transforms use bias.
94 | attention_fn: dot_product_attention or compatible function.
95 | decode: whether to prepare and use an autoregressive cache.
96 | normalize_qk: should QK normalization be applied.
97 | qkv_dot_general: dot_general function for QKV projection.
98 | out_dot_general: dot_general function for output projection.
99 | qkv_dot_general_cls: dot_general class for QKV projection.
100 | out_dot_general_cls: dot_general class for output projection.
101 | implementation: which implementation to use for attention. Options are:
102 | - "xla": Use XLA's default implementation
103 | - "cudnn": Use cuDNN's Flash Attention implementation (if available)
104 | - "flash": Alias for "cudnn"
105 | - None: Automatically select the best available implementation
106 | rngs: random number generator keys.
107 | """
108 | # Create a custom attention function that uses our dot_product_attention
109 | # with the specified implementation
110 | if attention_fn is None:
111 |
112 | def custom_attention_fn(
113 | query: jnp.ndarray,
114 | key: jnp.ndarray,
115 | value: jnp.ndarray,
116 | bias: jnp.ndarray | None = None,
117 | mask: jnp.ndarray | None = None,
118 | **kwargs: Any,
119 | ) -> jnp.ndarray:
120 | return dot_product_attention(
121 | query=query,
122 | key=key,
123 | value=value,
124 | bias=bias,
125 | mask=mask,
126 | implementation=implementation,
127 | **kwargs,
128 | )
129 |
130 | attention_fn = custom_attention_fn
131 |
132 | # Initialize the parent class with our custom attention function
133 | super().__init__(
134 | num_heads=num_heads,
135 | in_features=in_features,
136 | qkv_features=qkv_features,
137 | out_features=out_features,
138 | dtype=dtype,
139 | param_dtype=param_dtype,
140 | broadcast_dropout=broadcast_dropout,
141 | dropout_rate=dropout_rate,
142 | deterministic=deterministic,
143 | precision=precision,
144 | kernel_init=kernel_init,
145 | out_kernel_init=out_kernel_init,
146 | bias_init=bias_init,
147 | out_bias_init=out_bias_init,
148 | use_bias=use_bias,
149 | attention_fn=attention_fn,
150 | decode=decode,
151 | normalize_qk=normalize_qk,
152 | qkv_dot_general=qkv_dot_general,
153 | out_dot_general=out_dot_general,
154 | qkv_dot_general_cls=qkv_dot_general_cls,
155 | out_dot_general_cls=out_dot_general_cls,
156 | rngs=rngs,
157 | )
158 |
--------------------------------------------------------------------------------
/tests/models/test_t5.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import pytest
4 | from flax import nnx
5 |
6 | from jaxgarden.models.t5 import (
7 | T5MLP,
8 | T5Attention,
9 | T5Block,
10 | T5Config,
11 | T5CrossAttention,
12 | T5ForCausalLM,
13 | T5LayerNorm,
14 | T5SelfAttention,
15 | T5Stack,
16 | )
17 |
18 |
19 | @pytest.fixture(scope="module")
20 | def tiny_config():
21 | return T5Config(
22 | hidden_size=128,
23 | dim_kv=64,
24 | num_layers=2,
25 | vocab_size=100,
26 | dtype=jnp.float32,
27 | )
28 |
29 |
30 | @pytest.fixture(scope="module")
31 | def dummy_rngs():
32 | return nnx.Rngs(params=jax.random.PRNGKey(0))
33 |
34 |
35 | @pytest.mark.parametrize(("dim", "dtype"), [(1024, jnp.float32), (2048, jnp.float16)])
36 | def test_t5_layer_norm(dim, dtype):
37 | layer_norm = T5LayerNorm(dim=dim, dtype=dtype, rngs=nnx.Rngs(0))
38 | x = jnp.ones((1, dim))
39 | output = layer_norm(x)
40 |
41 | assert output.shape == (1, dim)
42 | assert output.dtype == dtype
43 |
44 |
45 | @pytest.mark.parametrize(
46 | ("dim", "intermediate_dim", "dtype"), [(512, 2048, jnp.float32), (1024, 4096, jnp.float16)]
47 | )
48 | def test_t5_mlp(dim, intermediate_dim, dtype):
49 | # small sequence length for testing
50 | seq_len = 128
51 |
52 | mlp = T5MLP(dim=dim, intermediate_dim=intermediate_dim, dtype=dtype, rngs=nnx.Rngs(0))
53 | x = jnp.ones((1, seq_len, dim))
54 | output = mlp(x)
55 |
56 | assert output.shape == (1, seq_len, dim)
57 | assert output.dtype == dtype
58 |
59 |
60 | @pytest.mark.parametrize(
61 | ("hidden_size", "dim_kv", "dtype", "masked"),
62 | [(768, 64, jnp.float32, True), (1024, 128, jnp.float16, False)],
63 | )
64 | def test_t5_attention(hidden_size, dim_kv, dtype, masked):
65 | # small sequence length for testing
66 | seq_len = 128
67 |
68 | attention = T5Attention(
69 | hidden_size=hidden_size,
70 | dim_kv=dim_kv,
71 | num_heads=hidden_size // dim_kv,
72 | relative_attention_num_buckets=32,
73 | relative_attention_max_distance=128,
74 | dtype=dtype,
75 | rngs=nnx.Rngs(0),
76 | )
77 |
78 | hidden_states = jnp.ones((1, seq_len, hidden_size))
79 | attention_mask = jnp.ones((1, seq_len)) if masked else None
80 | output = attention(hidden_states, attention_mask=attention_mask)
81 |
82 | assert output.shape == (1, seq_len, hidden_size)
83 | assert output.dtype == dtype
84 |
85 |
86 | @pytest.mark.parametrize(
87 | ("hidden_size", "dim_kv", "dtype", "masked"),
88 | [(768, 64, jnp.float32, True), (1024, 128, jnp.float16, False)],
89 | )
90 | def test_t5_self_attention(hidden_size, dim_kv, dtype, masked):
91 | # small sequence length for testing
92 | seq_len = 128
93 |
94 | config = T5Config(hidden_size=hidden_size, dim_kv=dim_kv, dtype=dtype)
95 | self_attention = T5SelfAttention(config=config, rngs=nnx.Rngs(0))
96 |
97 | hidden_states = jnp.ones((1, seq_len, hidden_size))
98 | attention_mask = jnp.ones((1, seq_len)) if masked else None
99 |
100 | output = self_attention(hidden_states, attention_mask=attention_mask)
101 |
102 | assert output.shape == (1, seq_len, hidden_size)
103 | assert output.dtype == dtype
104 |
105 |
106 | @pytest.mark.parametrize(
107 | ("hidden_size", "dim_kv", "dtype", "masked"),
108 | [(768, 64, jnp.float32, True), (1024, 128, jnp.float16, False)],
109 | )
110 | def test_t5_cross_attention(hidden_size, dim_kv, dtype, masked):
111 | # small sequence length for testing
112 | seq_len = 128
113 |
114 | config = T5Config(hidden_size=hidden_size, dim_kv=dim_kv, dtype=dtype)
115 | cross_attention = T5CrossAttention(config=config, rngs=nnx.Rngs(0))
116 |
117 | hidden_states = jnp.ones((1, seq_len, hidden_size))
118 | attention_mask = jnp.ones((1, seq_len)) if masked else None
119 |
120 | output = cross_attention(hidden_states, attention_mask=attention_mask)
121 |
122 | assert output.shape == (1, seq_len, hidden_size)
123 | assert output.dtype == dtype
124 |
125 |
126 | @pytest.mark.parametrize(
127 | ("hidden_size", "dim_kv", "dtype", "masked", "causal"),
128 | [
129 | (768, 64, jnp.float32, True, True),
130 | (768, 64, jnp.float32, False, False),
131 | (1024, 128, jnp.float16, True, False),
132 | (1024, 128, jnp.float16, False, True),
133 | ],
134 | )
135 | def test_t5_block(hidden_size, dim_kv, dtype, masked, causal):
136 | # small sequence length for testing
137 | seq_len = 128
138 |
139 | config = T5Config(hidden_size=hidden_size, dim_kv=dim_kv, dtype=dtype)
140 | block = T5Block(config=config, causal=causal, rngs=nnx.Rngs(0))
141 |
142 | hidden_states = jnp.ones((1, seq_len, hidden_size))
143 | attention_mask = jnp.ones((1, seq_len)) if masked else None
144 |
145 | output = block(hidden_states, attention_mask=attention_mask, deterministic=True)
146 |
147 | assert output.shape == (1, seq_len, hidden_size)
148 | assert output.dtype == dtype
149 |
150 |
151 | @pytest.mark.parametrize("dtype", [jnp.float32, jnp.float16])
152 | @pytest.mark.parametrize("causal", [True, False])
153 | @pytest.mark.parametrize("masked", [True, False])
154 | @pytest.mark.parametrize("with_encoder", [True, False])
155 | def test_t5_stack(dtype, causal, masked, with_encoder):
156 | hidden_size = 768
157 | dim_kv = 64
158 | vocab_size = 100
159 | num_layers = 2
160 | seq_len = 8
161 | batch_size = 2
162 |
163 | config = T5Config(hidden_size=hidden_size, dim_kv=dim_kv, dtype=dtype)
164 | embed_tokens = nnx.Embed(
165 | num_embeddings=vocab_size,
166 | features=hidden_size,
167 | dtype=dtype,
168 | rngs=nnx.Rngs(0),
169 | )
170 | stack = T5Stack(
171 | config=config,
172 | embed_tokens=embed_tokens,
173 | num_layers=num_layers,
174 | causal=causal,
175 | rngs=nnx.Rngs(0),
176 | )
177 |
178 | input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
179 | attention_mask = jnp.ones((batch_size, seq_len), dtype=dtype) if masked else None
180 |
181 | encoder_hidden_states = (
182 | jnp.ones((batch_size, seq_len, hidden_size), dtype=dtype) if with_encoder else None
183 | )
184 | encoder_attention_mask = (
185 | jnp.ones((batch_size, seq_len), dtype=dtype) if (with_encoder and masked) else None
186 | )
187 |
188 | output = stack(
189 | input_ids,
190 | attention_mask=attention_mask,
191 | encoder_hidden_states=encoder_hidden_states,
192 | encoder_attention_mask=encoder_attention_mask,
193 | deterministic=True,
194 | )
195 |
196 | assert output.shape == (batch_size, seq_len, hidden_size)
197 | assert output.dtype == dtype
198 |
199 |
200 | # --- Test Full Model ---
201 | def test_t5_for_causal_lm(tiny_config, dummy_rngs):
202 | # small values for testing
203 | seq_len = 8
204 | batch = 2
205 |
206 | model = T5ForCausalLM(config=tiny_config, rngs=dummy_rngs)
207 |
208 | input_ids = jnp.ones((batch, seq_len), dtype=jnp.int32)
209 | pos_ids = jnp.arange(seq_len)[None, :].repeat(batch, axis=0)
210 | attn_mask = None
211 |
212 | out = model(input_ids, pos_ids, attn_mask)
213 |
214 | assert out.shape == (batch, seq_len, tiny_config.vocab_size)
215 | assert out.dtype == tiny_config.dtype
216 |
--------------------------------------------------------------------------------
/tests/models/test_gemma2.py:
--------------------------------------------------------------------------------
1 | """
2 | Unit tests for Gemma2 model (jaxgarden/models/gemma2.py)
3 | """
4 |
5 | import dataclasses
6 |
7 | import jax
8 | import jax.numpy as jnp
9 | import numpy as np
10 | import pytest
11 | from flax import nnx
12 |
13 | from jaxgarden.models.gemma2 import (
14 | Gemma2Attention,
15 | Gemma2Config,
16 | Gemma2ForCausalLM,
17 | Gemma2MLP,
18 | Gemma2RMSNorm,
19 | Gemma2RotaryEmbedding,
20 | )
21 |
22 |
23 | # Helper: minimal config for fast tests
24 | @pytest.fixture(scope="module") # Use module scope for efficiency
25 | def tiny_config():
26 | return Gemma2Config(
27 | vocab_size=128,
28 | hidden_size=32,
29 | intermediate_size=64, # Must be even for GeGLU
30 | num_hidden_layers=2,
31 | num_attention_heads=4,
32 | num_key_value_heads=2,
33 | head_dim=8,
34 | context_length=16,
35 | param_dtype=jnp.float32,
36 | dtype=jnp.float32,
37 | attn_logits_soft_cap=50.0,
38 | final_logit_soft_cap=30.0,
39 | )
40 |
41 |
42 | @pytest.fixture(scope="module")
43 | def dummy_rngs():
44 | return nnx.Rngs(params=jax.random.PRNGKey(0))
45 |
46 |
47 | # --- Test Core Modules ---
48 |
49 |
50 | def test_rmsnorm_output(dummy_rngs):
51 | dim = 4
52 | norm = Gemma2RMSNorm(dim=dim, eps=1e-6, rngs=dummy_rngs)
53 | # Initialize dummy weights
54 | norm.weight = nnx.Param(jnp.ones((dim,), dtype=jnp.float32)) # Use ones for non-trivial test
55 |
56 | x = jnp.array([[0.1, 0.2, 0.3, 0.4]])
57 | out = norm(x)
58 | expected = jnp.array([[0.7302919, 1.4605838, 2.1908758, 2.9211676]])
59 |
60 | assert out.shape == (1, dim)
61 | assert jnp.allclose(out, expected, atol=1e-5)
62 |
63 | # Test with zero weights (as initialized in the module)
64 | norm.weight = nnx.Param(jnp.zeros((dim,), dtype=jnp.float32))
65 | x_ones = jnp.ones((2, dim))
66 | out_zeros_w = norm(x_ones)
67 | assert out_zeros_w.shape == (2, dim)
68 | # With input=ones and weight=zeros, output should be normalized input
69 | # (close to input / sqrt(mean(square(input))))
70 | # For ones input, mean(square(input)) = 1, sqrt = 1, output = input / 1 = 1
71 | assert jnp.allclose(out_zeros_w, jnp.ones_like(x_ones), atol=1e-6)
72 |
73 |
74 | def test_rope_embedding():
75 | dim = 4
76 | batch = 2
77 | seq = 1
78 | num_heads = 2
79 | head_dim = dim // num_heads
80 | rope = Gemma2RotaryEmbedding(dim=dim // num_heads)
81 |
82 | # Shape [B, L, N, H] - Use head_dim here!
83 | x = jnp.ones((batch, seq, num_heads, head_dim))
84 | positions = jnp.array([[1], [0]]) # Example positions
85 |
86 | out = rope(x, positions)
87 | assert out.shape == x.shape
88 |
89 |     # For pos=1, head=0: sin(1/10000^(0/2))=sin(1)=0.841, cos(1)=0.540
90 | # RoPE applies rotations like [x0*c - x1*s, x0*s + x1*c]
91 | # Input is all 1s.
92 | # Expected head 0, pos 1: [1*c0-1*s0, 1*s0+1*c0] = [cos(1)-sin(1), cos(1)+sin(1)]
93 | expected_pos1 = jnp.array([-0.30116867, 1.38177329])
94 | # For pos=0, sin(0)=0, cos(0)=1. Output should be unchanged [1, 1]
95 | expected_pos0 = jnp.ones((head_dim,))
96 |
97 | # Check head 0 output
98 | np.testing.assert_allclose(out[0, 0, 0, :], expected_pos1, atol=1e-5)
99 | np.testing.assert_allclose(out[1, 0, 0, :], expected_pos0, atol=1e-5)
100 |
101 |
102 | def test_attention_shape_and_gqa(tiny_config, dummy_rngs):
103 | # Initialize Gemma2Attention layer
104 | attn = Gemma2Attention(
105 | layer_idx=0,
106 | config=tiny_config,
107 | attention_type="global",
108 | rngs=dummy_rngs,
109 | )
110 | batch = 2
111 | seq = 8
112 | x = jnp.ones((batch, seq, tiny_config.hidden_size))
113 | pos_ids = jnp.arange(seq)[None, :].repeat(batch, axis=0)
114 | mask = jnp.tril(jnp.ones((seq, seq), dtype=jnp.bool_))[None, None, :, :]
115 | mask = jnp.where(mask, 0.0, -jnp.inf) # Additive mask
116 |
117 | # Gemma2Attention.__call__ always returns (output, cache)
118 | out, _ = attn(x, pos_ids, attention_mask=mask, cache=None)
119 | assert out.shape == (batch, seq, tiny_config.hidden_size)
120 |
121 | # Test repeat_kv helper
122 | kv_heads = tiny_config.num_key_value_heads
123 | head_dim = tiny_config.head_dim
124 | num_attn_heads = tiny_config.num_attention_heads
125 | n_rep = tiny_config.num_attention_heads // kv_heads
126 | # Input shape should be (batch, kv_heads, seq, head_dim)
127 | hidden_kv = jnp.ones((batch, kv_heads, seq, head_dim))
128 | repeated = attn._repeat_kv(hidden_kv, n_rep)
129 | # Expected output shape: (batch, num_attn_heads, seq, head_dim)
130 | assert repeated.shape == (batch, num_attn_heads, seq, head_dim)
131 |
132 |
133 | def test_attention_soft_cap(tiny_config, dummy_rngs):
134 | cap_value = 50.0
135 | config_with_cap = dataclasses.replace(tiny_config, attn_logits_soft_cap=cap_value)
136 | attn = Gemma2Attention(
137 | layer_idx=0,
138 | config=config_with_cap,
139 | attention_type="global",
140 | rngs=dummy_rngs,
141 | )
142 |
143 | logits = jnp.array([-100.0, -10.0, 0.0, 10.0, 100.0])
144 | capped_logits = attn.apply_soft_cap(logits, cap_value)
145 |
146 | expected = cap_value * jnp.tanh(logits / cap_value)
147 | np.testing.assert_allclose(capped_logits, expected, atol=1e-6)
148 |
149 |
150 | def test_mlp_geglu_shape(tiny_config, dummy_rngs):
151 | mlp = Gemma2MLP(config=tiny_config, rngs=dummy_rngs)
152 | batch = 2
153 | seq = 8
154 | x = jnp.ones((batch, seq, tiny_config.hidden_size))
155 | out = mlp(x)
156 | assert out.shape == (batch, seq, tiny_config.hidden_size)
157 |
158 | # Test static geglu method
159 | intermediate_x = jnp.ones((batch, seq, tiny_config.intermediate_size * 2))
160 | geglu_out = Gemma2MLP.geglu(intermediate_x)
161 | assert geglu_out.shape == (batch, seq, tiny_config.intermediate_size)
162 |
163 |
164 | # --- Test Decoder Layer ---
165 |
166 |
167 | def test_decoder_layer_structure(tiny_config, dummy_rngs):
168 | layer_idx = 0
169 | model = Gemma2ForCausalLM(config=tiny_config, rngs=dummy_rngs)
170 | decoder = model.layers[layer_idx]
171 |
172 | assert isinstance(decoder.pre_attn_norm, Gemma2RMSNorm)
173 | assert isinstance(decoder.attn, Gemma2Attention)
174 | assert isinstance(decoder.post_attn_norm, Gemma2RMSNorm)
175 | assert isinstance(decoder.pre_mlp_norm, Gemma2RMSNorm)
176 | assert isinstance(decoder.mlp, Gemma2MLP)
177 | assert isinstance(decoder.post_mlp_norm, Gemma2RMSNorm)
178 |
179 | # Check attention type alternation
180 | assert decoder.attn.attention_type == "global"
181 | layer_idx = 1
182 | decoder1 = model.layers[layer_idx]
183 | assert decoder1.attn.attention_type == "local"
184 |
185 |
186 | # --- Test Full Model ---
187 |
188 |
189 | def test_gemma2_init(tiny_config, dummy_rngs):
190 | model = Gemma2ForCausalLM(config=tiny_config, rngs=dummy_rngs)
191 | assert isinstance(model, Gemma2ForCausalLM)
192 | assert hasattr(model, "embed_tokens")
193 | assert hasattr(model, "layers")
194 | assert len(model.layers) == tiny_config.num_hidden_layers
195 | assert hasattr(model, "norm")
196 |
197 |
198 | def test_gemma2_forward_shape_dtype(tiny_config, dummy_rngs):
199 | model = Gemma2ForCausalLM(config=tiny_config, rngs=dummy_rngs)
200 | batch = 2
201 | seq = 8
202 | input_ids = jnp.ones((batch, seq), dtype=jnp.int32)
203 | pos_ids = jnp.arange(seq)[None, :].repeat(batch, axis=0)
204 | # Pass None for attention_mask; model should handle creation
205 | attn_mask = None
206 |
207 | out, cache = model(input_ids, pos_ids, attn_mask)
208 |
209 | assert out.shape == (batch, seq, tiny_config.vocab_size)
210 | assert out.dtype == tiny_config.dtype
211 | assert cache is not None # Cache returned in forward pass
212 |
213 |
214 | def test_final_logit_soft_cap(tiny_config, dummy_rngs):
215 | # Test with and without final soft cap
216 | config_no_cap = dataclasses.replace(tiny_config, final_logit_soft_cap=None)
217 | config_with_cap = dataclasses.replace(tiny_config, final_logit_soft_cap=30.0)
218 |
219 | model_no_cap = Gemma2ForCausalLM(config=config_no_cap, rngs=dummy_rngs)
220 | model_with_cap = Gemma2ForCausalLM(config=config_with_cap, rngs=dummy_rngs)
221 |
222 | nnx.update(model_with_cap, nnx.state(model_no_cap, nnx.Param))
223 |
224 | batch = 1
225 | seq = 4
226 | input_ids = jnp.ones((batch, seq), dtype=jnp.int32)
227 | pos_ids = jnp.arange(seq)[None, :].repeat(batch, axis=0)
228 | attn_mask = jnp.ones((batch, seq), dtype=jnp.bool_)
229 |
230 | out, _ = model_with_cap(input_ids, pos_ids, attn_mask)
231 | final_logits = out[:, -1, :]
232 |
233 | assert out.shape == (batch, seq, tiny_config.vocab_size)
234 | assert jnp.max(jnp.abs(final_logits)) <= config_with_cap.final_logit_soft_cap + 1e-6
235 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # JAXgarden
2 |
3 | [](https://github.com/ml-gde/jax-layers/actions/workflows/docs.yml)
4 | [](https://github.com/ml-gde/jax-layers/actions/workflows/tests.yml)
5 |
6 | 
7 |
8 | A reusable collection of high-performance neural network layers and models in JAX, aiming to match and exceed the capabilities available in the PyTorch ecosystem.
9 |
10 | ## Motivation
11 |
12 | JAXgarden was created to provide the JAX ecosystem with a comprehensive library of well-documented, thoroughly tested, and numerically accurate implementations of neural network layers and models. The project aims to:
13 |
14 | - Provide both functional APIs and Flax NNX wrappers for maximum flexibility
15 | - Ensure seamless integration with the broader JAX ecosystem, especially Flax
16 | - Facilitate easy upstreaming of implementations to core libraries
17 | - Maintain rigorous testing and documentation standards
18 | - Match or exceed the performance of equivalent PyTorch implementations
19 |
20 | Started within the ML GDE group, the project began with a high-performance MultiHeadAttention implementation supporting various attention backends, with plans to expand to more layers and models.
21 |
22 | ## Features
23 |
24 | - **MultiHeadAttention**: A Flax NNX-compatible implementation with support for different attention backends.
25 | - Supports JAX's native Flash Attention implementation through cuDNN
26 | - Seamlessly integrates with Flax NNX's module system
27 | - Provides a simple interface for switching between attention implementations
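
For example, selecting a backend is a single constructor argument (a minimal sketch; per the module docstring, `None` auto-selects the best available backend and `"flash"` is an alias for `"cudnn"`):

```python
import flax.nnx as nnx
from jaxgarden.attention import MultiHeadAttention

# Same module, three different attention backends.
attn_auto = MultiHeadAttention(num_heads=8, in_features=512, rngs=nnx.Rngs(0))  # auto-select
attn_xla = MultiHeadAttention(num_heads=8, in_features=512, implementation="xla", rngs=nnx.Rngs(1))
attn_flash = MultiHeadAttention(num_heads=8, in_features=512, implementation="cudnn", rngs=nnx.Rngs(2))
```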
28 |
29 | ## Installation
30 |
31 | ```bash
32 | # Install from source
33 | git clone https://github.com/ml-gde/jax-layers.git
34 | cd jax-layers
35 | pip install -e .
36 | ```
37 |
38 | ## Usage
39 |
40 | ### LLaMA inference
41 |
42 | ```python
43 | from jaxgarden import LlamaConfig, LlamaForCausalLM, Tokenizer
44 | from flax import nnx
45 |
46 |
47 | # HF repo id of the LLaMA variant that you want to use
48 | model_id = "meta-llama/Llama-3.2-1B"
49 |
50 | # initialize the LLaMA architecture
51 | config = LlamaConfig()
52 | model = LlamaForCausalLM(config, rngs=nnx.Rngs(0))
53 |
54 | # This is a one-liner to download HF checkpoint from HuggingFace Hub,
55 | # convert it to jaxgarden format,
56 | # save it in an Orbax checkpoint,
57 | # and then remove the HF checkpoint.
58 | model.from_hf(model_id)
59 |
60 | # this works just like `transformers.AutoTokenizer`,
61 | # but without depending on the whole `transformers` library.
62 | # Instead, we simply extend the `tokenizers` package and add some convenience code for JAX.
63 | tokenizer = Tokenizer.from_pretrained(model_id)
64 |
65 | text = "The meaning of life is"
66 | model_inputs = tokenizer.encode(text)
67 | output = model.generate(**model_inputs, max_length=20, do_sample=True)
68 | output_text = tokenizer.decode(output)
69 | print(output_text)
70 | ```
71 |
72 |
73 | ### MultiHeadAttention Module (Flax NNX)
74 |
75 | ```python
76 | import jax
77 | import jax.numpy as jnp
78 | import flax.nnx as nnx
79 | from jaxgarden.attention import MultiHeadAttention
80 |
81 | # Create a MultiHeadAttention module with Flash Attention support
82 | attention = MultiHeadAttention(
83 | num_heads=8,
84 | in_features=512,
85 | implementation="cudnn", # Use cuDNN's Flash Attention if available
86 | rngs=nnx.Rngs(0),
87 | )
88 |
89 | # Create input data
90 | key = jax.random.PRNGKey(0)
91 | x = jax.random.normal(key, (2, 128, 512)) # (batch, seq_length, hidden_dim)
92 |
93 | # Create a causal attention mask
94 | mask = jnp.tril(jnp.ones((2, 1, 128, 128))) # (batch, 1, q_len, kv_len)
95 |
96 | # Apply the model
97 | output = attention(x, mask=mask)
98 | ```
99 |
100 | ### RoPEMultiHeadAttention Module (Flax NNX)
101 |
102 | ```python
103 | import jax
104 | import jax.numpy as jnp
105 | import flax.linen as nn
106 | from jaxgarden.attention.rope_multi_head_attention import RoPEMultiHeadAttention
107 |
108 | # 1. Setup
109 | key = jax.random.PRNGKey(0)
110 | batch_size, seq_len = 2, 16
111 | num_heads, head_dim = 4, 32
112 | embed_dim = num_heads * head_dim
113 | x = jnp.ones((batch_size, seq_len, embed_dim))
114 |
115 | # 2. Instantiate Module
116 | attention = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim)
117 |
118 | # 3. Initialize Parameters
119 | params = attention.init(key, x)['params']
120 |
121 | # 4. Apply Module (Forward Pass)
122 | output = attention.apply({'params': params}, x)
123 | ```
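
The rotary helpers used by this module can also be called directly (a minimal sketch based on the test suite; `precompute_rotary_embeddings` returns `(1, seq_len, 1, head_dim)` cos/sin tables and requires an even `head_dim`):

```python
import jax.numpy as jnp
from jaxgarden.attention.rope_multi_head_attention import (
    apply_rotary_pos_emb,
    precompute_rotary_embeddings,
)

batch, seq_len, num_heads, head_dim = 2, 16, 4, 32
x = jnp.ones((batch, seq_len, num_heads, head_dim))

# Precompute cos/sin tables once, then rotate queries/keys with them.
cos_emb, sin_emb = precompute_rotary_embeddings(seq_len, head_dim)
x_rotated = apply_rotary_pos_emb(x, cos_emb, sin_emb)
```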
124 |
125 | ### Functional API
126 |
127 | #### Dot Product Attention with Implementation Selection
128 |
129 | ```python
130 | import jax
131 | import jax.numpy as jnp
132 | from jaxgarden.functional import dot_product_attention
133 |
134 | # Create random query, key, value tensors
135 | key = jax.random.PRNGKey(0)
136 | query = jax.random.normal(key, (2, 128, 8, 64)) # (batch, seq_len, heads, head_dim)
137 | key_tensor = jax.random.normal(key, (2, 128, 8, 64))
138 | value = jax.random.normal(key, (2, 128, 8, 64))
139 |
140 | # Create a causal attention mask
141 | mask = jnp.tril(jnp.ones((2, 1, 128, 128))) # (batch, 1, q_len, kv_len)
142 |
143 | # Apply dot product attention with Flash Attention implementation
144 | output = dot_product_attention(
145 | query=query,
146 | key=key_tensor,
147 | value=value,
148 | mask=mask,
149 | implementation="cudnn", # Use cuDNN's Flash Attention implementation
150 | )
151 | ```
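
Continuing the example above: an additive `bias` can be supplied instead of a boolean `mask`. As exercised in `tests/functional/test_attention.py`, passing a boolean mask is equivalent to passing the bias `jnp.where(mask, 0.0, -1e10)`:

```python
# Equivalent additive-bias formulation of the causal mask above.
bias = jnp.where(mask.astype(bool), 0.0, -1e10)
output_from_bias = dot_product_attention(query=query, key=key_tensor, value=value, bias=bias)
```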
152 |
153 | ## Development
154 |
155 | ### Setup
156 |
157 | 1. Please fork the repository to your account first.
158 | 2. Follow the instructions below.
159 |
160 | ```bash
161 | # Clone the repository
162 | git clone https://github.com/yourusername/jax-layers.git
163 | cd jax-layers
164 |
165 | # Install development dependencies
166 | pip install -e ".[dev]"
167 | ```
168 |
169 | ### Pre-commit
170 |
171 | This project uses pre-commit hooks to ensure code quality and consistency. Pre-commit automatically runs linting and formatting tools (such as ruff) before each commit, helping to catch issues early.
172 |
173 | ```bash
174 | # Install Pre-commit Hooks
175 | pre-commit install
176 |
177 | # Run Pre-commit on All Files
178 | pre-commit run --all-files
179 | ```
180 |
181 | Every time you attempt to commit, pre-commit automatically runs the configured hooks (e.g., ruff). If any issues are detected, the commit will be blocked until they are resolved.
182 |
183 | ### Testing
184 |
185 | The project maintains a comprehensive test suite to ensure correctness and numerical accuracy:
186 |
187 | ```bash
188 | # Run all tests
189 | pytest
190 |
191 | # Run tests with coverage
192 | pytest tests/ --cov=jaxgarden
193 |
194 | # Run specific test file
195 | pytest tests/attention/test_multi_head_attention.py
196 | ```
197 |
198 | ### Code Quality
199 |
200 | We maintain high code quality standards through automated checks:
201 |
202 | ```bash
203 | # Run linting
204 | ruff check .
205 |
206 | # Run type checking
207 | mypy jaxgarden
208 |
209 | # Run tests
210 | pytest
211 | ```
212 |
213 | ### Documentation
214 |
215 | Documentation is automatically generated from docstrings:
216 |
217 | ```bash
218 | # Build documentation
219 | cd docs
220 | make html
221 | ```
222 |
223 | ### Development Container (for Windows users)
224 |
225 | Since JAX doesn't support CUDA on Windows natively, we provide a development container configuration:
226 |
227 | 1. Install [Docker Desktop](https://www.docker.com/products/docker-desktop/) with WSL 2 backend
228 | 2. Install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)
229 | 3. Install [Visual Studio Code](https://code.visualstudio.com/) with the [Remote - Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension
230 | 4. Open the project in VS Code
231 | 5. Click the green icon in the bottom-left corner and select "Reopen in Container"
232 |
233 | The container provides:
234 |
235 | - Python 3.10
236 | - CUDA 12.4 with cuDNN 9
237 | - JAX with CUDA support
238 | - All dependencies from the project's pyproject.toml
239 |
240 | See [.devcontainer/README.md](.devcontainer/README.md) for more details.
241 |
242 | ## Contributing
243 |
244 | Contributions are more than welcome! Whether it's:
245 |
246 | - Adding new layer implementations
247 | - Improving documentation
248 | - Adding tests
249 | - Reporting bugs
250 | - Suggesting improvements
251 |
252 | Please feel free to open issues and pull requests.
253 |
254 | ## License
255 |
256 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
257 |
258 | ## Acknowledgements
259 |
260 | **Google AI Developer Programs team supported this work by providing Google Cloud Credit.**
261 |
262 | - Thanks to the JAX and Flax teams for their excellent libraries.
263 | - Special thanks to the ML GDE group for initiating this project.
264 |
--------------------------------------------------------------------------------
/jaxgarden/models/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import shutil
4 | from collections.abc import Iterator
5 | from dataclasses import dataclass, field
6 | from pathlib import Path
7 | from typing import Any
8 |
9 | import jax
10 | import jax.numpy as jnp
11 | import orbax.checkpoint as ocp # type: ignore
12 | from flax import nnx
13 | from huggingface_hub import snapshot_download # type: ignore
14 | from safetensors import safe_open
15 |
16 | # Set up logging
17 | logger = logging.getLogger(__name__)
18 |
19 | DEFAULT_PARAMS_FILE = "jaxgarden_state"
20 |
21 |
22 | @dataclass
23 | class BaseConfig:
24 | """Base configuration for all the models implemented in the JAXgarden library.
25 |
26 | Each model implemented in JAXgarden should subclass this class for configuration management.
27 | """
28 |
29 | seed: int = 42
30 | log_level: str = "info"
31 | extra: dict[str, Any] = field(default_factory=dict)
32 |
33 | def to_dict(self) -> dict[str, Any]:
34 | return self.__dict__
35 |
36 | def update(self, **kwargs: dict) -> None:
37 | for k, v in kwargs.items():
38 | if hasattr(self, k):
39 | setattr(self, k, v)
40 | else:
41 | self.extra[k] = v
42 |
43 |
44 | class BaseModel(nnx.Module):
45 | """Base class for all the models implemented in the JAXgarden library."""
46 |
47 | def __init__(
48 | self,
49 | config: BaseConfig,
50 | *,
51 | dtype: jnp.dtype | None = None,
52 | param_dtype: jnp.dtype = jnp.float32,
53 | precision: jax.lax.Precision | str | None = None,
54 | rngs: nnx.Rngs,
55 | ):
56 | """Initialize the model.
57 |
58 | Args:
59 | config: config class for this model.
60 | dtype: Data type in which computation is performed.
61 | param_dtype: Data type in which params are stored.
62 | precision: Numerical precision.
63 | rngs: Random number generators for param initialization etc.
64 | """
65 | self.config = config
66 | self.dtype = dtype
67 | self.param_dtype = param_dtype
68 | self.precision = precision
69 | self.rngs = rngs
70 |
71 | @property
72 | def state(self) -> nnx.State:
73 | """Splits state from the graph and returns it"""
74 | return nnx.split(self, nnx.Param, ...)[1] # type: ignore
75 |
76 | @property
77 | def state_dict(self) -> dict[str, jnp.ndarray]:
78 | """Splits state from the graph and returns it as a dictionary.
79 |
80 | It can be used for serialization with orbax."""
81 | state = self.state
82 | pure_dict_state = nnx.to_pure_dict(state)
83 | return pure_dict_state
84 |
85 | def save(self, path: str) -> None:
86 | """Saves the model state to a directory.
87 |
88 | Args:
89 | path: The directory path to save the model state to.
90 | """
91 | state = self.state_dict
92 | checkpointer = ocp.StandardCheckpointer()
93 | checkpointer.save(os.path.join(path, DEFAULT_PARAMS_FILE), state)
94 | checkpointer.wait_until_finished()
95 |
96 | def load(self, path: str) -> nnx.Module:
97 | """Loads the model state from a directory.
98 |
99 | Args:
100 | path: The directory path to load the model state from.
101 | """
102 | checkpointer = ocp.StandardCheckpointer()
103 | restored_pure_dict = checkpointer.restore(os.path.join(path, DEFAULT_PARAMS_FILE))
104 | abstract_model = nnx.eval_shape(lambda: self)
105 | graphdef, abstract_state = nnx.split(abstract_model)
106 | nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)
107 | return nnx.merge(graphdef, abstract_state)
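    # Typical round trip (paths are illustrative): `model.save("/tmp/ckpt")` writes an
    # Orbax checkpoint, and `restored = model.load("/tmp/ckpt")` rebuilds the module
    # from it via nnx.merge.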
108 |
109 | @staticmethod
110 | def download_from_hf(
111 | repo_id: str, local_dir: str, token: str | None = None, force_download: bool = False
112 | ) -> None:
113 | """Downloads the model from the Hugging Face Hub.
114 |
115 | Args:
116 | repo_id: The repository ID of the model to download.
117 | local_dir: The local directory to save the model to.
118 | token: The hf auth token to download the model with.
119 | - If `True`, the token is read from the HuggingFace config
120 | folder.
121 | - If a string, it's used as the authentication token.
122 | force_download (`bool`, *optional*, defaults to `False`):
123 | Whether the file should be downloaded even if it already exists in the local cache.
124 | """
125 | logger.info(f"Attempting to download {repo_id} from Hugging Face Hub to {local_dir}.")
126 | try:
127 | snapshot_download(
128 | repo_id, local_dir=local_dir, token=token, force_download=force_download
129 | )
130 | logger.info(f"Successfully downloaded {repo_id} to {local_dir}.")
131 | except Exception as e:
132 | logger.error(f"Failed to download {repo_id}: {e}")
133 | raise
134 |
135 | @staticmethod
136 | def iter_safetensors(path_to_model_weights: str) -> Iterator[tuple[Any, Any]]:
137 | """Helper function to lazily load params from safetensors file.
138 |
139 | Use this static method to iterate over weights for conversion tasks.
140 |
141 | Args:
142 | path_to_model_weights: Path to directory containing .safetensors files."""
143 | if not os.path.isdir(path_to_model_weights):
144 | raise ValueError(f"{path_to_model_weights} is not a valid directory.")
145 |
146 | safetensors_files = Path(path_to_model_weights).glob("*.safetensors")
147 |
148 | for file in safetensors_files:
149 | with safe_open(file, framework="jax", device="cpu") as f:
150 | for key in f.keys(): # noqa: SIM118
151 | yield key, f.get_tensor(key)
152 |
153 | def from_hf(
154 | self,
155 | model_repo_or_id: str,
156 | token: str | None = None,
157 | force_download: bool = False,
158 | save_in_orbax: bool = True,
159 | remove_hf_after_conversion: bool = True,
160 | ) -> None:
161 |         """Downloads the model from the Hugging Face Hub and updates this model's state in place.
162 |
163 | It can also save the converted weights in an Orbax checkpoint
164 |         and remove the original HF checkpoint after conversion.
165 |
166 | Args:
167 | model_repo_or_id: The repository ID or name of the model to download.
168 | token: The token to use for authentication with the Hugging Face Hub.
169 | force_download: (`bool`, *optional*, defaults to `False`):
170 | Whether the file should be downloaded even if it already exists in the local cache.
171 | save_in_orbax: Whether to save the converted weights in an Orbax checkpoint.
172 | remove_hf_after_conversion: Whether to remove the downloaded HuggingFace checkpoint
173 | after conversion.
174 | """
175 | logger.info(f"Starting from_hf process for model: {model_repo_or_id}")
176 | local_dir = os.path.join(
177 | os.path.expanduser("~"), ".jaxgarden", "hf_models", *model_repo_or_id.split("/")
178 | )
179 | save_dir = local_dir.replace("hf_models", "models")
180 | if os.path.exists(save_dir):
181 | if force_download:
182 | logger.warning(f"Removing {save_dir} because force_download is set to True")
183 | shutil.rmtree(save_dir)
184 | else:
185 | raise RuntimeError(
186 | f"Path {save_dir} already exists."
187 |                     + " Set force_download to True to run conversion again."
188 | )
189 |
190 | logger.debug(f"Local Hugging Face model directory set to: {local_dir}")
191 |
192 | BaseModel.download_from_hf(
193 | model_repo_or_id, local_dir, token=token, force_download=force_download
194 | )
195 | logger.info(f"Initiating weight iteration from safetensors in {local_dir}")
196 | weights = BaseModel.iter_safetensors(local_dir)
197 | state = self.state
198 | logger.info("Running weight conversion...")
199 | self.convert_weights_from_hf(state, weights)
200 | logger.info("Weight conversion finished. Updating model state...")
201 | nnx.update(self, state)
202 | logger.warning("Model state successfully updated with converted weights.")
203 |
204 | if remove_hf_after_conversion:
205 | logger.warning(f"Removing HuggingFace checkpoint from {local_dir}...")
206 | shutil.rmtree(local_dir)
207 |
208 | if save_in_orbax:
209 |             logger.warning(f"Saving Orbax checkpoint in {save_dir}.")
210 | self.save(save_dir)
211 |
212 | logger.warning(f"from_hf process completed for {model_repo_or_id}.")
213 |
214 | def convert_weights_from_hf(self, state: nnx.State, weights: Iterator[tuple[Any, Any]]) -> None:
215 | """Convert weights from Hugging Face Hub to the model's state.
216 |
217 | This method should be implemented in downstream classes
218 | to support conversion from HuggingFace format.
219 | """
220 | raise NotImplementedError("This model does not support conversion from HuggingFace yet.")
221 |
--------------------------------------------------------------------------------
/.cursor/rules/baseconfig.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description: JAXgarden tutorial chapter detailing BaseConfig, the base dataclass for managing model hyperparameters and configurations.
3 | globs:
4 | alwaysApply: false
5 | ---
6 | # Chapter 3: BaseConfig
7 |
8 | In the [previous chapter](basemodel.mdc), we explored `BaseModel`, the foundational class for models in `jaxgarden`. We saw that each `BaseModel` instance is initialized with a configuration object. This chapter dives into the base class for those configuration objects: `BaseConfig`.
9 |
10 | **Motivation:** Neural models are complex systems with numerous hyperparameters (like hidden size, number of layers, dropout rates) and settings (like data types, precision). Managing these configurations consistently across different models is crucial for reproducibility, maintainability, and ease of use. Using simple dictionaries or ad-hoc parameter passing can quickly become messy and error-prone. `BaseConfig` provides a standardized, structured way to define, manage, and serialize these configurations.
11 |
12 | **Central Use Case:** Defining the set of hyperparameters for a new model implementation (e.g., creating a `MyTransformerConfig` that inherits from `BaseConfig`) or inspecting and modifying the configuration of an existing `jaxgarden` model like [LlamaForCausalLM](llamaforcausallm.mdc), which uses its own `LlamaConfig` subclass derived from `BaseConfig`.
13 |
14 | ## Key Concepts
15 |
16 | `BaseConfig` leverages Python's `dataclasses` to provide a simple yet powerful configuration system:
17 |
18 | 1. **Dataclass Structure:** It's defined using the `@dataclass` decorator, offering type hints, default values, and automatic `__init__` generation.
19 | 2. **Inheritance:** Model-specific configurations (e.g., `LlamaConfig`, `ModernBERTConfig`) inherit from `BaseConfig`, adding their unique parameters while retaining the common base attributes.
20 | 3. **Common Attributes:** Provides essential base attributes applicable to most models, such as `seed` for reproducibility and `log_level` for controlling verbosity.
21 | 4. **Extensibility:** An `extra` dictionary allows storing arbitrary additional parameters not explicitly defined in the dataclass fields, offering flexibility.
22 | 5. **Serialization:** The `to_dict()` method converts the configuration object into a standard Python dictionary, useful for logging, saving, or inspection.
23 | 6. **Programmatic Updates:** The `update()` method allows modifying configuration attributes after instantiation using keyword arguments.
24 |
25 | ## Using `BaseConfig`
26 |
27 | You typically interact with subclasses of `BaseConfig` tailored to specific models.
28 |
29 | ### Defining a Custom Configuration
30 |
31 | To define a configuration for a new model, create a class inheriting from `BaseConfig` and use the `@dataclass` decorator. Add fields for your model's specific hyperparameters.
32 |
33 | ```python
34 | from dataclasses import dataclass, field
35 | from jaxgarden.models.base import BaseConfig
36 | from typing import Any
37 |
38 | @dataclass
39 | class MyModelConfig(BaseConfig):
40 | """Configuration for MyCustomModel."""
41 | hidden_size: int = 256
42 | num_layers: int = 4
43 | dropout_rate: float = 0.1
44 | activation_fn: str = "relu"
45 | # Override a base attribute default if needed
46 | seed: int = 123
47 |
48 | # Instantiate the configuration
49 | my_config = MyModelConfig(hidden_size=512) # Override default hidden_size
50 |
51 | print(f"MyModelConfig instance: {my_config}")
52 | # Output: MyModelConfig instance: MyModelConfig(seed=123, log_level='info', extra={}, hidden_size=512, num_layers=4, dropout_rate=0.1, activation_fn='relu')
53 |
54 | print(f"Hidden size: {my_config.hidden_size}")
55 | # Output: Hidden size: 512
56 | print(f"Seed (overridden): {my_config.seed}")
57 | # Output: Seed (overridden): 123
58 | print(f"Log Level (from base): {my_config.log_level}")
59 | # Output: Log Level (from base): info
60 | ```
61 |
62 | **Explanation:**
63 | - We define `MyModelConfig` inheriting from `BaseConfig`.
64 | - `@dataclass` automatically generates methods like `__init__`.
65 | - We add model-specific fields like `hidden_size` and `num_layers` with type hints and default values.
66 | - We can override defaults during instantiation (`hidden_size=512`).
67 | - Base attributes like `seed` and `log_level` are inherited.
68 |
69 | ### Serialization (`to_dict`)
70 |
71 | Convert the configuration object to a dictionary for easy inspection or storage.
72 |
73 | ```python
74 | config_dict = my_config.to_dict()
75 | print(f"Configuration as dictionary: {config_dict}")
76 | # Output: Configuration as dictionary: {'seed': 123, 'log_level': 'info', 'extra': {}, 'hidden_size': 512, 'num_layers': 4, 'dropout_rate': 0.1, 'activation_fn': 'relu'}
77 |
78 | import json
79 | print(f"Configuration as JSON: {json.dumps(config_dict, indent=2)}")
80 | # Output:
81 | # Configuration as JSON: {
82 | # "seed": 123,
83 | # "log_level": "info",
84 | # "extra": {},
85 | # "hidden_size": 512,
86 | # "num_layers": 4,
87 | # "dropout_rate": 0.1,
88 | # "activation_fn": "relu"
89 | # }
90 | ```
91 |
92 | **Explanation:** The `to_dict()` method simply returns the internal `__dict__` of the dataclass instance, making it compatible with standard serialization libraries like `json`.
93 |
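Because `to_dict()` returns plain Python values in this example, the resulting dictionary can also be fed straight back into the constructor to rebuild an equivalent configuration. A minimal sketch (assuming, as here, that every key in the dictionary is a declared dataclass field):

```python
# Rebuild a config from its dictionary form.
# Assumption: all keys in config_dict are declared fields of MyModelConfig.
restored_config = MyModelConfig(**config_dict)

print(restored_config == my_config)
# Output: True (dataclasses compare field-by-field)
```
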
94 | ### Programmatic Updates (`update`)
95 |
96 | Modify configuration values after the object has been created.
97 |
98 | ```python
99 | print(f"Original dropout rate: {my_config.dropout_rate}")
100 | # Output: Original dropout rate: 0.1
101 | print(f"Original extra dict: {my_config.extra}")
102 | # Output: Original extra dict: {}
103 |
104 |
105 | # Update existing attributes and add new ones to 'extra'
106 | update_params = {
107 | "dropout_rate": 0.15,
108 | "new_experimental_param": True,
109 | "learning_rate": 1e-4
110 | }
111 | my_config.update(**update_params)
112 |
113 | print(f"Updated dropout rate: {my_config.dropout_rate}")
114 | # Output: Updated dropout rate: 0.15
115 | print(f"Updated extra dict: {my_config.extra}")
116 | # Output: Updated extra dict: {'new_experimental_param': True, 'learning_rate': 0.0001}
117 | ```
118 |
119 | **Explanation:** The `update()` method iterates through the provided keyword arguments. If an argument name matches an existing attribute of the config object, it updates that attribute. If the name doesn't match any defined attribute, it's added to the `extra` dictionary.
120 |
121 | ## Internal Implementation
122 |
123 | `BaseConfig` is fundamentally a Python `dataclass` with a couple of added helper methods.
124 |
125 | * **Location:** `jaxgarden/models/base.py`
126 |
127 | * **Definition:**
128 | ```python
129 | # From jaxgarden/models/base.py
130 | from dataclasses import dataclass, field
131 | from typing import Any
132 |
133 | @dataclass
134 | class BaseConfig:
135 | """Base configuration for models."""
136 | seed: int = 42
137 | log_level: str = "info"
138 | extra: dict[str, Any] = field(default_factory=dict)
139 | # ... (Subclasses add more fields here)
140 |
141 | def to_dict(self) -> dict[str, Any]:
142 | # Simply returns the instance's dictionary
143 | return self.__dict__
144 |
145 | def update(self, **kwargs: dict) -> None:
146 | # Iterates through provided keyword arguments
147 | for k, v in kwargs.items():
148 | if hasattr(self, k):
149 | # If attribute exists, set its value
150 | setattr(self, k, v)
151 | else:
152 | # Otherwise, add it to the 'extra' dictionary
153 | self.extra[k] = v
154 | ```
155 | * **Mechanism:**
156 | 1. `@dataclass`: Handles the `__init__`, `__repr__`, `__eq__`, etc., automatically based on the defined fields and type hints.
157 | 2. `field(default_factory=dict)`: Ensures that each instance gets its own empty `extra` dictionary by default, preventing accidental sharing between instances. This is demonstrated in the short snippet after this list.
158 | 3. `to_dict()`: Leverages the standard Python object attribute `__dict__` for a direct conversion to a dictionary.
159 | 4. `update()`: Uses Python's reflection capabilities (`hasattr`, `setattr`) to dynamically update attributes or populate the `extra` dictionary.
160 |
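The `field(default_factory=dict)` detail above is easy to verify: every instance receives its own `extra` dictionary, so updates to one config never leak into another. A quick illustrative check (not part of the library code), reusing `MyModelConfig` from earlier in this chapter:

```python
# Each instance gets a fresh `extra` dict thanks to field(default_factory=dict).
config_a = MyModelConfig()
config_b = MyModelConfig()
config_a.update(new_flag=True)  # unknown attribute, so it lands in config_a.extra

print(config_a.extra)  # {'new_flag': True}
print(config_b.extra)  # {}
```
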
161 | ## Relationship with `BaseModel`
162 |
163 | As discussed in [Chapter 2: BaseModel](basemodel.mdc), every `BaseModel` instance requires a configuration object (a subclass of `BaseConfig`) during initialization. This `config` object is stored as an attribute (`self.config`) within the model instance, making the model's hyperparameters readily accessible.
164 |
165 | ```python
166 | # Conceptual reminder from BaseModel chapter
167 | # class MyModel(BaseModel):
168 | # def __init__(self, config: MyModelConfig, *, rngs):
169 | # super().__init__(config, rngs=rngs) # Pass config to BaseModel
170 | # # Access config later: self.config.hidden_size
171 | # self.layer = nnx.Linear(config.hidden_size, ...)
172 |
173 | # my_config = MyModelConfig()
174 | # model = MyModel(config=my_config, ...)
175 | # print(model.config.num_layers) # Accessing config via the model
176 | ```
177 |
178 | ## Conclusion
179 |
180 | `BaseConfig` provides a clean, structured, and extensible foundation for managing model configurations in `jaxgarden`. By using Python dataclasses and offering simple methods for serialization (`to_dict`) and updates (`update`), it promotes consistency and simplifies the process of defining and working with model hyperparameters. All specific model configurations, like the one we'll see next for Llama, build upon this base.
181 |
182 | **Next:** [LlamaForCausalLM](llamaforcausallm.mdc)
183 |
184 |
185 | ---
186 |
187 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai)
--------------------------------------------------------------------------------
/jaxgarden/attention/rope_multi_head_attention.py:
--------------------------------------------------------------------------------
1 | """
2 | JAX/Flax implementation of Multi-Head Attention with Rotary Positional Embedding (RoPE).
3 |
4 | This code implements the RoPE technique within a standard Multi-Head Attention
5 | framework. RoPE injects relative positional information by rotating pairs of
6 | features in the Query and Key vectors based on their absolute position before
7 | the attention calculation.
8 |
9 | The method was introduced in the paper:
10 | "RoFormer: Enhanced Transformer with Rotary Position Embedding"
11 | by Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, Yunfeng Liu.
12 | arXiv:2104.09864v5 [cs.CL] (Submitted on 20 Apr 2021)
13 | """
14 |
15 | import flax.linen as nn
16 | import jax
17 | import jax.numpy as jnp
18 |
19 |
20 | def rotate_half(x: jnp.ndarray) -> jnp.ndarray:
21 | """Rotates half the hidden dims of the input tensor."""
22 |     # Split into halves to match the concatenated (freqs, freqs) layout produced
23 |     # by precompute_rotary_embeddings; negate the second half and swap the halves.
24 |     x1 = x[..., : x.shape[-1] // 2]
25 |     x2 = x[..., x.shape[-1] // 2 :]
26 |     return jnp.concatenate((-x2, x1), axis=-1)
27 |
28 |
29 | def apply_rotary_pos_emb(x: jnp.ndarray, cos_emb: jnp.ndarray, sin_emb: jnp.ndarray) -> jnp.ndarray:
30 | """Applies Rotary Positional Embedding to the input tensor.
31 |
32 | Args:
33 | x: Input tensor, e.g., query or key (batch, seq_len, num_heads, head_dim)
34 | cos_emb: Cosine component of the positional embedding.
35 | Shape: (1, seq_len, 1, head_dim) or compatible via broadcasting.
36 | sin_emb: Sine component of the positional embedding.
37 | Shape: (1, seq_len, 1, head_dim) or compatible via broadcasting.
38 | Returns:
39 | Tensor with RoPE applied.
40 | """
41 | # Applying the rotation formula:
42 | # x_rotated = x * cos(theta) + rotate_half(x) * sin(theta)
43 | # Ensure shapes are broadcastable: cos_emb and sin_emb should have dimensions
44 | # for sequence length and features, matching the corresponding dimensions in x.
45 | # Typically, precomputed embeddings have shape (seq_len, head_dim)
46 | # or (1, seq_len, 1, head_dim) for easy broadcasting.
47 | return (x * cos_emb) + (rotate_half(x) * sin_emb)
48 |
49 |
50 | def precompute_rotary_embeddings(
51 | seq_len: int, head_dim: int, base: float = 10000.0
52 | ) -> tuple[jnp.ndarray, jnp.ndarray]:
53 | """Precomputes the RoPE cosine and sine embeddings.
54 |
55 | Args:
56 | seq_len: The maximum sequence length.
57 | head_dim: The dimension of each attention head (must be even).
58 | base: The base value for the inverse frequency calculation.
59 |
60 | Returns:
61 | cos_emb: Cosine embeddings (1, seq_len, 1, head_dim)
62 | sin_emb: Sine embeddings (1, seq_len, 1, head_dim)
63 | """
64 | if head_dim % 2 != 0:
65 | raise ValueError(f"head_dim must be even, got {head_dim}")
66 |
67 | # Calculate inverse frequencies (theta_i)
68 | # theta_i = 1 / (base^(2*i / head_dim)) for i in [0, 1, ..., head_dim/2 - 1]
69 | inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
70 |
71 | # Calculate position indices (m)
72 | pos = jnp.arange(seq_len, dtype=jnp.float32)
73 |
74 | # Calculate angles (m * theta_i)
75 | freqs = jnp.outer(pos, inv_freq) # Shape: (seq_len, head_dim / 2)
76 |
77 | # Duplicate frequencies for the full head dimension (for both elements in pairs)
78 | emb = jnp.concatenate((freqs, freqs), axis=-1) # Shape: (seq_len, head_dim)
79 |
80 | # Calculate cosine and sine embeddings
81 | cos_emb = jnp.cos(emb)[None, :, None, :] # Shape: (1, seq_len, 1, head_dim)
82 | sin_emb = jnp.sin(emb)[None, :, None, :] # Shape: (1, seq_len, 1, head_dim)
83 |
84 | return cos_emb, sin_emb
85 |
86 |
87 | class RoPEMultiHeadAttention(nn.Module):
88 | """Multi-Head Attention with Rotary Positional Embeddings."""
89 |
90 | num_heads: int
91 | head_dim: int
92 | rope_base: float = 10000.0
93 | dtype: jnp.dtype = jnp.float32
94 |
95 | def setup(self) -> None: # Added -> None return type
96 | """Initializes the attention projections."""
97 | # Check head_dim validity early during setup
98 | if self.head_dim % 2 != 0:
99 | raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE.")
100 |
101 | # Define layers here - they will be initialized when the module is first called
102 | total_head_dim = self.num_heads * self.head_dim
103 | self.query_proj = nn.Dense(
104 | features=total_head_dim, use_bias=False, dtype=self.dtype, name="query_proj"
105 | )
106 | self.key_proj = nn.Dense(
107 | features=total_head_dim, use_bias=False, dtype=self.dtype, name="key_proj"
108 | )
109 | self.value_proj = nn.Dense(
110 | features=total_head_dim, use_bias=False, dtype=self.dtype, name="value_proj"
111 | )
112 | self.output_proj = nn.Dense(
113 | features=self.num_heads * self.head_dim, # Output should match embed_dim
114 | use_bias=False,
115 | dtype=self.dtype,
116 | name="output_proj",
117 | )
118 |
119 | @nn.compact
120 | # Also using Optional for the mask type hint for clarity with None default
121 | def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarray:
122 | """Forward pass for RoPE MHA.
123 |
124 | Args:
125 | x: Input tensor (batch_size, seq_len, embed_dim).
126 | mask: Optional attention mask (batch_size, 1, seq_len, seq_len)
127 | or (batch_size, 1, 1, seq_len) for causal masking.
128 |                 Boolean mask where True marks positions that should be masked out.
129 |                 (Note: flax.linen.make_causal_mask uses the opposite convention, 1/True = attend.)
130 |
131 | Returns:
132 | Output tensor (batch_size, seq_len, embed_dim).
133 | """
134 | batch_size, seq_len, embed_dim = x.shape
135 | total_head_dim = self.num_heads * self.head_dim
136 |
137 | if embed_dim != total_head_dim:
138 | raise ValueError(
139 | f"embed_dim ({embed_dim}) must equal num_heads*head_dim ({total_head_dim})"
140 | )
141 | # Note: head_dim even check moved to setup for earlier failure
142 |
143 | # 1. Linear projections for Q, K, V
144 | query = self.query_proj(x)
145 | key = self.key_proj(x)
146 | value = self.value_proj(x)
147 |
148 | # 2. Reshape for multi-head processing
149 | # (batch, seq_len, embed_dim) -> (batch, seq_len, num_heads, head_dim)
150 | query = query.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
151 | key = key.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
152 | value = value.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
153 |
154 | # 3. Precompute RoPE embeddings (cosine and sine)
155 | # We compute them dynamically based on the input sequence length
156 | cos_emb, sin_emb = precompute_rotary_embeddings(seq_len, self.head_dim, base=self.rope_base)
157 | # Ensure RoPE embeddings have correct dtype
158 | cos_emb = cos_emb.astype(self.dtype)
159 | sin_emb = sin_emb.astype(self.dtype)
160 |
161 | # 4. Apply RoPE to Query and Key
162 | query = apply_rotary_pos_emb(query, cos_emb, sin_emb)
163 | key = apply_rotary_pos_emb(key, cos_emb, sin_emb)
164 |
165 | # 5. Transpose for attention calculation: (batch, num_heads, seq_len, head_dim)
166 | query = query.transpose((0, 2, 1, 3))
167 | key = key.transpose((0, 2, 1, 3))
168 | value = value.transpose((0, 2, 1, 3))
169 |
170 | # 6. Scaled Dot-Product Attention
171 | # Attention scores: (batch, num_heads, seq_len, seq_len)
172 | attn_scores = jnp.matmul(query, key.transpose((0, 1, 3, 2))) / jnp.sqrt(
173 | self.head_dim
174 | ).astype(self.dtype) # Ensure sqrt is correct dtype
175 |
176 | # Apply mask (if provided)
177 | if mask is not None:
178 |             # This module treats boolean masks as True = masked out (note that
179 |             # flax.linen.make_causal_mask uses the opposite convention, 1/True = attend)
180 | # Check if mask needs broadcasting or conversion
181 | if mask.ndim == 2: # Likely (seq_len, seq_len)
182 | mask = mask[None, None, :, :] # -> (1, 1, seq_len, seq_len)
183 | elif mask.ndim == 3 and mask.shape[1] != self.num_heads:
184 | # Likely (batch, seq_len, seq_len) or causal (1, sl, sl)
185 | mask = mask[:, None, :, :]
186 | # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len)
187 |
188 | # Ensure mask is broadcastable to attn_scores shape
189 | mask_shape_expected = (batch_size, self.num_heads, seq_len, seq_len)
190 | if mask.shape != mask_shape_expected:
191 | # Attempt broadcasting common causal mask shapes
192 | if mask.shape == (1, 1, seq_len, seq_len) or mask.shape == (
193 | batch_size,
194 | 1,
195 | seq_len,
196 | seq_len,
197 | ): # Causal mask for all batches/heads
198 | mask = jnp.broadcast_to(mask, mask_shape_expected)
199 | # Add other broadcasting cases if needed
200 | else:
201 | raise ValueError(f"Mask shape {mask.shape} != exp shape {mask_shape_expected}")
202 |
203 |             # Apply mask: use a large negative number where mask is True
204 |             # (boolean convention used throughout this module: True = masked out).
205 |             # If you have an additive 0/-inf mask instead, the equivalent logic
206 |             # would be: attn_scores = attn_scores + mask
207 | attn_scores = jnp.where(mask, jnp.finfo(self.dtype).min, attn_scores)
208 |
209 | # Softmax to get attention weights
210 | attn_weights = jax.nn.softmax(attn_scores, axis=-1).astype(
211 | self.dtype
212 | ) # Shape: (batch, num_heads, seq_len, seq_len)
213 |
214 | # Apply attention weights to Value
215 | # Output per head: (batch, num_heads, seq_len, head_dim)
216 | attn_output = jnp.matmul(attn_weights, value)
217 |
218 | # 7. Concatenate heads and final projection
219 | # Transpose back: (batch, seq_len, num_heads, head_dim)
220 | attn_output = attn_output.transpose((0, 2, 1, 3))
221 | # Reshape to (batch, seq_len, embed_dim)
222 | attn_output = attn_output.reshape(batch_size, seq_len, total_head_dim)
223 |
224 | # Final linear projection
225 | output = self.output_proj(attn_output) # Use self.output_proj defined in setup
226 |
227 | return output
228 |
--------------------------------------------------------------------------------
/.cursor/rules/rotary_position_embeddings__rope_.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description: JAXgarden tutorial chapter on Rotary Position Embeddings (RoPE) for injecting relative positional information in attention.
3 | globs:
4 | alwaysApply: false
5 | ---
6 | # Chapter 8: Rotary Position Embeddings (RoPE)
7 |
8 | In the [previous chapter](attention_mechanism__multiheadattention___dot_product_attention_.mdc), we explored the core `MultiHeadAttention` mechanism and its efficient backend implementations. However, standard self-attention is permutation-invariant, meaning it doesn't inherently understand the order of tokens in a sequence. To address this, models need positional information. This chapter introduces Rotary Position Embeddings (RoPE), a clever technique for incorporating relative position information directly into the attention mechanism.
9 |
10 | **Motivation:** Traditional methods for adding positional information include adding learned absolute position embeddings or fixed sinusoidal embeddings to the input token embeddings. While effective, these methods add extra parameters (learned embeddings) or modify the input representation before the transformer layers. RoPE offers an alternative by *rotating* the query and key vectors based on their absolute positions *before* the dot-product calculation in the attention mechanism. This rotation is designed such that the dot product between a rotated query and a rotated key naturally depends on their relative positions, effectively injecting relative positional awareness directly into the attention scores without adding separate embedding vectors.
11 |
12 | **Central Use Case:** Applying RoPE within the attention modules of models like [LlamaForCausalLM](llamaforcausallm.mdc) (`LlamaAttention`) or [ModernBERTForMaskedLM](modernbertformaskedlm.mdc) (`ModernBertAttention`). RoPE modifies the query and key vectors just before their dot product, enabling the attention mechanism to weigh interactions based on how far apart tokens are in the sequence.
13 |
14 | ## Key Concepts
15 |
16 | 1. **Core Idea:** RoPE operates by viewing pairs of features in the query and key vectors as complex numbers and rotating them in the complex plane. The angle of rotation depends on the token's absolute position (`m`) and the feature index (`i`). When computing the dot product between a rotated query (at position `m`) and a rotated key (at position `n`), the resulting score implicitly depends on their relative distance (`m - n`). A small numerical check of this property appears right after this list.
17 | 2. **Mathematical Basis:** The rotation is achieved using sinusoidal functions derived from the position index (`m`) and the feature dimension (`d`). Specifically, frequencies (`theta_i`) are calculated based on the feature index `i` and a base value (`base`, often 10000 or larger). The rotation involves multiplying elements by `cos(m * theta_i)` and `sin(m * theta_i)`.
18 | 3. **Implementation Variations:** `jaxgarden` provides two main implementations reflecting common practices:
19 | * **`LlamaRotaryEmbedding`:** Used in [LlamaForCausalLM](llamaforcausallm.mdc). It calculates the `cos` and `sin` values *on-the-fly* based on the input `position_ids`. This is flexible for varying sequence lengths during generation.
20 | * **`RoPEPositionalEmbedding`:** Used in [ModernBERTForMaskedLM](modernbertformaskedlm.mdc). It *pre-computes* and caches the `cos` and `sin` values for positions up to a specified `max_position_embeddings` during initialization. This can be more efficient if the maximum sequence length is known and fixed, as it replaces calculation with lookup.
21 | 4. **Application Point:** RoPE is applied to the query and key vectors *after* their initial linear projections (`Wq`, `Wk`) but *before* the dot-product attention score calculation (`QK^T`). This modification is typically encapsulated within a helper function (e.g., `apply_rotary_pos_emb`) called by the attention module.
22 |
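To make the relative-position property from point 1 concrete, here is a small self-contained check. It is illustrative only and re-derives the rotation directly instead of using the `jaxgarden` classes: rotating a query at position `m` and a key at position `n` leaves their dot product dependent only on `m - n`.

```python
import jax.numpy as jnp

def rope_rotate(x, pos, base=10000.0):
    """Rotate feature pairs (x[i], x[i + dim/2]) by the angle pos * theta_i."""
    half = x.shape[-1] // 2
    theta = 1.0 / (base ** (jnp.arange(half, dtype=jnp.float32) / half))
    cos, sin = jnp.cos(pos * theta), jnp.sin(pos * theta)
    x1, x2 = x[:half], x[half:]
    return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin])

q = jnp.array([1.0, 2.0, 3.0, 4.0])
k = jnp.array([0.5, -1.0, 2.0, 0.1])

# Same relative offset (m - n = 2) at two different absolute positions:
score_a = jnp.dot(rope_rotate(q, 5.0), rope_rotate(k, 3.0))
score_b = jnp.dot(rope_rotate(q, 9.0), rope_rotate(k, 7.0))
print(jnp.allclose(score_a, score_b))  # True: the score depends only on m - n
```
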
23 | ## Using RoPE
24 |
25 | RoPE is generally an internal component of attention modules. You typically don't interact with `LlamaRotaryEmbedding` or `RoPEPositionalEmbedding` directly unless implementing a custom attention layer. Instead, you configure the model (via its `Config` object) which implicitly configures the RoPE within its attention layers.
26 |
27 | **Example 1: RoPE within `LlamaAttention`**
28 |
29 | The `LlamaAttention` module internally instantiates `LlamaRotaryEmbedding` and uses it within its `__call__` method.
30 |
31 | ```python
32 | # Inside LlamaAttention initialization (__init__)
33 | # config values like head_dim and rope_theta are passed
34 | self.rotary_emb = LlamaRotaryEmbedding(
35 | dim=config.head_dim,
36 | base=config.rope_theta,
37 | rngs=rngs
38 | )
39 |
40 | # Inside LlamaAttention forward pass (__call__)
41 | def __call__(self, x, position_ids, attention_mask):
42 | # ... project x to query, key, value ...
43 | query = self.q_proj(x).reshape(...) # [batch, n_heads, seq_len, head_dim]
44 | key = self.k_proj(x).reshape(...) # [batch, n_kv_heads, seq_len, head_dim]
45 | value = self.v_proj(x).reshape(...) # [batch, n_kv_heads, seq_len, head_dim]
46 |
47 | # Calculate cos/sin on the fly (assuming batch_size=1 for simplicity here)
48 | cos, sin = self.rotary_emb(position_ids[0])
49 |
50 | # Apply RoPE rotation *before* attention calculation
51 | query, key = self.apply_rotary_pos_emb(query, key, cos, sin)
52 |
53 | # ... repeat key/value if GQA, compute attention scores ...
54 | # attn_weights = jnp.matmul(query, key.transpose(...)) / scale
55 | # ... compute final output ...
56 | return output
57 | ```
58 | **Explanation:** `LlamaAttention` creates a `LlamaRotaryEmbedding` instance. In the forward pass, it calls this instance with `position_ids` to get the `cos` and `sin` values dynamically. These are then passed to `apply_rotary_pos_emb` (a method within `LlamaAttention`) to rotate `query` and `key`.
59 |
60 | **Example 2: RoPE within `ModernBertAttention`**
61 |
62 | `ModernBertAttention` uses the pre-computed cache from `RoPEPositionalEmbedding`.
63 |
64 | ```python
65 | # Inside ModernBertAttention initialization (__init__)
66 | # config values like head_dim, max_position_embeddings, rope_theta are passed
67 | self.rotary_emb = RoPEPositionalEmbedding(
68 | rngs=rngs,
69 | dim=self.head_dim,
70 | max_position_embeddings=max_pos, # Can depend on local/global
71 | base=rope_theta # Can depend on local/global
72 | )
73 | # The cache is stored in self.rotary_emb.cache
74 |
75 | # Inside ModernBertAttention forward pass (__call__)
76 | # It uses a standalone apply_rotary_pos_emb function defined in modernbert.py
77 | from jaxgarden.models.modernbert import apply_rotary_pos_emb
78 |
79 | def __call__(self, hidden_states, attention_mask, ..., position_ids):
80 | # ... project hidden_states to qkv ...
81 | # Split into query, key, value: [batch, num_heads, seq_len, head_dim]
82 | # Transpose for RoPE application: [batch, seq_len, num_heads, head_dim]
83 | query = query.transpose((0, 2, 1, 3))
84 | key = key.transpose((0, 2, 1, 3))
85 |
86 | # Apply RoPE using the pre-computed cache
87 | query = apply_rotary_pos_emb(query, self.rotary_emb.cache, position_ids)
88 | key = apply_rotary_pos_emb(key, self.rotary_emb.cache, position_ids)
89 |
90 | # Transpose back: [batch, num_heads, seq_len, head_dim]
91 | query = query.transpose((0, 2, 1, 3))
92 | key = key.transpose((0, 2, 1, 3))
93 |
94 | # ... compute attention scores ...
95 | # attn_scores = jnp.matmul(query, key.swapaxes(-2, -1)) / scale
96 | # ... compute final output ...
97 | return output_tuple
98 | ```
99 | **Explanation:** `ModernBertAttention` creates a `RoPEPositionalEmbedding` instance, which pre-computes the cache. In the forward pass, the standalone `apply_rotary_pos_emb` function is called, passing the `query` and `key` vectors along with the *cached* sinusoidal values (`self.rotary_emb.cache`) and optional `position_ids` for lookup.
100 |
101 | ## Internal Implementation
102 |
103 | Let's examine the core logic.
104 |
105 | ### `LlamaRotaryEmbedding` (On-the-fly Calculation)
106 |
107 | * **File:** `jaxgarden/models/llama.py`
108 |
109 | **Walkthrough:**
110 | 1. The `__call__` method receives `position_ids` (typically `[1, seq_len]`).
111 | 2. It calculates the inverse frequencies (`inv_freq`) based on `self.dim` and `self.base`. This determines how quickly the angle changes for different feature dimensions.
112 | 3. It uses `jnp.einsum` to efficiently compute the angles (`freqs`) for all positions and half the dimensions: `angle = position_id * inv_freq`.
113 | 4. It concatenates the angles to cover the full dimension (`emb`).
114 | 5. It computes and returns the cosine and sine of these angles.
115 |
116 | ```python
117 | # Simplified from LlamaRotaryEmbedding.__call__
118 | @nnx.jit()
119 | def __call__(self, position_ids: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
120 | # inv_freq shape: [dim/2]
121 | inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
122 | # Expand dims for einsum: inv_freq -> [1, 1, dim/2]
123 | inv_freq_expanded = jnp.expand_dims(inv_freq, axis=(0, 1))
124 | # position_ids -> [1, seq_len, 1]
125 | position_ids_expanded = jnp.expand_dims(position_ids, axis=(0, 2)).astype(jnp.float32)
126 |
127 | # Calculate angles: freqs shape [1, seq_len, dim/2]
128 | freqs = jnp.einsum("bij,bjk->bik", position_ids_expanded, inv_freq_expanded)
129 |
130 | # Concatenate angles for full dimension: emb shape [1, seq_len, dim]
131 | emb = jnp.concatenate([freqs, freqs], axis=-1)
132 |
133 | # Compute cos and sin: Shapes [seq_len, dim] (after squeeze)
134 | cos = jnp.cos(emb).squeeze(0).astype(jnp.bfloat16) # Assuming batch=1 input
135 | sin = jnp.sin(emb).squeeze(0).astype(jnp.bfloat16)
136 | return cos, sin
137 | ```
138 |
139 | ### `RoPEPositionalEmbedding` (Pre-computed Cache)
140 |
141 | * **File:** `jaxgarden/models/modernbert.py`
142 |
143 | **Walkthrough:**
144 | 1. **`create_sinusoidal_positions` (Helper Function):** Calculates angles `pos * theta` for all positions up to `max_length` and half the dimensions. Returns stacked `(cos, sin)` values. Shape: `[max_length, dim//2, 2]`.
145 | 2. **`RoPEPositionalEmbedding.__init__`:** Calls `create_sinusoidal_positions` and stores the result in `self.cache`.
146 | 3. **`apply_rotary_pos_emb` (Helper Function):**
147 | * Takes an input tensor `x` (query or key), the `cache`, and optional `positions`.
148 | * Determines the sequence length (`seq_len`) from `x`.
149 | * Looks up the `cos`/`sin` values from the `cache`. If `positions` are given, it uses advanced indexing; otherwise, it takes the first `seq_len` entries from the cache.
150 | * Reshapes `x` and the looked-up cache values for broadcasting and rotation.
151 | * Performs the rotation using the complex number multiplication analogy: `x_rotated = x * cos +/- x_flipped * sin`.
152 | * Reshapes the result back to the original shape.
153 |
154 | ```mermaid
155 | sequenceDiagram
156 | participant Attn as Attention Module
157 | participant ApplyRoPE as apply_rotary_pos_emb
158 | participant Cache as RoPEPositionalEmbedding.cache
159 | participant QK as Query/Key Vectors
160 |
161 | Attn->>QK: Original Query/Key
162 | Attn->>+ApplyRoPE: apply_rotary_pos_emb(QK, cache, positions)
163 | ApplyRoPE->>Cache: Lookup cos/sin values using positions/seq_len
164 | Cache-->>ApplyRoPE: cos_sin_values
165 | ApplyRoPE->>ApplyRoPE: Reshape QK and cos_sin_values
166 | ApplyRoPE->>ApplyRoPE: Perform rotation (complex multiplication)
167 | ApplyRoPE-->>-Attn: Rotated Query/Key
168 | Attn->>Attn: Compute Attention Scores (using rotated Q/K)
169 | ```
170 |
171 | ```python
172 | # Simplified from modernbert.py create_sinusoidal_positions
173 | def create_sinusoidal_positions(max_length, dim, base=10000.0):
174 | positions = jnp.arange(max_length, dtype=jnp.float32) # [max_length]
175 | dim_indices = jnp.arange(0, dim, 2, dtype=jnp.float32) # [dim/2]
176 | theta = 1.0 / (base ** (dim_indices / dim)) # [dim/2]
177 | angles = jnp.einsum("i,j->ij", positions, theta) # [max_length, dim/2]
178 | # Stack cos and sin along a new last dimension
179 | return jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=-1) # [max_length, dim//2, 2]
180 |
181 | # Simplified from RoPEPositionalEmbedding.__init__
182 | class RoPEPositionalEmbedding(nnx.Module):
183 | def __init__(self, ..., dim, max_position_embeddings, base):
184 | # ... super().__init__() ...
185 | # Pre-compute and store the cache
186 | self.cache = create_sinusoidal_positions(
187 | max_position_embeddings, dim, base
188 | )
189 | # Cache shape: [max_position_embeddings, dim//2, 2]
190 |
191 | # Simplified from modernbert.py apply_rotary_pos_emb
192 | def apply_rotary_pos_emb(x, cache, positions=None):
193 | # x shape: [batch, seq, heads, dim]
194 | # cache shape: [max_len, dim//2, 2]
195 | seq_len = x.shape[1]
196 |
197 | # Lookup from cache based on positions or seq_len
198 | if positions is not None:
199 | rope_cache = cache[positions] # [batch, seq, dim//2, 2]
200 | else:
201 | rope_cache = cache[:seq_len] # [seq, dim//2, 2]
202 | rope_cache = jnp.expand_dims(rope_cache, 0) # [1, seq, dim//2, 2]
203 |
204 | # Reshape x for rotation: [batch, seq, heads, dim//2, 2]
205 | x_reshaped = x.reshape(*x.shape[:-1], -1, 2)
206 | # Expand cache dims for broadcasting: [batch/1, seq, 1, dim//2, 2]
207 | rope_cache = jnp.expand_dims(rope_cache, 2)
208 |
209 | # Perform rotation (like complex multiplication)
210 | cos_vals = rope_cache[..., 0]
211 | sin_vals = rope_cache[..., 1]
212 | x_real = x_reshaped[..., 0]
213 | x_imag = x_reshaped[..., 1]
214 | x_rotated_real = x_real * cos_vals - x_imag * sin_vals
215 | x_rotated_imag = x_imag * cos_vals + x_real * sin_vals
216 |
217 | x_out = jnp.stack([x_rotated_real, x_rotated_imag], axis=-1)
218 | # Reshape back to original: [batch, seq, heads, dim]
219 | return x_out.reshape(x.shape)
220 | ```
221 | *Note:* The `apply_rotary_pos_emb` method within `LlamaAttention` performs a similar rotation logic but uses the dynamically calculated `cos` and `sin` values instead of a pre-computed cache. The core rotation math (`x * cos +/- rotate_half(x) * sin`) is equivalent.
222 |
223 | ## Conclusion
224 |
225 | Rotary Position Embeddings (RoPE) provide an elegant and effective way to incorporate relative positional information into transformer attention mechanisms. By rotating query and key vectors based on their absolute positions using sinusoidal functions, RoPE makes the attention scores sensitive to relative distances without requiring additional learned parameters or modifying input embeddings directly. `jaxgarden` offers both on-the-fly (`LlamaRotaryEmbedding`) and pre-computed/cached (`RoPEPositionalEmbedding`) implementations, catering to different model architectures and efficiency considerations, as seen in [LlamaForCausalLM](llamaforcausallm.mdc) and [ModernBERTForMaskedLM](modernbertformaskedlm.mdc) respectively.
226 |
227 | This concludes the core chapters outlined for the `jaxgarden` tutorial structure. Understanding these components provides a solid foundation for using and extending the library.
228 |
229 |
230 | ---
231 |
232 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai)
--------------------------------------------------------------------------------
/.cursor/rules/generationmixin.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description: jaxgarden tutorial on GenerationMixin, providing autoregressive text generation with sampling for JAX models.
3 | globs: jaxgarden/models/generation_utils.py
4 | alwaysApply: false
5 | ---
6 | # Chapter 5: GenerationMixin
7 |
8 | In the [previous chapter](llamaforcausallm.mdc), we explored `LlamaForCausalLM`, a causal language model built using `jaxgarden` components. We saw that it inherits text generation capabilities. This chapter focuses on the `GenerationMixin` class, the source of that functionality.
9 |
10 | **Motivation:** Implementing autoregressive text generation involves complex logic: managing the token-by-token loop efficiently, handling various sampling strategies (temperature, top-k, top-p, min-p), managing padding and end-of-sequence tokens, and dealing with JAX's PRNG key management and JIT compilation nuances. Encapsulating this logic in a reusable mixin avoids code duplication across different causal language models and provides a standardized generation interface. `GenerationMixin` solves this by providing a robust `generate` method that can be easily added to any compatible causal LM inheriting from [BaseModel](basemodel.mdc).
11 |
12 | **Central Use Case:** Adding the `generate` method to a custom causal language model (or using it via an existing model like `LlamaForCausalLM`). This allows the model instance to perform text generation based on a prompt, controlling the output's creativity and coherence through sampling parameters, and leveraging JAX optimizations like `lax.scan` and JIT compilation for performance on accelerators.
13 |
14 | ## Key Concepts
15 |
16 | 1. **Mixin Pattern:** `GenerationMixin` is not meant to be used standalone. It's designed to be inherited *alongside* a primary base class (like `BaseModel` or a specific model class like `LlamaForCausalLM`). It "mixes in" the `generate` method and its helpers into the inheriting class.
17 | 2. **Autoregressive Loop (`jax.lax.scan`):** Text generation is sequential: the model predicts the next token based on the previously generated tokens. `GenerationMixin` implements this loop with `jax.lax.scan`, which traces the loop body once and runs it as a single compiled XLA loop, avoiding Python-level per-step overhead (and the long compile times a fully unrolled loop would incur) on JAX accelerators (GPU/TPU). A toy sketch of this scan pattern follows the list.
18 | 3. **Sampling Strategies:** Controls how the next token is chosen from the model's output probability distribution (logits).
19 | * `temperature`: Scales logits before sampling. Higher values -> more randomness; lower values -> more determinism.
20 | * `top_k`: Restricts sampling to the `k` most likely tokens.
21 | * `top_p` (Nucleus Sampling): Restricts sampling to the smallest set of tokens whose cumulative probability exceeds `p`.
22 | * `min_p`: Restricts sampling to tokens with probability `p * max_probability` or higher.
23 | * `do_sample`: Boolean flag to enable/disable sampling (if `False`, uses greedy decoding - picks the most likely token).
24 | 4. **Helper Functions:** Sampling strategies are implemented via standalone helper functions (`temperature_scale`, `top_k_logits`, `top_p_logits`, `min_p_logits`, `sample_logits`) for clarity and testability.
25 | 5. **State Management:** The generation loop manages the sequence length, detects the End-of-Sequence (EOS) token to stop generation for specific sequences in a batch, handles padding (`pad_token_id`), and correctly splits and passes the JAX PRNG key (`rng`) at each step if sampling is enabled.
26 | 6. **JIT Compilation (`use_jit`):** The `generate` method offers a `use_jit` flag. If `True`, it calls a pre-compiled version of the core generation loop (`_generate_compiled`). This requires specifying `static_argnames` (like `max_length`, `temperature`, `top_k`, etc., and crucially `self`) to `jax.jit`, as these values influence the computation graph structure and cannot be dynamic JAX tracers during compilation.
27 |
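To see the scan pattern from point 2 in isolation, here is a deliberately tiny, self-contained sketch (not the actual `jaxgarden` implementation): the carry holds a fixed-size token buffer plus the current length, and each step appends one "predicted" token.

```python
import jax
import jax.numpy as jnp

def step(carry, _):
    tokens, length = carry
    # Stand-in for a model call: pretend the next token is simply last token + 1.
    next_token = tokens[length - 1] + 1
    tokens = tokens.at[length].set(next_token)
    return (tokens, length + 1), None

max_length = 6
buffer = jnp.zeros(max_length, dtype=jnp.int32).at[0].set(3)  # prompt = [3]
init_carry = (buffer, jnp.array(1))                           # (token buffer, current length)
(final_tokens, _), _ = jax.lax.scan(step, init_carry, None, length=max_length - 1)
print(final_tokens)  # [3 4 5 6 7 8]
```

The real loop differs mainly in that each step calls the model's forward pass, samples from the resulting logits, and tracks per-sequence `finished` flags, but the carry-threading structure is the same.
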
28 | ## Using `GenerationMixin`
29 |
30 | You typically don't interact with `GenerationMixin` directly. Instead, you call the `generate` method on a model class that inherits from it, like `LlamaForCausalLM`.
31 |
32 | ```python
33 | import jax
34 | import jax.numpy as jnp
35 | from flax import nnx
36 | # Assume LlamaForCausalLM and Tokenizer are imported and initialized
37 | # from jaxgarden.models.llama import LlamaConfig, LlamaForCausalLM
38 | # from jaxgarden.tokenization import Tokenizer
39 |
40 | # --- Setup (Conceptual) ---
41 | # model_config = LlamaConfig(...) # Load appropriate config
42 | # tokenizer = Tokenizer.from_pretrained(...) # Load matching tokenizer
43 | # rngs = nnx.Rngs(0)
44 | # model = LlamaForCausalLM(model_config, rngs=rngs)
45 | # model.from_hf(...) # Optional: Load pretrained weights
46 |
47 | # --- Generation Call ---
48 | # prompt = "The definition of JAX is"
49 | # inputs = tokenizer.encode(prompt, return_tensors="jax")
50 | # input_ids = inputs['input_ids']
51 | # attention_mask = inputs['attention_mask'] # Important for initial prompt
52 |
53 | # Set generation parameters
54 | # max_new_tokens = 50
55 | # target_max_length = input_ids.shape[1] + max_new_tokens
56 | # generation_rng = jax.random.PRNGKey(42)
57 | # pad_token_id = tokenizer.pad_token_id
58 | # eos_token_id = tokenizer.eos_token_id
59 |
60 | # print(f"Generating text with max_length={target_max_length}, temperature=0.8, top_k=50")
61 |
62 | # --- The core call to the generate method ---
63 | # output_ids = model.generate(
64 | # input_ids=input_ids,
65 | # attention_mask=attention_mask, # Use mask from prompt encoding
66 | # max_length=target_max_length,
67 | # temperature=0.8,
68 | # top_k=50,
69 | # top_p=0.9,
70 | # min_p=None, # Example: min_p not used
71 | # do_sample=True, # Enable sampling
72 | # pad_token_id=pad_token_id,
73 | # eos_token_id=eos_token_id,
74 | # rng=generation_rng,
75 | # use_jit=True # Use the compiled version for speed
76 | # )
77 |
78 | # --- Decoding ---
79 | # generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
80 | # print(f"\nPrompt: {prompt}")
81 | # print(f"Generated: {generated_text}")
82 |
83 | # Example Output (Conceptual - depends heavily on model and sampling):
84 | # Generating text with max_length=58, temperature=0.8, top_k=50
85 | #
86 | # Prompt: The definition of JAX is
87 | # Generated: The definition of JAX is a high-performance numerical computation library
88 | # focused on machine learning research. It combines automatic differentiation
89 | # (autograd) and XLA compilation for speed on accelerators like GPUs and TPUs...
90 | print("Skipping actual generation execution.")
91 | ```
92 |
93 | **Explanation:**
94 | - `input_ids`: The starting token sequence (prompt) as a JAX array `[batch_size, seq_len]`.
95 | - `attention_mask`: A mask indicating which input tokens are real (1) vs padding (0), shape `[batch_size, seq_len]`. Important for the model to ignore padding in the prompt.
96 | - `max_length`: The *total* maximum length of the output sequence (prompt + generated tokens).
97 | - `temperature`, `top_k`, `top_p`, `min_p`: Control the sampling process. Setting `top_k=1` or `do_sample=False` results in greedy decoding.
98 | - `do_sample`: If `True`, performs sampling according to the other parameters; if `False`, performs greedy decoding.
99 | - `pad_token_id`: The ID used for padding shorter sequences in the batch *after* they finish generating (e.g., hit EOS). If not provided, it tries to infer from `model.config.pad_token_id` or defaults to 0.
100 | - `eos_token_id`: The ID that signals the end of a sequence. Generation stops for a sequence once this token is sampled. If not provided, generation continues until `max_length`.
101 | - `rng`: A `jax.random.PRNGKey` required if `do_sample=True`.
102 | - `use_jit`: If `True`, calls the JIT-compiled version of the generation loop. Highly recommended for performance.
103 | - **Output:** The method returns a JAX array `output_ids` of shape `[batch_size, max_length]` containing the prompt and the generated tokens, padded up to `max_length`.
104 |
105 | ## Internal Implementation
106 |
107 | The `generate` method orchestrates the process, while the core token-by-token logic resides in `_generate_scan_logic`, which is wrapped by `jax.lax.scan`.
108 |
109 | 1. **`generate` Method:**
110 | * Performs input validation (shapes, parameter ranges).
111 | * Handles default values for `pad_token_id`, `eos_token_id`, and `rng`.
112 | * Determines the initial sequence length (`initial_seq_len`) and batch size.
113 | * Initializes the `finished_sequences` boolean array based on whether the *last* token of the input is already EOS.
114 | * **Conditional Dispatch:** Based on the `use_jit` flag, it calls either:
115 | * `self._generate_compiled(...)`: A pre-compiled version of `_generate_scan_logic` created using `partial(jax.jit, static_argnames=...)`.
116 | * `self._generate_scan_logic(...)`: The raw Python function (slower, useful for debugging).
117 |
118 | 2. **`_generate_scan_logic` Method (Core Loop Logic):**
119 | * Initializes the output tensor `output_ids` of shape `[batch_size, max_length]` filled with `pad_token_id`, and copies the `initial_input_ids` into the beginning.
120 | * Sets up the initial `carry` dictionary for `lax.scan`, containing `output_ids`, `current_length`, `rng`, and `finished` status.
121 | * Defines the `scan_step` function, which performs one step of generation:
122 | * Takes the current `carry` state and a loop counter (`_`) as input.
123 | * Splits the PRNG key (`rng`) if `do_sample` is true.
124 | * Creates the `attention_mask` for the *current* sequence length within the `max_length` buffer. Note: Causal masking is typically handled *inside* the model's attention mechanism, but this mask handles padding.
125 | * **Model Call:** Calls `self(input_ids=current_output_ids, attention_mask=attention_mask, deterministic=True)`. This executes the forward pass of the model (e.g., `LlamaForCausalLM.__call__`). `deterministic=True` ensures dropout etc. are disabled.
126 | * Extracts the logits for the *next* token prediction (logits at the `current_length - 1` position).
127 | * **Sampling:** Calls `sample_logits` with the extracted logits and sampling parameters (`temperature`, `top_k`, etc.) to get the `next_token`.
128 | * **EOS/Padding:** Checks if the `next_token` is the `eos_token_id`. Updates the `finished` status for the sequence. Determines the token to actually write: `pad_token_id` if already finished, otherwise the sampled `next_token`.
129 | * Updates the `output_ids` tensor at the `current_length` position with the chosen token.
130 | * Updates the `carry` dictionary for the next step (`current_length` incremented, `rng` updated, `finished` status updated).
131 | * Returns the updated `carry` and `None` (as `lax.scan` expects `(carry, scan_output)`).
132 | * Calls `jax.lax.scan(scan_step, initial_carry, None, length=num_steps_to_generate)`.
133 | * Returns the final `output_ids` from the resulting `carry`.
134 |
135 | 3. **`sample_logits` Function** (a condensed, standalone sketch of this pipeline appears just before the Conclusion):
136 | * Takes raw logits, RNG key, and all sampling parameters.
137 | * If `do_sample=False`, returns `jnp.argmax(logits)`.
138 | * Applies `temperature_scale`.
139 | * Sequentially applies filtering functions (`min_p_logits`, `top_k_logits`, `top_p_logits`) if the corresponding parameters are set. These functions mask invalid logits by setting them to `-jnp.inf`.
140 | * **Edge Case Handling:** If *all* logits become `-jnp.inf` after filtering (can happen with very restrictive filtering), it falls back to sampling from the *pre-filtered* (but temperature-scaled) logits to prevent errors.
141 | * Uses `jax.random.categorical` to sample from the final filtered (or fallback) logit distribution.
142 |
143 | ```mermaid
144 | sequenceDiagram
145 | participant User
146 | participant Model as Model (e.g., LlamaForCausalLM)
147 | participant GenMixin as GenerationMixin Logic
148 | participant JitCompiled as _generate_compiled (JIT)
149 | participant ScanLogic as _generate_scan_logic
150 | participant LaxScan as jax.lax.scan
151 | participant ScanStep as scan_step (Inner Function)
152 | participant Sampler as sample_logits
153 |
154 | User->>+Model: generate(input_ids, ..., use_jit=True, rng=key)
155 | Model->>+GenMixin: generate(...)
156 | GenMixin->>+JitCompiled: _generate_compiled(input_ids, finished, rng, ..., static_args...)
157 | Note over JitCompiled: Pre-compiled version of ScanLogic
158 | JitCompiled->>+ScanLogic: _generate_scan_logic(...) # Compiled call
159 | ScanLogic->>ScanLogic: Initialize output_ids, initial_carry
160 | ScanLogic->>+LaxScan: scan(scan_step, initial_carry, length=N)
161 | loop N times (Number of tokens to generate)
162 | LaxScan->>+ScanStep: scan_step(current_carry, _)
163 | ScanStep->>Model: self(current_output_ids, mask, deterministic=True)
164 | Model-->>ScanStep: next_token_logits
165 | ScanStep->>+Sampler: sample_logits(logits, rng_step, temp, k, p, ...)
166 | Sampler-->>-ScanStep: next_token
167 | ScanStep->>ScanStep: Update finished status (EOS check)
168 | ScanStep->>ScanStep: Determine output_token (handle padding)
169 | ScanStep->>ScanStep: Update output_ids tensor
170 | ScanStep->>ScanStep: Prepare next_carry (update length, rng, finished)
171 | ScanStep-->>-LaxScan: next_carry, None
172 | end
173 | LaxScan-->>-ScanLogic: final_carry, _
174 | ScanLogic->>ScanLogic: Extract final_output_ids from final_carry
175 | ScanLogic-->>-JitCompiled: final_output_ids
176 | JitCompiled-->>-GenMixin: final_output_ids
177 | GenMixin-->>-User: final_output_ids (Generated Sequence)
178 | ```
179 |
180 | 4. **JIT Compilation (`_generate_compiled`)**:
181 | * Defined using `partial(jax.jit, static_argnames=...)`.
182 | * `static_argnames` must include parameters that affect the *structure* of the computation or are needed as Python values inside the loop logic (not JAX tracers). This includes:
183 | * `self`: Needed to call the model's forward pass (`self.__call__`) inside `scan_step`. `self` contains the model's graph definition.
184 | * `max_length`: Determines the size of the output tensor and loop length.
185 | * `temperature`, `top_k`, `top_p`, `min_p`, `do_sample`: Control conditional logic within `sample_logits` and `scan_step`.
186 | * `pad_token_id`, `eos_token_id`: Used in conditional logic (`jnp.where`).
187 | * `initial_seq_len`: Affects the number of scan steps.
188 | * Arguments *not* listed as static (`initial_input_ids`, `initial_finished_sequences`, `initial_rng`) are treated as dynamic JAX arrays (tracers) during compilation.
189 |
190 | * **Code Location:** `jaxgarden/models/generation_utils.py`
191 |
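As a concrete illustration of the filtering pipeline described for `sample_logits` above, here is a condensed, standalone sketch. It is not the `jaxgarden` implementation (it omits `top_p`, `min_p`, and the all-`-inf` fallback), but it shows the shape of the logic: greedy decoding when sampling is disabled, otherwise temperature scaling, optional top-k filtering, and categorical sampling.

```python
import jax
import jax.numpy as jnp

def sample_next_token(logits, rng, *, temperature=1.0, top_k=None, do_sample=True):
    """Pick next-token ids from a [batch, vocab] logits array (illustrative only)."""
    if not do_sample:
        return jnp.argmax(logits, axis=-1)                 # greedy decoding
    logits = logits / jnp.maximum(temperature, 1e-6)       # temperature scaling
    if top_k is not None:
        kth = jnp.sort(logits, axis=-1)[..., -top_k]       # k-th largest logit per row
        logits = jnp.where(logits < kth[..., None], -jnp.inf, logits)  # keep only top-k
    return jax.random.categorical(rng, logits, axis=-1)

rng = jax.random.PRNGKey(0)
dummy_logits = jnp.array([[0.1, 2.0, -1.0, 0.5]])
print(sample_next_token(dummy_logits, rng, temperature=0.8, top_k=2))
# Most likely prints [1]; index 3 is the only other candidate after top-k filtering.
```
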
192 | ## Conclusion
193 |
194 | `GenerationMixin` provides a powerful and reusable mechanism for adding sophisticated autoregressive text generation capabilities to causal language models in `jaxgarden`. By leveraging `jax.lax.scan` for efficiency, offering flexible sampling strategies, and integrating seamlessly with JAX's JIT compilation, it allows developers to easily enable text generation in their models while benefiting from high performance on accelerators. Understanding its parameters and internal workings is key to effectively controlling the generation process.
195 |
196 | **Next:** [ModernBERTForMaskedLM](modernbertformaskedlm.mdc)
197 |
198 |
199 | ---
200 |
201 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai)
--------------------------------------------------------------------------------
/tests/models/test_generation_utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import pytest
4 | from jax import random
5 |
6 | from jaxgarden.models.generation_utils import GenerationMixin
7 |
8 | # --- Mock Model for GenerationMixin Tests ---
9 |
10 |
11 | class MockModel(GenerationMixin):
12 | """A simple mock model inheriting GenerationMixin for testing."""
13 |
14 | def __init__(self, vocab_size, eos_token_id=None, pad_token_id=0):
15 | # Mock config
16 | self.config = {
17 | "vocab_size": vocab_size,
18 | "eos_token_id": eos_token_id,
19 | "pad_token_id": pad_token_id,
20 | }
21 | self.vocab_size = vocab_size
22 | self._eos_token_id = eos_token_id # Store separately if needed
23 |
24 | def __call__(self, input_ids, attention_mask=None, **kwargs):
25 | """Mock call that returns deterministic logits."""
26 | batch_size, seq_len = input_ids.shape
27 | if attention_mask is None:
28 | # Fallback: If no mask, assume the sequence length is the full length
29 | current_length = jnp.array(seq_len)
30 | else:
31 | # Determine current length from attention mask
32 | # Handle 1D [max_length] or 2D [batch, max_length] masks
33 | if attention_mask.ndim == 1:
34 | current_length = jnp.sum(attention_mask)
35 | elif attention_mask.ndim == 2:
36 | # Ensure current_length is scalar if attention_mask is [1, seq_len]
37 | if attention_mask.shape[0] == 1:
38 | current_length = jnp.sum(attention_mask[0])
39 | else: # Handle [batch, seq_len]
40 | current_length = jnp.sum(attention_mask, axis=-1)
41 | else:
42 | raise ValueError(f"Unexpected attention_mask ndim: {attention_mask.ndim}")
43 |
44 | # Ensure length is at least 1 to avoid index -1
45 | valid_length = jnp.maximum(1, current_length)
46 | last_token_index = valid_length - 1 # Index of the last valid token
47 |
48 | # Gather the last valid token for each item in the batch
49 | # Needs to handle scalar or per-batch indices
50 | if isinstance(last_token_index, int) or (
51 | hasattr(last_token_index, "shape") and last_token_index.shape == ()
52 | ):
53 | # If current_length (and thus index) is scalar across batch
54 | last_token = input_ids[:, last_token_index] # Shape [batch_size]
55 | elif hasattr(last_token_index, "shape") and last_token_index.ndim == 1:
56 | # If current_length varies per batch item (shape [batch_size])
57 | # Use gather to select the token at last_token_index for each batch item
58 | last_token_index_expanded = last_token_index[:, None] # Shape [batch_size, 1]
59 | last_token = jnp.take_along_axis(input_ids, last_token_index_expanded, axis=1)[:, 0]
60 | else:
61 | raise TypeError(f"""Unexpected type or shape for last_token_index:
62 | {type(last_token_index)}""")
63 |
64 | # Create deterministic logits: next token is always (last_token + 1) % vocab_size
65 | next_token_logits = (
66 | jax.nn.one_hot(
67 | (last_token + 1) % self.vocab_size, num_classes=self.vocab_size, dtype=jnp.float32
68 | )
69 | * 10.0
70 | ) # Multiply to make it peaky
71 | # Return shape [batch_size, seq_len, vocab_size]
72 | # We only care about the last token logits for generation,
73 | # so we just broadcast it for simplicity in this mock.
74 | # A real model would compute logits based on the full sequence.
75 | return jnp.repeat(next_token_logits[:, None, :], seq_len, axis=1)
76 |
77 |
78 | # --- Tests for GenerationMixin ---
79 |
80 |
81 | @pytest.fixture
82 | def generation_setup():
83 | key = random.PRNGKey(42)
84 | vocab_size = 10
85 | eos_token_id = 9
86 | pad_token_id = 0
87 | model = MockModel(vocab_size, eos_token_id, pad_token_id)
88 | input_ids = jnp.array([[1, 2], [5, 6]], dtype=jnp.int32)
89 | return key, model, input_ids, vocab_size, eos_token_id, pad_token_id
90 |
91 |
92 | def test_generate_greedy(generation_setup):
93 | """Tests greedy generation (do_sample=False)."""
94 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup
95 | max_length = 7
96 |
97 | output_ids = model.generate(
98 | input_ids,
99 | max_length=max_length,
100 | do_sample=False,
101 | eos_token_id=eos_token_id,
102 | pad_token_id=pad_token_id,
103 | )
104 |
105 | assert output_ids.shape == (input_ids.shape[0], max_length)
106 |
107 | # Expected sequence based on mock model: next = (last + 1) % vocab_size
108 | # Batch 1: [1, 2] -> 3 -> 4 -> 5 -> 6 -> 7
109 | # Batch 2: [5, 6] -> 7 -> 8 -> 9 (EOS) -> 0 (PAD) -> 0 (PAD)
110 | expected_output = jnp.array([[1, 2, 3, 4, 5, 6, 7], [5, 6, 7, 8, 9, 0, 0]], dtype=jnp.int32)
111 | assert jnp.array_equal(output_ids, expected_output)
112 |
113 |
114 | def test_generate_sampling(generation_setup):
115 | """Tests sampling generation (do_sample=True)."""
116 | key, model, input_ids, vocab_size, eos_token_id, pad_token_id = generation_setup
117 | max_length = 6
118 |
119 | # Use a very low temperature to make it nearly deterministic for testing
120 | key, subkey = random.split(key)
121 | output_ids_low_temp = model.generate(
122 | input_ids,
123 | max_length=max_length,
124 | do_sample=True,
125 | temperature=0.01, # Very low temp
126 | eos_token_id=eos_token_id,
127 | pad_token_id=pad_token_id,
128 | rng=subkey,
129 | )
130 |
131 | # Expected sequence (likely, due to low temp):
132 | # Batch 1: [1, 2] -> 3 -> 4 -> 5 -> 6
133 | # Batch 2: [5, 6] -> 7 -> 8 -> 9 (EOS) -> 0 (PAD)
134 | expected_output_likely = jnp.array([[1, 2, 3, 4, 5, 6], [5, 6, 7, 8, 9, 0]], dtype=jnp.int32)
135 | assert output_ids_low_temp.shape == (input_ids.shape[0], max_length)
136 | # Check it's highly likely the same as greedy due to low temp
137 | assert jnp.array_equal(output_ids_low_temp, expected_output_likely)
138 |
139 | # Test reproducibility with the same key
140 | key, subkey = random.split(key)
141 | output_ids1 = model.generate(
142 | input_ids,
143 | max_length=max_length,
144 | do_sample=True,
145 | temperature=1.0,
146 | rng=subkey,
147 | eos_token_id=eos_token_id,
148 | pad_token_id=pad_token_id,
149 | )
150 | output_ids2 = model.generate(
151 | input_ids,
152 | max_length=max_length,
153 | do_sample=True,
154 | temperature=1.0,
155 | rng=subkey,
156 | eos_token_id=eos_token_id,
157 | pad_token_id=pad_token_id,
158 | )
159 | assert jnp.array_equal(output_ids1, output_ids2)
160 |
161 |
162 | def test_generate_max_length(generation_setup):
163 | """Tests that generation stops at max_length."""
164 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup
165 | max_length = 4 # Shorter than needed to reach EOS for batch 1
166 |
167 | output_ids = model.generate(
168 | input_ids,
169 | max_length=max_length,
170 | do_sample=False,
171 | eos_token_id=eos_token_id,
172 | pad_token_id=pad_token_id,
173 | )
174 |
175 | assert output_ids.shape == (input_ids.shape[0], max_length)
176 | # Expected sequence (truncated):
177 | # Batch 1: [1, 2] -> 3 -> 4
178 | # Batch 2: [5, 6] -> 7 -> 8
179 | expected_output = jnp.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=jnp.int32)
180 | assert jnp.array_equal(output_ids, expected_output)
181 |
182 |
183 | def test_generate_eos_handling(generation_setup):
184 | """Tests EOS token handling and subsequent padding."""
185 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup
186 | max_length = 8 # Long enough for EOS
187 |
188 | output_ids = model.generate(
189 | input_ids,
190 | max_length=max_length,
191 | do_sample=False,
192 | eos_token_id=eos_token_id,
193 | pad_token_id=pad_token_id,
194 | )
195 |
196 | assert output_ids.shape == (input_ids.shape[0], max_length)
197 | # Expected sequence:
198 | # Batch 1: [1, 2] -> 3 -> 4 -> 5 -> 6 -> 7 -> 8
199 | # Batch 2: [5, 6] -> 7 -> 8 -> 9 (EOS) -> 0 (PAD) -> 0 (PAD) -> 0 (PAD)
200 | expected_output = jnp.array(
201 | [[1, 2, 3, 4, 5, 6, 7, 8], [5, 6, 7, 8, 9, 0, 0, 0]], dtype=jnp.int32
202 | )
203 | assert jnp.array_equal(output_ids, expected_output)
204 |
205 | # Test case without EOS token (should just fill to max_length)
206 | model_no_eos = MockModel(model.vocab_size, eos_token_id=None, pad_token_id=pad_token_id)
207 | output_ids_no_eos = model_no_eos.generate(
208 | input_ids,
209 | max_length=max_length,
210 | do_sample=False,
211 | eos_token_id=None,
212 | pad_token_id=pad_token_id,
213 | )
214 | expected_output_no_eos = jnp.array(
215 | [[1, 2, 3, 4, 5, 6, 7, 8], [5, 6, 7, 8, 9, 0, 1, 2]], # Continues sequence cyclically
216 | dtype=jnp.int32,
217 | )
218 | assert jnp.array_equal(output_ids_no_eos, expected_output_no_eos)
219 |
220 |
221 | def test_generate_padding_default(generation_setup):
222 | """Tests generation using default pad_token_id from config."""
223 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup
224 | max_length = 7
225 |
226 | # Don't pass pad_token_id, should use model.config['pad_token_id'] (which is 0)
227 | output_ids = model.generate(
228 | input_ids,
229 | max_length=max_length,
230 | do_sample=False,
231 | eos_token_id=eos_token_id,
232 | # pad_token_id not provided
233 | )
234 | # Expected is same as test_generate_greedy
235 | expected_output = jnp.array([[1, 2, 3, 4, 5, 6, 7], [5, 6, 7, 8, 9, 0, 0]], dtype=jnp.int32)
236 | assert jnp.array_equal(output_ids, expected_output)
237 |
238 |
239 | def test_generate_input_length_equals_max_length(generation_setup):
240 | """Tests generation when input_ids length equals max_length."""
241 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup
242 | max_length = input_ids.shape[1] # 2
243 |
244 | output_ids = model.generate(
245 | input_ids,
246 | max_length=max_length,
247 | do_sample=False,
248 | eos_token_id=eos_token_id,
249 | pad_token_id=pad_token_id,
250 | )
251 | # Should return the input_ids unchanged
252 | assert output_ids.shape == (input_ids.shape[0], max_length)
253 | assert jnp.array_equal(output_ids, input_ids)
254 |
255 |
256 | def test_generate_input_length_greater_than_max_length(generation_setup):
257 | """Tests generation when input_ids length exceeds max_length."""
258 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup
259 | max_length = 1
260 |
261 | output_ids = model.generate(
262 | input_ids,
263 | max_length=max_length,
264 | do_sample=False,
265 | eos_token_id=eos_token_id,
266 | pad_token_id=pad_token_id,
267 | )
268 | # Should return input_ids truncated to max_length
269 | assert output_ids.shape == (input_ids.shape[0], max_length)
270 | assert jnp.array_equal(output_ids, input_ids[:, :max_length])
271 |
272 |
273 | def test_generate_sampling_with_filtering(generation_setup):
274 | """Tests sampling with filtering parameters (top_k, top_p, min_p)."""
275 | key, model, input_ids, vocab_size, eos_token_id, pad_token_id = generation_setup
276 | max_length = 6
277 |
278 | # Test with top_k
279 | key, subkey = random.split(key)
280 | output_ids_top_k = model.generate(
281 | input_ids,
282 | max_length=max_length,
283 | do_sample=True,
284 | top_k=3,
285 | rng=subkey,
286 | eos_token_id=eos_token_id,
287 | pad_token_id=pad_token_id,
288 | )
289 | assert output_ids_top_k.shape == (input_ids.shape[0], max_length)
290 |
291 | # Test with top_p
292 | key, subkey = random.split(key)
293 | output_ids_top_p = model.generate(
294 | input_ids,
295 | max_length=max_length,
296 | do_sample=True,
297 | top_p=0.8,
298 | rng=subkey,
299 | eos_token_id=eos_token_id,
300 | pad_token_id=pad_token_id,
301 | )
302 | assert output_ids_top_p.shape == (input_ids.shape[0], max_length)
303 |
304 | # Test with min_p
305 | key, subkey = random.split(key)
306 | output_ids_min_p = model.generate(
307 | input_ids,
308 | max_length=max_length,
309 | do_sample=True,
310 | min_p=0.1,
311 | rng=subkey,
312 | eos_token_id=eos_token_id,
313 | pad_token_id=pad_token_id,
314 | )
315 | assert output_ids_min_p.shape == (input_ids.shape[0], max_length)
316 |
317 | # Test with multiple filtering methods
318 | key, subkey = random.split(key)
319 | output_ids_combo = model.generate(
320 | input_ids,
321 | max_length=max_length,
322 | do_sample=True,
323 | top_k=3,
324 | top_p=0.8,
325 | min_p=0.1,
326 | rng=subkey,
327 | eos_token_id=eos_token_id,
328 | pad_token_id=pad_token_id,
329 | )
330 | assert output_ids_combo.shape == (input_ids.shape[0], max_length)
331 |
332 |
333 | def test_generate_rng_none_warning(generation_setup, capsys):
334 | """Tests that a warning is printed when do_sample is True but rng is None."""
335 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup
336 | max_length = 5
337 |
338 | model.generate(
339 | input_ids,
340 | max_length=max_length,
341 | do_sample=True,
342 | rng=None, # No RNG key provided
343 | eos_token_id=eos_token_id,
344 | pad_token_id=pad_token_id,
345 | )
346 |
347 | captured = capsys.readouterr()
348 | assert "Warning: No RNG key provided for sampling, using default key 0." in captured.out
349 |
350 |
351 | def test_generate_no_sample_rng_none(generation_setup):
352 | """Tests that do_sample=False (greedy) works without rng."""
353 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup
354 | max_length = 5
355 | output_ids = model.generate(
356 | input_ids,
357 | max_length=max_length,
358 | do_sample=False,
359 | rng=None, # No RNG key provided
360 | eos_token_id=eos_token_id,
361 | pad_token_id=pad_token_id,
362 | )
363 |
364 | assert output_ids.shape == (input_ids.shape[0], max_length)
365 |
366 |
367 | def test_generate_jax_compile():
368 | """Tests that jax.jit can be used with generate without error."""
369 | vocab_size = 10
370 | eos_token_id = 9
371 | pad_token_id = 0
372 | model = MockModel(vocab_size, eos_token_id, pad_token_id)
373 | input_ids = jnp.array([[1, 2], [5, 6]], dtype=jnp.int32)
374 | max_length = 7
375 |
376 | output_ids = model.generate(
377 | input_ids,
378 | max_length=max_length,
379 | do_sample=False,
380 | eos_token_id=eos_token_id,
381 | pad_token_id=pad_token_id,
382 | use_jit=True,
383 | )
384 |
385 | assert output_ids.shape == (input_ids.shape[0], max_length)
386 |
387 | # Expected sequence based on mock model: next = (last + 1) % vocab_size
388 | # Batch 1: [1, 2] -> 3 -> 4 -> 5 -> 6 -> 7
389 | # Batch 2: [5, 6] -> 7 -> 8 -> 9 (EOS) -> 0 (PAD) -> 0 (PAD)
390 | expected_output = jnp.array([[1, 2, 3, 4, 5, 6, 7], [5, 6, 7, 8, 9, 0, 0]], dtype=jnp.int32)
391 | assert jnp.array_equal(output_ids, expected_output) # Should match greedy output
392 |
393 | # Test that jax.jit does not affect the output
394 | output_ids2 = model.generate(
395 | input_ids,
396 | max_length=max_length,
397 | do_sample=False,
398 | eos_token_id=eos_token_id,
399 | pad_token_id=pad_token_id,
400 | use_jit=True,
401 | )
402 | assert jnp.array_equal(output_ids, output_ids2)
403 |
--------------------------------------------------------------------------------
/.cursor/rules/tokenizer.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description: Tutorial chapter for the jaxgarden Tokenizer, detailing text encoding/decoding and chat templating for JAX.
3 | globs:
4 | alwaysApply: false
5 | ---
6 | # Chapter 1: Tokenizer
7 |
8 | Welcome to the `jaxgarden` library tutorial! This first chapter introduces the `Tokenizer` class, a fundamental component for processing text data within JAX-based Natural Language Processing (NLP) models.
9 |
10 | **Motivation:** Deep learning models, especially those built with JAX, operate on numerical tensors. Raw text needs to be converted into a numerical format (token IDs) that models can understand, and conversely, model outputs (token IDs) need to be converted back into human-readable text. Furthermore, different models, particularly instruction-tuned ones, expect conversational inputs to be formatted in specific ways (chat templates). The Hugging Face `tokenizers` library is excellent for this, but its outputs are standard Python lists. `jaxgarden.Tokenizer` wraps this library to provide a seamless experience for JAX users, returning `jax.numpy.ndarray` (jnp arrays) directly and integrating features like chat templating.
11 |
12 | **Central Use Case:** Preparing text input for a JAX-based language model like [LlamaForCausalLM](llamaforcausallm.mdc) and decoding its generated token IDs. For conversational models, formatting user prompts and conversation history according to the model's specific chat template is crucial.
13 |
14 | ## Key Concepts
15 |
16 | The `jaxgarden.Tokenizer` provides several core functionalities:
17 |
18 | 1. **Loading:** Instantiating a tokenizer from pre-trained configurations stored on the Hugging Face Hub or locally.
19 | 2. **Encoding:** Converting text strings into sequences of token IDs, handling padding and truncation, and returning JAX arrays.
20 | 3. **Decoding:** Converting sequences of token IDs back into text strings.
21 | 4. **Special Token Management:** Automatically identifying or allowing specification of crucial tokens like Beginning-of-Sequence (BOS), End-of-Sequence (EOS), and Padding (PAD).
22 | 5. **Chat Templating:** Applying Jinja-based templates to format conversational data for instruction-tuned models.
23 |
24 | ## Using the Tokenizer
25 |
26 | Let's explore how to use the `Tokenizer`.
27 |
28 | ### Loading a Tokenizer
29 |
30 | The primary way to get a `Tokenizer` instance is using the `from_pretrained` class method. You provide a model identifier from the Hugging Face Hub (e.g., `"gpt2"`, `"meta-llama/Llama-2-7b-chat-hf"`) or a path to a local directory containing `tokenizer.json` and optionally `tokenizer_config.json`.
31 |
32 | ```python
33 | # Assuming jaxgarden is installed
34 | from jaxgarden.tokenization import Tokenizer
35 |
36 | # Load from Hugging Face Hub
37 | tokenizer = Tokenizer.from_pretrained("gpt2")
38 |
39 | # Example: Load from a local directory (if you have one)
40 | # tokenizer_local = Tokenizer.from_pretrained("./path/to/local_tokenizer_files")
41 |
42 | print(f"Loaded tokenizer for 'gpt2' with vocab size: {tokenizer.vocab_size}")
43 | print(f"Pad token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}")
44 | print(f"BOS token: {tokenizer.bos_token}, ID: {tokenizer.bos_token_id}")
45 | print(f"EOS token: {tokenizer.eos_token}, ID: {tokenizer.eos_token_id}")
46 | ```
47 |
48 | **Explanation:** `from_pretrained` downloads necessary files (`tokenizer.json`, `tokenizer_config.json`) from the Hub or reads them locally. It then instantiates the underlying Hugging Face `tokenizers.Tokenizer` and extracts configuration like special tokens and chat templates (if available in `tokenizer_config.json`). The `jaxgarden.Tokenizer` wrapper uses this information to set its own attributes like `pad_token_id`, `bos_token_id`, etc.
49 |
50 | ### Encoding Text
51 |
52 | The `encode` method converts text into token IDs. It offers options for handling batches, padding, and truncation, returning JAX arrays by default.
53 |
54 | ```python
55 | import jax.numpy as jnp
56 |
57 | text = "Hello, world!"
58 | batch_text = ["First sequence.", "This is a second sequence."]
59 |
60 | # Basic encoding
61 | encoded_single = tokenizer.encode(text)
62 | print("Encoded Single:", encoded_single)
63 | # Output (illustrative; exact token IDs depend on the tokenizer): Encoded Single: {'input_ids': DeviceArray([[50256, 15496, 11, 1917, 25, 50256]], dtype=int32),
64 | # 'attention_mask': DeviceArray([[1, 1, 1, 1, 1, 1]], dtype=int32)}
65 |
66 | # Encoding a batch with padding to the longest sequence
67 | encoded_batch = tokenizer.encode(batch_text, padding=True, add_special_tokens=False)
68 | print("Encoded Batch (padded):", encoded_batch)
69 | # Output: Encoded Batch (padded): {
70 | # 'input_ids': DeviceArray([[ 8285, 16337, 13, 50256, 50256, 50256],
71 | # [ 1212, 318, 257, 1144, 16337, 13]], dtype=int32),
72 | # 'attention_mask': DeviceArray([[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1]], dtype=int32) }
73 |
74 |
75 | # Encoding with truncation and padding to a max length
76 | encoded_truncated = tokenizer.encode(
77 | batch_text, padding="max_length", truncation=True, max_length=5, add_special_tokens=False
78 | )
79 | print("Encoded Batch (truncated/padded):", encoded_truncated)
80 | # Output: Encoded Batch (truncated/padded): {
81 | # 'input_ids': DeviceArray([[ 8285, 16337, 13, 50256, 50256],
82 | # [ 1212, 318, 257, 1144, 16337]], dtype=int32),
83 | # 'attention_mask': DeviceArray([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=int32)}
84 |
85 | ```
86 |
87 | **Explanation:**
88 | - `text`: The input string or list of strings.
89 | - `add_special_tokens`: Controls whether BOS/EOS tokens are added (based on tokenizer config). Default is `True`.
90 | - `padding`: Can be `False` (no padding), `True` or `'longest'` (pad to the longest sequence in the batch), or `'max_length'` (pad to `max_length`). Requires `max_length` if set to `'max_length'`.
91 | - `truncation`: Boolean. If `True`, truncates sequences longer than `max_length`. Requires `max_length`.
92 | - `max_length`: Integer specifying the target length for padding or truncation.
93 | - `return_tensors`: Set to `"jax"` to get `jnp.ndarray`. Set to `None` to get lists of integers.
94 | The output is a dictionary containing `input_ids` and `attention_mask` as JAX arrays. The attention mask indicates which tokens are real (1) and which are padding (0).
95 |
96 | ### Decoding Token IDs
97 |
98 | The `decode` method converts token IDs back into human-readable text.
99 |
100 | ```python
101 | # Use the 'input_ids' from the previous encoding example
102 | ids_to_decode = encoded_batch['input_ids'] # Example: DeviceArray([[ 8285, 16337, 13, 50256, 50256, 50256], ...])
103 |
104 | # Decode the batch
105 | decoded_text = tokenizer.decode(ids_to_decode, skip_special_tokens=True)
106 | print("Decoded Text:", decoded_text)
107 | # Output: Decoded Text: ['First sequence.', 'This is a second sequence.']
108 |
109 | # Decode a single sequence (e.g., the first one from the batch)
110 | single_decoded = tokenizer.decode(ids_to_decode[0], skip_special_tokens=True)
111 | print("Single Decoded:", single_decoded)
112 | # Output: Single Decoded: First sequence.
113 | ```
114 |
115 | **Explanation:**
116 | - `token_ids`: A list of integers, a list of lists of integers, or a JAX array.
117 | - `skip_special_tokens`: If `True` (default), removes special tokens like BOS, EOS, PAD from the output string(s).
118 |
119 | ### Applying Chat Templates
120 |
121 | For models fine-tuned on conversational data, inputs must be formatted correctly. The `apply_chat_template` method uses a Jinja template (either defined in the tokenizer's config or provided explicitly) to structure conversations.
122 |
123 | ```python
124 | # Load a tokenizer known to have a chat template
125 | # Note: You might need to log in to Hugging Face CLI: `huggingface-cli login`
126 | try:
127 | chat_tokenizer = Tokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") # Example
128 | except Exception as e:
129 | print(f"Skipping chat template example: Could not load Llama-2 tokenizer. Error: {e}")
130 | chat_tokenizer = None # Set to None to skip the rest of the block
131 |
132 | if chat_tokenizer and chat_tokenizer.chat_template:
133 | conversation = [
134 | {"role": "user", "content": "Hello, how are you?"},
135 | {"role": "assistant", "content": "I'm doing well, thank you!"},
136 | {"role": "user", "content": "What is JAX?"}
137 | ]
138 |
139 | # Apply the template to format the conversation
140 | formatted_prompt = chat_tokenizer.apply_chat_template(
141 | conversation,
142 | add_generation_prompt=True # Adds the prompt for the assistant's turn
143 | )
144 | print("\nFormatted Chat Prompt:\n", formatted_prompt)
145 |
146 | # Example of Llama-2 Chat format (output structure may vary slightly):
147 | # Formatted Chat Prompt:
148 | # [INST] Hello, how are you? [/INST] I'm doing well, thank you! [INST] What is JAX? [/INST]
149 | else:
150 | print("\nSkipping chat template example: No chat_tokenizer or template found.")
151 |
152 | ```
153 |
154 | **Explanation:**
155 | - `conversation`: A list of dictionaries, each with a `role` (e.g., 'user', 'assistant', 'system') and `content` (the message text).
156 | - `chat_template`: Optional override for the tokenizer's default template.
157 | - `add_generation_prompt`: A common pattern in chat templates is to append tokens indicating it's the assistant's turn to speak. Setting this to `True` often achieves this, though the exact implementation depends on the template itself.
158 | - `**kwargs`: Allows passing additional variables to the Jinja template.
159 | The method returns the formatted string, ready to be encoded and fed into the model.
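
Once rendered, the prompt is an ordinary string, so it can be encoded with the same `encode` method shown earlier. A minimal usage sketch, continuing the guarded example above (the variable name `model_inputs` is illustrative):

```python
if chat_tokenizer and chat_tokenizer.chat_template:
    # Encode the rendered conversation; returns jnp arrays ready for a model's generate()
    model_inputs = chat_tokenizer.encode(formatted_prompt, return_tensors="jax")
    print(model_inputs["input_ids"].shape)  # (1, prompt_length)
```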
160 |
161 | ## Internal Implementation
162 |
163 | Understanding the internals helps in debugging and extending functionality.
164 |
165 | 1. **Initialization (`__init__`)**:
166 | - Stores the Hugging Face `HfTokenizer` instance.
167 |    - Tries to infer `bos_token`, `eos_token`, and `pad_token` from the `HfTokenizer`'s known special tokens or common defaults (e.g., `<s>`, `</s>`, `[PAD]`). It logs warnings if defaults or fallbacks are used.
168 | - If `pad_token` isn't explicitly provided or found, it defaults to `eos_token` (with a warning) or raises an error if no suitable token can be determined.
169 | - It fetches the corresponding IDs (`bos_token_id`, etc.) using `hf_tokenizer.token_to_id`.
170 | - Critically, it ensures the `HfTokenizer` has padding configured using the determined `pad_token_id` and `pad_token`, either by enabling it or verifying/correcting existing padding settings.
171 |
172 | ```python
173 | # Simplified __init__ logic for special tokens
174 |    def __init__(self, hf_tokenizer, chat_template=None, bos_token=None, eos_token=None, pad_token=None):
175 | self.hf_tokenizer = hf_tokenizer
176 | # ... (chat_template assignment) ...
177 |
178 | # Infer BOS/EOS (simplified example)
179 |        self.bos_token = bos_token or self._infer_token(hf_tokenizer, ["[BOS]", "<s>"])
180 |        self.eos_token = eos_token or self._infer_token(hf_tokenizer, ["[EOS]", "</s>"])
181 |
182 | # Infer/Validate PAD token
183 | if pad_token:
184 | self.pad_token = pad_token
185 | elif hf_tokenizer.padding and hf_tokenizer.padding.get("pad_token"):
186 | self.pad_token = hf_tokenizer.padding["pad_token"]
187 |        # ... (fallback logic using EOS or searching for [PAD]/<pad>) ...
188 | else:
189 | raise ValueError("Cannot determine padding token.")
190 |
191 | # Get IDs
192 | self.bos_token_id = hf_tokenizer.token_to_id(self.bos_token) # if self.bos_token else None
193 | self.eos_token_id = hf_tokenizer.token_to_id(self.eos_token) # if self.eos_token else None
194 | self.pad_token_id = hf_tokenizer.token_to_id(self.pad_token)
195 | # ... (error if pad_token_id is None) ...
196 |
197 | # Configure HF tokenizer padding
198 | self.hf_tokenizer.enable_padding(pad_id=self.pad_token_id, pad_token=self.pad_token)
199 | # ... (logic to handle pre-existing padding config) ...
200 |
201 | # Helper function (conceptual)
202 | def _infer_token(self, hf_tok, candidates):
203 | for token_str in candidates:
204 | if hf_tok.token_to_id(token_str) is not None:
205 | return token_str
206 | return None
207 | ```
208 |
209 | 2. **Loading (`from_pretrained`)**:
210 | - Checks if the `identifier` is a local path.
211 | - If local: Looks for `tokenizer.json` and `tokenizer_config.json` in that directory.
212 | - If not local (Hub): Uses `huggingface_hub.hf_hub_download` to fetch the files.
213 | - Loads the core tokenizer from `tokenizer.json` using `HfTokenizer.from_file`.
214 | - Loads the configuration (chat template, special tokens) from `tokenizer_config.json` if it exists.
215 | - Prioritizes explicitly passed arguments (`bos_token`, `eos_token`, etc.) over values found in `tokenizer_config.json`.
216 | - Calls the `__init__` method with the loaded `HfTokenizer` and extracted/provided configuration.
217 |
218 | 3. **Encoding (`encode`)**:
219 | - Temporarily configures the `hf_tokenizer` instance based on `padding`, `truncation`, and `max_length` arguments by calling `hf_tokenizer.enable_padding`, `hf_tokenizer.enable_truncation`, or `hf_tokenizer.no_padding`/`no_truncation`.
220 | - Configures the `hf_tokenizer.post_processor` using `TemplateProcessing` if `add_special_tokens=True` and BOS/EOS tokens are available, otherwise sets it to `None`. This handles adding BOS/EOS correctly.
221 | - Calls `hf_tokenizer.encode` (for single text) or `hf_tokenizer.encode_batch` (for list of texts).
222 | - Extracts the `ids` and `attention_mask` from the result(s).
223 | - If `return_tensors="jax"`, converts the lists of integers into `jnp.int32` arrays using `jnp.array`. It handles potential raggedness after encoding if `padding='longest'` was used by manually padding again to the true max length found in the batch *after* encoding.
224 | - Returns the dictionary `{'input_ids': ..., 'attention_mask': ...}`.
225 |
226 | ```mermaid
227 | sequenceDiagram
228 | participant User
229 | participant Tokenizer as JaxTokenizer
230 | participant HfTokenizer as HF Tokenizer
231 | participant JNP as jax.numpy
232 |
233 | User->>+JaxTokenizer: encode(text, padding=True, ...)
234 |     JaxTokenizer->>HfTokenizer: enable_padding(pad_id, pad_token, ...)
235 |     JaxTokenizer->>HfTokenizer: enable_truncation(...) / no_truncation() (based on args)
236 |     JaxTokenizer->>HfTokenizer: set post_processor (for special tokens)
237 | JaxTokenizer->>+HfTokenizer: encode_batch(text) / encode(text)
238 | HfTokenizer-->>-JaxTokenizer: list[Encoding] (contains ids, attention_mask)
239 | JaxTokenizer->>JaxTokenizer: Extract ids and masks into lists
240 | alt return_tensors == "jax"
241 | JaxTokenizer->>JNP: array(ids_list, dtype=int32)
242 | JNP-->>JaxTokenizer: ids_array (jnp.ndarray)
243 | JaxTokenizer->>JNP: array(mask_list, dtype=int32)
244 | JNP-->>JaxTokenizer: mask_array (jnp.ndarray)
245 | JaxTokenizer-->>-User: {'input_ids': ids_array, 'attention_mask': mask_array}
246 | else
247 |         JaxTokenizer-->>User: {'input_ids': ids_list, 'attention_mask': mask_list}
248 | end
249 |
250 | ```
251 |
252 | 4. **Decoding (`decode`)**:
253 | - Converts input `jnp.ndarray` to lists if necessary.
254 | - Determines if the input is a single sequence or a batch.
255 | - Calls the underlying `hf_tokenizer.decode` or `hf_tokenizer.decode_batch` method, passing `skip_special_tokens`.
256 | - Returns the resulting string or list of strings.
257 |
258 | 5. **Chat Templating (`apply_chat_template`)**:
259 | - Selects the Jinja template (either `self.chat_template` or the one passed as an argument). Raises an error if none is available.
260 | - Creates a dictionary of variables to pass to the template, including `messages`, `bos_token`, `eos_token`, and any `**kwargs`.
261 | - Performs basic validation on the template structure (optional, checks for expected variables like `messages`).
262 | - Modifies the template string if `add_generation_prompt=True` (using a common but potentially model-specific pattern).
263 | - Creates a Jinja environment and renders the template with the variables.
264 | - Returns the rendered string.
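
A minimal sketch of that rendering step using the `jinja2` package directly (the function and variable names here are illustrative, not the actual `jaxgarden` implementation):

```python
from jinja2 import Environment

def render_chat(template_str, messages, bos_token="", eos_token="", **kwargs):
    # Expose the same variables that apply_chat_template passes to the template
    env = Environment()
    template = env.from_string(template_str)
    return template.render(messages=messages, bos_token=bos_token, eos_token=eos_token, **kwargs)

# Example with a toy template
toy_template = "{% for m in messages %}[{{ m.role }}] {{ m.content }}\n{% endfor %}"
print(render_chat(toy_template, [{"role": "user", "content": "Hi"}]))
```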
265 |
266 | ## Conclusion
267 |
268 | The `jaxgarden.Tokenizer` provides a crucial bridge between raw text and JAX-based NLP models. It leverages the power of Hugging Face `tokenizers` while ensuring compatibility with JAX workflows by returning `jnp.ndarray` objects. Key functionalities include easy loading from the Hub/local files, robust encoding/decoding with padding and truncation control, automatic handling of special tokens, and essential chat templating for conversational AI.
269 |
270 | Understanding how to use the `Tokenizer` is the first step in building or using models within `jaxgarden`. The next chapter will introduce the foundational building blocks for models themselves.
271 |
272 | **Next:** [BaseModel](basemodel.mdc)
273 |
274 |
275 | ---
276 |
277 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai)
--------------------------------------------------------------------------------
/.cursor/rules/llamaforcausallm.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description: Details the jaxgarden LlamaForCausalLM model for causal language modeling, covering architecture, components, and HF weight conversion.
3 | globs: jaxgarden/models/llama.py
4 | alwaysApply: false
5 | ---
6 | # Chapter 4: LlamaForCausalLM
7 |
8 | In the [previous chapter](baseconfig.mdc), we examined `BaseConfig`, the foundation for configuring models in `jaxgarden`. We saw how model-specific configurations like `LlamaConfig` inherit from it to define hyperparameters. Now, we dive into the complete model implementation that uses this configuration: `LlamaForCausalLM`.
9 |
10 | **Motivation:** Large language models like Meta's Llama have demonstrated remarkable capabilities in natural language understanding and generation. Implementing such complex architectures efficiently within the JAX ecosystem requires careful integration of various components (attention mechanisms, normalization layers, embeddings) and adherence to best practices for performance and state management. `LlamaForCausalLM` provides a faithful and optimized implementation of the Llama architecture, ready for training and inference using JAX and Flax NNX.
11 |
12 | **Central Use Case:** Loading pretrained Llama weights (e.g., from Hugging Face Hub) into a `jaxgarden.LlamaForCausalLM` instance and using it for autoregressive text generation. This involves initializing the model with the correct `LlamaConfig`, leveraging the `from_hf` method inherited from [BaseModel](basemodel.mdc) for weight conversion, and then using the `generate` method from the [GenerationMixin](generationmixin.mdc) for text synthesis.
13 |
14 | ## Key Concepts
15 |
16 | `LlamaForCausalLM` integrates several advanced transformer components into a cohesive causal language model:
17 |
18 | 1. **Core Structure:** It inherits from [BaseModel](basemodel.mdc) for configuration, state management, and HF integration capabilities, and from [GenerationMixin](generationmixin.mdc) for text generation methods.
19 | 2. **Token Embeddings (`nnx.Embed`):** Maps input token IDs from a vocabulary to dense vector representations (hidden states).
20 | 3. **`LlamaTransformerBlock`:** The main building block, repeated multiple times (`n_layers` specified in `LlamaConfig`). Each block contains:
21 | * **`LlamaRMSNorm` (Pre-Normalization):** Applied before the attention and MLP layers for improved training stability. RMSNorm is a simpler and often faster alternative to LayerNorm.
22 | * **`LlamaAttention`:** Implements multi-head self-attention. It incorporates [Rotary Position Embeddings (RoPE)](rotary_position_embeddings__rope_.mdc) via `LlamaRotaryEmbedding` to inject positional information dynamically. It also supports Grouped Query Attention (GQA) where the number of key/value heads (`n_kv_heads`) can be smaller than the number of query heads (`n_heads`) for reduced computational cost and memory footprint, particularly during inference.
23 |    * **`LlamaMLP`:** A feed-forward network using the SwiGLU activation function (`silu(gate(x)) * up(x)`), which has shown strong performance in recent models. A minimal sketch of this pattern follows this list.
24 | 4. **Final Normalization & LM Head:** After the last `LlamaTransformerBlock`, a final `LlamaRMSNorm` is applied. The output is then passed through a linear layer (`lm_head`) that projects the final hidden state back to the vocabulary size, producing logits for the next token prediction.
25 | 5. **Weight Tying:** The weights of the `lm_head` are typically tied to the weights of the `token_embed` layer. This reduces the total number of parameters and can improve performance. This tying is handled during weight initialization and conversion.
26 | 6. **HF Weight Conversion (`convert_weights_from_hf`):** Implements the logic to map parameter names and shapes from Hugging Face Llama checkpoints (stored in Safetensors format) to the specific structure and naming convention of the `jaxgarden` `LlamaForCausalLM` state.
27 | 7. **Text Generation:** Inherits the `generate` method from [GenerationMixin](generationmixin.mdc), enabling autoregressive text generation with sampling strategies like temperature, top-k, and top-p.
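
To make the SwiGLU feed-forward pattern from point 3 concrete, here is a small self-contained sketch in Flax NNX (the class and argument names are illustrative and simplified relative to the actual `LlamaMLP`):

```python
import jax
import jax.numpy as jnp
from flax import nnx

class SwiGLUMLP(nnx.Module):
    """Illustrative SwiGLU block: down(silu(gate(x)) * up(x))."""

    def __init__(self, dim: int, hidden_dim: int, *, rngs: nnx.Rngs):
        self.gate_proj = nnx.Linear(dim, hidden_dim, use_bias=False, rngs=rngs)
        self.up_proj = nnx.Linear(dim, hidden_dim, use_bias=False, rngs=rngs)
        self.down_proj = nnx.Linear(hidden_dim, dim, use_bias=False, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x))

mlp = SwiGLUMLP(512, 1024, rngs=nnx.Rngs(0))
y = mlp(jnp.ones((1, 10, 512)))  # -> shape (1, 10, 512)
```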
28 |
29 | ## Using `LlamaForCausalLM`
30 |
31 | Let's see how to initialize, load weights, and use the model.
32 |
33 | ### Initialization
34 |
35 | First, define the configuration (`LlamaConfig`) and instantiate the model using `nnx.Rngs`.
36 |
37 | ```python
38 | import jax
39 | import jax.numpy as jnp
40 | from flax import nnx
41 | from jaxgarden.models.llama import LlamaConfig, LlamaForCausalLM
42 |
43 | # Example: Configuration for a small Llama model
44 | config = LlamaConfig(
45 | dim=512,
46 | n_layers=4,
47 | n_heads=8,
48 | n_kv_heads=4, # GQA enabled (n_kv_heads < n_heads)
49 | head_dim=64,
50 | intermediate_size=1024,
51 | vocab_size=10000, # Example vocabulary size
52 | norm_eps=1e-5,
53 | rope_theta=10000.0
54 | )
55 |
56 | # Initialize PRNG keys
57 | rngs = nnx.Rngs(params=0) # Or use more sophisticated key management
58 |
59 | # Instantiate the model
60 | model = LlamaForCausalLM(config, rngs=rngs, param_dtype=jnp.bfloat16)
61 |
62 | print(f"Initialized LlamaForCausalLM with {config.n_layers} layers.")
63 | # Output: Initialized LlamaForCausalLM with 4 layers.
64 | ```
65 | **Explanation:** We create a `LlamaConfig` dataclass instance, specifying the architecture details. We then pass this config and an `nnx.Rngs` object to the `LlamaForCausalLM` constructor. `param_dtype=jnp.bfloat16` is often used for large models to save memory.
66 |
67 | ### Loading Pretrained Weights
68 |
69 | To load weights from a Hugging Face checkpoint (assuming you have one compatible, e.g., `"meta-llama/Llama-2-7b-hf"`), use the `from_hf` method.
70 |
71 | ```python
72 | # hf_model_id = "meta-llama/Llama-2-7b-hf" # Example HF model ID
73 | # print(f"Attempting to load weights from {hf_model_id}...")
74 | # try:
75 | # # Ensure config matches the HF model being loaded
76 | # # config_hf = LlamaConfig(...) # Load/define config matching the HF model
77 | # # model_hf = LlamaForCausalLM(config_hf, rngs=nnx.Rngs(1), param_dtype=jnp.bfloat16)
78 | # # model_hf.from_hf(hf_model_id, save_in_orbax=False, remove_hf_after_conversion=True)
79 | # # print("Weights loaded successfully.")
80 | # except Exception as e:
81 | # # print(f"Skipping HF weight loading example due to error: {e}")
82 | # pass
83 | print("Skipping actual HF weight loading execution.")
84 | ```
85 | **Explanation:** Calling `model.from_hf(hf_model_id)` triggers the download (via `BaseModel.download_from_hf`), weight iteration (`BaseModel.iter_safetensors`), and crucially, the conversion logic defined in `LlamaForCausalLM.convert_weights_from_hf`. The model's state is updated in place with the loaded weights. Ensure the `LlamaConfig` used to initialize the model matches the architecture of the Hugging Face checkpoint.
86 |
87 | ### Forward Pass (`__call__`)
88 |
89 | The `__call__` method performs a standard forward pass, taking token IDs and an optional attention mask, and returning logits.
90 |
91 | ```python
92 | # Prepare dummy input (Batch size 1, sequence length 10)
93 | batch_size = 1
94 | seq_len = 10
95 | dummy_input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
96 | dummy_attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) # 1s indicate valid tokens
97 |
98 | # Perform forward pass
99 | # Note: Actual model execution requires JIT compilation or running on device
100 | # Split the model into graph definition and state for a functional jax.jit call
101 | graphdef, state = nnx.split(model)
102 | 
103 | @jax.jit
104 | def forward_pass(state, input_ids, attention_mask):
105 |     return nnx.merge(graphdef, state)(input_ids=input_ids, attention_mask=attention_mask)
106 |
107 | # We won't execute the JITted function here, just show the structure
108 | # logits = forward_pass(state, dummy_input_ids, dummy_attention_mask)
109 | # print(f"Output logits shape: (batch_size, seq_len, vocab_size) = {logits.shape}")
110 | # Expected output shape if run: (1, 10, 10000)
111 | print(f"Expected output logits shape: ({batch_size}, {seq_len}, {config.vocab_size})")
112 | # Output: Expected output logits shape: (1, 10, 10000)
113 | ```
114 | **Explanation:** The `__call__` method expects `input_ids` (shape `[batch_size, seq_len]`) and an optional `attention_mask` (same shape, 1 for real tokens, 0 for padding). It returns logits (shape `[batch_size, seq_len, vocab_size]`). For performance, the forward pass is typically JIT-compiled. Note that the current implementation asserts `batch_size == 1`.
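
Internally, this boolean mask is converted into an additive bias (0.0 for valid positions, a very large negative value for masked ones) that is added to the attention scores before softmax. A minimal sketch of that conversion (the helper name `to_additive_mask` is illustrative, not the exact `jaxgarden` function):

```python
import jax.numpy as jnp

def to_additive_mask(attention_mask: jnp.ndarray, dtype=jnp.float32) -> jnp.ndarray:
    # attention_mask: [batch, seq_len] with 1 for real tokens and 0 for padding.
    # Returns [batch, 1, 1, seq_len]: 0.0 where attention is allowed,
    # a very large negative number where it should be blocked.
    mask = attention_mask[:, None, None, :].astype(dtype)
    return (1.0 - mask) * jnp.finfo(dtype).min

print(to_additive_mask(jnp.array([[1, 1, 0]])))  # last position masked out
```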
115 |
116 | ### Text Generation (`generate`)
117 |
118 | Use the inherited `generate` method for autoregressive text generation.
119 |
120 | ```python
121 | from jaxgarden.tokenization import Tokenizer # Assuming Tokenizer is available
122 |
123 | # Assume 'model' is initialized and potentially loaded with weights
124 | # Assume 'tokenizer' is loaded, e.g., Tokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
125 |
126 | prompt = "The capital of France is"
127 | # encoded_prompt = tokenizer.encode(prompt, return_tensors="jax")
128 | # input_ids = encoded_prompt['input_ids']
129 | # attention_mask = encoded_prompt['attention_mask']
130 |
131 | # # Set generation parameters
132 | # max_new_tokens = 10
133 | # target_max_length = input_ids.shape[1] + max_new_tokens
134 |
135 | # # Generate text (requires JIT or device execution)
136 | # generation_rng = jax.random.PRNGKey(42)
137 | # # output_ids = model.generate(
138 | # # input_ids,
139 | # # attention_mask=attention_mask,
140 | # # max_length=target_max_length,
141 | # # temperature=0.7,
142 | # # top_k=50,
143 | # # do_sample=True,
144 | # # pad_token_id=tokenizer.pad_token_id,
145 | # # eos_token_id=tokenizer.eos_token_id,
146 | # # rng=generation_rng,
147 | # # use_jit=True # Recommended for performance
148 | # # )
149 |
150 | # # Decode the output
151 | # # decoded_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
152 | # # print(f"Prompt: {prompt}")
153 | # # print(f"Generated Text: {decoded_output}")
154 | # Expected Output (conceptual):
155 | # Prompt: The capital of France is
156 | # Generated Text: The capital of France is Paris. It is known for...
157 | print("Skipping actual text generation execution.")
158 | ```
159 | **Explanation:** We encode a prompt using the [Tokenizer](tokenizer.mdc). The resulting `input_ids` and `attention_mask` are passed to `model.generate`, along with generation parameters (`max_length`, `temperature`, `top_k`, `do_sample`, etc.) and token IDs for padding and end-of-sequence. The `generate` method (detailed in [GenerationMixin](generationmixin.mdc)) handles the autoregressive sampling loop. The output token IDs are then decoded back into text. Using `use_jit=True` is highly recommended for efficient generation.
160 |
161 | ## Internal Implementation
162 |
163 | Understanding the internal structure helps in debugging and potential extensions.
164 |
165 | * **File:** `jaxgarden/models/llama.py`
166 |
167 | 1. **Initialization (`__init__`)**:
168 | * Calls `super().__init__` from [BaseModel](basemodel.mdc).
169 | * Initializes `self.token_embed` as an `nnx.Embed` layer.
170 | * Creates a list of `LlamaTransformerBlock` instances (`self.layers`), passing the configuration and layer index to each.
171 | * Initializes the final `self.norm` as a `LlamaRMSNorm` layer.
172 | * Initializes `self.lm_head` as an `nnx.Linear` layer. Weight tying with `self.token_embed` happens implicitly during weight loading/conversion (in `convert_weights_from_hf`) or potentially during custom initialization logic not shown here.
173 |
174 | 2. **Forward Pass (`__call__`)**:
175 |
176 | ```mermaid
177 | sequenceDiagram
178 | participant Input
179 | participant LlamaForCausalLM
180 | participant Embed as nnx.Embed
181 | participant Blocks as LlamaTransformerBlock (Loop)
182 | participant Norm as LlamaRMSNorm
183 | participant Head as nnx.Linear (lm_head)
184 | participant Output
185 |
186 | Input->>+LlamaForCausalLM: __call__(input_ids, attention_mask)
187 | LlamaForCausalLM->>LlamaForCausalLM: Calculate position_ids
188 | LlamaForCausalLM->>LlamaForCausalLM: Convert attention_mask to additive mask
189 | LlamaForCausalLM->>+Embed: token_embed(input_ids)
190 | Embed-->>-LlamaForCausalLM: hidden_states (x)
191 | loop N Layers
192 | LlamaForCausalLM->>+Blocks: layer(x, position_ids, additive_mask)
193 | Blocks-->>-LlamaForCausalLM: updated hidden_states (x)
194 | end
195 | LlamaForCausalLM->>+Norm: norm(x)
196 | Norm-->>-LlamaForCausalLM: normalized_states
197 | LlamaForCausalLM->>+Head: lm_head(normalized_states)
198 | Head-->>-LlamaForCausalLM: logits
199 | LlamaForCausalLM-->>-Output: logits
200 | ```
201 | * Calculates `position_ids` based on the sequence length.
202 | * Converts the boolean `attention_mask` into an additive mask (0.0 for valid, -inf for masked) suitable for softmax normalization in attention.
203 | * Passes `input_ids` through `self.token_embed`.
204 | * Iteratively passes the hidden states `x` through each `LlamaTransformerBlock` in `self.layers`, along with `position_ids` and the `attention_mask`.
205 | * Applies the final `self.norm`.
206 | * Applies `self.lm_head` to get the final logits.
207 |
208 | 3. **Hugging Face Weight Conversion (`convert_weights_from_hf`)**:
209 | * This method receives the model's `state` (an `nnx.State` object or dict) and an iterator yielding `(hf_key, hf_tensor)` tuples from the Safetensors files.
210 | * It iterates through the HF weights and maps the HF key names to the corresponding `jaxgarden` state keys.
211 | * Key transformations include:
212 |    * Renaming: e.g., `model.layers.{i}.self_attn...` -> `state["layers"][i]["attention"]...`
213 | * Transposition: Weights from `nn.Linear` layers in HF PyTorch often need to be transposed (`.T`) for Flax `nnx.Linear` kernel shape `[in_features, out_features]`.
214 |    * Direct Assignment: Normalization weights (`model.layers.{i}.input_layernorm.weight`) are directly assigned (`state["layers"][i]["input_layernorm"]["norm_weights"].value = tensor`).
215 | * Weight Tying: The embedding weights (`model.embed_tokens.weight`) are assigned to both `state["token_embed"].embedding` and (after transposition) `state["lm_head"].kernel`.
216 |
217 | ```python
218 | # Snippet from LlamaForCausalLM.convert_weights_from_hf
219 | def convert_weights_from_hf(
220 | self, state: nnx.State | dict[str, jnp.ndarray], weights: Iterator[tuple[Any, Any]]
221 | ) -> None:
222 | for wholekey, tensor in weights:
223 | keys = wholekey.split(".")
224 | if keys[0] == "model": # Strip 'model.' prefix often found in HF checkpoints
225 | keys = keys[1:]
226 |
227 | if keys[0] == "layers":
228 | layer_idx = int(keys[1])
229 | component = keys[2]
230 | param_name = keys[3]
231 | target_state = state["layers"][layer_idx]
232 |
233 | if component == "self_attn": component = "attention" # Rename
234 |
235 | if component in ["attention", "mlp"]:
236 | # Linear layer weights need transpose
237 | target_state[component][param_name]["kernel"].value = tensor.T
238 | elif component in ["input_layernorm", "post_attention_layernorm"]:
239 | # RMSNorm weights
240 | target_state[component]["norm_weights"].value = tensor
241 | elif keys[0] == "embed_tokens":
242 | state["token_embed"].embedding.value = tensor
243 | # Tie weights with lm_head (transpose required)
244 | state["lm_head"].kernel.value = tensor.T
245 | elif keys[0] == "norm":
246 | state["norm"].norm_weights.value = tensor
247 | elif keys[0] == "lm_head":
248 | # This case might occur if weights aren't tied in HF checkpoint
249 | # If already tied via embed_tokens, this might overwrite or be redundant
250 | if not jnp.array_equal(state["lm_head"].kernel.value, tensor.T):
251 | print(f"Warning: Overwriting lm_head from separate checkpoint tensor: {wholekey}")
252 | state["lm_head"].kernel.value = tensor.T
253 | # else: # Handle potential unexpected keys
254 | # print(f"Warning: Unhandled HF key: {wholekey}")
255 |
256 | ```
257 |
258 | ## Conclusion
259 |
260 | `LlamaForCausalLM` provides a comprehensive and efficient implementation of the Llama architecture within the `jaxgarden` framework. By leveraging `BaseModel` for structure and `GenerationMixin` for functionality, and integrating key components like `LlamaRMSNorm`, `LlamaAttention` with RoPE and GQA, and `LlamaMLP` with SwiGLU, it offers a powerful tool for causal language modeling tasks. Its ability to load pretrained weights via `convert_weights_from_hf` makes it easy to utilize existing Llama models directly in JAX.
261 |
262 | The next chapter will delve deeper into the text generation capabilities provided by the mixin class used by this model.
263 |
264 | **Next:** [GenerationMixin](generationmixin.mdc)
265 |
266 |
267 | ---
268 |
269 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai)
--------------------------------------------------------------------------------
/.cursor/rules/modernbertformaskedlm.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description: Details the jaxgarden ModernBERTForMaskedLM model, a bidirectional encoder with RoPE, pre-LN, and mixed attention for MLM tasks.
3 | globs: jaxgarden/models/modernbert.py
4 | alwaysApply: false
5 | ---
6 | # Chapter 6: ModernBERTForMaskedLM
7 |
8 | In the [previous chapter](generationmixin.mdc), we explored `GenerationMixin`, focusing on autoregressive text generation for causal language models. Now, we shift our focus to a different type of transformer architecture: a bidirectional encoder designed specifically for Masked Language Modeling (MLM) tasks, incorporating modern optimizations. Welcome to `ModernBERTForMaskedLM`.
9 |
10 | **Motivation:** While causal models excel at text generation, many NLP tasks benefit from understanding context bidirectionally (both left and right). Traditional BERT achieved this but often suffered from computational inefficiency, especially with long sequences. The ModernBERT architecture, as proposed by Answer.AI, aims to create a "Smarter, Better, Faster, Longer" bidirectional encoder by incorporating techniques like Rotary Position Embeddings (RoPE), pre-Layer Normalization, and efficient attention mechanisms, making it suitable for fast, memory-efficient training and inference, particularly for long contexts. `ModernBERTForMaskedLM` implements this architecture within `jaxgarden` for the MLM pre-training objective.
11 |
12 | **Central Use Case:** Pre-training or fine-tuning a language model using the masked language modeling objective, where the model predicts randomly masked tokens in the input sequence. This model can also serve as a powerful feature extractor for downstream NLP tasks requiring bidirectional context understanding, potentially after loading pre-trained weights using the framework provided by [BaseModel](basemodel.mdc).
13 |
14 | ## Key Concepts
15 |
16 | `ModernBERTForMaskedLM` integrates several components and modern architectural choices:
17 |
18 | 1. **Inheritance:** It inherits from [BaseModel](basemodel.mdc), gaining standardized configuration management, state handling (Flax NNX), checkpointing (`save`/`load`), and the interface for Hugging Face weight conversion (`from_hf`).
19 | 2. **`ModernBertEmbeddings`:** Handles the initial input processing. It converts token IDs into dense vectors using an embedding layer and applies Layer Normalization. Crucially, unlike traditional BERT, it does *not* include explicit learned position embeddings; positional information is injected later via RoPE.
20 | 3. **`ModernBERTEncoder`:** The core of the model, consisting of a stack of `ModernBertLayer` instances (`num_hidden_layers` defined in `ModernBERTConfig`).
21 | 4. **`ModernBertLayer`:** Implements a single transformer block using the **Pre-LayerNorm** architecture (LayerNorm applied *before* the attention and MLP sub-layers, followed by residual connections). This structure often leads to more stable training compared to the original Post-LayerNorm BERT. Each layer contains:
22 | * `ModernBertAttention`: Performs multi-head self-attention.
23 | * `ModernBertMLP`: A feed-forward network applied after attention.
24 | 5. **`ModernBertAttention`:**
25 | * Calculates query, key, and value projections from the input.
26 | * Applies **[Rotary Position Embeddings (RoPE)](rotary_position_embeddings__rope_.mdc)** directly to the query and key vectors before computing attention scores. `RoPEPositionalEmbedding` generates the necessary sinusoidal encodings.
27 |    * Supports **Mixed Global and Local Attention:** Can operate in standard global attention mode (all tokens attend to all others) or use a sliding window (local attention) where tokens only attend to nearby tokens (`local_attention` parameter). This can be configured to happen only on certain layers (`global_attn_every_n_layers`) to balance global context understanding with computational efficiency. A sliding-window mask sketch follows this list.
28 | 6. **`ModernBertMLP`:** Uses GELU activation function within the feed-forward network.
29 | 7. **`ModernBERTMLMHead`:** Added on top of the `ModernBERTEncoder`. It takes the final hidden states, applies Layer Normalization, a dense layer with GELU activation, and finally projects the result to the vocabulary size to produce logits for predicting the original masked tokens. The final projection layer (decoder) is often tied to the token embedding weights.
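
To illustrate the sliding-window idea from point 5, here is a minimal mask construction (a sketch only; the function name and exact masking value are illustrative, not the `jaxgarden` implementation):

```python
import jax.numpy as jnp

def sliding_window_bias(seq_len: int, window=(128, 128), dtype=jnp.float32):
    # Positions j with i - window[0] <= j <= i + window[1] get 0.0 (attend normally);
    # everything outside the window gets a large negative bias added to the scores.
    pos = jnp.arange(seq_len)
    dist = pos[None, :] - pos[:, None]        # signed distance j - i
    inside = (dist >= -window[0]) & (dist <= window[1])
    return jnp.where(inside, 0.0, jnp.finfo(dtype).min)

print(sliding_window_bias(6, window=(2, 2)))  # 6x6 bias matrix
```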
30 |
31 | ## Using `ModernBERTForMaskedLM`
32 |
33 | Let's see how to instantiate and use the model.
34 |
35 | ### Initialization
36 |
37 | Define a `ModernBERTConfig` and initialize the model. Key parameters in the config control the architecture, including attention behavior.
38 |
39 | ```python
40 | import jax
41 | import jax.numpy as jnp
42 | from flax import nnx
43 | from jaxgarden.models.modernbert import ModernBERTConfig, ModernBERTForMaskedLM
44 |
45 | # Configuration for a smaller ModernBERT model
46 | config = ModernBERTConfig(
47 | vocab_size=30522, # Example BERT vocab size
48 | hidden_size=256,
49 | num_hidden_layers=4,
50 | num_attention_heads=8,
51 | intermediate_size=512,
52 | max_position_embeddings=512, # Max sequence length for RoPE cache
53 | attention_dropout=0.1,
54 | hidden_dropout=0.1,
55 | # Use local attention (window size 128 left, 128 right)
56 | local_attention=(128, 128),
57 | # Apply global attention every 2 layers (layers 0, 2)
58 | global_attn_every_n_layers=2,
59 | pad_token_id=0
60 | )
61 |
62 | # Initialize PRNG keys
63 | rngs = nnx.Rngs(params=0) # Or use more sophisticated key management
64 |
65 | # Instantiate the model
66 | model = ModernBERTForMaskedLM(config, rngs=rngs, param_dtype=jnp.float32)
67 |
68 | print(f"Initialized ModernBERTForMaskedLM with {config.num_hidden_layers} layers.")
69 | print(f"Local attention window: {config.local_attention}")
70 | print(f"Global attention every {config.global_attn_every_n_layers} layers.")
71 | # Output: Initialized ModernBERTForMaskedLM with 4 layers.
72 | # Output: Local attention window: (128, 128)
73 | # Output: Global attention every 2 layers.
74 | ```
75 | **Explanation:** We create a `ModernBERTConfig` instance, customizing parameters like layer count, hidden size, and attention settings (`local_attention`, `global_attn_every_n_layers`). We then instantiate `ModernBERTForMaskedLM` with this config and an `nnx.Rngs` object.
76 |
77 | ### Forward Pass (`__call__`)
78 |
79 | The `__call__` method takes token IDs and an optional attention mask, returning a dictionary containing logits and optionally hidden states and attentions.
80 |
81 | ```python
82 | # Prepare dummy input (Batch size 2, sequence length 64)
83 | batch_size = 2
84 | seq_len = 64
85 | dummy_input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
86 | # Mask indicating valid tokens (1) vs padding (0)
87 | dummy_attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
88 |
89 | # Define forward pass function (typically JIT-compiled)
90 | # Split the model into graph definition and state for a functional jax.jit call
91 | graphdef, state = nnx.split(model)
92 | 
93 | @jax.jit
94 | def forward_pass(state, input_ids, attention_mask):
95 |     outputs = nnx.merge(graphdef, state)(
96 |         input_ids=input_ids,
97 |         attention_mask=attention_mask,
98 |         deterministic=True,  # Disable dropout for inference
99 |     )
100 |     return outputs["logits"]  # Return only logits for this example
101 |
102 | # --- Execution ---
103 | # We won't execute the JITted function here, just show structure/shapes
104 | # logits = forward_pass(state, dummy_input_ids, dummy_attention_mask)
105 | # print(f"Output logits shape: {logits.shape}")
106 | # Expected output shape if run: (2, 64, 30522)
107 | print(f"Expected output logits shape: ({batch_size}, {seq_len}, {config.vocab_size})")
108 | # Output: Expected output logits shape: (2, 64, 30522)
109 | ```
110 | **Explanation:**
111 | - `input_ids`: `[batch_size, seq_len]` tensor of token IDs.
112 | - `attention_mask`: `[batch_size, seq_len]` tensor (1 for real tokens, 0 for padding). This mask is used by the attention mechanism to prevent attending to padding tokens.
113 | - `deterministic`: If `True`, dropout layers are disabled. Typically `True` for evaluation/inference and `False` for training.
114 | - The output is a dictionary. The primary output for MLM is `logits` (`[batch_size, seq_len, vocab_size]`), representing the model's prediction scores for each token position. Optional keys `hidden_states` and `attentions` can be requested via `output_hidden_states=True` and `output_attentions=True`.
115 |
116 | ### Loading Pretrained Weights (Conceptual)
117 |
118 | Like other `BaseModel` subclasses, `ModernBERTForMaskedLM` can potentially load weights from Hugging Face using `from_hf`, provided a `convert_weights_from_hf` method is implemented for this specific architecture (mapping HF ModernBERT checkpoint names to `jaxgarden` state names).
119 |
120 | ```python
121 | # hf_model_id = "AnswerDotAI/modernbert-8k-S-500k" # Example HF model ID
122 | # print(f"Attempting to load weights from {hf_model_id}...")
123 | # try:
124 | # # Ensure config matches the HF model being loaded
125 | # # config_hf = ModernBERTConfig(...) # Load/define config matching HF model
126 | # # model_hf = ModernBERTForMaskedLM(config_hf, rngs=nnx.Rngs(1))
127 | #
128 | # # This call requires a ModernBERT-specific implementation of convert_weights_from_hf
129 | # # model_hf.from_hf(hf_model_id)
130 | #
131 | # # print("Weights loaded successfully.")
132 | # except NotImplementedError:
133 | # print(f"NOTE: {type(model).__name__} does not currently implement HF weight conversion.")
134 | # except Exception as e:
135 | # # print(f"Skipping HF weight loading example due to error: {e}")
136 | # pass
137 | print("Skipping actual HF weight loading execution (requires implementation).")
138 | ```
139 |
140 | ## Internal Implementation
141 |
142 | Understanding the flow of data within the model helps in debugging and customization.
143 |
144 | **High-Level Walkthrough (`__call__`)**:
145 |
146 | 1. **Embeddings:** Input `input_ids` are passed to `self.embeddings` (`ModernBertEmbeddings`), which performs token lookup, scales embeddings, applies Layer Normalization, and optional dropout. Result: `hidden_states`.
147 | 2. **Attention Mask Prep:** The input `attention_mask` (e.g., `[1, 1, 0]`) is converted internally by the attention mechanism into an additive mask (e.g., `[0, 0, -inf]`) suitable for adding to attention scores before softmax. Sliding window masks are generated or used if local attention is active for the current layer.
148 | 3. **Encoder:** `hidden_states` and the prepared `attention_mask` (and optionally `position_ids`) are passed to `self.encoder` (`ModernBERTEncoder`).
149 | * The encoder iterates through its list of `ModernBertLayer`s.
150 | * Each `ModernBertLayer`:
151 | * Applies LayerNorm to the input (`attn_norm`, potentially `Identity` for layer 0).
152 | * Passes the normalized input to `self.attn` (`ModernBertAttention`), which calculates attention output using RoPE and potentially local masking.
153 | * Adds the attention output back to the layer's input (first residual connection).
154 | * Applies LayerNorm to the result (`mlp_norm`).
155 | * Passes the normalized result to `self.mlp` (`ModernBertMLP`).
156 | * Adds the MLP output back (second residual connection).
157 | * The final output of the last layer is passed through a final LayerNorm (`self.encoder.final_norm`). Result: `sequence_output`.
158 | 4. **MLM Head:** `sequence_output` is passed to `self.mlm_head` (`ModernBERTMLMHead`).
159 | * Applies LayerNorm (`self.mlm_head.norm`).
160 | * Applies a dense layer (`self.mlm_head.dense`) followed by GELU activation.
161 | * Applies the final decoder/projection layer (`self.mlm_head.decoder`) to get scores over the vocabulary. Result: `logits`.
162 | 5. **Output:** Returns a dictionary containing `logits` and any requested optional outputs (`hidden_states`, `attentions`).
163 |
164 | ```mermaid
165 | sequenceDiagram
166 | participant Input
167 | participant Model as ModernBERTForMaskedLM
168 | participant Embed as ModernBertEmbeddings
169 | participant Encoder as ModernBERTEncoder
170 | participant Layer as ModernBertLayer (Loop)
171 | participant Attn as ModernBertAttention
172 | participant MLP as ModernBertMLP
173 | participant Head as ModernBERTMLMHead
174 | participant OutputDict
175 |
176 | Input->>+Model: __call__(input_ids, attention_mask)
177 | Model->>+Embed: embeddings(input_ids)
178 | Embed-->>-Model: hidden_states
179 | Model->>+Encoder: encoder(hidden_states, attention_mask)
180 | loop N Layers
181 | Encoder->>+Layer: layer(hidden_states, mask)
182 | Layer->>Layer: Apply attn_norm (pre-norm)
183 | Layer->>+Attn: attn(normed_states, mask) # Applies RoPE internally
184 | Attn-->>-Layer: attention_output
185 | Layer->>Layer: Residual 1 (hidden_states + attention_output)
186 | Layer->>Layer: Apply mlp_norm (pre-norm)
187 | Layer->>+MLP: mlp(normed_states)
188 | MLP-->>-Layer: mlp_output
189 | Layer->>Layer: Residual 2 (hidden_states + mlp_output) -> new hidden_states
190 | Layer-->>-Encoder: updated hidden_states
191 | end
192 | Encoder->>Encoder: Apply final_norm
193 | Encoder-->>-Model: sequence_output
194 | Model->>+Head: mlm_head(sequence_output)
195 | Head-->>-Model: logits
196 | Model->>+OutputDict: Create {'logits': logits, ...}
197 | OutputDict-->>-Model: output_dict
198 | Model-->>-Input: output_dict
199 | ```
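
The pre-LayerNorm residual pattern used by each layer can be summarized in a few lines (a conceptual sketch with placeholder callables, not the actual `ModernBertLayer` code):

```python
def pre_ln_block(x, attn_norm, attn, mlp_norm, mlp, mask):
    # Pre-LN: normalize first, run the sub-layer, then add the residual.
    x = x + attn(attn_norm(x), mask)  # attention sub-layer + residual 1
    x = x + mlp(mlp_norm(x))          # MLP sub-layer + residual 2
    return x
```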
200 |
201 | **Code Details (from `jaxgarden/models/modernbert.py`)**:
202 |
203 | * **`ModernBERTForMaskedLM.__init__`**: Initializes the main components (`self.embeddings`, `self.encoder`, `self.mlm_head`) by passing the `config` and `rngs`.
204 | ```python
205 | # Simplified from ModernBERTForMaskedLM.__init__
206 | class ModernBERTForMaskedLM(BaseModel):
207 | def __init__(self, config: ModernBERTConfig, *, rngs: nnx.Rngs, ...):
208 | super().__init__(config, rngs=rngs, ...)
209 |
210 | self.embeddings = ModernBertEmbeddings(rngs=rngs, ...) # Pass relevant config fields
211 | self.encoder = ModernBERTEncoder(rngs=rngs, ...) # Pass relevant config fields
212 | self.mlm_head = ModernBERTMLMHead(rngs=rngs, ...) # Pass relevant config fields
213 | ```
214 |
215 | * **`ModernBERTForMaskedLM.__call__`**: Orchestrates the forward pass.
216 | ```python
217 | # Simplified from ModernBERTForMaskedLM.__call__
218 | def __call__(self, input_ids, attention_mask=None, ..., deterministic=True, ...):
219 | # 1. Get embeddings
220 | hidden_states = self.embeddings(
221 | input_ids=input_ids, deterministic=deterministic, ...
222 | )
223 |
224 | # 2. Apply encoder
225 | encoder_outputs = self.encoder(
226 | hidden_states, attention_mask=attention_mask, deterministic=deterministic, ...
227 | )
228 | sequence_output = encoder_outputs[0]
229 | # (Handle optional hidden_states/attentions from encoder_outputs)
230 |
231 | # 3. Apply MLM head
232 | logits = self.mlm_head(sequence_output)
233 |
234 | # 4. Build output dictionary
235 | outputs = {"logits": logits}
236 | # (Add optional outputs if requested)
237 | return outputs
238 | ```
239 |
240 | * **`ModernBertAttention.__call__`**: Key steps include QKV projection, RoPE application, attention score calculation, mask application (including sliding window if active), softmax, dropout, and output projection. A small RoPE sketch follows the snippet.
241 | ```python
242 | # Conceptual steps within ModernBertAttention.__call__
243 | # qkv = self.Wqkv(hidden_states)
244 | # query, key, value = split_heads_and_transpose(qkv)
245 | # query = apply_rotary_pos_emb(query, self.rotary_emb.cache, position_ids)
246 | # key = apply_rotary_pos_emb(key, self.rotary_emb.cache, position_ids)
247 | # attention_scores = compute_scores(query, key)
248 | # if self.local_attention != (-1, -1):
249 | # sliding_window_mask = create_sliding_window_mask(...) or use provided
250 | # attention_scores += sliding_window_mask
251 | # if attention_mask is not None:
252 | # attention_scores += attention_mask # Additive mask
253 | # attention_probs = softmax(attention_scores)
254 | # # Apply dropout if needed
255 | # attention_output = combine_heads(matmul(attention_probs, value))
256 | # attention_output = self.Wo(attention_output)
257 | ```
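
For reference, the rotate-half form of RoPE used conceptually in `apply_rotary_pos_emb` above looks roughly like this (a sketch under the usual RoPE conventions, not the `jaxgarden` implementation; see the [RoPE chapter](rotary_position_embeddings__rope_.mdc) for details):

```python
import jax.numpy as jnp

def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)

def apply_rope(x, cos, sin):
    # x: [batch, heads, seq, head_dim]; cos/sin broadcast over batch and heads
    return x * cos + rotate_half(x) * sin
```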
258 |
259 | ## Conclusion
260 |
261 | `ModernBERTForMaskedLM` provides a `jaxgarden` implementation of a modern, efficient bidirectional transformer encoder tailored for masked language modeling. By incorporating features like RoPE, pre-Layer Normalization, and optional sliding window attention, it aims to improve upon traditional BERT architectures in terms of speed, memory usage, and handling of long sequences. It leverages the `BaseModel` foundation for consistency and integrates specialized components like `ModernBertEmbeddings`, `ModernBERTEncoder`, `ModernBertAttention`, `ModernBertMLP`, and `ModernBERTMLMHead`.
262 |
263 | Understanding the general principles of attention mechanisms is fundamental to grasping how models like ModernBERT work internally. The next chapter delves into the details of attention itself.
264 |
265 | **Next:** [Attention Mechanism (MultiHeadAttention / dot_product_attention)](attention_mechanism__multiheadattention___dot_product_attention_.mdc)
266 |
267 |
268 | ---
269 |
270 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai)
--------------------------------------------------------------------------------