├── tests ├── __init__.py ├── run_tests.py ├── models │ ├── test_base.py │ ├── test_gemma3.py │ ├── test_llama.py │ ├── test_t5.py │ ├── test_gemma2.py │ └── test_generation_utils.py ├── functional │ └── test_attention.py └── attention │ ├── test_RoPEMultiHeadAttention.py │ └── test_multi_head_attention.py ├── .python-version ├── docs ├── .gitignore ├── modules │ └── index.rst ├── Makefile ├── make.bat ├── index.rst └── conf.py ├── examples ├── __init__.py ├── t5_inference_example.py ├── gemma2_inference_example.py ├── llama_inference_example.py ├── gemma3_inference_example.py └── multi_head_attention_example.py ├── assets └── logo.png ├── .pre-commit-config.yaml ├── jaxgarden ├── attention │ ├── __init__.py │ ├── multi_head_attention.py │ └── rope_multi_head_attention.py ├── functional │ ├── __init__.py │ └── attention.py ├── models │ ├── __init__.py │ └── base.py └── __init__.py ├── .github └── workflows │ ├── ruff.yml │ ├── mypy.yml │ ├── tests.yml │ └── docs.yml ├── .gitignore ├── LICENSE ├── .devcontainer ├── devcontainer.json ├── Dockerfile ├── README.md └── verify_jax_cuda.py ├── pyproject.toml ├── .cursor └── rules │ ├── guide.mdc │ ├── baseconfig.mdc │ ├── rotary_position_embeddings__rope_.mdc │ ├── generationmixin.mdc │ ├── tokenizer.mdc │ ├── llamaforcausallm.mdc │ └── modernbertformaskedlm.mdc └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.10 2 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _build/ 2 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | """Example scripts for the jax-layers library.""" 2 | -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-gde/jaxgarden/HEAD/assets/logo.png -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.9.10 4 | hooks: 5 | - id: ruff 6 | - id: ruff-format 7 | -------------------------------------------------------------------------------- /jaxgarden/attention/__init__.py: -------------------------------------------------------------------------------- 1 | """Attention modules for JAXgarden.""" 2 | 3 | from jaxgarden.attention.multi_head_attention import MultiHeadAttention 4 | 5 | __all__ = [ 6 | "MultiHeadAttention", 7 | ] 8 | -------------------------------------------------------------------------------- /jaxgarden/functional/__init__.py: -------------------------------------------------------------------------------- 1 | """Functional implementations for JAXgarden.""" 2 | 3 | from jaxgarden.functional.attention import dot_product_attention 4 | 5 | __all__ = [ 6 | "dot_product_attention", 7 | ] 8 | -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: 
-------------------------------------------------------------------------------- 1 | name: Ruff 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | ruff: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: "3.10" 20 | 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install ruff 25 | 26 | - name: Run Ruff 27 | run: | 28 | ruff check . 29 | ruff format --check . -------------------------------------------------------------------------------- /docs/modules/index.rst: -------------------------------------------------------------------------------- 1 | API Documentation 2 | ================ 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | attention 8 | functional 9 | models 10 | 11 | Attention Modules 12 | --------------- 13 | 14 | .. automodule:: jaxgarden.attention 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | Functional Interfaces 20 | ------------------ 21 | 22 | .. automodule:: jaxgarden.functional 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | Models 28 | ------------------ 29 | 30 | .. automodule:: jaxgarden.models 31 | :members: 32 | :undoc-members: 33 | :show-inheritance: -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: -------------------------------------------------------------------------------- 1 | name: Type Checking 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | mypy: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: "3.10" 20 | 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install -e ".[dev]" 25 | pip install mypy 26 | 27 | - name: Run mypy 28 | run: | 29 | mypy jaxgarden tests -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | 3 | # You can set these variables from the command line, and also 4 | # from the environment for the first two. 5 | SPHINXOPTS ?= 6 | SPHINXBUILD ?= sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
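# For example, "make html" (the target the docs workflow runs) writes the HTML docs to
# _build/html.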
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /examples/t5_inference_example.py: -------------------------------------------------------------------------------- 1 | from flax import nnx 2 | 3 | from jaxgarden import T5Config, T5ForCausalLM, Tokenizer 4 | 5 | if __name__ == "__main__": 6 | config = T5Config() 7 | model = T5ForCausalLM(config, rngs=nnx.Rngs(0)) 8 | model_id = "google-t5/t5-base" 9 | 10 | # download checkpoint from HuggingFace Hub 11 | model.from_hf(model_id, force_download=True) 12 | 13 | tokenizer = Tokenizer.from_pretrained(model_id) 14 | 15 | text = "The meaning of life is" 16 | model_inputs = tokenizer.encode(text) 17 | output = model.generate(**model_inputs, max_length=20, do_sample=True) 18 | output_text = tokenizer.decode(output) 19 | 20 | print(output, output.shape) 21 | print(output_text) 22 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.10", "3.11", "3.12"] 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install -e ".[dev]" 28 | 29 | - name: Run tests 30 | run: | 31 | python tests/run_tests.py -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | docs: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: "3.10" 20 | 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install -e ".[dev]" 25 | 26 | - name: Build documentation 27 | run: | 28 | cd docs 29 | make html 30 | 31 | - name: Deploy to GitHub Pages 32 | if: github.event_name == 'push' && github.ref == 'refs/heads/main' 33 | uses: peaceiris/actions-gh-pages@v3 34 | with: 35 | github_token: ${{ secrets.GITHUB_TOKEN }} 36 | publish_dir: ./docs/_build/html -------------------------------------------------------------------------------- /examples/gemma2_inference_example.py: -------------------------------------------------------------------------------- 1 | from flax import nnx 2 | 3 | from jaxgarden import Gemma2Config, Gemma2ForCausalLM, Tokenizer 4 | 5 | # HF repo id of the Gemma variant that you want to use 6 | model_id = "google/gemma-2-2b-it" 7 | 8 | # initialize the Gemma architecture 9 | config = Gemma2Config() 10 | model = Gemma2ForCausalLM(config, rngs=nnx.Rngs(0)) 11 | 12 | # This is a one-liner to download the HF checkpoint from the HuggingFace Hub, 13 | # convert it to jaxgarden format, 14 | # save it in an Orbax checkpoint, 15 | # and then remove the HF checkpoint. 16 | model.from_hf(model_id, force_download=True) 17 | 18 | # this works just like `transformers.AutoTokenizer`, 19 | # but without depending on the whole `transformers` library. 20 | # Instead, we simply extend the `tokenizers` package and add some convenience code for JAX. 
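# Gemma checkpoints on the HuggingFace Hub are gated, so the `from_hf` call above and the
# tokenizer download below assume an authorized HF token is configured (the Llama example
# notes that a token can also be passed to `from_hf` directly).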
21 | tokenizer = Tokenizer.from_pretrained(model_id) 22 | 23 | text = "The meaning of life is" 24 | model_inputs = tokenizer.encode(text) 25 | output = model.generate(**model_inputs, max_length=20, do_sample=True) 26 | output_text = tokenizer.decode(output) 27 | print(output_text) 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python-generated files 2 | __pycache__/ 3 | *.py[oc] 4 | build/ 5 | dist/ 6 | wheels/ 7 | *.egg-info 8 | 9 | # Virtual environments 10 | .venv 11 | jax_env 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Jupyter Notebook 52 | .ipynb_checkpoints 53 | 54 | # Environments 55 | .env 56 | .venv 57 | env/ 58 | venv/ 59 | ENV/ 60 | env.bak/ 61 | venv.bak/ 62 | 63 | # IDE specific files 64 | .idea/ 65 | .vscode/ 66 | *.swp 67 | *.swo 68 | 69 | # Library specific 70 | *.npy 71 | *.npz 72 | .ruff_cache/ 73 | mypy_cache/ 74 | pytest_cache/ 75 | 76 | # OS specific 77 | .DS_Store 78 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to JAXgarden documentation! 2 | ===================================== 3 | 4 | JAXgarden provides high-performance and hackable neural network model implementations in JAX, leveraging optimized kernels and layers like FlashAttention. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: Contents: 9 | 10 | readme 11 | modules/index 12 | contributing 13 | changelog 14 | 15 | Features 16 | -------- 17 | 18 | - **MultiHeadAttention**: A Flax NNX-compatible implementation with support for different attention backends. 19 | 20 | Installation 21 | ------------ 22 | 23 | .. code-block:: bash 24 | 25 | pip install git+https://github.com/ml-gde/jax-layers.git 26 | 27 | For development installation: 28 | 29 | .. code-block:: bash 30 | 31 | # First, fork the repository to your account. 32 | # Then, clone it to your machine. 
33 | git clone https://github.com/yourusername/jax-layers.git 34 | cd jax-layers 35 | pip install -e ".[dev]" 36 | 37 | Indices and tables 38 | ================== 39 | 40 | * :ref:`genindex` 41 | * :ref:`modindex` 42 | * :ref:`search` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 JAX Layers Contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/run_tests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Script to run all tests for the jax-layers project.""" 3 | 4 | import os 5 | import sys 6 | 7 | import pytest 8 | 9 | 10 | def main(): 11 | """Run all tests for the jax-layers project.""" 12 | # Get the directory of this script 13 | script_dir = os.path.dirname(os.path.abspath(__file__)) 14 | 15 | # Get the root directory of the project 16 | root_dir = os.path.dirname(script_dir) 17 | 18 | # Add the root directory to the Python path 19 | sys.path.insert(0, root_dir) 20 | 21 | # Run the tests 22 | args = [ 23 | "-xvs", # -x: exit on first failure, -v: verbose, -s: don't capture stdout 24 | script_dir, # Run all tests in the tests directory 25 | "--cov=jax_layers", # Generate coverage report for jax_layers 26 | "--cov-report=term", # Output coverage report to terminal 27 | ] 28 | 29 | # Add any additional arguments passed to this script 30 | args.extend(sys.argv[1:]) 31 | 32 | # Run pytest with the arguments 33 | return pytest.main(args) 34 | 35 | 36 | if __name__ == "__main__": 37 | sys.exit(main()) 38 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "JAX Layers Development", 3 | "build": { 4 | "dockerfile": "Dockerfile", 5 | "context": ".." 
6 | }, 7 | "runArgs": [ 8 | "--gpus=all" 9 | ], 10 | "customizations": { 11 | "vscode": { 12 | "extensions": [ 13 | "ms-python.python", 14 | "ms-python.vscode-pylance", 15 | "charliermarsh.ruff", 16 | "matangover.mypy" 17 | ], 18 | "settings": { 19 | "python.defaultInterpreterPath": "/usr/local/bin/python", 20 | "python.linting.enabled": true, 21 | "editor.formatOnSave": true, 22 | "editor.codeActionsOnSave": { 23 | "source.organizeImports": "always", 24 | "source.fixAll": "always" 25 | }, 26 | "python.formatting.provider": "none", 27 | "[python]": { 28 | "editor.defaultFormatter": "charliermarsh.ruff" 29 | } 30 | } 31 | } 32 | }, 33 | "forwardPorts": [], 34 | "postCreateCommand": "pip install -e '.[dev]'", 35 | "remoteUser": "vscode" 36 | } -------------------------------------------------------------------------------- /tests/models/test_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import jax 5 | import numpy as np 6 | from flax import nnx 7 | 8 | from jaxgarden.models.base import BaseConfig, BaseModel 9 | 10 | 11 | class TwoLayerMLPConfig(BaseConfig): 12 | dim: int = 4 13 | use_bias: bool = False 14 | 15 | 16 | class TwoLayerMLP(BaseModel): 17 | def __init__(self, config, rngs: nnx.Rngs): 18 | self.linear1 = nnx.Linear(config.dim, config.dim, rngs=rngs, use_bias=config.use_bias) 19 | self.linear2 = nnx.Linear(config.dim, config.dim, rngs=rngs, use_bias=config.use_bias) 20 | 21 | def __call__(self, x): 22 | x = self.linear1(x) 23 | return self.linear2(x) 24 | 25 | 26 | def test_save_and_load(): 27 | ckpt_dir = "/tmp/jaxgarden_test_ckpt" 28 | if os.path.exists(ckpt_dir): 29 | shutil.rmtree(ckpt_dir) 30 | 31 | config = TwoLayerMLPConfig() 32 | model = TwoLayerMLP(config, rngs=nnx.Rngs(0)) 33 | x = jax.random.normal(jax.random.key(42), (3, 4)) 34 | assert model(x).shape == (3, 4) 35 | state = model.state_dict 36 | model.save(ckpt_dir) 37 | model_restored = TwoLayerMLP(config, rngs=nnx.Rngs(1)).load(ckpt_dir) 38 | state_restored = model_restored.state_dict 39 | jax.tree.map(np.testing.assert_array_equal, state, state_restored) 40 | -------------------------------------------------------------------------------- /examples/llama_inference_example.py: -------------------------------------------------------------------------------- 1 | from flax import nnx 2 | 3 | from jaxgarden import LlamaConfig, LlamaForCausalLM, Tokenizer 4 | 5 | if __name__ == "__main__": 6 | # initialize a config object (with defaults for the 1B variant) 7 | # other variants to be added. 8 | config = LlamaConfig() 9 | model = LlamaForCausalLM(config, rngs=nnx.Rngs(0)) 10 | model_id = "meta-llama/Llama-3.2-1B" 11 | 12 | # this will download the HF checkpoint from the HuggingFace Hub, 13 | # convert it to jaxgarden format, 14 | # save it in an Orbax checkpoint, 15 | # and then remove the HF checkpoint. 16 | # If you didn't set your HF token globally, 17 | # you may need to pass your token as an argument to this method. 18 | model.from_hf(model_id, force_download=True) 19 | 20 | # this works just like `transformers.AutoTokenizer`, 21 | # but without depending on the whole `transformers` library. 22 | # Instead, we simply extend the `tokenizers` package and add some convenience code for JAX. 
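# Note: `generate` below is provided by GenerationMixin. Besides `max_length` and `do_sample`,
# it supports the sampling controls described in the project guide (temperature, top-k, top-p,
# min-p); see the GenerationMixin signature for the exact keyword names.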
23 | tokenizer = Tokenizer.from_pretrained(model_id) 24 | 25 | text = "The meaning of life is" 26 | model_inputs = tokenizer.encode(text) 27 | output = model.generate(**model_inputs, max_length=20, do_sample=True) 28 | output_text = tokenizer.decode(output) 29 | print(output, output.shape) 30 | print(output_text) 31 | -------------------------------------------------------------------------------- /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 2 | 3 | # Avoid prompts from apt 4 | ENV DEBIAN_FRONTEND=noninteractive 5 | 6 | # Install Python and other dependencies 7 | RUN apt-get update && apt-get install -y \ 8 | python3.10 \ 9 | python3.10-dev \ 10 | python3-pip \ 11 | python3-venv \ 12 | git \ 13 | curl \ 14 | wget \ 15 | build-essential \ 16 | && rm -rf /var/lib/apt/lists/* 17 | 18 | # Create symbolic links for python 19 | RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \ 20 | ln -sf /usr/bin/python3.10 /usr/bin/python3 21 | 22 | # Create a non-root user 23 | ARG USERNAME=vscode 24 | ARG USER_UID=1000 25 | ARG USER_GID=$USER_UID 26 | 27 | RUN groupadd --gid $USER_GID $USERNAME \ 28 | && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \ 29 | && apt-get update \ 30 | && apt-get install -y sudo \ 31 | && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \ 32 | && chmod 0440 /etc/sudoers.d/$USERNAME 33 | 34 | # Set up Python environment 35 | RUN python -m pip install --upgrade pip setuptools wheel 36 | 37 | # Install JAX with CUDA support 38 | RUN pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 39 | 40 | # Install matplotlib for benchmarking script 41 | RUN pip install matplotlib 42 | 43 | # Set working directory 44 | WORKDIR /workspace 45 | 46 | # Switch to non-root user 47 | USER $USERNAME 48 | 49 | # Note: The project dependencies will be installed by the postCreateCommand in devcontainer.json 50 | # which runs `pip install -e '.[dev]'` after the container is created -------------------------------------------------------------------------------- /tests/models/test_gemma3.py: -------------------------------------------------------------------------------- 1 | """Tests for Gemma3 model.""" 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | from flax import nnx 6 | 7 | from jaxgarden.models.gemma3 import ( 8 | Gemma3Attention, 9 | Gemma3Config, 10 | Gemma3DecoderLayer, 11 | Gemma3ForCausalLM, 12 | Gemma3MLP, 13 | Gemma3RMSNorm, 14 | Gemma3RotaryEmbedding, 15 | ) 16 | 17 | 18 | def test_gemma3_config(): 19 | """Test Gemma3Config initialization and validation.""" 20 | config = Gemma3Config() 21 | assert config.num_attention_heads % config.num_key_value_heads == 0 22 | assert config.head_dim == 256 # Default head_dim 23 | 24 | 25 | def test_gemma3_rms_norm(): 26 | """Test RMSNorm layer.""" 27 | rng = nnx.Rngs(0) 28 | dim = 32 29 | batch_size = 2 30 | seq_len = 4 31 | 32 | norm = Gemma3RMSNorm(dim, eps=1e-6, rngs=rng) 33 | x = jnp.ones((batch_size, seq_len, dim), dtype=jnp.float32) 34 | out = norm(x) 35 | 36 | assert out.shape == x.shape 37 | # Output should be normalized 38 | variance = jnp.mean(jnp.square(out), axis=-1) 39 | np.testing.assert_allclose(variance, 1.0, rtol=1e-5) 40 | 41 | 42 | def test_gemma3_rotary_embedding(): 43 | """Test RotaryEmbedding module.""" 44 | dim = 32 45 | batch_size = 2 46 | seq_len = 4 47 | num_heads = 2 48 | 49 | rope = Gemma3RotaryEmbedding(dim=dim, 
max_position_embeddings=8192) 50 | x = jnp.ones((batch_size, seq_len, num_heads, dim), dtype=jnp.float32) 51 | position_ids = jnp.arange(seq_len, dtype=jnp.int32)[None, :] # [1, seq_len] 52 | position_ids = jnp.broadcast_to(position_ids, (batch_size, seq_len)) 53 | 54 | out = rope(x, position_ids) 55 | assert out.shape == (batch_size, seq_len, num_heads, dim) 56 | -------------------------------------------------------------------------------- /jaxgarden/models/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxgarden.models.base import BaseConfig, BaseModel 2 | from jaxgarden.models.gemma2 import ( 3 | Gemma2Attention, 4 | Gemma2Config, 5 | Gemma2ForCausalLM, 6 | Gemma2MLP, 7 | Gemma2RMSNorm, 8 | Gemma2RotaryEmbedding, 9 | ) 10 | from jaxgarden.models.gemma3 import ( 11 | Gemma3Attention, 12 | Gemma3Config, 13 | Gemma3ForCausalLM, 14 | Gemma3MLP, 15 | Gemma3RMSNorm, 16 | Gemma3RotaryEmbedding, 17 | ) 18 | from jaxgarden.models.generation_utils import GenerationMixin 19 | from jaxgarden.models.llama import ( 20 | LlamaAttention, 21 | LlamaConfig, 22 | LlamaForCausalLM, 23 | LlamaMLP, 24 | LlamaRMSNorm, 25 | LlamaRotaryEmbedding, 26 | LlamaTransformerBlock, 27 | ) 28 | from jaxgarden.models.modernbert import ( 29 | ModernBertAttention, 30 | ModernBertEmbeddings, 31 | ModernBERTEncoder, 32 | ModernBERTForMaskedLM, 33 | ModernBertLayer, 34 | ModernBertMLP, 35 | ) 36 | 37 | __all__ = [ 38 | "BaseConfig", 39 | "BaseModel", 40 | "Gemma2Attention", 41 | "Gemma2Config", 42 | "Gemma2ForCausalLM", 43 | "Gemma2MLP", 44 | "Gemma2RMSNorm", 45 | "Gemma2RotaryEmbedding", 46 | "Gemma3Attention", 47 | "Gemma3Config", 48 | "Gemma3ForCausalLM", 49 | "Gemma3MLP", 50 | "Gemma3RMSNorm", 51 | "Gemma3RotaryEmbedding", 52 | "GenerationMixin", 53 | "LlamaAttention", 54 | "LlamaConfig", 55 | "LlamaForCausalLM", 56 | "LlamaMLP", 57 | "LlamaRMSNorm", 58 | "LlamaRotaryEmbedding", 59 | "LlamaTransformerBlock", 60 | "ModernBERTEncoder", 61 | "ModernBERTForMaskedLM", 62 | "ModernBertAttention", 63 | "ModernBertEmbeddings", 64 | "ModernBertLayer", 65 | "ModernBertMLP", 66 | ] 67 | -------------------------------------------------------------------------------- /tests/models/test_llama.py: -------------------------------------------------------------------------------- 1 | """Tests for LLama model.""" 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from flax import nnx 6 | 7 | from jaxgarden.models.llama import LlamaConfig, LlamaForCausalLM 8 | 9 | 10 | def test_llama_initialization(): 11 | """Test that the LLama model can be properly initialized.""" 12 | # Set configuration for a tiny model for testing 13 | config = LlamaConfig( 14 | dim=32, 15 | n_layers=2, 16 | n_heads=4, 17 | n_kv_heads=2, 18 | head_dim=8, 19 | intermediate_size=64, 20 | vocab_size=100, 21 | ) 22 | 23 | # Initialize the model 24 | key = jax.random.PRNGKey(0) 25 | rngs = nnx.Rngs(params=key) 26 | 27 | model = LlamaForCausalLM(config, rngs=rngs) 28 | 29 | # Verify the model was created with expected structure 30 | assert len(model.layers) == 2 31 | assert isinstance(model, LlamaForCausalLM) 32 | 33 | 34 | def test_llama_inference(): 35 | """Test that the LLama model can run inference end-to-end.""" 36 | # Set configuration for a tiny model for testing 37 | config = LlamaConfig( 38 | dim=32, 39 | n_layers=2, 40 | n_heads=4, 41 | n_kv_heads=2, 42 | head_dim=8, 43 | intermediate_size=64, 44 | vocab_size=100, 45 | ) 46 | 47 | # Initialize the model 48 | key = jax.random.PRNGKey(0) 49 | rngs = 
nnx.Rngs(params=key) 50 | 51 | model = LlamaForCausalLM(config, rngs=rngs) 52 | 53 | # Create sample input 54 | batch_size = 1 55 | seq_len = 4 56 | input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) 57 | position_ids = jnp.arange(seq_len)[None, :] 58 | 59 | # Run forward pass 60 | logits = model(input_ids, position_ids) 61 | 62 | # Check output shape 63 | assert logits.shape == (batch_size, seq_len, config.vocab_size) 64 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | """Configuration file for the Sphinx documentation builder.""" 2 | 3 | import os 4 | import sys 5 | 6 | # Add the project root directory to the Python path 7 | sys.path.insert(0, os.path.abspath("..")) 8 | 9 | # Project information 10 | project = "JAXgarden" 11 | copyright = "2025, JAXgarden Contributors" 12 | author = "JAXgarden Contributors" 13 | 14 | # The full version, including alpha/beta/rc tags 15 | release = "0.1.0" 16 | 17 | # list of Sphinx extension module names 18 | extensions = [ 19 | "sphinx.ext.autodoc", 20 | "sphinx.ext.napoleon", 21 | "sphinx.ext.viewcode", 22 | "sphinx.ext.intersphinx", 23 | "myst_parser", 24 | ] 25 | 26 | # Add any paths that contain templates here 27 | templates_path = ["_templates"] 28 | 29 | # List of patterns, relative to source directory, that match files and 30 | # directories to ignore when looking for source files 31 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 32 | 33 | # The theme to use for HTML and HTML Help pages 34 | html_theme = "sphinx_rtd_theme" 35 | 36 | # Add any paths that contain custom static files 37 | html_static_path = ["_static"] 38 | 39 | # Intersphinx configuration 40 | intersphinx_mapping = { 41 | "python": ("https://docs.python.org/3", None), 42 | "jax": ("https://jax.readthedocs.io/en/latest/", None), 43 | "flax": ("https://flax.readthedocs.io/en/latest/", None), 44 | } 45 | 46 | # Napoleon settings 47 | napoleon_google_docstring = True 48 | napoleon_numpy_docstring = True 49 | napoleon_include_init_with_doc = True 50 | napoleon_include_private_with_doc = False 51 | napoleon_include_special_with_doc = True 52 | napoleon_use_admonition_for_examples = False 53 | napoleon_use_admonition_for_notes = False 54 | napoleon_use_admonition_for_references = False 55 | napoleon_use_ivar = False 56 | napoleon_use_param = True 57 | napoleon_use_rtype = True 58 | napoleon_type_aliases = None 59 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "jaxgarden" 7 | version = "0.2.0" 8 | description = "High-performance model implementations in JAX" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | license = {file = "LICENSE"} 12 | authors = [ 13 | {name = "JAXgarden Contributors"} 14 | ] 15 | classifiers = [ 16 | "Development Status :: 3 - Alpha", 17 | "Intended Audience :: Developers", 18 | "Intended Audience :: Science/Research", 19 | "License :: OSI Approved :: MIT License", 20 | "Programming Language :: Python :: 3", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: 3.12", 24 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 25 | ] 26 | dependencies = [ 
27 | "jax==0.5.1", 28 | "flax>=0.10.4", 29 | "safetensors>=0.5.3", 30 | "huggingface-hub>=0.30.2", 31 | "orbax>=0.1.9", 32 | "jinja2>=3.1.6", 33 | "tokenizers>=0.21.1", 34 | ] 35 | 36 | [project.optional-dependencies] 37 | dev = [ 38 | "pytest>=7.0.0", 39 | "pytest-cov>=4.0.0", 40 | "ruff>=0.1.0", 41 | "mypy>=1.0.0", 42 | "pre-commit", 43 | "sphinx", 44 | "sphinx-rtd-theme", 45 | "myst-parser", 46 | "torch", 47 | "torchvision", 48 | "torchaudio", 49 | "transformers", 50 | ] 51 | 52 | [project.urls] 53 | "Homepage" = "https://github.com/ml-gde/jax-layers" 54 | "Bug Tracker" = "https://github.com/ml-gde/jax-layers/issues" 55 | 56 | [tool.setuptools] 57 | packages = ["jaxgarden"] 58 | 59 | [tool.ruff] 60 | line-length = 100 61 | target-version = "py310" 62 | 63 | [tool.ruff.lint] 64 | select = ["E", "F", "I", "N", "W", "B", "UP", "C4", "PT", "RUF", "SIM", "TID"] 65 | ignore = [] 66 | 67 | [tool.mypy] 68 | python_version = "3.10" 69 | warn_return_any = true 70 | warn_unused_configs = true 71 | disallow_untyped_defs = true 72 | disallow_incomplete_defs = true 73 | disable_error_code = "var-annotated" 74 | 75 | [[tool.mypy.overrides]] 76 | module = ["tests.*", "examples.*"] 77 | disallow_untyped_defs = false 78 | disallow_incomplete_defs = false 79 | -------------------------------------------------------------------------------- /.devcontainer/README.md: -------------------------------------------------------------------------------- 1 | # Development Container for JAX Layers 2 | 3 | This directory contains configuration files for setting up a development container with JAX and CUDA support, which is especially useful for Windows users where JAX doesn't natively support CUDA. 4 | 5 | ## Prerequisites 6 | 7 | To use this development container, you need: 8 | 9 | 1. [Docker Desktop](https://www.docker.com/products/docker-desktop/) installed and configured with WSL 2 backend 10 | 2. [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) installed 11 | 3. [Visual Studio Code](https://code.visualstudio.com/) with the [Remote - Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension 12 | 13 | ## GPU Support 14 | 15 | The container is configured to use all available GPUs. Make sure your NVIDIA drivers are up-to-date and that Docker has access to your GPUs. 16 | 17 | ## Usage 18 | 19 | 1. Open the project in Visual Studio Code 20 | 2. Click on the green icon in the bottom-left corner of VS Code 21 | 3. Select "Reopen in Container" from the menu 22 | 4. Wait for the container to build and start (this may take a while the first time) 23 | 24 | Once the container is running, you'll have a fully configured development environment with: 25 | 26 | - Python 3.10 27 | - CUDA 12.2 with cuDNN 9 28 | - JAX with CUDA support 29 | - All dependencies from pyproject.toml 30 | 31 | ## Dependency Management 32 | 33 | The container installs dependencies directly from your project's `pyproject.toml` file using the `pip install -e '.[dev]'` command, ensuring consistency between your development environment and the container. 34 | 35 | ## Customization 36 | 37 | You can customize the container by modifying: 38 | 39 | - `devcontainer.json`: VS Code settings, extensions, and container configuration 40 | - `Dockerfile`: Base image, dependencies, and environment setup 41 | 42 | ## Troubleshooting 43 | 44 | If you encounter issues with GPU access: 45 | 46 | 1. 
Verify that Docker Desktop is configured to use WSL 2 47 | 2. Check that NVIDIA Container Toolkit is properly installed 48 | 3. Ensure your NVIDIA drivers are up-to-date 49 | 4. Run `nvidia-smi` in WSL to verify GPU access 50 | 5. Check Docker logs for any error messages related to GPU access 51 | -------------------------------------------------------------------------------- /examples/gemma3_inference_example.py: -------------------------------------------------------------------------------- 1 | """Example script demonstrating Gemma3 model usage.""" 2 | 3 | import jax.numpy as jnp 4 | from flax import nnx 5 | from transformers import AutoTokenizer # type: ignore 6 | 7 | from jaxgarden.models.gemma3 import Gemma3Config, Gemma3ForCausalLM 8 | 9 | 10 | def main(): 11 | """Run Gemma3 inference example.""" 12 | # Initialize model with correct configuration 13 | print("Initializing model...") 14 | config = Gemma3Config( 15 | vocab_size=262_208, 16 | hidden_size=2304, 17 | intermediate_size=9216, 18 | num_hidden_layers=26, 19 | num_attention_heads=8, 20 | num_key_value_heads=4, 21 | head_dim=256, 22 | rope_theta=1_000_000.0, 23 | rope_local_base_freq=10_000.0, 24 | max_position_embeddings=131_072, 25 | sliding_window=4096, 26 | sliding_window_pattern=6, 27 | hidden_activation="gelu_pytorch_tanh", 28 | # Optional RoPE scaling for longer sequences 29 | rope_scaling={ 30 | "rope_type": "linear", 31 | "factor": 2.0, # Enables processing sequences twice the original length 32 | }, 33 | ) 34 | rng = nnx.Rngs(0) 35 | model = Gemma3ForCausalLM(config, rngs=rng) 36 | 37 | # Load tokenizer 38 | print("Loading tokenizer...") 39 | tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b") 40 | 41 | # Prepare input 42 | prompt = "Write a short story about a robot learning to paint:" 43 | print(f"\nPrompt: {prompt}") 44 | 45 | # Tokenize input 46 | input_ids = tokenizer(prompt, return_tensors="np").input_ids 47 | input_ids = jnp.array(input_ids) 48 | 49 | # Generate text 50 | print("\nGenerating response...") 51 | max_new_tokens = 100 52 | eos_token_id = config.eos_token_id 53 | 54 | # Initialize cache for faster generation 55 | cache = None 56 | generated = input_ids 57 | 58 | for _ in range(max_new_tokens): 59 | # Get logits and updated cache 60 | logits, cache = model( 61 | generated, 62 | use_cache=True, # Enable KV caching 63 | deterministic=True, # No dropout during inference 64 | ) 65 | # Get next token (use argmax for simplicity) 66 | next_token = jnp.argmax(logits[:, -1, :], axis=-1) 67 | # Check if we hit the end of sequence 68 | if next_token[0] == eos_token_id: 69 | break 70 | # Append next token 71 | generated = jnp.concatenate([generated, next_token[:, None]], axis=1) 72 | 73 | # Decode generated text 74 | generated_text = tokenizer.decode(generated[0], skip_special_tokens=True) 75 | print(f"\nGenerated text:\n{generated_text}") 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /jaxgarden/__init__.py: -------------------------------------------------------------------------------- 1 | """JAXgarden - High-performance neural network layers for JAX.""" 2 | 3 | from jaxgarden.attention.multi_head_attention import MultiHeadAttention 4 | from jaxgarden.functional.attention import dot_product_attention 5 | from jaxgarden.models.base import BaseConfig, BaseModel 6 | from jaxgarden.models.gemma2 import ( 7 | Gemma2Attention, 8 | Gemma2Config, 9 | Gemma2ForCausalLM, 10 | Gemma2MLP, 11 | Gemma2RMSNorm, 12 | 
Gemma2RotaryEmbedding, 13 | ) 14 | from jaxgarden.models.gemma3 import ( 15 | Gemma3Attention, 16 | Gemma3Config, 17 | Gemma3ForCausalLM, 18 | Gemma3MLP, 19 | Gemma3RMSNorm, 20 | Gemma3RotaryEmbedding, 21 | ) 22 | from jaxgarden.models.generation_utils import GenerationMixin 23 | from jaxgarden.models.llama import ( 24 | LlamaAttention, 25 | LlamaConfig, 26 | LlamaForCausalLM, 27 | LlamaMLP, 28 | LlamaRMSNorm, 29 | LlamaRotaryEmbedding, 30 | LlamaTransformerBlock, 31 | ) 32 | from jaxgarden.models.modernbert import ( 33 | ModernBertAttention, 34 | ModernBertEmbeddings, 35 | ModernBERTEncoder, 36 | ModernBERTForMaskedLM, 37 | ModernBertLayer, 38 | ModernBertMLP, 39 | ) 40 | from jaxgarden.models.t5 import ( 41 | T5MLP, 42 | T5Attention, 43 | T5Block, 44 | T5Config, 45 | T5CrossAttention, 46 | T5ForCausalLM, 47 | T5LayerNorm, 48 | T5SelfAttention, 49 | T5Stack, 50 | ) 51 | from jaxgarden.tokenization import Tokenizer # type: ignore 52 | 53 | __all__ = [ 54 | "T5MLP", 55 | # Base classes 56 | "BaseConfig", 57 | "BaseModel", 58 | # Gemma Models 59 | "Gemma2Attention", 60 | "Gemma2Config", 61 | "Gemma2ForCausalLM", 62 | "Gemma2MLP", 63 | "Gemma2RMSNorm", 64 | "Gemma2RotaryEmbedding", 65 | # Gemma3 Models 66 | "Gemma3Attention", 67 | "Gemma3Config", 68 | "Gemma3ForCausalLM", 69 | "Gemma3MLP", 70 | "Gemma3RMSNorm", 71 | "Gemma3RotaryEmbedding", 72 | # Mixins 73 | "GenerationMixin", 74 | # Llama Models 75 | "LlamaAttention", 76 | "LlamaConfig", 77 | "LlamaForCausalLM", 78 | "LlamaMLP", 79 | "LlamaRMSNorm", 80 | "LlamaRotaryEmbedding", 81 | "LlamaTransformerBlock", 82 | "ModernBERTEncoder", 83 | "ModernBERTForMaskedLM", 84 | "ModernBertAttention", 85 | "ModernBertEmbeddings", 86 | "ModernBertLayer", 87 | "ModernBertMLP", 88 | # Attention modules 89 | "MultiHeadAttention", 90 | # T5 Models 91 | "T5Attention", 92 | "T5Block", 93 | "T5Config", 94 | "T5CrossAttention", 95 | "T5ForCausalLM", 96 | "T5LayerNorm", 97 | "T5SelfAttention", 98 | "T5Stack", 99 | # tokenization 100 | "Tokenizer", 101 | # Functional interfaces 102 | "dot_product_attention", 103 | ] 104 | 105 | __version__ = "0.2.0" 106 | -------------------------------------------------------------------------------- /.devcontainer/verify_jax_cuda.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Script to verify that JAX with CUDA is working correctly. 4 | Run this script after the container is built to confirm GPU access. 
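Typical invocation (e.g., from the repository root inside the dev container):

    python .devcontainer/verify_jax_cuda.py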
5 | """ 6 | 7 | import re 8 | import subprocess 9 | import time 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | 14 | 15 | def get_cuda_version(): 16 | """Get the CUDA version from nvcc.""" 17 | try: 18 | result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True) 19 | version_match = re.search(r"release (\d+\.\d+)", result.stdout) 20 | if version_match: 21 | return version_match.group(1) 22 | return "Unknown" 23 | except Exception: 24 | return "Unknown" 25 | 26 | 27 | def main(): 28 | print("\n" + "=" * 50) 29 | print("JAX CUDA Verification Script") 30 | print("=" * 50) 31 | 32 | # Check JAX version 33 | print(f"JAX version: {jax.__version__}") 34 | 35 | # Check CUDA version 36 | cuda_version = get_cuda_version() 37 | print(f"CUDA version: {cuda_version}") 38 | 39 | # Check available devices 40 | print("\nAvailable devices:") 41 | for i, device in enumerate(jax.devices()): 42 | print(f" Device {i}: {device}") 43 | 44 | # Check if GPU is available 45 | gpu_available = any(d.platform == "gpu" for d in jax.devices()) 46 | print(f"\nGPU available: {gpu_available}") 47 | 48 | if not gpu_available: 49 | print("\n⚠️ No GPU devices found! JAX is not using CUDA.") 50 | print("Please check your installation and GPU configuration.") 51 | return 52 | 53 | # Run a simple benchmark 54 | print("\nRunning simple matrix multiplication benchmark...") 55 | 56 | # Create large matrices 57 | n = 5000 58 | print(f"Creating {n}x{n} matrices...") 59 | 60 | # CPU benchmark 61 | with jax.devices("cpu")[0]: 62 | x_cpu = jnp.ones((n, n)) 63 | y_cpu = jnp.ones((n, n)) 64 | 65 | # Warm-up 66 | _ = jnp.dot(x_cpu, y_cpu) 67 | jax.block_until_ready(_) 68 | 69 | # Benchmark 70 | start = time.time() 71 | result_cpu = jnp.dot(x_cpu, y_cpu) 72 | jax.block_until_ready(result_cpu) 73 | cpu_time = time.time() - start 74 | 75 | # GPU benchmark 76 | with jax.devices("gpu")[0]: 77 | x_gpu = jnp.ones((n, n)) 78 | y_gpu = jnp.ones((n, n)) 79 | 80 | # Warm-up 81 | _ = jnp.dot(x_gpu, y_gpu) 82 | jax.block_until_ready(_) 83 | 84 | # Benchmark 85 | start = time.time() 86 | result_gpu = jnp.dot(x_gpu, y_gpu) 87 | jax.block_until_ready(result_gpu) 88 | gpu_time = time.time() - start 89 | 90 | # Print results 91 | print(f"\nCPU time: {cpu_time:.4f} seconds") 92 | print(f"GPU time: {gpu_time:.4f} seconds") 93 | print(f"Speedup: {cpu_time / gpu_time:.2f}x") 94 | 95 | if cpu_time > gpu_time: 96 | print("\n✅ GPU is faster than CPU! JAX with CUDA is working correctly.") 97 | else: 98 | print("\n⚠️ GPU is not faster than CPU. Something might be wrong with the CUDA setup.") 99 | 100 | print("\n" + "=" * 50) 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /.cursor/rules/guide.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: Guidelines for using jaxgarden 3 | globs: 4 | alwaysApply: true 5 | --- 6 | The `jaxgarden` project provides a library for building and utilizing neural network models, primarily focused on **Transformer architectures**, using *JAX* and *Flax NNX*. It establishes a structured framework through abstract base classes: `BaseModel` defines the core model interface, including state management (parameters, mutable states via `nnx.State`), checkpointing (*Serialization*) with Orbax, and a standardized mechanism for importing and converting weights from Hugging Face's *Safetensors* format. 
`BaseConfig` offers a consistent way to manage model hyperparameters (*Configuration Management*). 7 | 8 | The library implements specific models like `LlamaForCausalLM` (a decoder-only model) and `ModernBERTForMaskedLM` (an encoder model), showcasing modern techniques. Key architectural components are modularized: 9 | - `Attention Mechanism`: Provides multi-head attention, potentially leveraging hardware acceleration like *Flash Attention* via cuDNN backend selection in `dot_product_attention`. 10 | - `Rotary Position Embeddings (RoPE)`: Implements relative position encoding within the attention mechanism itself, used by both Llama and ModernBERT variants. 11 | 12 | Functionality includes: 13 | - **Text Generation**: The `GenerationMixin` adds autoregressive text generation capabilities to causal models like Llama, supporting various sampling strategies (temperature, top-k, top-p, min-p) and efficient implementation via `jax.lax.scan`. 14 | - **Tokenization**: A `Tokenizer` class wraps the Hugging Face `tokenizers` library, providing a JAX-friendly API for encoding/decoding text and applying chat templates. 15 | - **Interoperability**: Facilitates using pretrained models from the Hugging Face Hub. 16 | 17 | Overall, `jaxgarden` aims for high performance and *modularity*, enabling researchers and developers to work with modern NLP models within the JAX ecosystem, promoting *code reuse* and *extensibility* through its base classes and focused components. 18 | 19 | 20 | **Source Repository:** [https://github.com/ml-gde/jaxgarden.git](https://github.com/ml-gde/jaxgarden.git) 21 | 22 | ```mermaid 23 | flowchart TD 24 | A0["BaseModel"] 25 | A1["BaseConfig"] 26 | A2["Tokenizer"] 27 | A3["Attention Mechanism (MultiHeadAttention / dot_product_attention)"] 28 | A4["GenerationMixin"] 29 | A5["LlamaForCausalLM"] 30 | A6["ModernBERTForMaskedLM"] 31 | A7["Rotary Position Embeddings (RoPE)"] 32 | A0 -- "Manages" --> A1 33 | A5 -- "Inherits From" --> A0 34 | A6 -- "Inherits From" --> A0 35 | A5 -- "Uses Config" --> A1 36 | A6 -- "Uses Config" --> A1 37 | A4 -- "Uses Token IDs" --> A2 38 | A5 -- "Uses Attention" --> A3 39 | A6 -- "Uses Attention" --> A3 40 | A5 -- "Inherits From" --> A4 41 | A4 -- "Calls Model" --> A5 42 | A5 -- "Uses RoPE" --> A7 43 | A6 -- "Uses RoPE" --> A7 44 | ``` 45 | 46 | ## Chapters 47 | 48 | [Tokenizer](tokenizer.mdc) 49 | [BaseModel](basemodel.mdc) 50 | [BaseConfig](baseconfig.mdc) 51 | [LlamaForCausalLM](llamaforcausallm.mdc) 52 | [GenerationMixin](generationmixin.mdc) 53 | [ModernBERTForMaskedLM](modernbertformaskedlm.mdc) 54 | [Attention Mechanism (MultiHeadAttention / dot_product_attention)](attention_mechanism__multiheadattention___dot_product_attention_.mdc) 55 | [Rotary Position Embeddings (RoPE)](rotary_position_embeddings__rope_.mdc) 56 | 57 | 58 | --- 59 | 60 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai) -------------------------------------------------------------------------------- /jaxgarden/functional/attention.py: -------------------------------------------------------------------------------- 1 | """Attention functions for JAX Layers. 2 | 3 | This module provides attention functions that are compatible with both JAX and Flax NNX. 
4 | """ 5 | 6 | from typing import Literal 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | 11 | 12 | def dot_product_attention( 13 | query: jnp.ndarray, 14 | key: jnp.ndarray, 15 | value: jnp.ndarray, 16 | bias: jnp.ndarray | None = None, 17 | mask: jnp.ndarray | None = None, 18 | broadcast_dropout: bool = True, 19 | dropout_rng: jnp.ndarray | None = None, 20 | dropout_rate: float = 0.0, 21 | deterministic: bool = False, 22 | dtype: jnp.dtype | None = None, 23 | precision: jax.lax.Precision | str | None = None, 24 | implementation: Literal["xla", "cudnn", "flash"] | None = None, 25 | module: object | None = None, 26 | ) -> jnp.ndarray: 27 | """Computes dot-product attention with optional Flash Attention support. 28 | 29 | This function provides a wrapper around JAX's dot_product_attention with the option 30 | to use Flash Attention when available. It follows the Flax NNX interface while 31 | allowing the use of different implementations through the implementation parameter. 32 | 33 | Args: 34 | query: queries for calculating attention with shape of 35 | `[batch..., q_length, num_heads, qk_depth_per_head]`. 36 | key: keys for calculating attention with shape of 37 | `[batch..., kv_length, num_heads, qk_depth_per_head]`. 38 | value: values to be used in attention with shape of 39 | `[batch..., kv_length, num_heads, v_depth_per_head]`. 40 | bias: bias for the attention weights. This should be broadcastable to 41 | the shape [batch..., num_heads, q_length, kv_length]. 42 | mask: mask for the attention weights. This should be broadcastable to 43 | the shape [batch..., num_heads, q_length, kv_length]. 44 | broadcast_dropout: bool: use a broadcasted dropout along batch dims. 45 | dropout_rng: JAX PRNGKey: to be used for dropout. 46 | dropout_rate: dropout rate. 47 | deterministic: bool, deterministic or not (to apply dropout). 48 | dtype: the dtype of the computation (default: infer from inputs). 49 | precision: numerical precision of the computation. 50 | implementation: which implementation to use. Options are: 51 | - "xla": Use XLA's default implementation 52 | - "cudnn": Use cuDNN's Flash Attention implementation (if available) 53 | - "flash": Alias for "cudnn" 54 | - None: Automatically select the best available implementation 55 | module: the Module that will sow the attention weights. 56 | 57 | Returns: 58 | Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`. 
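    Example (an illustrative sketch; shapes are arbitrary and mirror the unit tests):
        >>> import jax.numpy as jnp
        >>> from jaxgarden.functional import dot_product_attention
        >>> q = k = v = jnp.ones((2, 16, 4, 8))  # [batch, seq, heads, head_dim]
        >>> causal_mask = jnp.tril(jnp.ones((2, 1, 16, 16), dtype=bool))
        >>> out = dot_product_attention(q, k, v, mask=causal_mask, implementation="xla")
        >>> out.shape
        (2, 16, 4, 8)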
59 | """ 60 | # Map "flash" to "cudnn" for clarity 61 | if implementation == "flash": 62 | implementation = "cudnn" 63 | 64 | # Convert mask to bias if needed 65 | if mask is not None: 66 | # In JAX, mask=True means keep the value, while in Flax mask=False means mask out 67 | # So we need to invert the mask and convert to a bias 68 | if bias is None: 69 | bias = jnp.where(mask, 0.0, -1e10) 70 | else: 71 | bias = jnp.where(mask, bias, -1e10) 72 | 73 | # Call JAX's dot_product_attention with the implementation parameter 74 | return jax.nn.dot_product_attention( 75 | query=query, 76 | key=key, 77 | value=value, 78 | bias=bias, 79 | # JAX-specific parameters 80 | scale=None, # Use default scaling 81 | is_causal=False, # We handle causal masking through the bias/mask 82 | query_seq_lengths=None, 83 | key_value_seq_lengths=None, 84 | local_window_size=None, 85 | implementation=implementation, 86 | ) 87 | -------------------------------------------------------------------------------- /tests/functional/test_attention.py: -------------------------------------------------------------------------------- 1 | """Tests for the jax_layers/functional/attention.py implementations.""" 2 | 3 | from unittest.mock import MagicMock 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from jaxgarden.functional import dot_product_attention 9 | 10 | 11 | def test_dot_product_attention(): 12 | """Test that the dot_product_attention function works with different implementations.""" 13 | batch_size = 2 14 | seq_len = 16 15 | num_heads = 4 16 | head_dim = 8 17 | 18 | key = jax.random.PRNGKey(0) 19 | key1, key2, key3 = jax.random.split(key, 3) 20 | 21 | query = jax.random.normal(key1, (batch_size, seq_len, num_heads, head_dim)) 22 | key_tensor = jax.random.normal(key2, (batch_size, seq_len, num_heads, head_dim)) 23 | value = jax.random.normal(key3, (batch_size, seq_len, num_heads, head_dim)) 24 | 25 | # Execute using default (XLA) implementation 26 | output_default = dot_product_attention(query, key_tensor, value) 27 | output_xla = dot_product_attention(query, key_tensor, value, implementation="xla") 28 | 29 | # Verify that the output shapes and values are identical 30 | assert output_default.shape == output_xla.shape 31 | assert jnp.allclose(output_default, output_xla, rtol=1e-5, atol=1e-5) 32 | 33 | 34 | def test_dot_product_attention_flash_mapping(monkeypatch): 35 | """Test that when implementation is 'flash', it is remapped to 'cudnn' 36 | before calling jax.nn.dot_product_attention.""" 37 | 38 | mock_fn = MagicMock(return_value=jnp.array(0)) 39 | monkeypatch.setattr(jax.nn, "dot_product_attention", mock_fn) 40 | 41 | key = jax.random.PRNGKey(42) 42 | key1, key2, key3 = jax.random.split(key, 3) 43 | 44 | query = jax.random.normal(key1, (2, 16, 4, 8)) 45 | key_tensor = jax.random.normal(key2, (2, 16, 4, 8)) 46 | value = jax.random.normal(key3, (2, 16, 4, 8)) 47 | 48 | _ = dot_product_attention(query, key_tensor, value, implementation="flash") 49 | 50 | mock_fn.assert_called_once_with( 51 | query=query, 52 | key=key_tensor, 53 | value=value, 54 | bias=None, 55 | scale=None, 56 | is_causal=False, 57 | query_seq_lengths=None, 58 | key_value_seq_lengths=None, 59 | local_window_size=None, 60 | implementation="cudnn", 61 | ) 62 | 63 | 64 | def test_dot_product_attention_mask_without_bias(): 65 | """Test the mask branch when bias is None: 66 | bias should be created as jnp.where(mask, 0.0, -1e10).""" 67 | batch_size = 2 68 | seq_len = 16 69 | num_heads = 4 70 | head_dim = 8 71 | 72 | key = jax.random.PRNGKey(1) 73 | 
key1, key2, key3 = jax.random.split(key, 3) 74 | 75 | query = jax.random.normal(key1, (batch_size, seq_len, num_heads, head_dim)) 76 | key_tensor = jax.random.normal(key2, (batch_size, seq_len, num_heads, head_dim)) 77 | value = jax.random.normal(key3, (batch_size, seq_len, num_heads, head_dim)) 78 | 79 | # Create a causal mask with boolean values. 80 | mask = jnp.tril(jnp.ones((batch_size, 1, seq_len, seq_len), dtype=bool)) 81 | 82 | # Compute output using the mask (without providing bias explicitly). 83 | output_with_mask = dot_product_attention(query, key_tensor, value, mask=mask) 84 | 85 | # Manually create the bias as per conversion: if mask is True -> 0.0, else -1e10. 86 | bias_manual = jnp.where(mask, 0.0, -1e10) 87 | output_with_manual_bias = dot_product_attention(query, key_tensor, value, bias=bias_manual) 88 | 89 | assert output_with_mask.shape == output_with_manual_bias.shape 90 | assert jnp.allclose(output_with_mask, output_with_manual_bias, rtol=1e-5, atol=1e-5) 91 | 92 | 93 | def test_dot_product_attention_mask_with_bias(): 94 | """Test the mask branch when bias is provided: 95 | bias should be converted using jnp.where(mask, bias, -1e10).""" 96 | batch_size = 2 97 | seq_len = 16 98 | num_heads = 4 99 | head_dim = 8 100 | 101 | key = jax.random.PRNGKey(2) 102 | key1, key2, key3 = jax.random.split(key, 3) 103 | 104 | query = jax.random.normal(key1, (batch_size, seq_len, num_heads, head_dim)) 105 | key_tensor = jax.random.normal(key2, (batch_size, seq_len, num_heads, head_dim)) 106 | value = jax.random.normal(key3, (batch_size, seq_len, num_heads, head_dim)) 107 | 108 | # Create a causal mask with boolean values. 109 | mask = jnp.tril(jnp.ones((batch_size, 1, seq_len, seq_len), dtype=bool)) 110 | # Provide a custom bias (e.g., constant value 5.0). 
111 | custom_bias = jnp.full(mask.shape, 5.0) 112 | 113 | # Expected bias after conversion: jnp.where(mask, custom_bias, -1e10) 114 | bias_manual = jnp.where(mask, custom_bias, -1e10) 115 | 116 | output_with_mask_bias = dot_product_attention( 117 | query, 118 | key_tensor, 119 | value, 120 | mask=mask, 121 | bias=custom_bias, 122 | ) 123 | output_with_manual_bias = dot_product_attention(query, key_tensor, value, bias=bias_manual) 124 | 125 | assert output_with_mask_bias.shape == output_with_manual_bias.shape 126 | assert jnp.allclose(output_with_mask_bias, output_with_manual_bias, rtol=1e-5, atol=1e-5) 127 | -------------------------------------------------------------------------------- /tests/attention/test_RoPEMultiHeadAttention.py: -------------------------------------------------------------------------------- 1 | """Tests for the RoPEMultiHeadAttention class.""" 2 | 3 | import flax.linen as nn 4 | import jax 5 | import jax.numpy as jnp 6 | import pytest 7 | 8 | from jaxgarden.attention.rope_multi_head_attention import ( 9 | RoPEMultiHeadAttention, 10 | apply_rotary_pos_emb, 11 | precompute_rotary_embeddings, 12 | rotate_half, 13 | ) 14 | 15 | 16 | def test_rotate_half(): 17 | """Tests the rotate_half function.""" 18 | key = jax.random.PRNGKey(0) 19 | x = jax.random.normal(key, (2, 4, 6, 8)) # batch, seq, heads, dim 20 | rotated_x = rotate_half(x) 21 | 22 | assert rotated_x.shape == x.shape 23 | # Check specific values after rotation 24 | x1 = x[..., ::2] 25 | x2 = x[..., 1::2] 26 | expected = jnp.concatenate((-x2, x1), axis=-1) 27 | assert jnp.allclose(rotated_x, expected) 28 | 29 | 30 | def test_precompute_rotary_embeddings(): 31 | """Tests the precompute_rotary_embeddings function.""" 32 | seq_len = 16 33 | head_dim = 8 34 | base = 10000.0 35 | 36 | cos_emb, sin_emb = precompute_rotary_embeddings(seq_len, head_dim, base) 37 | 38 | assert cos_emb.shape == (1, seq_len, 1, head_dim) 39 | assert sin_emb.shape == (1, seq_len, 1, head_dim) 40 | 41 | # Check properties - e.g., cos^2 + sin^2 = 1 42 | assert jnp.allclose(cos_emb**2 + sin_emb**2, jnp.ones_like(cos_emb), atol=1e-6) 43 | 44 | # Check different base value 45 | cos_emb_b2, sin_emb_b2 = precompute_rotary_embeddings(seq_len, head_dim, base=500.0) 46 | assert not jnp.allclose(cos_emb, cos_emb_b2) 47 | 48 | # Test with odd head_dim (should raise error) 49 | with pytest.raises(ValueError, match="head_dim must be even"): 50 | precompute_rotary_embeddings(seq_len, head_dim=7) 51 | 52 | 53 | def test_apply_rotary_pos_emb(): 54 | """Tests the apply_rotary_pos_emb function.""" 55 | key = jax.random.PRNGKey(1) 56 | batch, seq_len, num_heads, head_dim = 2, 16, 4, 8 57 | x = jax.random.normal(key, (batch, seq_len, num_heads, head_dim)) 58 | 59 | cos_emb, sin_emb = precompute_rotary_embeddings(seq_len, head_dim) 60 | 61 | rotated_x = apply_rotary_pos_emb(x, cos_emb, sin_emb) 62 | 63 | assert rotated_x.shape == x.shape 64 | # Applying RoPE again should not give the original x (unless pos=0, which isn't the whole seq) 65 | assert not jnp.allclose(rotated_x, x) 66 | 67 | 68 | # --- Test RoPEMultiHeadAttention Module --- 69 | 70 | 71 | @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) 72 | def test_rope_mha_forward_pass(dtype): 73 | """Tests the forward pass of RoPEMultiHeadAttention.""" 74 | key = jax.random.PRNGKey(2) 75 | batch_size = 2 76 | seq_len = 16 77 | num_heads = 4 78 | head_dim = 8 79 | embed_dim = num_heads * head_dim 80 | 81 | x = jax.random.normal(key, (batch_size, seq_len, embed_dim), dtype=dtype) 82 | 83 | rope_mha = 
RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim, dtype=dtype) 84 | params = rope_mha.init(key, x)["params"] 85 | output = rope_mha.apply({"params": params}, x) 86 | 87 | assert output.shape == (batch_size, seq_len, embed_dim) 88 | assert output.dtype == dtype 89 | 90 | 91 | def test_rope_mha_masking(): 92 | """Tests causal masking in RoPEMultiHeadAttention.""" 93 | key = jax.random.PRNGKey(3) 94 | batch_size = 1 95 | seq_len = 4 96 | num_heads = 2 97 | head_dim = 4 98 | embed_dim = num_heads * head_dim 99 | 100 | x = jax.random.normal(key, (batch_size, seq_len, embed_dim)) 101 | # Create a causal mask (True means masked) 102 | causal_mask = nn.make_causal_mask(x[:, :, 0]) # Gets (batch, seq, seq) or (1, seq, seq) 103 | 104 | rope_mha = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim) 105 | params = rope_mha.init(key, x, causal_mask)["params"] 106 | 107 | # Apply without mask 108 | output_unmasked = rope_mha.apply({"params": params}, x) 109 | 110 | # Apply with mask 111 | output_masked = rope_mha.apply({"params": params}, x, mask=causal_mask) 112 | 113 | # Basic check: outputs should differ if mask has an effect 114 | assert not jnp.allclose(output_unmasked, output_masked, atol=1e-5) 115 | 116 | # More rigorous check (requires inspecting attention weights, omitted for brevity) 117 | 118 | 119 | def test_rope_mha_errors(): 120 | """Tests error conditions for RoPEMultiHeadAttention.""" 121 | key = jax.random.PRNGKey(4) 122 | rope_mha_odd_dim = RoPEMultiHeadAttention(num_heads=8, head_dim=7) 123 | x_dummy_odd = jax.random.normal(key, (2, 16, 8 * 7)) 124 | # Test with odd head_dim (should raise error during initialization/setup) 125 | with pytest.raises(ValueError, match=r"head_dim \(\d+\) must be even"): 126 | rope_mha_odd_dim.init(key, x_dummy_odd) 127 | 128 | # Test with mismatched embed_dim (should raise error during forward pass / init) 129 | rope_mha = RoPEMultiHeadAttention(num_heads=4, head_dim=8) # Expects embed_dim 32 130 | x_mismatch = jax.random.normal(key, (2, 16, 100)) # Incorrect embed_dim 131 | 132 | with pytest.raises(ValueError, match=r"embed_dim \(\d+\) must equal"): 133 | rope_mha.init(key, x_mismatch) 134 | -------------------------------------------------------------------------------- /examples/multi_head_attention_example.py: -------------------------------------------------------------------------------- 1 | """Example demonstrating the use of MultiHeadAttention with different implementations.""" 2 | 3 | import time 4 | 5 | import flax.nnx as nnx 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | from jaxgarden.attention import MultiHeadAttention 10 | 11 | 12 | def benchmark_attention(implementation=None, batch_size=2, seq_len=1024, num_heads=8, head_dim=64): 13 | """Benchmark MultiHeadAttention with different implementations.""" 14 | print(f"\nBenchmarking MultiHeadAttention with implementation={implementation}") 15 | print(f"Input shape (b, s, h, d) = ({batch_size}, {seq_len}, {num_heads}, {head_dim})") 16 | 17 | # Create random input data 18 | key = jax.random.PRNGKey(0) 19 | key1, key2 = jax.random.split(key) 20 | x = jax.random.normal(key1, (batch_size, seq_len, num_heads * head_dim)) 21 | 22 | # Create a causal attention mask 23 | mask = jnp.tril(jnp.ones((batch_size, 1, seq_len, seq_len))) 24 | 25 | # Create the MultiHeadAttention module 26 | attention = MultiHeadAttention( 27 | num_heads=num_heads, 28 | in_features=num_heads * head_dim, 29 | implementation=implementation, 30 | rngs=nnx.Rngs(key2), 31 | ) 32 | 33 | # Compile the 
forward pass 34 | @nnx.jit 35 | def forward(x, mask): 36 | return attention(x, mask=mask) 37 | 38 | # Warm-up 39 | _ = forward(x, mask) 40 | 41 | # Benchmark 42 | start_time = time.time() 43 | num_runs = 10 44 | for _ in range(num_runs): 45 | output = forward(x, mask) 46 | output.block_until_ready() # Ensure computation is complete 47 | end_time = time.time() 48 | 49 | avg_time = (end_time - start_time) / num_runs 50 | print(f"Average time per run: {avg_time:.6f} seconds") 51 | 52 | return output, avg_time 53 | 54 | 55 | def compare_implementations(): 56 | """Compare different implementations of MultiHeadAttention.""" 57 | print("Comparing different implementations of MultiHeadAttention") 58 | 59 | # Parameters for the comparison 60 | batch_size = 2 61 | seq_len = 1024 62 | num_heads = 8 63 | head_dim = 64 64 | 65 | # Benchmark each implementation 66 | implementations = [None, "xla", "cudnn"] 67 | outputs = {} 68 | times = {} 69 | 70 | for impl in implementations: 71 | try: 72 | output, avg_time = benchmark_attention( 73 | implementation=impl, 74 | batch_size=batch_size, 75 | seq_len=seq_len, 76 | num_heads=num_heads, 77 | head_dim=head_dim, 78 | ) 79 | outputs[impl] = output 80 | times[impl] = avg_time 81 | except Exception as e: 82 | print(f"Error with implementation {impl}: {e}") 83 | 84 | # Compare outputs for correctness 85 | print("\nComparing outputs for correctness:") 86 | for impl1 in outputs: 87 | for impl2 in outputs: 88 | if impl1 != impl2: 89 | max_diff = jnp.max(jnp.abs(outputs[impl1] - outputs[impl2])) 90 | print(f"Max difference between {impl1} and {impl2}: {max_diff}") 91 | 92 | # Compare performance 93 | if len(times) > 1: 94 | print("\nPerformance comparison:") 95 | baseline = times[None] # Default implementation as baseline 96 | for impl, avg_time in times.items(): 97 | if impl is not None: 98 | speedup = baseline / avg_time 99 | print(f"{impl} implementation: {speedup:.2f}x speedup over default") 100 | 101 | 102 | def main(): 103 | """Run the example.""" 104 | # Enable JAX logging to see which implementation is being used 105 | jax.config.update("jax_log_compiles", True) 106 | 107 | # Compare different implementations 108 | compare_implementations() 109 | 110 | # Show an example of using MultiHeadAttention with a specific implementation 111 | print("\nExample of using MultiHeadAttention with Flash Attention:") 112 | key = jax.random.PRNGKey(0) 113 | key1, key2 = jax.random.split(key) 114 | 115 | # Create input data 116 | batch_size = 2 117 | seq_len = 128 118 | num_heads = 8 119 | head_dim = 64 120 | hidden_dim = num_heads * head_dim 121 | 122 | x = jax.random.normal(key1, (batch_size, seq_len, hidden_dim)) 123 | 124 | # Create a causal attention mask 125 | mask = jnp.tril(jnp.ones((batch_size, 1, seq_len, seq_len))) 126 | 127 | # Create the MultiHeadAttention module with Flash Attention 128 | attention = MultiHeadAttention( 129 | num_heads=num_heads, 130 | in_features=hidden_dim, 131 | implementation="flash", # Use Flash Attention (alias for "cudnn") 132 | rngs=nnx.Rngs(key2), 133 | ) 134 | 135 | # Apply the attention 136 | output = attention(x, mask=mask) 137 | 138 | print(f"Input shape: {x.shape}") 139 | print(f"Output shape: {output.shape}") 140 | print(f"Output mean: {jnp.mean(output)}") 141 | print(f"Output std: {jnp.std(output)}") 142 | 143 | 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /tests/attention/test_multi_head_attention.py: 
-------------------------------------------------------------------------------- 1 | """Tests for the MultiHeadAttention class.""" 2 | 3 | import flax.nnx as nnx 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from jaxgarden.attention import MultiHeadAttention 8 | 9 | 10 | def test_multi_head_attention(): 11 | """Test that the MultiHeadAttention module works with different implementations.""" 12 | # Set up parameters 13 | batch_size = 2 14 | seq_len = 16 15 | num_heads = 4 16 | head_dim = 8 17 | hidden_dim = num_heads * head_dim 18 | 19 | # Create random input data 20 | key = jax.random.PRNGKey(0) 21 | key1, key2, key3 = jax.random.split(key, 3) 22 | x = jax.random.normal(key1, (batch_size, seq_len, hidden_dim)) 23 | 24 | # Create the MultiHeadAttention modules with different implementations 25 | attention_default = MultiHeadAttention( 26 | num_heads=num_heads, 27 | in_features=hidden_dim, 28 | decode=False, 29 | rngs=nnx.Rngs(key2), 30 | ) 31 | 32 | attention_xla = MultiHeadAttention( 33 | num_heads=num_heads, 34 | in_features=hidden_dim, 35 | decode=False, 36 | implementation="xla", 37 | rngs=nnx.Rngs(key3), 38 | ) 39 | 40 | # Apply the attention with different implementations 41 | output_default = attention_default(x, x, x) 42 | output_xla = attention_xla(x, x, x) 43 | 44 | # Check that the outputs are similar (not exactly the same due to different initializations) 45 | assert output_default.shape == output_xla.shape, f"{type(output_xla)}, {len(output_xla)}" 46 | 47 | 48 | def test_multi_head_attention_shape(): 49 | """Test that the MultiHeadAttention module produces the expected output shape.""" 50 | # Set up parameters 51 | batch_size = 2 52 | seq_len = 16 53 | num_heads = 4 54 | head_dim = 8 55 | hidden_dim = num_heads * head_dim 56 | 57 | # Create random input data 58 | key = jax.random.PRNGKey(0) 59 | key1, key2 = jax.random.split(key) 60 | x = jax.random.normal(key1, (batch_size, seq_len, hidden_dim)) 61 | 62 | # Create the MultiHeadAttention module 63 | attention = MultiHeadAttention( 64 | num_heads=num_heads, 65 | in_features=hidden_dim, 66 | decode=False, 67 | rngs=nnx.Rngs(key2), 68 | ) 69 | 70 | # Apply the attention 71 | output = attention(x, x, x) 72 | 73 | # Check the output shape 74 | assert output.shape == (batch_size, seq_len, hidden_dim) 75 | 76 | 77 | def test_multi_head_attention_mask(): 78 | """Test that the MultiHeadAttention module correctly applies attention masks.""" 79 | # Set up parameters 80 | batch_size = 2 81 | seq_len = 16 82 | num_heads = 4 83 | head_dim = 8 84 | hidden_dim = num_heads * head_dim 85 | 86 | # Create random input data 87 | key = jax.random.PRNGKey(0) 88 | key1, key2 = jax.random.split(key) 89 | x = jax.random.normal(key1, (batch_size, seq_len, hidden_dim)) 90 | 91 | # Create a causal attention mask 92 | mask = jnp.tril(jnp.ones((batch_size, 1, seq_len, seq_len))) 93 | 94 | # Create the MultiHeadAttention module 95 | attention = MultiHeadAttention( 96 | num_heads=num_heads, 97 | in_features=hidden_dim, 98 | decode=False, 99 | rngs=nnx.Rngs(key2), 100 | ) 101 | 102 | # Apply the attention with and without mask 103 | output_with_mask = attention(x, x, x, mask=mask) 104 | output_without_mask = attention(x, x, x) 105 | 106 | # Check that the outputs are different 107 | assert not jnp.allclose(output_with_mask, output_without_mask) 108 | 109 | 110 | def test_multi_head_attention_self_attention(): 111 | """Test that the MultiHeadAttention module works as a self-attention module.""" 112 | # Set up parameters 113 | batch_size = 2 114 | seq_len = 
16 115 | num_heads = 4 116 | head_dim = 8 117 | hidden_dim = num_heads * head_dim 118 | 119 | # Create random input data 120 | key = jax.random.PRNGKey(0) 121 | key1, key2 = jax.random.split(key) 122 | x = jax.random.normal(key1, (batch_size, seq_len, hidden_dim)) 123 | 124 | # Create the MultiHeadAttention module 125 | attention = MultiHeadAttention( 126 | num_heads=num_heads, 127 | in_features=hidden_dim, 128 | decode=False, 129 | rngs=nnx.Rngs(key2), 130 | ) 131 | 132 | # Apply the attention as self-attention 133 | output = attention(x) 134 | 135 | # Check the output shape 136 | assert output.shape == (batch_size, seq_len, hidden_dim) 137 | 138 | 139 | def test_multi_head_attention_cross_attention(): 140 | """Test that the MultiHeadAttention module works as a cross-attention module.""" 141 | # Set up parameters 142 | batch_size = 2 143 | q_seq_len = 16 144 | kv_seq_len = 32 145 | num_heads = 4 146 | head_dim = 8 147 | hidden_dim = num_heads * head_dim 148 | 149 | # Create random input data 150 | key = jax.random.PRNGKey(0) 151 | key1, key2, key3, key4 = jax.random.split(key, 4) 152 | q = jax.random.normal(key1, (batch_size, q_seq_len, hidden_dim)) 153 | k = jax.random.normal(key2, (batch_size, kv_seq_len, hidden_dim)) 154 | v = jax.random.normal(key3, (batch_size, kv_seq_len, hidden_dim)) 155 | 156 | # Create the MultiHeadAttention module 157 | attention = MultiHeadAttention( 158 | num_heads=num_heads, 159 | in_features=hidden_dim, 160 | decode=False, 161 | rngs=nnx.Rngs(key4), 162 | ) 163 | 164 | # Apply the attention as cross-attention 165 | output = attention(q, k, v) 166 | 167 | # Check the output shape 168 | assert output.shape == (batch_size, q_seq_len, hidden_dim) 169 | -------------------------------------------------------------------------------- /jaxgarden/attention/multi_head_attention.py: -------------------------------------------------------------------------------- 1 | """MultiHeadAttention implementation with Flash Attention support for JAX Layers.""" 2 | 3 | from collections.abc import Callable 4 | from typing import Any, Literal 5 | 6 | import flax.nnx as nnx 7 | import jax 8 | import jax.numpy as jnp 9 | from flax.nnx.nn.linear import default_kernel_init 10 | 11 | from jaxgarden.functional.attention import dot_product_attention 12 | 13 | 14 | class MultiHeadAttention(nnx.MultiHeadAttention): 15 | """Multi-head attention with support for Flash Attention. 16 | 17 | This class extends Flax NNX's MultiHeadAttention to support Flash Attention 18 | through JAX's dot_product_attention implementation parameter. 
19 | 20 | Example usage: 21 | 22 | ```python 23 | import jax 24 | import jax.numpy as jnp 25 | import flax.nnx as nnx 26 | from jaxgarden.attention import MultiHeadAttention 27 | 28 | # Create a MultiHeadAttention module with Flash Attention support 29 | attention = MultiHeadAttention( 30 | num_heads=8, 31 | in_features=512, 32 | implementation="cudnn", # Use cuDNN's Flash Attention if available 33 | rngs=nnx.Rngs(0), 34 | ) 35 | 36 | # Initialize parameters 37 | key = jax.random.PRNGKey(0) 38 | x = jax.random.normal(key, (2, 128, 512)) # (batch, seq_length, hidden_dim) 39 | 40 | # Create a causal attention mask 41 | mask = jnp.tril(jnp.ones((2, 1, 128, 128))) # (batch, 1, q_len, kv_len) 42 | 43 | # Apply the model 44 | output = attention(x, mask=mask) 45 | ``` 46 | """ 47 | 48 | def __init__( 49 | self, 50 | num_heads: int, 51 | in_features: int, 52 | qkv_features: int | None = None, 53 | out_features: int | None = None, 54 | *, 55 | dtype: jnp.dtype | None = None, 56 | param_dtype: jnp.dtype = jnp.float32, 57 | broadcast_dropout: bool = True, 58 | dropout_rate: float = 0.0, 59 | deterministic: bool | None = None, 60 | precision: jax.lax.Precision | str | None = None, 61 | kernel_init: Callable = default_kernel_init, 62 | out_kernel_init: Callable | None = None, 63 | bias_init: Callable = nnx.initializers.zeros, 64 | out_bias_init: Callable | None = None, 65 | use_bias: bool = True, 66 | attention_fn: Callable | None = None, 67 | decode: bool | None = None, 68 | normalize_qk: bool = False, 69 | qkv_dot_general: Callable | None = None, 70 | out_dot_general: Callable | None = None, 71 | qkv_dot_general_cls: type | None = None, 72 | out_dot_general_cls: type | None = None, 73 | implementation: Literal["xla", "cudnn", "flash"] | None = None, 74 | rngs: nnx.Rngs, 75 | ): 76 | """Initialize the MultiHeadAttention module. 77 | 78 | Args: 79 | num_heads: number of attention heads. 80 | in_features: int or tuple with number of input features. 81 | qkv_features: dimension of the key, query, and value. 82 | out_features: dimension of the last projection. 83 | dtype: the dtype of the computation. 84 | param_dtype: the dtype passed to parameter initializers. 85 | broadcast_dropout: bool: use a broadcasted dropout along batch dims. 86 | dropout_rate: dropout rate. 87 | deterministic: if false, the attention weight is masked randomly using dropout. 88 | precision: numerical precision of the computation. 89 | kernel_init: initializer for the kernel of the Dense layers. 90 | out_kernel_init: initializer for the kernel of the output Dense layer. 91 | bias_init: initializer for the bias of the Dense layers. 92 | out_bias_init: initializer for the bias of the output Dense layer. 93 | use_bias: bool: whether pointwise QKVO dense transforms use bias. 94 | attention_fn: dot_product_attention or compatible function. 95 | decode: whether to prepare and use an autoregressive cache. 96 | normalize_qk: should QK normalization be applied. 97 | qkv_dot_general: dot_general function for QKV projection. 98 | out_dot_general: dot_general function for output projection. 99 | qkv_dot_general_cls: dot_general class for QKV projection. 100 | out_dot_general_cls: dot_general class for output projection. 101 | implementation: which implementation to use for attention. 
Options are: 102 | - "xla": Use XLA's default implementation 103 | - "cudnn": Use cuDNN's Flash Attention implementation (if available) 104 | - "flash": Alias for "cudnn" 105 | - None: Automatically select the best available implementation 106 | rngs: random number generator keys. 107 | """ 108 | # Create a custom attention function that uses our dot_product_attention 109 | # with the specified implementation 110 | if attention_fn is None: 111 | 112 | def custom_attention_fn( 113 | query: jnp.ndarray, 114 | key: jnp.ndarray, 115 | value: jnp.ndarray, 116 | bias: jnp.ndarray | None = None, 117 | mask: jnp.ndarray | None = None, 118 | **kwargs: Any, 119 | ) -> jnp.ndarray: 120 | return dot_product_attention( 121 | query=query, 122 | key=key, 123 | value=value, 124 | bias=bias, 125 | mask=mask, 126 | implementation=implementation, 127 | **kwargs, 128 | ) 129 | 130 | attention_fn = custom_attention_fn 131 | 132 | # Initialize the parent class with our custom attention function 133 | super().__init__( 134 | num_heads=num_heads, 135 | in_features=in_features, 136 | qkv_features=qkv_features, 137 | out_features=out_features, 138 | dtype=dtype, 139 | param_dtype=param_dtype, 140 | broadcast_dropout=broadcast_dropout, 141 | dropout_rate=dropout_rate, 142 | deterministic=deterministic, 143 | precision=precision, 144 | kernel_init=kernel_init, 145 | out_kernel_init=out_kernel_init, 146 | bias_init=bias_init, 147 | out_bias_init=out_bias_init, 148 | use_bias=use_bias, 149 | attention_fn=attention_fn, 150 | decode=decode, 151 | normalize_qk=normalize_qk, 152 | qkv_dot_general=qkv_dot_general, 153 | out_dot_general=out_dot_general, 154 | qkv_dot_general_cls=qkv_dot_general_cls, 155 | out_dot_general_cls=out_dot_general_cls, 156 | rngs=rngs, 157 | ) 158 | -------------------------------------------------------------------------------- /tests/models/test_t5.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pytest 4 | from flax import nnx 5 | 6 | from jaxgarden.models.t5 import ( 7 | T5MLP, 8 | T5Attention, 9 | T5Block, 10 | T5Config, 11 | T5CrossAttention, 12 | T5ForCausalLM, 13 | T5LayerNorm, 14 | T5SelfAttention, 15 | T5Stack, 16 | ) 17 | 18 | 19 | @pytest.fixture(scope="module") 20 | def tiny_config(): 21 | return T5Config( 22 | hidden_size=128, 23 | dim_kv=64, 24 | num_layers=2, 25 | vocab_size=100, 26 | dtype=jnp.float32, 27 | ) 28 | 29 | 30 | @pytest.fixture(scope="module") 31 | def dummy_rngs(): 32 | return nnx.Rngs(params=jax.random.PRNGKey(0)) 33 | 34 | 35 | @pytest.mark.parametrize(("dim", "dtype"), [(1024, jnp.float32), (2048, jnp.float16)]) 36 | def test_t5_layer_norm(dim, dtype): 37 | layer_norm = T5LayerNorm(dim=dim, dtype=dtype, rngs=nnx.Rngs(0)) 38 | x = jnp.ones((1, dim)) 39 | output = layer_norm(x) 40 | 41 | assert output.shape == (1, dim) 42 | assert output.dtype == dtype 43 | 44 | 45 | @pytest.mark.parametrize( 46 | ("dim", "intermediate_dim", "dtype"), [(512, 2048, jnp.float32), (1024, 4096, jnp.float16)] 47 | ) 48 | def test_t5_mlp(dim, intermediate_dim, dtype): 49 | # small sequence length for testing 50 | seq_len = 128 51 | 52 | mlp = T5MLP(dim=dim, intermediate_dim=intermediate_dim, dtype=dtype, rngs=nnx.Rngs(0)) 53 | x = jnp.ones((1, seq_len, dim)) 54 | output = mlp(x) 55 | 56 | assert output.shape == (1, seq_len, dim) 57 | assert output.dtype == dtype 58 | 59 | 60 | @pytest.mark.parametrize( 61 | ("hidden_size", "dim_kv", "dtype", "masked"), 62 | [(768, 64, jnp.float32, True), (1024, 128, 
jnp.float16, False)], 63 | ) 64 | def test_t5_attention(hidden_size, dim_kv, dtype, masked): 65 | # small sequence length for testing 66 | seq_len = 128 67 | 68 | attention = T5Attention( 69 | hidden_size=hidden_size, 70 | dim_kv=dim_kv, 71 | num_heads=hidden_size // dim_kv, 72 | relative_attention_num_buckets=32, 73 | relative_attention_max_distance=128, 74 | dtype=dtype, 75 | rngs=nnx.Rngs(0), 76 | ) 77 | 78 | hidden_states = jnp.ones((1, seq_len, hidden_size)) 79 | attention_mask = jnp.ones((1, seq_len)) if masked else None 80 | output = attention(hidden_states, attention_mask=attention_mask) 81 | 82 | assert output.shape == (1, seq_len, hidden_size) 83 | assert output.dtype == dtype 84 | 85 | 86 | @pytest.mark.parametrize( 87 | ("hidden_size", "dim_kv", "dtype", "masked"), 88 | [(768, 64, jnp.float32, True), (1024, 128, jnp.float16, False)], 89 | ) 90 | def test_t5_self_attention(hidden_size, dim_kv, dtype, masked): 91 | # small sequence length for testing 92 | seq_len = 128 93 | 94 | config = T5Config(hidden_size=hidden_size, dim_kv=dim_kv, dtype=dtype) 95 | self_attention = T5SelfAttention(config=config, rngs=nnx.Rngs(0)) 96 | 97 | hidden_states = jnp.ones((1, seq_len, hidden_size)) 98 | attention_mask = jnp.ones((1, seq_len)) if masked else None 99 | 100 | output = self_attention(hidden_states, attention_mask=attention_mask) 101 | 102 | assert output.shape == (1, seq_len, hidden_size) 103 | assert output.dtype == dtype 104 | 105 | 106 | @pytest.mark.parametrize( 107 | ("hidden_size", "dim_kv", "dtype", "masked"), 108 | [(768, 64, jnp.float32, True), (1024, 128, jnp.float16, False)], 109 | ) 110 | def test_t5_cross_attention(hidden_size, dim_kv, dtype, masked): 111 | # small sequence length for testing 112 | seq_len = 128 113 | 114 | config = T5Config(hidden_size=hidden_size, dim_kv=dim_kv, dtype=dtype) 115 | cross_attention = T5CrossAttention(config=config, rngs=nnx.Rngs(0)) 116 | 117 | hidden_states = jnp.ones((1, seq_len, hidden_size)) 118 | attention_mask = jnp.ones((1, seq_len)) if masked else None 119 | 120 | output = cross_attention(hidden_states, attention_mask=attention_mask) 121 | 122 | assert output.shape == (1, seq_len, hidden_size) 123 | assert output.dtype == dtype 124 | 125 | 126 | @pytest.mark.parametrize( 127 | ("hidden_size", "dim_kv", "dtype", "masked", "causal"), 128 | [ 129 | (768, 64, jnp.float32, True, True), 130 | (768, 64, jnp.float32, False, False), 131 | (1024, 128, jnp.float16, True, False), 132 | (1024, 128, jnp.float16, False, True), 133 | ], 134 | ) 135 | def test_t5_block(hidden_size, dim_kv, dtype, masked, causal): 136 | # small sequence length for testing 137 | seq_len = 128 138 | 139 | config = T5Config(hidden_size=hidden_size, dim_kv=dim_kv, dtype=dtype) 140 | block = T5Block(config=config, causal=causal, rngs=nnx.Rngs(0)) 141 | 142 | hidden_states = jnp.ones((1, seq_len, hidden_size)) 143 | attention_mask = jnp.ones((1, seq_len)) if masked else None 144 | 145 | output = block(hidden_states, attention_mask=attention_mask, deterministic=True) 146 | 147 | assert output.shape == (1, seq_len, hidden_size) 148 | assert output.dtype == dtype 149 | 150 | 151 | @pytest.mark.parametrize("dtype", [jnp.float32, jnp.float16]) 152 | @pytest.mark.parametrize("causal", [True, False]) 153 | @pytest.mark.parametrize("masked", [True, False]) 154 | @pytest.mark.parametrize("with_encoder", [True, False]) 155 | def test_t5_stack(dtype, causal, masked, with_encoder): 156 | hidden_size = 768 157 | dim_kv = 64 158 | vocab_size = 100 159 | num_layers = 2 160 | seq_len = 8 
161 | batch_size = 2 162 | 163 | config = T5Config(hidden_size=hidden_size, dim_kv=dim_kv, dtype=dtype) 164 | embed_tokens = nnx.Embed( 165 | num_embeddings=vocab_size, 166 | features=hidden_size, 167 | dtype=dtype, 168 | rngs=nnx.Rngs(0), 169 | ) 170 | stack = T5Stack( 171 | config=config, 172 | embed_tokens=embed_tokens, 173 | num_layers=num_layers, 174 | causal=causal, 175 | rngs=nnx.Rngs(0), 176 | ) 177 | 178 | input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) 179 | attention_mask = jnp.ones((batch_size, seq_len), dtype=dtype) if masked else None 180 | 181 | encoder_hidden_states = ( 182 | jnp.ones((batch_size, seq_len, hidden_size), dtype=dtype) if with_encoder else None 183 | ) 184 | encoder_attention_mask = ( 185 | jnp.ones((batch_size, seq_len), dtype=dtype) if (with_encoder and masked) else None 186 | ) 187 | 188 | output = stack( 189 | input_ids, 190 | attention_mask=attention_mask, 191 | encoder_hidden_states=encoder_hidden_states, 192 | encoder_attention_mask=encoder_attention_mask, 193 | deterministic=True, 194 | ) 195 | 196 | assert output.shape == (batch_size, seq_len, hidden_size) 197 | assert output.dtype == dtype 198 | 199 | 200 | # --- Test Full Model --- 201 | def test_t5_for_causal_lm(tiny_config, dummy_rngs): 202 | # small values for testing 203 | seq_len = 8 204 | batch = 2 205 | 206 | model = T5ForCausalLM(config=tiny_config, rngs=dummy_rngs) 207 | 208 | input_ids = jnp.ones((batch, seq_len), dtype=jnp.int32) 209 | pos_ids = jnp.arange(seq_len)[None, :].repeat(batch, axis=0) 210 | attn_mask = None 211 | 212 | out = model(input_ids, pos_ids, attn_mask) 213 | 214 | assert out.shape == (batch, seq_len, tiny_config.vocab_size) 215 | assert out.dtype == tiny_config.dtype 216 | -------------------------------------------------------------------------------- /tests/models/test_gemma2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for Gemma2 model (jaxgarden/models/gemma2.py) 3 | """ 4 | 5 | import dataclasses 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | import numpy as np 10 | import pytest 11 | from flax import nnx 12 | 13 | from jaxgarden.models.gemma2 import ( 14 | Gemma2Attention, 15 | Gemma2Config, 16 | Gemma2ForCausalLM, 17 | Gemma2MLP, 18 | Gemma2RMSNorm, 19 | Gemma2RotaryEmbedding, 20 | ) 21 | 22 | 23 | # Helper: minimal config for fast tests 24 | @pytest.fixture(scope="module") # Use module scope for efficiency 25 | def tiny_config(): 26 | return Gemma2Config( 27 | vocab_size=128, 28 | hidden_size=32, 29 | intermediate_size=64, # Must be even for GeGLU 30 | num_hidden_layers=2, 31 | num_attention_heads=4, 32 | num_key_value_heads=2, 33 | head_dim=8, 34 | context_length=16, 35 | param_dtype=jnp.float32, 36 | dtype=jnp.float32, 37 | attn_logits_soft_cap=50.0, 38 | final_logit_soft_cap=30.0, 39 | ) 40 | 41 | 42 | @pytest.fixture(scope="module") 43 | def dummy_rngs(): 44 | return nnx.Rngs(params=jax.random.PRNGKey(0)) 45 | 46 | 47 | # --- Test Core Modules --- 48 | 49 | 50 | def test_rmsnorm_output(dummy_rngs): 51 | dim = 4 52 | norm = Gemma2RMSNorm(dim=dim, eps=1e-6, rngs=dummy_rngs) 53 | # Initialize dummy weights 54 | norm.weight = nnx.Param(jnp.ones((dim,), dtype=jnp.float32)) # Use ones for non-trivial test 55 | 56 | x = jnp.array([[0.1, 0.2, 0.3, 0.4]]) 57 | out = norm(x) 58 | expected = jnp.array([[0.7302919, 1.4605838, 2.1908758, 2.9211676]]) 59 | 60 | assert out.shape == (1, dim) 61 | assert jnp.allclose(out, expected, atol=1e-5) 62 | 63 | # Test with zero weights (as initialized in 
the module) 64 | norm.weight = nnx.Param(jnp.zeros((dim,), dtype=jnp.float32)) 65 | x_ones = jnp.ones((2, dim)) 66 | out_zeros_w = norm(x_ones) 67 | assert out_zeros_w.shape == (2, dim) 68 | # With input=ones and weight=zeros, output should be normalized input 69 | # (close to input / sqrt(mean(square(input)))) 70 | # For ones input, mean(square(input)) = 1, sqrt = 1, output = input / 1 = 1 71 | assert jnp.allclose(out_zeros_w, jnp.ones_like(x_ones), atol=1e-6) 72 | 73 | 74 | def test_rope_embedding(): 75 | dim = 4 76 | batch = 2 77 | seq = 1 78 | num_heads = 2 79 | head_dim = dim // num_heads 80 | rope = Gemma2RotaryEmbedding(dim=dim // num_heads) 81 | 82 | # Shape [B, L, N, H] - Use head_dim here! 83 | x = jnp.ones((batch, seq, num_heads, head_dim)) 84 | positions = jnp.array([[1], [0]]) # Example positions 85 | 86 | out = rope(x, positions) 87 | assert out.shape == x.shape 88 | 89 | # For pos=1, head=0: sin(1/100^(0/2))=sin(1)=0.841, cos(1)=0.540 90 | # RoPE applies rotations like [x0*c - x1*s, x0*s + x1*c] 91 | # Input is all 1s. 92 | # Expected head 0, pos 1: [1*c0-1*s0, 1*s0+1*c0] = [cos(1)-sin(1), cos(1)+sin(1)] 93 | expected_pos1 = jnp.array([-0.30116867, 1.38177329]) 94 | # For pos=0, sin(0)=0, cos(0)=1. Output should be unchanged [1, 1] 95 | expected_pos0 = jnp.ones((head_dim,)) 96 | 97 | # Check head 0 output 98 | np.testing.assert_allclose(out[0, 0, 0, :], expected_pos1, atol=1e-5) 99 | np.testing.assert_allclose(out[1, 0, 0, :], expected_pos0, atol=1e-5) 100 | 101 | 102 | def test_attention_shape_and_gqa(tiny_config, dummy_rngs): 103 | # Initialize Gemma2Attention layer 104 | attn = Gemma2Attention( 105 | layer_idx=0, 106 | config=tiny_config, 107 | attention_type="global", 108 | rngs=dummy_rngs, 109 | ) 110 | batch = 2 111 | seq = 8 112 | x = jnp.ones((batch, seq, tiny_config.hidden_size)) 113 | pos_ids = jnp.arange(seq)[None, :].repeat(batch, axis=0) 114 | mask = jnp.tril(jnp.ones((seq, seq), dtype=jnp.bool_))[None, None, :, :] 115 | mask = jnp.where(mask, 0.0, -jnp.inf) # Additive mask 116 | 117 | # Gemma2Attention.__call__ always returns (output, cache) 118 | out, _ = attn(x, pos_ids, attention_mask=mask, cache=None) 119 | assert out.shape == (batch, seq, tiny_config.hidden_size) 120 | 121 | # Test repeat_kv helper 122 | kv_heads = tiny_config.num_key_value_heads 123 | head_dim = tiny_config.head_dim 124 | num_attn_heads = tiny_config.num_attention_heads 125 | n_rep = tiny_config.num_attention_heads // kv_heads 126 | # Input shape should be (batch, kv_heads, seq, head_dim) 127 | hidden_kv = jnp.ones((batch, kv_heads, seq, head_dim)) 128 | repeated = attn._repeat_kv(hidden_kv, n_rep) 129 | # Expected output shape: (batch, num_attn_heads, seq, head_dim) 130 | assert repeated.shape == (batch, num_attn_heads, seq, head_dim) 131 | 132 | 133 | def test_attention_soft_cap(tiny_config, dummy_rngs): 134 | cap_value = 50.0 135 | config_with_cap = dataclasses.replace(tiny_config, attn_logits_soft_cap=cap_value) 136 | attn = Gemma2Attention( 137 | layer_idx=0, 138 | config=config_with_cap, 139 | attention_type="global", 140 | rngs=dummy_rngs, 141 | ) 142 | 143 | logits = jnp.array([-100.0, -10.0, 0.0, 10.0, 100.0]) 144 | capped_logits = attn.apply_soft_cap(logits, cap_value) 145 | 146 | expected = cap_value * jnp.tanh(logits / cap_value) 147 | np.testing.assert_allclose(capped_logits, expected, atol=1e-6) 148 | 149 | 150 | def test_mlp_geglu_shape(tiny_config, dummy_rngs): 151 | mlp = Gemma2MLP(config=tiny_config, rngs=dummy_rngs) 152 | batch = 2 153 | seq = 8 154 | x = 
jnp.ones((batch, seq, tiny_config.hidden_size)) 155 | out = mlp(x) 156 | assert out.shape == (batch, seq, tiny_config.hidden_size) 157 | 158 | # Test static geglu method 159 | intermediate_x = jnp.ones((batch, seq, tiny_config.intermediate_size * 2)) 160 | geglu_out = Gemma2MLP.geglu(intermediate_x) 161 | assert geglu_out.shape == (batch, seq, tiny_config.intermediate_size) 162 | 163 | 164 | # --- Test Decoder Layer --- 165 | 166 | 167 | def test_decoder_layer_structure(tiny_config, dummy_rngs): 168 | layer_idx = 0 169 | model = Gemma2ForCausalLM(config=tiny_config, rngs=dummy_rngs) 170 | decoder = model.layers[layer_idx] 171 | 172 | assert isinstance(decoder.pre_attn_norm, Gemma2RMSNorm) 173 | assert isinstance(decoder.attn, Gemma2Attention) 174 | assert isinstance(decoder.post_attn_norm, Gemma2RMSNorm) 175 | assert isinstance(decoder.pre_mlp_norm, Gemma2RMSNorm) 176 | assert isinstance(decoder.mlp, Gemma2MLP) 177 | assert isinstance(decoder.post_mlp_norm, Gemma2RMSNorm) 178 | 179 | # Check attention type alternation 180 | assert decoder.attn.attention_type == "global" 181 | layer_idx = 1 182 | decoder1 = model.layers[layer_idx] 183 | assert decoder1.attn.attention_type == "local" 184 | 185 | 186 | # --- Test Full Model --- 187 | 188 | 189 | def test_gemma2_init(tiny_config, dummy_rngs): 190 | model = Gemma2ForCausalLM(config=tiny_config, rngs=dummy_rngs) 191 | assert isinstance(model, Gemma2ForCausalLM) 192 | assert hasattr(model, "embed_tokens") 193 | assert hasattr(model, "layers") 194 | assert len(model.layers) == tiny_config.num_hidden_layers 195 | assert hasattr(model, "norm") 196 | 197 | 198 | def test_gemma2_forward_shape_dtype(tiny_config, dummy_rngs): 199 | model = Gemma2ForCausalLM(config=tiny_config, rngs=dummy_rngs) 200 | batch = 2 201 | seq = 8 202 | input_ids = jnp.ones((batch, seq), dtype=jnp.int32) 203 | pos_ids = jnp.arange(seq)[None, :].repeat(batch, axis=0) 204 | # Pass None for attention_mask; model should handle creation 205 | attn_mask = None 206 | 207 | out, cache = model(input_ids, pos_ids, attn_mask) 208 | 209 | assert out.shape == (batch, seq, tiny_config.vocab_size) 210 | assert out.dtype == tiny_config.dtype 211 | assert cache is not None # Cache returned in forward pass 212 | 213 | 214 | def test_final_logit_soft_cap(tiny_config, dummy_rngs): 215 | # Test with and without final soft cap 216 | config_no_cap = dataclasses.replace(tiny_config, final_logit_soft_cap=None) 217 | config_with_cap = dataclasses.replace(tiny_config, final_logit_soft_cap=30.0) 218 | 219 | model_no_cap = Gemma2ForCausalLM(config=config_no_cap, rngs=dummy_rngs) 220 | model_with_cap = Gemma2ForCausalLM(config=config_with_cap, rngs=dummy_rngs) 221 | 222 | nnx.update(model_with_cap, nnx.state(model_no_cap, nnx.Param)) 223 | 224 | batch = 1 225 | seq = 4 226 | input_ids = jnp.ones((batch, seq), dtype=jnp.int32) 227 | pos_ids = jnp.arange(seq)[None, :].repeat(batch, axis=0) 228 | attn_mask = jnp.ones((batch, seq), dtype=jnp.bool_) 229 | 230 | out, _ = model_with_cap(input_ids, pos_ids, attn_mask) 231 | final_logits = out[:, -1, :] 232 | 233 | assert out.shape == (batch, seq, tiny_config.vocab_size) 234 | assert jnp.max(jnp.abs(final_logits)) <= config_with_cap.final_logit_soft_cap + 1e-6 235 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JAXgarden 2 | 3 | 
[![Doc](https://github.com/ml-gde/jax-layers/actions/workflows/docs.yml/badge.svg)](https://github.com/ml-gde/jax-layers/actions/workflows/docs.yml) 4 | [![Tests](https://github.com/ml-gde/jax-layers/actions/workflows/tests.yml/badge.svg)](https://github.com/ml-gde/jax-layers/actions/workflows/tests.yml) 5 | 6 | ![Logo](./assets/logo.png) 7 | 8 | A reusable collection of high-performance neural network layers and models in JAX, aiming to match and exceed the capabilities available in the PyTorch ecosystem. 9 | 10 | ## Motivation 11 | 12 | JAXgarden was created to provide the JAX ecosystem with a comprehensive library of well-documented, thoroughly tested, and numerically accurate implementations of neural network layers and models. The project aims to: 13 | 14 | - Provide both functional APIs and Flax NNX wrappers for maximum flexibility 15 | - Ensure seamless integration with the broader JAX ecosystem, especially Flax 16 | - Facilitate easy upstreaming of implementations to core libraries 17 | - Maintain rigorous testing and documentation standards 18 | - Match or exceed the performance of equivalent PyTorch implementations 19 | 20 | Initially started within the ML GDE group, the project began with a high-performance MultiHeadAttention implementation supporting various attention backends, with plans to expand to more layers and models. 21 | 22 | ## Features 23 | 24 | - **MultiHeadAttention**: A Flax NNX-compatible implementation with support for different attention backends. 25 | - Supports JAX's native Flash Attention implementation through cuDNN 26 | - Seamlessly integrates with Flax NNX's module system 27 | - Provides a simple interface for switching between attention implementations 28 | 29 | ## Installation 30 | 31 | ```bash 32 | # Install from source 33 | git clone https://github.com/ml-gde/jax-layers.git 34 | cd jax-layers 35 | pip install -e . 36 | ``` 37 | 38 | ## Usage 39 | 40 | ### LLaMA inference 41 | 42 | ```python 43 | from jaxgarden import LlamaConfig, LlamaForCausalLM, Tokenizer 44 | from flax import nnx 45 | 46 | 47 | # HF repo id of the LLaMA variant that you want to use 48 | model_id = "meta-llama/Llama-3.2-1B" 49 | 50 | # initialize the LLaMA architecture 51 | config = LlamaConfig() 52 | model = LlamaForCausalLM(config, rngs=nnx.Rngs(0)) 53 | 54 | # This is a one-liner to download HF checkpoint from HuggingFace Hub, 55 | # convert it to jaxgarden format, 56 | # save it in an Orbax checkpoint, 57 | # and then remove the HF checkpoint. 58 | model.from_hf(model_id) 59 | 60 | # this works just like `transformers.AutoTokenizer`, 61 | # but without the dependency of the whole `transformers` library. 62 | # Instead, we simply extend the `tokenizers` package and add some convenience code for JAX. 
63 | tokenizer = Tokenizer.from_pretrained(model_id) 64 | 65 | text = "The meaning of life is" 66 | model_inputs = tokenizer.encode(text) 67 | output = model.generate(**model_inputs, max_length=20, do_sample=True) 68 | output_text = tokenizer.decode(output) 69 | print(output_text) 70 | ``` 71 | 72 | 73 | ### MultiHeadAttention Module (Flax NNX) 74 | 75 | ```python 76 | import jax 77 | import jax.numpy as jnp 78 | import flax.nnx as nnx 79 | from jaxgarden.attention import MultiHeadAttention 80 | 81 | # Create a MultiHeadAttention module with Flash Attention support 82 | attention = MultiHeadAttention( 83 | num_heads=8, 84 | in_features=512, 85 | implementation="cudnn", # Use cuDNN's Flash Attention if available 86 | rngs=nnx.Rngs(0), 87 | ) 88 | 89 | # Create input data 90 | key = jax.random.PRNGKey(0) 91 | x = jax.random.normal(key, (2, 128, 512)) # (batch, seq_length, hidden_dim) 92 | 93 | # Create a causal attention mask 94 | mask = jnp.tril(jnp.ones((2, 1, 128, 128))) # (batch, 1, q_len, kv_len) 95 | 96 | # Apply the model 97 | output = attention(x, mask=mask) 98 | ``` 99 | 100 | ### RoPEMultiHeadAttention Module (Flax NNX) 101 | 102 | ```python 103 | import jax 104 | import jax.numpy as jnp 105 | import flax.linen as nn 106 | from jaxgarden.attention.rope_multi_head_attention import RoPEMultiHeadAttention 107 | 108 | # 1. Setup 109 | key = jax.random.PRNGKey(0) 110 | batch_size, seq_len = 2, 16 111 | num_heads, head_dim = 4, 32 112 | embed_dim = num_heads * head_dim 113 | x = jnp.ones((batch_size, seq_len, embed_dim)) 114 | 115 | # 2. Instantiate Module 116 | attention = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim) 117 | 118 | # 3. Initialize Parameters 119 | params = attention.init(key, x)['params'] 120 | 121 | # 4. Apply Module (Forward Pass) 122 | output = attention.apply({'params': params}, x) 123 | ``` 124 | 125 | ### Functional API 126 | 127 | #### Dot Product Attention with Implementation Selection 128 | 129 | ```python 130 | import jax 131 | import jax.numpy as jnp 132 | from jaxgarden.functional import dot_product_attention 133 | 134 | # Create random query, key, value tensors 135 | key = jax.random.PRNGKey(0) 136 | query = jax.random.normal(key, (2, 128, 8, 64)) # (batch, seq_len, heads, head_dim) 137 | key_tensor = jax.random.normal(key, (2, 128, 8, 64)) 138 | value = jax.random.normal(key, (2, 128, 8, 64)) 139 | 140 | # Create a causal attention mask 141 | mask = jnp.tril(jnp.ones((2, 1, 128, 128))) # (batch, 1, q_len, kv_len) 142 | 143 | # Apply dot product attention with Flash Attention implementation 144 | output = dot_product_attention( 145 | query=query, 146 | key=key_tensor, 147 | value=value, 148 | mask=mask, 149 | implementation="cudnn", # Use cuDNN's Flash Attention implementation 150 | ) 151 | ``` 152 | 153 | ## Development 154 | 155 | ### Setup 156 | 157 | 1. Please fork the repository to your account first. 158 | 2. Follow the instructions below. 159 | 160 | ```bash 161 | # Clone the repository 162 | git clone https://github.com/yourusername/jax-layers.git 163 | cd jax-layers 164 | 165 | # Install development dependencies 166 | pip install -e ".[dev]" 167 | ``` 168 | 169 | ### Pre-commit 170 | 171 | This project uses pre-commit hooks to ensure code quality and consistency. Pre-commit automatically runs linting and formatting tools (such as ruff) before each commit, helping to catch issues early. 
172 | 173 | ```bash 174 | # Install Pre-commit Hooks 175 | pre-commit install 176 | 177 | # Run Pre-commit on All Files 178 | pre-commit run --all-files 179 | ``` 180 | 181 | Every time you attempt to commit, pre-commit automatically runs the configured hooks (e.g., ruff). If any issues are detected, the commit will be blocked until they are resolved. 182 | 183 | ### Testing 184 | 185 | The project maintains a comprehensive test suite to ensure correctness and numerical accuracy: 186 | 187 | ```bash 188 | # Run all tests 189 | pytest 190 | 191 | # Run tests with coverage 192 | pytest tests/ --cov=jaxgarden 193 | 194 | # Run specific test file 195 | pytest tests/attention/test_multi_head_attention.py 196 | ``` 197 | 198 | ### Code Quality 199 | 200 | We maintain high code quality standards through automated checks: 201 | 202 | ```bash 203 | # Run linting 204 | ruff check . 205 | 206 | # Run type checking 207 | mypy jaxgarden 208 | 209 | # Run tests 210 | pytest 211 | ``` 212 | 213 | ### Documentation 214 | 215 | Documentation is automatically generated from docstrings: 216 | 217 | ```bash 218 | # Build documentation 219 | cd docs 220 | make html 221 | ``` 222 | 223 | ### Development Container (for Windows users) 224 | 225 | Since JAX doesn't support CUDA on Windows natively, we provide a development container configuration: 226 | 227 | 1. Install [Docker Desktop](https://www.docker.com/products/docker-desktop/) with WSL 2 backend 228 | 2. Install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) 229 | 3. Install [Visual Studio Code](https://code.visualstudio.com/) with the [Remote - Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension 230 | 4. Open the project in VS Code 231 | 5. Click the green icon in the bottom-left corner and select "Reopen in Container" 232 | 233 | The container provides: 234 | 235 | - Python 3.10 236 | - CUDA 12.4 with cuDNN 9 237 | - JAX with CUDA support 238 | - All dependencies from your pyproject.toml 239 | 240 | See [.devcontainer/README.md](.devcontainer/README.md) for more details. 241 | 242 | ## Contributing 243 | 244 | Contributions are more than welcome! Whether it's: 245 | 246 | - Adding new layer implementations 247 | - Improving documentation 248 | - Adding tests 249 | - Reporting bugs 250 | - Suggesting improvements 251 | 252 | Please feel free to open issues and pull requests. 253 | 254 | ## License 255 | 256 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 257 | 258 | ## Acknowledgements 259 | 260 | **The Google AI Developer Programs team supported this work by providing Google Cloud Credit.** 261 | 262 | - Thanks to the JAX and Flax teams for their excellent libraries. 263 | - Special thanks to the ML GDE group for initiating this project. 
264 | -------------------------------------------------------------------------------- /jaxgarden/models/base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | from collections.abc import Iterator 5 | from dataclasses import dataclass, field 6 | from pathlib import Path 7 | from typing import Any 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import orbax.checkpoint as ocp # type: ignore 12 | from flax import nnx 13 | from huggingface_hub import snapshot_download # type: ignore 14 | from safetensors import safe_open 15 | 16 | # Set up logging 17 | logger = logging.getLogger(__name__) 18 | 19 | DEFAULT_PARAMS_FILE = "jaxgarden_state" 20 | 21 | 22 | @dataclass 23 | class BaseConfig: 24 | """Base configuration for all the models implemented in the JAXgarden library. 25 | 26 | Each model implemented in JAXgarden should subclass this class for configuration management. 27 | """ 28 | 29 | seed: int = 42 30 | log_level: str = "info" 31 | extra: dict[str, Any] = field(default_factory=dict) 32 | 33 | def to_dict(self) -> dict[str, Any]: 34 | return self.__dict__ 35 | 36 | def update(self, **kwargs: dict) -> None: 37 | for k, v in kwargs.items(): 38 | if hasattr(self, k): 39 | setattr(self, k, v) 40 | else: 41 | self.extra[k] = v 42 | 43 | 44 | class BaseModel(nnx.Module): 45 | """Base class for all the models implemented in the JAXgarden library.""" 46 | 47 | def __init__( 48 | self, 49 | config: BaseConfig, 50 | *, 51 | dtype: jnp.dtype | None = None, 52 | param_dtype: jnp.dtype = jnp.float32, 53 | precision: jax.lax.Precision | str | None = None, 54 | rngs: nnx.Rngs, 55 | ): 56 | """Initialize the model. 57 | 58 | Args: 59 | config: config class for this model. 60 | dtype: Data type in which computation is performed. 61 | param_dtype: Data type in which params are stored. 62 | precision: Numerical precision. 63 | rngs: Random number generators for param initialization etc. 64 | """ 65 | self.config = config 66 | self.dtype = dtype 67 | self.param_dtype = param_dtype 68 | self.precision = precision 69 | self.rngs = rngs 70 | 71 | @property 72 | def state(self) -> nnx.State: 73 | """Splits state from the graph and returns it""" 74 | return nnx.split(self, nnx.Param, ...)[1] # type: ignore 75 | 76 | @property 77 | def state_dict(self) -> dict[str, jnp.ndarray]: 78 | """Splits state from the graph and returns it as a dictionary. 79 | 80 | It can be used for serialization with orbax.""" 81 | state = self.state 82 | pure_dict_state = nnx.to_pure_dict(state) 83 | return pure_dict_state 84 | 85 | def save(self, path: str) -> None: 86 | """Saves the model state to a directory. 87 | 88 | Args: 89 | path: The directory path to save the model state to. 90 | """ 91 | state = self.state_dict 92 | checkpointer = ocp.StandardCheckpointer() 93 | checkpointer.save(os.path.join(path, DEFAULT_PARAMS_FILE), state) 94 | checkpointer.wait_until_finished() 95 | 96 | def load(self, path: str) -> nnx.Module: 97 | """Loads the model state from a directory. 98 | 99 | Args: 100 | path: The directory path to load the model state from. 
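        Example (a minimal illustrative sketch; ``MyModel`` and the checkpoint path are hypothetical):
            model = MyModel(config, rngs=nnx.Rngs(0))
            restored_model = model.load("/path/to/checkpoint_dir")  # returns a new module merged from the restored state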
101 | """ 102 | checkpointer = ocp.StandardCheckpointer() 103 | restored_pure_dict = checkpointer.restore(os.path.join(path, DEFAULT_PARAMS_FILE)) 104 | abstract_model = nnx.eval_shape(lambda: self) 105 | graphdef, abstract_state = nnx.split(abstract_model) 106 | nnx.replace_by_pure_dict(abstract_state, restored_pure_dict) 107 | return nnx.merge(graphdef, abstract_state) 108 | 109 | @staticmethod 110 | def download_from_hf( 111 | repo_id: str, local_dir: str, token: str | None = None, force_download: bool = False 112 | ) -> None: 113 | """Downloads the model from the Hugging Face Hub. 114 | 115 | Args: 116 | repo_id: The repository ID of the model to download. 117 | local_dir: The local directory to save the model to. 118 | token: The hf auth token to download the model with. 119 | - If `True`, the token is read from the HuggingFace config 120 | folder. 121 | - If a string, it's used as the authentication token. 122 | force_download (`bool`, *optional*, defaults to `False`): 123 | Whether the file should be downloaded even if it already exists in the local cache. 124 | """ 125 | logger.info(f"Attempting to download {repo_id} from Hugging Face Hub to {local_dir}.") 126 | try: 127 | snapshot_download( 128 | repo_id, local_dir=local_dir, token=token, force_download=force_download 129 | ) 130 | logger.info(f"Successfully downloaded {repo_id} to {local_dir}.") 131 | except Exception as e: 132 | logger.error(f"Failed to download {repo_id}: {e}") 133 | raise 134 | 135 | @staticmethod 136 | def iter_safetensors(path_to_model_weights: str) -> Iterator[tuple[Any, Any]]: 137 | """Helper function to lazily load params from safetensors file. 138 | 139 | Use this static method to iterate over weights for conversion tasks. 140 | 141 | Args: 142 | path_to_model_weights: Path to directory containing .safetensors files.""" 143 | if not os.path.isdir(path_to_model_weights): 144 | raise ValueError(f"{path_to_model_weights} is not a valid directory.") 145 | 146 | safetensors_files = Path(path_to_model_weights).glob("*.safetensors") 147 | 148 | for file in safetensors_files: 149 | with safe_open(file, framework="jax", device="cpu") as f: 150 | for key in f.keys(): # noqa: SIM118 151 | yield key, f.get_tensor(key) 152 | 153 | def from_hf( 154 | self, 155 | model_repo_or_id: str, 156 | token: str | None = None, 157 | force_download: bool = False, 158 | save_in_orbax: bool = True, 159 | remove_hf_after_conversion: bool = True, 160 | ) -> None: 161 | """Downloads the model from the Hugging Face Hub and returns a new instance of the model. 162 | 163 | It can also save the converted weights in an Orbax checkpoint 164 | and removes the original HF checkpoint after conversion. 165 | 166 | Args: 167 | model_repo_or_id: The repository ID or name of the model to download. 168 | token: The token to use for authentication with the Hugging Face Hub. 169 | force_download: (`bool`, *optional*, defaults to `False`): 170 | Whether the file should be downloaded even if it already exists in the local cache. 171 | save_in_orbax: Whether to save the converted weights in an Orbax checkpoint. 172 | remove_hf_after_conversion: Whether to remove the downloaded HuggingFace checkpoint 173 | after conversion. 
174 | """ 175 | logger.info(f"Starting from_hf process for model: {model_repo_or_id}") 176 | local_dir = os.path.join( 177 | os.path.expanduser("~"), ".jaxgarden", "hf_models", *model_repo_or_id.split("/") 178 | ) 179 | save_dir = local_dir.replace("hf_models", "models") 180 | if os.path.exists(save_dir): 181 | if force_download: 182 | logger.warning(f"Removing {save_dir} because force_download is set to True") 183 | shutil.rmtree(save_dir) 184 | else: 185 | raise RuntimeError( 186 | f"Path {save_dir} already exists." 187 | + " Set force_download to Tru to run conversion again." 188 | ) 189 | 190 | logger.debug(f"Local Hugging Face model directory set to: {local_dir}") 191 | 192 | BaseModel.download_from_hf( 193 | model_repo_or_id, local_dir, token=token, force_download=force_download 194 | ) 195 | logger.info(f"Initiating weight iteration from safetensors in {local_dir}") 196 | weights = BaseModel.iter_safetensors(local_dir) 197 | state = self.state 198 | logger.info("Running weight conversion...") 199 | self.convert_weights_from_hf(state, weights) 200 | logger.info("Weight conversion finished. Updating model state...") 201 | nnx.update(self, state) 202 | logger.warning("Model state successfully updated with converted weights.") 203 | 204 | if remove_hf_after_conversion: 205 | logger.warning(f"Removing HuggingFace checkpoint from {local_dir}...") 206 | shutil.rmtree(local_dir) 207 | 208 | if save_in_orbax: 209 | logger.warning(f")Saving Orbax checkpoint in {save_dir}.") 210 | self.save(save_dir) 211 | 212 | logger.warning(f"from_hf process completed for {model_repo_or_id}.") 213 | 214 | def convert_weights_from_hf(self, state: nnx.State, weights: Iterator[tuple[Any, Any]]) -> None: 215 | """Convert weights from Hugging Face Hub to the model's state. 216 | 217 | This method should be implemented in downstream classes 218 | to support conversion from HuggingFace format. 219 | """ 220 | raise NotImplementedError("This model does not support conversion from HuggingFace yet.") 221 | -------------------------------------------------------------------------------- /.cursor/rules/baseconfig.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: JAXgarden tutorial chapter detailing BaseConfig, the base dataclass for managing model hyperparameters and configurations. 3 | globs: 4 | alwaysApply: false 5 | --- 6 | # Chapter 3: BaseConfig 7 | 8 | In the [previous chapter](basemodel.mdc), we explored `BaseModel`, the foundational class for models in `jaxgarden`. We saw that each `BaseModel` instance is initialized with a configuration object. This chapter dives into the base class for those configuration objects: `BaseConfig`. 9 | 10 | **Motivation:** Neural models are complex systems with numerous hyperparameters (like hidden size, number of layers, dropout rates) and settings (like data types, precision). Managing these configurations consistently across different models is crucial for reproducibility, maintainability, and ease of use. Using simple dictionaries or ad-hoc parameter passing can quickly become messy and error-prone. `BaseConfig` provides a standardized, structured way to define, manage, and serialize these configurations. 
11 | 12 | **Central Use Case:** Defining the set of hyperparameters for a new model implementation (e.g., creating a `MyTransformerConfig` that inherits from `BaseConfig`) or inspecting and modifying the configuration of an existing `jaxgarden` model like [LlamaForCausalLM](llamaforcausallm.mdc), which uses its own `LlamaConfig` subclass derived from `BaseConfig`. 13 | 14 | ## Key Concepts 15 | 16 | `BaseConfig` leverages Python's `dataclasses` to provide a simple yet powerful configuration system: 17 | 18 | 1. **Dataclass Structure:** It's defined using the `@dataclass` decorator, offering type hints, default values, and automatic `__init__` generation. 19 | 2. **Inheritance:** Model-specific configurations (e.g., `LlamaConfig`, `ModernBERTConfig`) inherit from `BaseConfig`, adding their unique parameters while retaining the common base attributes. 20 | 3. **Common Attributes:** Provides essential base attributes applicable to most models, such as `seed` for reproducibility and `log_level` for controlling verbosity. 21 | 4. **Extensibility:** An `extra` dictionary allows storing arbitrary additional parameters not explicitly defined in the dataclass fields, offering flexibility. 22 | 5. **Serialization:** The `to_dict()` method converts the configuration object into a standard Python dictionary, useful for logging, saving, or inspection. 23 | 6. **Programmatic Updates:** The `update()` method allows modifying configuration attributes after instantiation using keyword arguments. 24 | 25 | ## Using `BaseConfig` 26 | 27 | You typically interact with subclasses of `BaseConfig` tailored to specific models. 28 | 29 | ### Defining a Custom Configuration 30 | 31 | To define a configuration for a new model, create a class inheriting from `BaseConfig` and use the `@dataclass` decorator. Add fields for your model's specific hyperparameters. 32 | 33 | ```python 34 | from dataclasses import dataclass, field 35 | from jaxgarden.models.base import BaseConfig 36 | from typing import Any 37 | 38 | @dataclass 39 | class MyModelConfig(BaseConfig): 40 | """Configuration for MyCustomModel.""" 41 | hidden_size: int = 256 42 | num_layers: int = 4 43 | dropout_rate: float = 0.1 44 | activation_fn: str = "relu" 45 | # Override a base attribute default if needed 46 | seed: int = 123 47 | 48 | # Instantiate the configuration 49 | my_config = MyModelConfig(hidden_size=512) # Override default hidden_size 50 | 51 | print(f"MyModelConfig instance: {my_config}") 52 | # Output: MyModelConfig instance: MyModelConfig(seed=123, log_level='info', extra={}, hidden_size=512, num_layers=4, dropout_rate=0.1, activation_fn='relu') 53 | 54 | print(f"Hidden size: {my_config.hidden_size}") 55 | # Output: Hidden size: 512 56 | print(f"Seed (overridden): {my_config.seed}") 57 | # Output: Seed (overridden): 123 58 | print(f"Log Level (from base): {my_config.log_level}") 59 | # Output: Log Level (from base): info 60 | ``` 61 | 62 | **Explanation:** 63 | - We define `MyModelConfig` inheriting from `BaseConfig`. 64 | - `@dataclass` automatically generates methods like `__init__`. 65 | - We add model-specific fields like `hidden_size` and `num_layers` with type hints and default values. 66 | - We can override defaults during instantiation (`hidden_size=512`). 67 | - Base attributes like `seed` and `log_level` are inherited. 68 | 69 | ### Serialization (`to_dict`) 70 | 71 | Convert the configuration object to a dictionary for easy inspection or storage. 
72 | 73 | ```python 74 | config_dict = my_config.to_dict() 75 | print(f"Configuration as dictionary: {config_dict}") 76 | # Output: Configuration as dictionary: {'seed': 123, 'log_level': 'info', 'extra': {}, 'hidden_size': 512, 'num_layers': 4, 'dropout_rate': 0.1, 'activation_fn': 'relu'} 77 | 78 | import json 79 | print(f"Configuration as JSON: {json.dumps(config_dict, indent=2)}") 80 | # Output: 81 | # Configuration as JSON: { 82 | # "seed": 123, 83 | # "log_level": "info", 84 | # "extra": {}, 85 | # "hidden_size": 512, 86 | # "num_layers": 4, 87 | # "dropout_rate": 0.1, 88 | # "activation_fn": "relu" 89 | # } 90 | ``` 91 | 92 | **Explanation:** The `to_dict()` method simply returns the internal `__dict__` of the dataclass instance, making it compatible with standard serialization libraries like `json`. 93 | 94 | ### Programmatic Updates (`update`) 95 | 96 | Modify configuration values after the object has been created. 97 | 98 | ```python 99 | print(f"Original dropout rate: {my_config.dropout_rate}") 100 | # Output: Original dropout rate: 0.1 101 | print(f"Original extra dict: {my_config.extra}") 102 | # Output: Original extra dict: {} 103 | 104 | 105 | # Update existing attributes and add new ones to 'extra' 106 | update_params = { 107 | "dropout_rate": 0.15, 108 | "new_experimental_param": True, 109 | "learning_rate": 1e-4 110 | } 111 | my_config.update(**update_params) 112 | 113 | print(f"Updated dropout rate: {my_config.dropout_rate}") 114 | # Output: Updated dropout rate: 0.15 115 | print(f"Updated extra dict: {my_config.extra}") 116 | # Output: Updated extra dict: {'new_experimental_param': True, 'learning_rate': 0.0001} 117 | ``` 118 | 119 | **Explanation:** The `update()` method iterates through the provided keyword arguments. If an argument name matches an existing attribute of the config object, it updates that attribute. If the name doesn't match any defined attribute, it's added to the `extra` dictionary. 120 | 121 | ## Internal Implementation 122 | 123 | `BaseConfig` is fundamentally a Python `dataclass` with a couple of added helper methods. 124 | 125 | * **Location:** `jaxgarden/models/base.py` 126 | 127 | * **Definition:** 128 | ```python 129 | # From jaxgarden/models/base.py 130 | from dataclasses import dataclass, field 131 | from typing import Any 132 | 133 | @dataclass 134 | class BaseConfig: 135 | """Base configuration for models.""" 136 | seed: int = 42 137 | log_level: str = "info" 138 | extra: dict[str, Any] = field(default_factory=dict) 139 | # ... (Subclasses add more fields here) 140 | 141 | def to_dict(self) -> dict[str, Any]: 142 | # Simply returns the instance's dictionary 143 | return self.__dict__ 144 | 145 | def update(self, **kwargs: dict) -> None: 146 | # Iterates through provided keyword arguments 147 | for k, v in kwargs.items(): 148 | if hasattr(self, k): 149 | # If attribute exists, set its value 150 | setattr(self, k, v) 151 | else: 152 | # Otherwise, add it to the 'extra' dictionary 153 | self.extra[k] = v 154 | ``` 155 | * **Mechanism:** 156 | 1. `@dataclass`: Handles the `__init__`, `__repr__`, `__eq__`, etc., automatically based on the defined fields and type hints. 157 | 2. `field(default_factory=dict)`: Ensures that each instance gets its own empty `extra` dictionary by default, preventing accidental sharing between instances. 158 | 3. `to_dict()`: Leverages the standard Python object attribute `__dict__` for a direct conversion to a dictionary. 159 | 4. 
`update()`: Uses Python's reflection capabilities (`hasattr`, `setattr`) to dynamically update attributes or populate the `extra` dictionary. 160 | 161 | ## Relationship with `BaseModel` 162 | 163 | As discussed in [Chapter 2: BaseModel](basemodel.mdc), every `BaseModel` instance requires a configuration object (a subclass of `BaseConfig`) during initialization. This `config` object is stored as an attribute (`self.config`) within the model instance, making the model's hyperparameters readily accessible. 164 | 165 | ```python 166 | # Conceptual reminder from BaseModel chapter 167 | # class MyModel(BaseModel): 168 | # def __init__(self, config: MyModelConfig, *, rngs): 169 | # super().__init__(config, rngs=rngs) # Pass config to BaseModel 170 | # # Access config later: self.config.hidden_size 171 | # self.layer = nnx.Linear(config.hidden_size, ...) 172 | 173 | # my_config = MyModelConfig() 174 | # model = MyModel(config=my_config, ...) 175 | # print(model.config.num_layers) # Accessing config via the model 176 | ``` 177 | 178 | ## Conclusion 179 | 180 | `BaseConfig` provides a clean, structured, and extensible foundation for managing model configurations in `jaxgarden`. By using Python dataclasses and offering simple methods for serialization (`to_dict`) and updates (`update`), it promotes consistency and simplifies the process of defining and working with model hyperparameters. All specific model configurations, like the one we'll see next for Llama, build upon this base. 181 | 182 | **Next:** [LlamaForCausalLM](llamaforcausallm.mdc) 183 | 184 | 185 | --- 186 | 187 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai) -------------------------------------------------------------------------------- /jaxgarden/attention/rope_multi_head_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | JAX/Flax implementation of Multi-Head Attention with Rotary Positional Embedding (RoPE). 3 | 4 | This code implements the RoPE technique within a standard Multi-Head Attention 5 | framework. RoPE injects relative positional information by rotating pairs of 6 | features in the Query and Key vectors based on their absolute position before 7 | the attention calculation. 8 | 9 | The method was introduced in the paper: 10 | "RoFormer: Enhanced Transformer with Rotary Position Embedding" 11 | by Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, Yunfeng Liu. 12 | arXiv:2104.09864v5 [cs.CL] (Submitted on 20 Apr 2021) 13 | """ 14 | 15 | import flax.linen as nn 16 | import jax 17 | import jax.numpy as jnp 18 | 19 | 20 | def rotate_half(x: jnp.ndarray) -> jnp.ndarray: 21 | """Rotates half the hidden dims of the input tensor.""" 22 | x1 = x[..., ::2] 23 | x2 = x[..., 1::2] 24 | # Builds the rotated tensor by concatenating the negated second half 25 | # and the first half along the last dimension. 26 | return jnp.concatenate((-x2, x1), axis=-1) 27 | 28 | 29 | def apply_rotary_pos_emb(x: jnp.ndarray, cos_emb: jnp.ndarray, sin_emb: jnp.ndarray) -> jnp.ndarray: 30 | """Applies Rotary Positional Embedding to the input tensor. 31 | 32 | Args: 33 | x: Input tensor, e.g., query or key (batch, seq_len, num_heads, head_dim) 34 | cos_emb: Cosine component of the positional embedding. 35 | Shape: (1, seq_len, 1, head_dim) or compatible via broadcasting. 36 | sin_emb: Sine component of the positional embedding. 37 | Shape: (1, seq_len, 1, head_dim) or compatible via broadcasting. 38 | Returns: 39 | Tensor with RoPE applied. 
40 | """ 41 | # Applying the rotation formula: 42 | # x_rotated = x * cos(theta) + rotate_half(x) * sin(theta) 43 | # Ensure shapes are broadcastable: cos_emb and sin_emb should have dimensions 44 | # for sequence length and features, matching the corresponding dimensions in x. 45 | # Typically, precomputed embeddings have shape (seq_len, head_dim) 46 | # or (1, seq_len, 1, head_dim) for easy broadcasting. 47 | return (x * cos_emb) + (rotate_half(x) * sin_emb) 48 | 49 | 50 | def precompute_rotary_embeddings( 51 | seq_len: int, head_dim: int, base: float = 10000.0 52 | ) -> tuple[jnp.ndarray, jnp.ndarray]: 53 | """Precomputes the RoPE cosine and sine embeddings. 54 | 55 | Args: 56 | seq_len: The maximum sequence length. 57 | head_dim: The dimension of each attention head (must be even). 58 | base: The base value for the inverse frequency calculation. 59 | 60 | Returns: 61 | cos_emb: Cosine embeddings (1, seq_len, 1, head_dim) 62 | sin_emb: Sine embeddings (1, seq_len, 1, head_dim) 63 | """ 64 | if head_dim % 2 != 0: 65 | raise ValueError(f"head_dim must be even, got {head_dim}") 66 | 67 | # Calculate inverse frequencies (theta_i) 68 | # theta_i = 1 / (base^(2*i / head_dim)) for i in [0, 1, ..., head_dim/2 - 1] 69 | inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)) 70 | 71 | # Calculate position indices (m) 72 | pos = jnp.arange(seq_len, dtype=jnp.float32) 73 | 74 | # Calculate angles (m * theta_i) 75 | freqs = jnp.outer(pos, inv_freq) # Shape: (seq_len, head_dim / 2) 76 | 77 | # Duplicate frequencies for the full head dimension (for both elements in pairs) 78 | emb = jnp.concatenate((freqs, freqs), axis=-1) # Shape: (seq_len, head_dim) 79 | 80 | # Calculate cosine and sine embeddings 81 | cos_emb = jnp.cos(emb)[None, :, None, :] # Shape: (1, seq_len, 1, head_dim) 82 | sin_emb = jnp.sin(emb)[None, :, None, :] # Shape: (1, seq_len, 1, head_dim) 83 | 84 | return cos_emb, sin_emb 85 | 86 | 87 | class RoPEMultiHeadAttention(nn.Module): 88 | """Multi-Head Attention with Rotary Positional Embeddings.""" 89 | 90 | num_heads: int 91 | head_dim: int 92 | rope_base: float = 10000.0 93 | dtype: jnp.dtype = jnp.float32 94 | 95 | def setup(self) -> None: # Added -> None return type 96 | """Initializes the attention projections.""" 97 | # Check head_dim validity early during setup 98 | if self.head_dim % 2 != 0: 99 | raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE.") 100 | 101 | # Define layers here - they will be initialized when the module is first called 102 | total_head_dim = self.num_heads * self.head_dim 103 | self.query_proj = nn.Dense( 104 | features=total_head_dim, use_bias=False, dtype=self.dtype, name="query_proj" 105 | ) 106 | self.key_proj = nn.Dense( 107 | features=total_head_dim, use_bias=False, dtype=self.dtype, name="key_proj" 108 | ) 109 | self.value_proj = nn.Dense( 110 | features=total_head_dim, use_bias=False, dtype=self.dtype, name="value_proj" 111 | ) 112 | self.output_proj = nn.Dense( 113 | features=self.num_heads * self.head_dim, # Output should match embed_dim 114 | use_bias=False, 115 | dtype=self.dtype, 116 | name="output_proj", 117 | ) 118 | 119 | @nn.compact 120 | # Also using Optional for the mask type hint for clarity with None default 121 | def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarray: 122 | """Forward pass for RoPE MHA. 123 | 124 | Args: 125 | x: Input tensor (batch_size, seq_len, embed_dim). 
126 | mask: Optional attention mask (batch_size, 1, seq_len, seq_len) 127 | or (batch_size, 1, 1, seq_len) for causal masking. 128 | Mask values should be 0 where attention is allowed, -inf otherwise. 129 | Flax convention often uses boolean masks (True=masked). We'll handle both. 130 | 131 | Returns: 132 | Output tensor (batch_size, seq_len, embed_dim). 133 | """ 134 | batch_size, seq_len, embed_dim = x.shape 135 | total_head_dim = self.num_heads * self.head_dim 136 | 137 | if embed_dim != total_head_dim: 138 | raise ValueError( 139 | f"embed_dim ({embed_dim}) must equal num_heads*head_dim ({total_head_dim})" 140 | ) 141 | # Note: head_dim even check moved to setup for earlier failure 142 | 143 | # 1. Linear projections for Q, K, V 144 | query = self.query_proj(x) 145 | key = self.key_proj(x) 146 | value = self.value_proj(x) 147 | 148 | # 2. Reshape for multi-head processing 149 | # (batch, seq_len, embed_dim) -> (batch, seq_len, num_heads, head_dim) 150 | query = query.reshape(batch_size, seq_len, self.num_heads, self.head_dim) 151 | key = key.reshape(batch_size, seq_len, self.num_heads, self.head_dim) 152 | value = value.reshape(batch_size, seq_len, self.num_heads, self.head_dim) 153 | 154 | # 3. Precompute RoPE embeddings (cosine and sine) 155 | # We compute them dynamically based on the input sequence length 156 | cos_emb, sin_emb = precompute_rotary_embeddings(seq_len, self.head_dim, base=self.rope_base) 157 | # Ensure RoPE embeddings have correct dtype 158 | cos_emb = cos_emb.astype(self.dtype) 159 | sin_emb = sin_emb.astype(self.dtype) 160 | 161 | # 4. Apply RoPE to Query and Key 162 | query = apply_rotary_pos_emb(query, cos_emb, sin_emb) 163 | key = apply_rotary_pos_emb(key, cos_emb, sin_emb) 164 | 165 | # 5. Transpose for attention calculation: (batch, num_heads, seq_len, head_dim) 166 | query = query.transpose((0, 2, 1, 3)) 167 | key = key.transpose((0, 2, 1, 3)) 168 | value = value.transpose((0, 2, 1, 3)) 169 | 170 | # 6. 
Scaled Dot-Product Attention 171 | # Attention scores: (batch, num_heads, seq_len, seq_len) 172 | attn_scores = jnp.matmul(query, key.transpose((0, 1, 3, 2))) / jnp.sqrt( 173 | self.head_dim 174 | ).astype(self.dtype) # Ensure sqrt is correct dtype 175 | 176 | # Apply mask (if provided) 177 | if mask is not None: 178 | # Standard Flax causal mask is boolean (True means mask) 179 | # nn.make_causal_mask returns (1, seq_len, seq_len) or (batch, 1, seq_len, seq_len) 180 | # Check if mask needs broadcasting or conversion 181 | if mask.ndim == 2: # Likely (seq_len, seq_len) 182 | mask = mask[None, None, :, :] # -> (1, 1, seq_len, seq_len) 183 | elif mask.ndim == 3 and mask.shape[1] != self.num_heads: 184 | # Likely (batch, seq_len, seq_len) or causal (1, sl, sl) 185 | mask = mask[:, None, :, :] 186 | # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len) 187 | 188 | # Ensure mask is broadcastable to attn_scores shape 189 | mask_shape_expected = (batch_size, self.num_heads, seq_len, seq_len) 190 | if mask.shape != mask_shape_expected: 191 | # Attempt broadcasting common causal mask shapes 192 | if mask.shape == (1, 1, seq_len, seq_len) or mask.shape == ( 193 | batch_size, 194 | 1, 195 | seq_len, 196 | seq_len, 197 | ): # Causal mask for all batches/heads 198 | mask = jnp.broadcast_to(mask, mask_shape_expected) 199 | # Add other broadcasting cases if needed 200 | else: 201 | raise ValueError(f"Mask shape {mask.shape} != exp shape {mask_shape_expected}") 202 | 203 | # Apply mask: Use large negative number where mask is True 204 | # (or where mask value is 0 if using 0/-inf convention) 205 | # Assuming boolean mask convention (True = mask) common in Flax examples 206 | # If using 0/-inf mask, the logic would be: attn_scores = attn_scores + mask 207 | attn_scores = jnp.where(mask, jnp.finfo(self.dtype).min, attn_scores) 208 | 209 | # Softmax to get attention weights 210 | attn_weights = jax.nn.softmax(attn_scores, axis=-1).astype( 211 | self.dtype 212 | ) # Shape: (batch, num_heads, seq_len, seq_len) 213 | 214 | # Apply attention weights to Value 215 | # Output per head: (batch, num_heads, seq_len, head_dim) 216 | attn_output = jnp.matmul(attn_weights, value) 217 | 218 | # 7. Concatenate heads and final projection 219 | # Transpose back: (batch, seq_len, num_heads, head_dim) 220 | attn_output = attn_output.transpose((0, 2, 1, 3)) 221 | # Reshape to (batch, seq_len, embed_dim) 222 | attn_output = attn_output.reshape(batch_size, seq_len, total_head_dim) 223 | 224 | # Final linear projection 225 | output = self.output_proj(attn_output) # Use self.output_proj defined in setup 226 | 227 | return output 228 | -------------------------------------------------------------------------------- /.cursor/rules/rotary_position_embeddings__rope_.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: JAXgarden tutorial chapter on Rotary Position Embeddings (RoPE) for injecting relative positional information in attention. 3 | globs: 4 | alwaysApply: false 5 | --- 6 | # Chapter 8: Rotary Position Embeddings (RoPE) 7 | 8 | In the [previous chapter](attention_mechanism__multiheadattention___dot_product_attention_.mdc), we explored the core `MultiHeadAttention` mechanism and its efficient backend implementations. However, standard self-attention is permutation-invariant, meaning it doesn't inherently understand the order of tokens in a sequence. To address this, models need positional information. 
This chapter introduces Rotary Position Embeddings (RoPE), a clever technique for incorporating relative position information directly into the attention mechanism. 9 | 10 | **Motivation:** Traditional methods for adding positional information include adding learned absolute position embeddings or fixed sinusoidal embeddings to the input token embeddings. While effective, these methods add extra parameters (learned embeddings) or modify the input representation before the transformer layers. RoPE offers an alternative by *rotating* the query and key vectors based on their absolute positions *before* the dot-product calculation in the attention mechanism. This rotation is designed such that the dot product between a rotated query and a rotated key naturally depends on their relative positions, effectively injecting relative positional awareness directly into the attention scores without adding separate embedding vectors. 11 | 12 | **Central Use Case:** Applying RoPE within the attention modules of models like [LlamaForCausalLM](llamaforcausallm.mdc) (`LlamaAttention`) or [ModernBERTForMaskedLM](modernbertformaskedlm.mdc) (`ModernBertAttention`). RoPE modifies the query and key vectors just before their dot product, enabling the attention mechanism to weigh interactions based on how far apart tokens are in the sequence. 13 | 14 | ## Key Concepts 15 | 16 | 1. **Core Idea:** RoPE operates by viewing pairs of features in the query and key vectors as complex numbers and rotating them in the complex plane. The angle of rotation depends on the token's absolute position (`m`) and the feature index (`i`). When computing the dot product between a rotated query (at position `m`) and a rotated key (at position `n`), the resulting score implicitly depends on their relative distance (`m - n`). 17 | 2. **Mathematical Basis:** The rotation is achieved using sinusoidal functions derived from the position index (`m`) and the feature dimension (`d`). Specifically, frequencies (`theta_i`) are calculated based on the feature index `i` and a base value (`base`, often 10000 or larger). The rotation involves multiplying elements by `cos(m * theta_i)` and `sin(m * theta_i)`. 18 | 3. **Implementation Variations:** `jaxgarden` provides two main implementations reflecting common practices: 19 | * **`LlamaRotaryEmbedding`:** Used in [LlamaForCausalLM](llamaforcausallm.mdc). It calculates the `cos` and `sin` values *on-the-fly* based on the input `position_ids`. This is flexible for varying sequence lengths during generation. 20 | * **`RoPEPositionalEmbedding`:** Used in [ModernBERTForMaskedLM](modernbertformaskedlm.mdc). It *pre-computes* and caches the `cos` and `sin` values for positions up to a specified `max_position_embeddings` during initialization. This can be more efficient if the maximum sequence length is known and fixed, as it replaces calculation with lookup. 21 | 4. **Application Point:** RoPE is applied to the query and key vectors *after* their initial linear projections (`Wq`, `Wk`) but *before* the dot-product attention score calculation (`QK^T`). This modification is typically encapsulated within a helper function (e.g., `apply_rotary_pos_emb`) called by the attention module. 22 | 23 | ## Using RoPE 24 | 25 | RoPE is generally an internal component of attention modules. You typically don't interact with `LlamaRotaryEmbedding` or `RoPEPositionalEmbedding` directly unless implementing a custom attention layer. 
Instead, you configure the model (via its `Config` object) which implicitly configures the RoPE within its attention layers. 26 | 27 | **Example 1: RoPE within `LlamaAttention`** 28 | 29 | The `LlamaAttention` module internally instantiates `LlamaRotaryEmbedding` and uses it within its `__call__` method. 30 | 31 | ```python 32 | # Inside LlamaAttention initialization (__init__) 33 | # config values like head_dim and rope_theta are passed 34 | self.rotary_emb = LlamaRotaryEmbedding( 35 | dim=config.head_dim, 36 | base=config.rope_theta, 37 | rngs=rngs 38 | ) 39 | 40 | # Inside LlamaAttention forward pass (__call__) 41 | def __call__(self, x, position_ids, attention_mask): 42 | # ... project x to query, key, value ... 43 | query = self.q_proj(x).reshape(...) # [batch, n_heads, seq_len, head_dim] 44 | key = self.k_proj(x).reshape(...) # [batch, n_kv_heads, seq_len, head_dim] 45 | value = self.v_proj(x).reshape(...) # [batch, n_kv_heads, seq_len, head_dim] 46 | 47 | # Calculate cos/sin on the fly (assuming batch_size=1 for simplicity here) 48 | cos, sin = self.rotary_emb(position_ids[0]) 49 | 50 | # Apply RoPE rotation *before* attention calculation 51 | query, key = self.apply_rotary_pos_emb(query, key, cos, sin) 52 | 53 | # ... repeat key/value if GQA, compute attention scores ... 54 | # attn_weights = jnp.matmul(query, key.transpose(...)) / scale 55 | # ... compute final output ... 56 | return output 57 | ``` 58 | **Explanation:** `LlamaAttention` creates a `LlamaRotaryEmbedding` instance. In the forward pass, it calls this instance with `position_ids` to get the `cos` and `sin` values dynamically. These are then passed to `apply_rotary_pos_emb` (a method within `LlamaAttention`) to rotate `query` and `key`. 59 | 60 | **Example 2: RoPE within `ModernBertAttention`** 61 | 62 | `ModernBertAttention` uses the pre-computed cache from `RoPEPositionalEmbedding`. 63 | 64 | ```python 65 | # Inside ModernBertAttention initialization (__init__) 66 | # config values like head_dim, max_position_embeddings, rope_theta are passed 67 | self.rotary_emb = RoPEPositionalEmbedding( 68 | rngs=rngs, 69 | dim=self.head_dim, 70 | max_position_embeddings=max_pos, # Can depend on local/global 71 | base=rope_theta # Can depend on local/global 72 | ) 73 | # The cache is stored in self.rotary_emb.cache 74 | 75 | # Inside ModernBertAttention forward pass (__call__) 76 | # It uses a standalone apply_rotary_pos_emb function defined in modernbert.py 77 | from jaxgarden.models.modernbert import apply_rotary_pos_emb 78 | 79 | def __call__(self, hidden_states, attention_mask, ..., position_ids): 80 | # ... project hidden_states to qkv ... 81 | # Split into query, key, value: [batch, num_heads, seq_len, head_dim] 82 | # Transpose for RoPE application: [batch, seq_len, num_heads, head_dim] 83 | query = query.transpose((0, 2, 1, 3)) 84 | key = key.transpose((0, 2, 1, 3)) 85 | 86 | # Apply RoPE using the pre-computed cache 87 | query = apply_rotary_pos_emb(query, self.rotary_emb.cache, position_ids) 88 | key = apply_rotary_pos_emb(key, self.rotary_emb.cache, position_ids) 89 | 90 | # Transpose back: [batch, num_heads, seq_len, head_dim] 91 | query = query.transpose((0, 2, 1, 3)) 92 | key = key.transpose((0, 2, 1, 3)) 93 | 94 | # ... compute attention scores ... 95 | # attn_scores = jnp.matmul(query, key.swapaxes(-2, -1)) / scale 96 | # ... compute final output ... 97 | return output_tuple 98 | ``` 99 | **Explanation:** `ModernBertAttention` creates a `RoPEPositionalEmbedding` instance, which pre-computes the cache. 
In the forward pass, the standalone `apply_rotary_pos_emb` function is called, passing the `query` and `key` vectors along with the *cached* sinusoidal values (`self.rotary_emb.cache`) and optional `position_ids` for lookup. 100 | 101 | ## Internal Implementation 102 | 103 | Let's examine the core logic. 104 | 105 | ### `LlamaRotaryEmbedding` (On-the-fly Calculation) 106 | 107 | * **File:** `jaxgarden/models/llama.py` 108 | 109 | **Walkthrough:** 110 | 1. The `__call__` method receives `position_ids` (typically `[1, seq_len]`). 111 | 2. It calculates the inverse frequencies (`inv_freq`) based on `self.dim` and `self.base`. This determines how quickly the angle changes for different feature dimensions. 112 | 3. It uses `jnp.einsum` to efficiently compute the angles (`freqs`) for all positions and half the dimensions: `angle = position_id * inv_freq`. 113 | 4. It concatenates the angles to cover the full dimension (`emb`). 114 | 5. It computes and returns the cosine and sine of these angles. 115 | 116 | ```python 117 | # Simplified from LlamaRotaryEmbedding.__call__ 118 | @nnx.jit() 119 | def __call__(self, position_ids: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: 120 | # inv_freq shape: [dim/2] 121 | inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim)) 122 | # Expand dims for einsum: inv_freq -> [1, 1, dim/2] 123 | inv_freq_expanded = jnp.expand_dims(inv_freq, axis=(0, 1)) 124 | # position_ids -> [1, seq_len, 1] 125 | position_ids_expanded = jnp.expand_dims(position_ids, axis=(0, 2)).astype(jnp.float32) 126 | 127 | # Calculate angles: freqs shape [1, seq_len, dim/2] 128 | freqs = jnp.einsum("bij,bjk->bik", position_ids_expanded, inv_freq_expanded) 129 | 130 | # Concatenate angles for full dimension: emb shape [1, seq_len, dim] 131 | emb = jnp.concatenate([freqs, freqs], axis=-1) 132 | 133 | # Compute cos and sin: Shapes [seq_len, dim] (after squeeze) 134 | cos = jnp.cos(emb).squeeze(0).astype(jnp.bfloat16) # Assuming batch=1 input 135 | sin = jnp.sin(emb).squeeze(0).astype(jnp.bfloat16) 136 | return cos, sin 137 | ``` 138 | 139 | ### `RoPEPositionalEmbedding` (Pre-computed Cache) 140 | 141 | * **File:** `jaxgarden/models/modernbert.py` 142 | 143 | **Walkthrough:** 144 | 1. **`create_sinusoidal_positions` (Helper Function):** Calculates angles `pos * theta` for all positions up to `max_length` and half the dimensions. Returns stacked `(cos, sin)` values. Shape: `[max_length, dim//2, 2]`. 145 | 2. **`RoPEPositionalEmbedding.__init__`:** Calls `create_sinusoidal_positions` and stores the result in `self.cache`. 146 | 3. **`apply_rotary_pos_emb` (Helper Function):** 147 | * Takes an input tensor `x` (query or key), the `cache`, and optional `positions`. 148 | * Determines the sequence length (`seq_len`) from `x`. 149 | * Looks up the `cos`/`sin` values from the `cache`. If `positions` are given, it uses advanced indexing; otherwise, it takes the first `seq_len` entries from the cache. 150 | * Reshapes `x` and the looked-up cache values for broadcasting and rotation. 151 | * Performs the rotation using the complex number multiplication analogy: `x_rotated = x * cos +/- x_flipped * sin`. 152 | * Reshapes the result back to the original shape. 
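

Before the diagram and the simplified source below, a tiny standalone check makes the key property concrete. This sketch is illustrative only (the helper `rope_rotate` is invented here and is not part of `jaxgarden`); it uses the half-split `rotate_half` convention noted at the end of this section, but the interleaved pairing used by `modernbert.py` yields the same property.

```python
import jax.numpy as jnp

def rope_rotate(x: jnp.ndarray, pos: float, base: float = 10000.0) -> jnp.ndarray:
    """Rotate a single head_dim vector as if it sat at absolute position `pos`."""
    dim = x.shape[-1]
    theta = 1.0 / (base ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))  # [dim/2]
    cos, sin = jnp.cos(pos * theta), jnp.sin(pos * theta)
    x1, x2 = x[: dim // 2], x[dim // 2 :]
    # Pairwise 2D rotation in the half-split convention.
    return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin])

q = jnp.arange(1.0, 9.0)        # toy query, head_dim = 8
k = jnp.arange(8.0, 0.0, -1.0)  # toy key, head_dim = 8

# Same relative offset (m - n = 3) at two different absolute positions:
score_a = rope_rotate(q, 5.0) @ rope_rotate(k, 2.0)
score_b = rope_rotate(q, 15.0) @ rope_rotate(k, 12.0)
print(jnp.allclose(score_a, score_b, atol=1e-3))  # True: the score depends only on m - n
```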
153 | 154 | ```mermaid 155 | sequenceDiagram 156 | participant Attn as Attention Module 157 | participant ApplyRoPE as apply_rotary_pos_emb 158 | participant Cache as RoPEPositionalEmbedding.cache 159 | participant QK as Query/Key Vectors 160 | 161 | Attn->>QK: Original Query/Key 162 | Attn->>+ApplyRoPE: apply_rotary_pos_emb(QK, cache, positions) 163 | ApplyRoPE->>Cache: Lookup cos/sin values using positions/seq_len 164 | Cache-->>ApplyRoPE: cos_sin_values 165 | ApplyRoPE->>ApplyRoPE: Reshape QK and cos_sin_values 166 | ApplyRoPE->>ApplyRoPE: Perform rotation (complex multiplication) 167 | ApplyRoPE-->>-Attn: Rotated Query/Key 168 | Attn->>Attn: Compute Attention Scores (using rotated Q/K) 169 | ``` 170 | 171 | ```python 172 | # Simplified from modernbert.py create_sinusoidal_positions 173 | def create_sinusoidal_positions(max_length, dim, base=10000.0): 174 | positions = jnp.arange(max_length, dtype=jnp.float32) # [max_length] 175 | dim_indices = jnp.arange(0, dim, 2, dtype=jnp.float32) # [dim/2] 176 | theta = 1.0 / (base ** (dim_indices / dim)) # [dim/2] 177 | angles = jnp.einsum("i,j->ij", positions, theta) # [max_length, dim/2] 178 | # Stack cos and sin along a new last dimension 179 | return jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=-1) # [max_length, dim//2, 2] 180 | 181 | # Simplified from RoPEPositionalEmbedding.__init__ 182 | class RoPEPositionalEmbedding(nnx.Module): 183 | def __init__(self, ..., dim, max_position_embeddings, base): 184 | # ... super().__init__() ... 185 | # Pre-compute and store the cache 186 | self.cache = create_sinusoidal_positions( 187 | max_position_embeddings, dim, base 188 | ) 189 | # Cache shape: [max_position_embeddings, dim//2, 2] 190 | 191 | # Simplified from modernbert.py apply_rotary_pos_emb 192 | def apply_rotary_pos_emb(x, cache, positions=None): 193 | # x shape: [batch, seq, heads, dim] 194 | # cache shape: [max_len, dim//2, 2] 195 | seq_len = x.shape[1] 196 | 197 | # Lookup from cache based on positions or seq_len 198 | if positions is not None: 199 | rope_cache = cache[positions] # [batch, seq, dim//2, 2] 200 | else: 201 | rope_cache = cache[:seq_len] # [seq, dim//2, 2] 202 | rope_cache = jnp.expand_dims(rope_cache, 0) # [1, seq, dim//2, 2] 203 | 204 | # Reshape x for rotation: [batch, seq, heads, dim//2, 2] 205 | x_reshaped = x.reshape(*x.shape[:-1], -1, 2) 206 | # Expand cache dims for broadcasting: [batch/1, seq, 1, dim//2, 2] 207 | rope_cache = jnp.expand_dims(rope_cache, 2) 208 | 209 | # Perform rotation (like complex multiplication) 210 | cos_vals = rope_cache[..., 0] 211 | sin_vals = rope_cache[..., 1] 212 | x_real = x_reshaped[..., 0] 213 | x_imag = x_reshaped[..., 1] 214 | x_rotated_real = x_real * cos_vals - x_imag * sin_vals 215 | x_rotated_imag = x_imag * cos_vals + x_real * sin_vals 216 | 217 | x_out = jnp.stack([x_rotated_real, x_rotated_imag], axis=-1) 218 | # Reshape back to original: [batch, seq, heads, dim] 219 | return x_out.reshape(x.shape) 220 | ``` 221 | *Note:* The `apply_rotary_pos_emb` method within `LlamaAttention` performs a similar rotation logic but uses the dynamically calculated `cos` and `sin` values instead of a pre-computed cache. The core rotation math (`x * cos +/- rotate_half(x) * sin`) is equivalent. 222 | 223 | ## Conclusion 224 | 225 | Rotary Position Embeddings (RoPE) provide an elegant and effective way to incorporate relative positional information into transformer attention mechanisms. 
By rotating query and key vectors based on their absolute positions using sinusoidal functions, RoPE makes the attention scores sensitive to relative distances without requiring additional learned parameters or modifying input embeddings directly. `jaxgarden` offers both on-the-fly (`LlamaRotaryEmbedding`) and pre-computed/cached (`RoPEPositionalEmbedding`) implementations, catering to different model architectures and efficiency considerations, as seen in [LlamaForCausalLM](llamaforcausallm.mdc) and [ModernBERTForMaskedLM](modernbertformaskedlm.mdc) respectively. 226 | 227 | This concludes the core chapters outlined for the `jaxgarden` tutorial structure. Understanding these components provides a solid foundation for using and extending the library. 228 | 229 | 230 | --- 231 | 232 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai) -------------------------------------------------------------------------------- /.cursor/rules/generationmixin.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: jaxgarden tutorial on GenerationMixin, providing autoregressive text generation with sampling for JAX models. 3 | globs: jaxgarden/models/generation_utils.py 4 | alwaysApply: false 5 | --- 6 | # Chapter 5: GenerationMixin 7 | 8 | In the [previous chapter](llamaforcausallm.mdc), we explored `LlamaForCausalLM`, a causal language model built using `jaxgarden` components. We saw that it inherits text generation capabilities. This chapter focuses on the `GenerationMixin` class, the source of that functionality. 9 | 10 | **Motivation:** Implementing autoregressive text generation involves complex logic: managing the token-by-token loop efficiently, handling various sampling strategies (temperature, top-k, top-p, min-p), managing padding and end-of-sequence tokens, and dealing with JAX's PRNG key management and JIT compilation nuances. Encapsulating this logic in a reusable mixin avoids code duplication across different causal language models and provides a standardized generation interface. `GenerationMixin` solves this by providing a robust `generate` method that can be easily added to any compatible causal LM inheriting from [BaseModel](basemodel.mdc). 11 | 12 | **Central Use Case:** Adding the `generate` method to a custom causal language model (or using it via an existing model like `LlamaForCausalLM`). This allows the model instance to perform text generation based on a prompt, controlling the output's creativity and coherence through sampling parameters, and leveraging JAX optimizations like `lax.scan` and JIT compilation for performance on accelerators. 13 | 14 | ## Key Concepts 15 | 16 | 1. **Mixin Pattern:** `GenerationMixin` is not meant to be used standalone. It's designed to be inherited *alongside* a primary base class (like `BaseModel` or a specific model class like `LlamaForCausalLM`). It "mixes in" the `generate` method and its helpers into the inheriting class. 17 | 2. **Autoregressive Loop (`jax.lax.scan`):** Text generation is sequential. The model predicts the next token based on the previously generated tokens. `GenerationMixin` implements this loop using `jax.lax.scan`, which is highly efficient for iterative computations on JAX accelerators (GPU/TPU) as it unrolls the loop within the compiled computation graph. 18 | 3. **Sampling Strategies:** Controls how the next token is chosen from the model's output probability distribution (logits). 19 | * `temperature`: Scales logits before sampling. 
Higher values -> more randomness; lower values -> more determinism. 20 | * `top_k`: Restricts sampling to the `k` most likely tokens. 21 | * `top_p` (Nucleus Sampling): Restricts sampling to the smallest set of tokens whose cumulative probability exceeds `p`. 22 | * `min_p`: Restricts sampling to tokens with probability `p * max_probability` or higher. 23 | * `do_sample`: Boolean flag to enable/disable sampling (if `False`, uses greedy decoding - picks the most likely token). 24 | 4. **Helper Functions:** Sampling strategies are implemented via standalone helper functions (`temperature_scale`, `top_k_logits`, `top_p_logits`, `min_p_logits`, `sample_logits`) for clarity and testability. 25 | 5. **State Management:** The generation loop manages the sequence length, detects the End-of-Sequence (EOS) token to stop generation for specific sequences in a batch, handles padding (`pad_token_id`), and correctly splits and passes the JAX PRNG key (`rng`) at each step if sampling is enabled. 26 | 6. **JIT Compilation (`use_jit`):** The `generate` method offers a `use_jit` flag. If `True`, it calls a pre-compiled version of the core generation loop (`_generate_compiled`). This requires specifying `static_argnames` (like `max_length`, `temperature`, `top_k`, etc., and crucially `self`) to `jax.jit`, as these values influence the computation graph structure and cannot be dynamic JAX tracers during compilation. 27 | 28 | ## Using `GenerationMixin` 29 | 30 | You typically don't interact with `GenerationMixin` directly. Instead, you call the `generate` method on a model class that inherits from it, like `LlamaForCausalLM`. 31 | 32 | ```python 33 | import jax 34 | import jax.numpy as jnp 35 | from flax import nnx 36 | # Assume LlamaForCausalLM and Tokenizer are imported and initialized 37 | # from jaxgarden.models.llama import LlamaConfig, LlamaForCausalLM 38 | # from jaxgarden.tokenization import Tokenizer 39 | 40 | # --- Setup (Conceptual) --- 41 | # model_config = LlamaConfig(...) # Load appropriate config 42 | # tokenizer = Tokenizer.from_pretrained(...) # Load matching tokenizer 43 | # rngs = nnx.Rngs(0) 44 | # model = LlamaForCausalLM(model_config, rngs=rngs) 45 | # model.from_hf(...) 
# Optional: Load pretrained weights 46 | 47 | # --- Generation Call --- 48 | # prompt = "The definition of JAX is" 49 | # inputs = tokenizer.encode(prompt, return_tensors="jax") 50 | # input_ids = inputs['input_ids'] 51 | # attention_mask = inputs['attention_mask'] # Important for initial prompt 52 | 53 | # Set generation parameters 54 | # max_new_tokens = 50 55 | # target_max_length = input_ids.shape[1] + max_new_tokens 56 | # generation_rng = jax.random.PRNGKey(42) 57 | # pad_token_id = tokenizer.pad_token_id 58 | # eos_token_id = tokenizer.eos_token_id 59 | 60 | # print(f"Generating text with max_length={target_max_length}, temperature=0.8, top_k=50") 61 | 62 | # --- The core call to the generate method --- 63 | # output_ids = model.generate( 64 | # input_ids=input_ids, 65 | # attention_mask=attention_mask, # Use mask from prompt encoding 66 | # max_length=target_max_length, 67 | # temperature=0.8, 68 | # top_k=50, 69 | # top_p=0.9, 70 | # min_p=None, # Example: min_p not used 71 | # do_sample=True, # Enable sampling 72 | # pad_token_id=pad_token_id, 73 | # eos_token_id=eos_token_id, 74 | # rng=generation_rng, 75 | # use_jit=True # Use the compiled version for speed 76 | # ) 77 | 78 | # --- Decoding --- 79 | # generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) 80 | # print(f"\nPrompt: {prompt}") 81 | # print(f"Generated: {generated_text}") 82 | 83 | # Example Output (Conceptual - depends heavily on model and sampling): 84 | # Generating text with max_length=58, temperature=0.8, top_k=50 85 | # 86 | # Prompt: The definition of JAX is 87 | # Generated: The definition of JAX is a high-performance numerical computation library 88 | # focused on machine learning research. It combines automatic differentiation 89 | # (autograd) and XLA compilation for speed on accelerators like GPUs and TPUs... 90 | print("Skipping actual generation execution.") 91 | ``` 92 | 93 | **Explanation:** 94 | - `input_ids`: The starting token sequence (prompt) as a JAX array `[batch_size, seq_len]`. 95 | - `attention_mask`: A mask indicating which input tokens are real (1) vs padding (0), shape `[batch_size, seq_len]`. Important for the model to ignore padding in the prompt. 96 | - `max_length`: The *total* maximum length of the output sequence (prompt + generated tokens). 97 | - `temperature`, `top_k`, `top_p`, `min_p`: Control the sampling process. Setting `top_k=1` or `do_sample=False` results in greedy decoding. 98 | - `do_sample`: If `True`, performs sampling according to the other parameters; if `False`, performs greedy decoding. 99 | - `pad_token_id`: The ID used for padding shorter sequences in the batch *after* they finish generating (e.g., hit EOS). If not provided, it tries to infer from `model.config.pad_token_id` or defaults to 0. 100 | - `eos_token_id`: The ID that signals the end of a sequence. Generation stops for a sequence once this token is sampled. If not provided, generation continues until `max_length`. 101 | - `rng`: A `jax.random.PRNGKey` required if `do_sample=True`. 102 | - `use_jit`: If `True`, calls the JIT-compiled version of the generation loop. Highly recommended for performance. 103 | - **Output:** The method returns a JAX array `output_ids` of shape `[batch_size, max_length]` containing the prompt and the generated tokens, padded up to `max_length`. 104 | 105 | ## Internal Implementation 106 | 107 | The `generate` method orchestrates the process, while the core token-by-token logic resides in `_generate_scan_logic`, which is wrapped by `jax.lax.scan`. 
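

Before the step-by-step breakdown, the following standalone sketch shows the general `lax.scan` pattern in isolation. It is a deliberately simplified illustration (the names `generate_greedy` and `toy_next_token` are invented here): it performs greedy decoding only and omits sampling, attention masks, and EOS/padding bookkeeping, all of which the real `_generate_scan_logic` handles.

```python
import jax
import jax.numpy as jnp

VOCAB_SIZE = 10  # toy vocabulary

def toy_next_token(output_ids: jnp.ndarray, current_length: jnp.ndarray) -> jnp.ndarray:
    """Stand-in for the model call: next token is (last token + 1) % VOCAB_SIZE."""
    last_token = output_ids[:, current_length - 1]
    return (last_token + 1) % VOCAB_SIZE

def generate_greedy(prompt_ids: jnp.ndarray, max_length: int, pad_token_id: int = 0) -> jnp.ndarray:
    batch, prompt_len = prompt_ids.shape
    # Pre-allocate the full output buffer and copy the prompt into it.
    output_ids = jnp.full((batch, max_length), pad_token_id, dtype=prompt_ids.dtype)
    output_ids = output_ids.at[:, :prompt_len].set(prompt_ids)

    def scan_step(carry, _):
        output_ids, current_length = carry
        next_token = toy_next_token(output_ids, current_length)        # "model" forward pass
        output_ids = output_ids.at[:, current_length].set(next_token)  # write at dynamic index
        return (output_ids, current_length + 1), None                  # updated carry, no per-step output

    init_carry = (output_ids, jnp.asarray(prompt_len, dtype=jnp.int32))
    (output_ids, _), _ = jax.lax.scan(scan_step, init_carry, None, length=max_length - prompt_len)
    return output_ids

prompt = jnp.array([[1, 2], [5, 6]], dtype=jnp.int32)
print(generate_greedy(prompt, max_length=6))
# [[1 2 3 4 5 6]
#  [5 6 7 8 9 0]]
```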
108 | 109 | 1. **`generate` Method:** 110 | * Performs input validation (shapes, parameter ranges). 111 | * Handles default values for `pad_token_id`, `eos_token_id`, and `rng`. 112 | * Determines the initial sequence length (`initial_seq_len`) and batch size. 113 | * Initializes the `finished_sequences` boolean array based on whether the *last* token of the input is already EOS. 114 | * **Conditional Dispatch:** Based on the `use_jit` flag, it calls either: 115 | * `self._generate_compiled(...)`: A pre-compiled version of `_generate_scan_logic` created using `partial(jax.jit, static_argnames=...)`. 116 | * `self._generate_scan_logic(...)`: The raw Python function (slower, useful for debugging). 117 | 118 | 2. **`_generate_scan_logic` Method (Core Loop Logic):** 119 | * Initializes the output tensor `output_ids` of shape `[batch_size, max_length]` filled with `pad_token_id`, and copies the `initial_input_ids` into the beginning. 120 | * Sets up the initial `carry` dictionary for `lax.scan`, containing `output_ids`, `current_length`, `rng`, and `finished` status. 121 | * Defines the `scan_step` function, which performs one step of generation: 122 | * Takes the current `carry` state and a loop counter (`_`) as input. 123 | * Splits the PRNG key (`rng`) if `do_sample` is true. 124 | * Creates the `attention_mask` for the *current* sequence length within the `max_length` buffer. Note: Causal masking is typically handled *inside* the model's attention mechanism, but this mask handles padding. 125 | * **Model Call:** Calls `self(input_ids=current_output_ids, attention_mask=attention_mask, deterministic=True)`. This executes the forward pass of the model (e.g., `LlamaForCausalLM.__call__`). `deterministic=True` ensures dropout etc. are disabled. 126 | * Extracts the logits for the *next* token prediction (logits at the `current_length - 1` position). 127 | * **Sampling:** Calls `sample_logits` with the extracted logits and sampling parameters (`temperature`, `top_k`, etc.) to get the `next_token`. 128 | * **EOS/Padding:** Checks if the `next_token` is the `eos_token_id`. Updates the `finished` status for the sequence. Determines the token to actually write: `pad_token_id` if already finished, otherwise the sampled `next_token`. 129 | * Updates the `output_ids` tensor at the `current_length` position with the chosen token. 130 | * Updates the `carry` dictionary for the next step (`current_length` incremented, `rng` updated, `finished` status updated). 131 | * Returns the updated `carry` and `None` (as `lax.scan` expects `(carry, scan_output)`). 132 | * Calls `jax.lax.scan(scan_step, initial_carry, None, length=num_steps_to_generate)`. 133 | * Returns the final `output_ids` from the resulting `carry`. 134 | 135 | 3. **`sample_logits` Function:** 136 | * Takes raw logits, RNG key, and all sampling parameters. 137 | * If `do_sample=False`, returns `jnp.argmax(logits)`. 138 | * Applies `temperature_scale`. 139 | * Sequentially applies filtering functions (`min_p_logits`, `top_k_logits`, `top_p_logits`) if the corresponding parameters are set. These functions mask invalid logits by setting them to `-jnp.inf`. 140 | * **Edge Case Handling:** If *all* logits become `-jnp.inf` after filtering (can happen with very restrictive filtering), it falls back to sampling from the *pre-filtered* (but temperature-scaled) logits to prevent errors. 141 | * Uses `jax.random.categorical` to sample from the final filtered (or fallback) logit distribution. 
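

The chain described above can be condensed into a small illustrative function. The names below are invented for this sketch (the real helpers are `temperature_scale`, `top_k_logits`, `top_p_logits`, `min_p_logits`, and `sample_logits`), and `min_p` filtering is omitted for brevity:

```python
import jax
import jax.numpy as jnp

def sketch_sample_logits(logits, rng, temperature=1.0, top_k=None, top_p=None, do_sample=True):
    """Illustrative next-token sampling: temperature -> top-k -> top-p -> categorical."""
    if not do_sample:
        return jnp.argmax(logits, axis=-1)                      # greedy decoding

    scaled = logits / temperature                               # temperature scaling
    filtered = scaled

    if top_k is not None:                                       # keep only the k largest logits
        kth_largest = jnp.sort(filtered, axis=-1)[..., -top_k, None]
        filtered = jnp.where(filtered < kth_largest, -jnp.inf, filtered)

    if top_p is not None:                                       # nucleus (top-p) filtering
        sorted_desc = jnp.sort(filtered, axis=-1)[..., ::-1]
        cum_probs = jnp.cumsum(jax.nn.softmax(sorted_desc, axis=-1), axis=-1)
        cutoff_idx = jnp.sum(cum_probs < top_p, axis=-1, keepdims=True)
        cutoff_logit = jnp.take_along_axis(sorted_desc, cutoff_idx, axis=-1)
        filtered = jnp.where(filtered < cutoff_logit, -jnp.inf, filtered)

    # Fallback described above: if filtering masked out every token,
    # sample from the temperature-scaled (pre-filtered) logits instead.
    all_masked = jnp.all(filtered == -jnp.inf, axis=-1, keepdims=True)
    final_logits = jnp.where(all_masked, scaled, filtered)

    return jax.random.categorical(rng, final_logits, axis=-1)

rng = jax.random.PRNGKey(0)
toy_logits = jnp.array([[2.0, 1.0, 0.5, -1.0, -3.0]])          # [batch=1, vocab=5]
print(sketch_sample_logits(toy_logits, rng, temperature=0.8, top_k=3, top_p=0.9))
```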
142 | 143 | ```mermaid 144 | sequenceDiagram 145 | participant User 146 | participant Model as Model (e.g., LlamaForCausalLM) 147 | participant GenMixin as GenerationMixin Logic 148 | participant JitCompiled as _generate_compiled (JIT) 149 | participant ScanLogic as _generate_scan_logic 150 | participant LaxScan as jax.lax.scan 151 | participant ScanStep as scan_step (Inner Function) 152 | participant Sampler as sample_logits 153 | 154 | User->>+Model: generate(input_ids, ..., use_jit=True, rng=key) 155 | Model->>+GenMixin: generate(...) 156 | GenMixin->>+JitCompiled: _generate_compiled(input_ids, finished, rng, ..., static_args...) 157 | Note over JitCompiled: Pre-compiled version of ScanLogic 158 | JitCompiled->>+ScanLogic: _generate_scan_logic(...) # Compiled call 159 | ScanLogic->>ScanLogic: Initialize output_ids, initial_carry 160 | ScanLogic->>+LaxScan: scan(scan_step, initial_carry, length=N) 161 | loop N times (Number of tokens to generate) 162 | LaxScan->>+ScanStep: scan_step(current_carry, _) 163 | ScanStep->>Model: self(current_output_ids, mask, deterministic=True) 164 | Model-->>ScanStep: next_token_logits 165 | ScanStep->>+Sampler: sample_logits(logits, rng_step, temp, k, p, ...) 166 | Sampler-->>-ScanStep: next_token 167 | ScanStep->>ScanStep: Update finished status (EOS check) 168 | ScanStep->>ScanStep: Determine output_token (handle padding) 169 | ScanStep->>ScanStep: Update output_ids tensor 170 | ScanStep->>ScanStep: Prepare next_carry (update length, rng, finished) 171 | ScanStep-->>-LaxScan: next_carry, None 172 | end 173 | LaxScan-->>-ScanLogic: final_carry, _ 174 | ScanLogic->>ScanLogic: Extract final_output_ids from final_carry 175 | ScanLogic-->>-JitCompiled: final_output_ids 176 | JitCompiled-->>-GenMixin: final_output_ids 177 | GenMixin-->>-User: final_output_ids (Generated Sequence) 178 | ``` 179 | 180 | 4. **JIT Compilation (`_generate_compiled`)**: 181 | * Defined using `partial(jax.jit, static_argnames=...)`. 182 | * `static_argnames` must include parameters that affect the *structure* of the computation or are needed as Python values inside the loop logic (not JAX tracers). This includes: 183 | * `self`: Needed to call the model's forward pass (`self.__call__`) inside `scan_step`. `self` contains the model's graph definition. 184 | * `max_length`: Determines the size of the output tensor and loop length. 185 | * `temperature`, `top_k`, `top_p`, `min_p`, `do_sample`: Control conditional logic within `sample_logits` and `scan_step`. 186 | * `pad_token_id`, `eos_token_id`: Used in conditional logic (`jnp.where`). 187 | * `initial_seq_len`: Affects the number of scan steps. 188 | * Arguments *not* listed as static (`initial_input_ids`, `initial_finished_sequences`, `initial_rng`) are treated as dynamic JAX arrays (tracers) during compilation. 189 | 190 | * **Code Location:** `jaxgarden/models/generation_utils.py` 191 | 192 | ## Conclusion 193 | 194 | `GenerationMixin` provides a powerful and reusable mechanism for adding sophisticated autoregressive text generation capabilities to causal language models in `jaxgarden`. By leveraging `jax.lax.scan` for efficiency, offering flexible sampling strategies, and integrating seamlessly with JAX's JIT compilation, it allows developers to easily enable text generation in their models while benefiting from high performance on accelerators. Understanding its parameters and internal workings is key to effectively controlling the generation process. 
195 | 196 | **Next:** [ModernBERTForMaskedLM](modernbertformaskedlm.mdc) 197 | 198 | 199 | --- 200 | 201 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai) -------------------------------------------------------------------------------- /tests/models/test_generation_utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pytest 4 | from jax import random 5 | 6 | from jaxgarden.models.generation_utils import GenerationMixin 7 | 8 | # --- Mock Model for GenerationMixin Tests --- 9 | 10 | 11 | class MockModel(GenerationMixin): 12 | """A simple mock model inheriting GenerationMixin for testing.""" 13 | 14 | def __init__(self, vocab_size, eos_token_id=None, pad_token_id=0): 15 | # Mock config 16 | self.config = { 17 | "vocab_size": vocab_size, 18 | "eos_token_id": eos_token_id, 19 | "pad_token_id": pad_token_id, 20 | } 21 | self.vocab_size = vocab_size 22 | self._eos_token_id = eos_token_id # Store separately if needed 23 | 24 | def __call__(self, input_ids, attention_mask=None, **kwargs): 25 | """Mock call that returns deterministic logits.""" 26 | batch_size, seq_len = input_ids.shape 27 | if attention_mask is None: 28 | # Fallback: If no mask, assume the sequence length is the full length 29 | current_length = jnp.array(seq_len) 30 | else: 31 | # Determine current length from attention mask 32 | # Handle 1D [max_length] or 2D [batch, max_length] masks 33 | if attention_mask.ndim == 1: 34 | current_length = jnp.sum(attention_mask) 35 | elif attention_mask.ndim == 2: 36 | # Ensure current_length is scalar if attention_mask is [1, seq_len] 37 | if attention_mask.shape[0] == 1: 38 | current_length = jnp.sum(attention_mask[0]) 39 | else: # Handle [batch, seq_len] 40 | current_length = jnp.sum(attention_mask, axis=-1) 41 | else: 42 | raise ValueError(f"Unexpected attention_mask ndim: {attention_mask.ndim}") 43 | 44 | # Ensure length is at least 1 to avoid index -1 45 | valid_length = jnp.maximum(1, current_length) 46 | last_token_index = valid_length - 1 # Index of the last valid token 47 | 48 | # Gather the last valid token for each item in the batch 49 | # Needs to handle scalar or per-batch indices 50 | if isinstance(last_token_index, int) or ( 51 | hasattr(last_token_index, "shape") and last_token_index.shape == () 52 | ): 53 | # If current_length (and thus index) is scalar across batch 54 | last_token = input_ids[:, last_token_index] # Shape [batch_size] 55 | elif hasattr(last_token_index, "shape") and last_token_index.ndim == 1: 56 | # If current_length varies per batch item (shape [batch_size]) 57 | # Use gather to select the token at last_token_index for each batch item 58 | last_token_index_expanded = last_token_index[:, None] # Shape [batch_size, 1] 59 | last_token = jnp.take_along_axis(input_ids, last_token_index_expanded, axis=1)[:, 0] 60 | else: 61 | raise TypeError(f"""Unexpected type or shape for last_token_index: 62 | {type(last_token_index)}""") 63 | 64 | # Create deterministic logits: next token is always (last_token + 1) % vocab_size 65 | next_token_logits = ( 66 | jax.nn.one_hot( 67 | (last_token + 1) % self.vocab_size, num_classes=self.vocab_size, dtype=jnp.float32 68 | ) 69 | * 10.0 70 | ) # Multiply to make it peaky 71 | # Return shape [batch_size, seq_len, vocab_size] 72 | # We only care about the last token logits for generation, 73 | # so we just broadcast it for simplicity in this mock. 74 | # A real model would compute logits based on the full sequence. 
75 | return jnp.repeat(next_token_logits[:, None, :], seq_len, axis=1) 76 | 77 | 78 | # --- Tests for GenerationMixin --- 79 | 80 | 81 | @pytest.fixture 82 | def generation_setup(): 83 | key = random.PRNGKey(42) 84 | vocab_size = 10 85 | eos_token_id = 9 86 | pad_token_id = 0 87 | model = MockModel(vocab_size, eos_token_id, pad_token_id) 88 | input_ids = jnp.array([[1, 2], [5, 6]], dtype=jnp.int32) 89 | return key, model, input_ids, vocab_size, eos_token_id, pad_token_id 90 | 91 | 92 | def test_generate_greedy(generation_setup): 93 | """Tests greedy generation (do_sample=False).""" 94 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup 95 | max_length = 7 96 | 97 | output_ids = model.generate( 98 | input_ids, 99 | max_length=max_length, 100 | do_sample=False, 101 | eos_token_id=eos_token_id, 102 | pad_token_id=pad_token_id, 103 | ) 104 | 105 | assert output_ids.shape == (input_ids.shape[0], max_length) 106 | 107 | # Expected sequence based on mock model: next = (last + 1) % vocab_size 108 | # Batch 1: [1, 2] -> 3 -> 4 -> 5 -> 6 -> 7 109 | # Batch 2: [5, 6] -> 7 -> 8 -> 9 (EOS) -> 0 (PAD) -> 0 (PAD) 110 | expected_output = jnp.array([[1, 2, 3, 4, 5, 6, 7], [5, 6, 7, 8, 9, 0, 0]], dtype=jnp.int32) 111 | assert jnp.array_equal(output_ids, expected_output) 112 | 113 | 114 | def test_generate_sampling(generation_setup): 115 | """Tests sampling generation (do_sample=True).""" 116 | key, model, input_ids, vocab_size, eos_token_id, pad_token_id = generation_setup 117 | max_length = 6 118 | 119 | # Use a very low temperature to make it nearly deterministic for testing 120 | key, subkey = random.split(key) 121 | output_ids_low_temp = model.generate( 122 | input_ids, 123 | max_length=max_length, 124 | do_sample=True, 125 | temperature=0.01, # Very low temp 126 | eos_token_id=eos_token_id, 127 | pad_token_id=pad_token_id, 128 | rng=subkey, 129 | ) 130 | 131 | # Expected sequence (likely, due to low temp): 132 | # Batch 1: [1, 2] -> 3 -> 4 -> 5 -> 6 133 | # Batch 2: [5, 6] -> 7 -> 8 -> 9 (EOS) -> 0 (PAD) 134 | expected_output_likely = jnp.array([[1, 2, 3, 4, 5, 6], [5, 6, 7, 8, 9, 0]], dtype=jnp.int32) 135 | assert output_ids_low_temp.shape == (input_ids.shape[0], max_length) 136 | # Check it's highly likely the same as greedy due to low temp 137 | assert jnp.array_equal(output_ids_low_temp, expected_output_likely) 138 | 139 | # Test reproducibility with the same key 140 | key, subkey = random.split(key) 141 | output_ids1 = model.generate( 142 | input_ids, 143 | max_length=max_length, 144 | do_sample=True, 145 | temperature=1.0, 146 | rng=subkey, 147 | eos_token_id=eos_token_id, 148 | pad_token_id=pad_token_id, 149 | ) 150 | output_ids2 = model.generate( 151 | input_ids, 152 | max_length=max_length, 153 | do_sample=True, 154 | temperature=1.0, 155 | rng=subkey, 156 | eos_token_id=eos_token_id, 157 | pad_token_id=pad_token_id, 158 | ) 159 | assert jnp.array_equal(output_ids1, output_ids2) 160 | 161 | 162 | def test_generate_max_length(generation_setup): 163 | """Tests that generation stops at max_length.""" 164 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup 165 | max_length = 4 # Shorter than needed to reach EOS for batch 1 166 | 167 | output_ids = model.generate( 168 | input_ids, 169 | max_length=max_length, 170 | do_sample=False, 171 | eos_token_id=eos_token_id, 172 | pad_token_id=pad_token_id, 173 | ) 174 | 175 | assert output_ids.shape == (input_ids.shape[0], max_length) 176 | # Expected sequence (truncated): 177 | # Batch 1: [1, 2] -> 3 -> 4 178 | # 
Batch 2: [5, 6] -> 7 -> 8 179 | expected_output = jnp.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=jnp.int32) 180 | assert jnp.array_equal(output_ids, expected_output) 181 | 182 | 183 | def test_generate_eos_handling(generation_setup): 184 | """Tests EOS token handling and subsequent padding.""" 185 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup 186 | max_length = 8 # Long enough for EOS 187 | 188 | output_ids = model.generate( 189 | input_ids, 190 | max_length=max_length, 191 | do_sample=False, 192 | eos_token_id=eos_token_id, 193 | pad_token_id=pad_token_id, 194 | ) 195 | 196 | assert output_ids.shape == (input_ids.shape[0], max_length) 197 | # Expected sequence: 198 | # Batch 1: [1, 2] -> 3 -> 4 -> 5 -> 6 -> 7 -> 8 199 | # Batch 2: [5, 6] -> 7 -> 8 -> 9 (EOS) -> 0 (PAD) -> 0 (PAD) -> 0 (PAD) 200 | expected_output = jnp.array( 201 | [[1, 2, 3, 4, 5, 6, 7, 8], [5, 6, 7, 8, 9, 0, 0, 0]], dtype=jnp.int32 202 | ) 203 | assert jnp.array_equal(output_ids, expected_output) 204 | 205 | # Test case without EOS token (should just fill to max_length) 206 | model_no_eos = MockModel(model.vocab_size, eos_token_id=None, pad_token_id=pad_token_id) 207 | output_ids_no_eos = model_no_eos.generate( 208 | input_ids, 209 | max_length=max_length, 210 | do_sample=False, 211 | eos_token_id=None, 212 | pad_token_id=pad_token_id, 213 | ) 214 | expected_output_no_eos = jnp.array( 215 | [[1, 2, 3, 4, 5, 6, 7, 8], [5, 6, 7, 8, 9, 0, 1, 2]], # Continues sequence cyclically 216 | dtype=jnp.int32, 217 | ) 218 | assert jnp.array_equal(output_ids_no_eos, expected_output_no_eos) 219 | 220 | 221 | def test_generate_padding_default(generation_setup): 222 | """Tests generation using default pad_token_id from config.""" 223 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup 224 | max_length = 7 225 | 226 | # Don't pass pad_token_id, should use model.config['pad_token_id'] (which is 0) 227 | output_ids = model.generate( 228 | input_ids, 229 | max_length=max_length, 230 | do_sample=False, 231 | eos_token_id=eos_token_id, 232 | # pad_token_id not provided 233 | ) 234 | # Expected is same as test_generate_greedy 235 | expected_output = jnp.array([[1, 2, 3, 4, 5, 6, 7], [5, 6, 7, 8, 9, 0, 0]], dtype=jnp.int32) 236 | assert jnp.array_equal(output_ids, expected_output) 237 | 238 | 239 | def test_generate_input_length_equals_max_length(generation_setup): 240 | """Tests generation when input_ids length equals max_length.""" 241 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup 242 | max_length = input_ids.shape[1] # 2 243 | 244 | output_ids = model.generate( 245 | input_ids, 246 | max_length=max_length, 247 | do_sample=False, 248 | eos_token_id=eos_token_id, 249 | pad_token_id=pad_token_id, 250 | ) 251 | # Should return the input_ids unchanged 252 | assert output_ids.shape == (input_ids.shape[0], max_length) 253 | assert jnp.array_equal(output_ids, input_ids) 254 | 255 | 256 | def test_generate_input_length_greater_than_max_length(generation_setup): 257 | """Tests generation when input_ids length exceeds max_length.""" 258 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup 259 | max_length = 1 260 | 261 | output_ids = model.generate( 262 | input_ids, 263 | max_length=max_length, 264 | do_sample=False, 265 | eos_token_id=eos_token_id, 266 | pad_token_id=pad_token_id, 267 | ) 268 | # Should return input_ids truncated to max_length 269 | assert output_ids.shape == (input_ids.shape[0], max_length) 270 | assert jnp.array_equal(output_ids, 
input_ids[:, :max_length]) 271 | 272 | 273 | def test_generate_sampling_with_filtering(generation_setup): 274 | """Tests sampling with filtering parameters (top_k, top_p, min_p).""" 275 | key, model, input_ids, vocab_size, eos_token_id, pad_token_id = generation_setup 276 | max_length = 6 277 | 278 | # Test with top_k 279 | key, subkey = random.split(key) 280 | output_ids_top_k = model.generate( 281 | input_ids, 282 | max_length=max_length, 283 | do_sample=True, 284 | top_k=3, 285 | rng=subkey, 286 | eos_token_id=eos_token_id, 287 | pad_token_id=pad_token_id, 288 | ) 289 | assert output_ids_top_k.shape == (input_ids.shape[0], max_length) 290 | 291 | # Test with top_p 292 | key, subkey = random.split(key) 293 | output_ids_top_p = model.generate( 294 | input_ids, 295 | max_length=max_length, 296 | do_sample=True, 297 | top_p=0.8, 298 | rng=subkey, 299 | eos_token_id=eos_token_id, 300 | pad_token_id=pad_token_id, 301 | ) 302 | assert output_ids_top_p.shape == (input_ids.shape[0], max_length) 303 | 304 | # Test with min_p 305 | key, subkey = random.split(key) 306 | output_ids_min_p = model.generate( 307 | input_ids, 308 | max_length=max_length, 309 | do_sample=True, 310 | min_p=0.1, 311 | rng=subkey, 312 | eos_token_id=eos_token_id, 313 | pad_token_id=pad_token_id, 314 | ) 315 | assert output_ids_min_p.shape == (input_ids.shape[0], max_length) 316 | 317 | # Test with multiple filtering methods 318 | key, subkey = random.split(key) 319 | output_ids_combo = model.generate( 320 | input_ids, 321 | max_length=max_length, 322 | do_sample=True, 323 | top_k=3, 324 | top_p=0.8, 325 | min_p=0.1, 326 | rng=subkey, 327 | eos_token_id=eos_token_id, 328 | pad_token_id=pad_token_id, 329 | ) 330 | assert output_ids_combo.shape == (input_ids.shape[0], max_length) 331 | 332 | 333 | def test_generate_rng_none_warning(generation_setup, capsys): 334 | """Tests that a warning is printed when do_sample is True but rng is None.""" 335 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup 336 | max_length = 5 337 | 338 | model.generate( 339 | input_ids, 340 | max_length=max_length, 341 | do_sample=True, 342 | rng=None, # No RNG key provided 343 | eos_token_id=eos_token_id, 344 | pad_token_id=pad_token_id, 345 | ) 346 | 347 | captured = capsys.readouterr() 348 | assert "Warning: No RNG key provided for sampling, using default key 0." 
in captured.out 349 | 350 | 351 | def test_generate_no_sample_rng_none(generation_setup): 352 | """Tests that do_sample=False (greedy) works without rng.""" 353 | key, model, input_ids, _, eos_token_id, pad_token_id = generation_setup 354 | max_length = 5 355 | output_ids = model.generate( 356 | input_ids, 357 | max_length=max_length, 358 | do_sample=False, 359 | rng=None, # No RNG key provided 360 | eos_token_id=eos_token_id, 361 | pad_token_id=pad_token_id, 362 | ) 363 | 364 | assert output_ids.shape == (input_ids.shape[0], max_length) 365 | 366 | 367 | def test_generate_jax_compile(): 368 | """Tests that jax.jit can be used with generate without error.""" 369 | vocab_size = 10 370 | eos_token_id = 9 371 | pad_token_id = 0 372 | model = MockModel(vocab_size, eos_token_id, pad_token_id) 373 | input_ids = jnp.array([[1, 2], [5, 6]], dtype=jnp.int32) 374 | max_length = 7 375 | 376 | output_ids = model.generate( 377 | input_ids, 378 | max_length=max_length, 379 | do_sample=False, 380 | eos_token_id=eos_token_id, 381 | pad_token_id=pad_token_id, 382 | use_jit=True, 383 | ) 384 | 385 | assert output_ids.shape == (input_ids.shape[0], max_length) 386 | 387 | # Expected sequence based on mock model: next = (last + 1) % vocab_size 388 | # Batch 1: [1, 2] -> 3 -> 4 -> 5 -> 6 -> 7 389 | # Batch 2: [5, 6] -> 7 -> 8 -> 9 (EOS) -> 0 (PAD) -> 0 (PAD) 390 | expected_output = jnp.array([[1, 2, 3, 4, 5, 6, 7], [5, 6, 7, 8, 9, 0, 0]], dtype=jnp.int32) 391 | assert jnp.array_equal(output_ids, expected_output) # Should match greedy output 392 | 393 | # Test that jax.jit does not affect the output 394 | output_ids2 = model.generate( 395 | input_ids, 396 | max_length=max_length, 397 | do_sample=False, 398 | eos_token_id=eos_token_id, 399 | pad_token_id=pad_token_id, 400 | use_jit=True, 401 | ) 402 | assert jnp.array_equal(output_ids, output_ids2) 403 | -------------------------------------------------------------------------------- /.cursor/rules/tokenizer.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: Tutorial chapter for the jaxgarden Tokenizer, detailing text encoding/decoding and chat templating for JAX. 3 | globs: 4 | alwaysApply: false 5 | --- 6 | # Chapter 1: Tokenizer 7 | 8 | Welcome to the `jaxgarden` library tutorial! This first chapter introduces the `Tokenizer` class, a fundamental component for processing text data within JAX-based Natural Language Processing (NLP) models. 9 | 10 | **Motivation:** Deep learning models, especially those built with JAX, operate on numerical tensors. Raw text needs to be converted into a numerical format (token IDs) that models can understand, and conversely, model outputs (token IDs) need to be converted back into human-readable text. Furthermore, different models, particularly instruction-tuned ones, expect conversational inputs to be formatted in specific ways (chat templates). The Hugging Face `tokenizers` library is excellent for this, but its outputs are standard Python lists. `jaxgarden.Tokenizer` wraps this library to provide a seamless experience for JAX users, returning `jax.numpy.ndarray` (jnp arrays) directly and integrating features like chat templating. 11 | 12 | **Central Use Case:** Preparing text input for a JAX-based language model like [LlamaForCausalLM](llamaforcausallm.mdc) and decoding its generated token IDs. For conversational models, formatting user prompts and conversation history according to the model's specific chat template is crucial. 
13 | 14 | ## Key Concepts 15 | 16 | The `jaxgarden.Tokenizer` provides several core functionalities: 17 | 18 | 1. **Loading:** Instantiating a tokenizer from pre-trained configurations stored on the Hugging Face Hub or locally. 19 | 2. **Encoding:** Converting text strings into sequences of token IDs, handling padding and truncation, and returning JAX arrays. 20 | 3. **Decoding:** Converting sequences of token IDs back into text strings. 21 | 4. **Special Token Management:** Automatically identifying or allowing specification of crucial tokens like Beginning-of-Sequence (BOS), End-of-Sequence (EOS), and Padding (PAD). 22 | 5. **Chat Templating:** Applying Jinja-based templates to format conversational data for instruction-tuned models. 23 | 24 | ## Using the Tokenizer 25 | 26 | Let's explore how to use the `Tokenizer`. 27 | 28 | ### Loading a Tokenizer 29 | 30 | The primary way to get a `Tokenizer` instance is using the `from_pretrained` class method. You provide a model identifier from the Hugging Face Hub (e.g., `"gpt2"`, `"meta-llama/Llama-2-7b-chat-hf"`) or a path to a local directory containing `tokenizer.json` and optionally `tokenizer_config.json`. 31 | 32 | ```python 33 | # Assuming jaxgarden is installed 34 | from jaxgarden.tokenization import Tokenizer 35 | 36 | # Load from Hugging Face Hub 37 | tokenizer = Tokenizer.from_pretrained("gpt2") 38 | 39 | # Example: Load from a local directory (if you have one) 40 | # tokenizer_local = Tokenizer.from_pretrained("./path/to/local_tokenizer_files") 41 | 42 | print(f"Loaded tokenizer for 'gpt2' with vocab size: {tokenizer.vocab_size}") 43 | print(f"Pad token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}") 44 | print(f"BOS token: {tokenizer.bos_token}, ID: {tokenizer.bos_token_id}") 45 | print(f"EOS token: {tokenizer.eos_token}, ID: {tokenizer.eos_token_id}") 46 | ``` 47 | 48 | **Explanation:** `from_pretrained` downloads necessary files (`tokenizer.json`, `tokenizer_config.json`) from the Hub or reads them locally. It then instantiates the underlying Hugging Face `tokenizers.Tokenizer` and extracts configuration like special tokens and chat templates (if available in `tokenizer_config.json`). The `jaxgarden.Tokenizer` wrapper uses this information to set its own attributes like `pad_token_id`, `bos_token_id`, etc. 49 | 50 | ### Encoding Text 51 | 52 | The `encode` method converts text into token IDs. It offers options for handling batches, padding, and truncation, returning JAX arrays by default. 53 | 54 | ```python 55 | import jax.numpy as jnp 56 | 57 | text = "Hello, world!" 
58 | batch_text = ["First sequence.", "This is a second sequence."] 59 | 60 | # Basic encoding 61 | encoded_single = tokenizer.encode(text) 62 | print("Encoded Single:", encoded_single) 63 | # Output: Encoded Single: {'input_ids': DeviceArray([[50256, 15496, 11, 1917, 25, 50256]], dtype=int32), 64 | # 'attention_mask': DeviceArray([[1, 1, 1, 1, 1, 1]], dtype=int32)} 65 | 66 | # Encoding a batch with padding to the longest sequence 67 | encoded_batch = tokenizer.encode(batch_text, padding=True, add_special_tokens=False) 68 | print("Encoded Batch (padded):", encoded_batch) 69 | # Output: Encoded Batch (padded): { 70 | # 'input_ids': DeviceArray([[ 8285, 16337, 13, 50256, 50256, 50256], 71 | # [ 1212, 318, 257, 1144, 16337, 13]], dtype=int32), 72 | # 'attention_mask': DeviceArray([[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1]], dtype=int32) } 73 | 74 | 75 | # Encoding with truncation and padding to a max length 76 | encoded_truncated = tokenizer.encode( 77 | batch_text, padding="max_length", truncation=True, max_length=5, add_special_tokens=False 78 | ) 79 | print("Encoded Batch (truncated/padded):", encoded_truncated) 80 | # Output: Encoded Batch (truncated/padded): { 81 | # 'input_ids': DeviceArray([[ 8285, 16337, 13, 50256, 50256], 82 | # [ 1212, 318, 257, 1144, 16337]], dtype=int32), 83 | # 'attention_mask': DeviceArray([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=int32)} 84 | 85 | ``` 86 | 87 | **Explanation:** 88 | - `text`: The input string or list of strings. 89 | - `add_special_tokens`: Controls whether BOS/EOS tokens are added (based on tokenizer config). Default is `True`. 90 | - `padding`: Can be `False` (no padding), `True` or `'longest'` (pad to the longest sequence in the batch), or `'max_length'` (pad to `max_length`). Requires `max_length` if set to `'max_length'`. 91 | - `truncation`: Boolean. If `True`, truncates sequences longer than `max_length`. Requires `max_length`. 92 | - `max_length`: Integer specifying the target length for padding or truncation. 93 | - `return_tensors`: Set to `"jax"` to get `jnp.ndarray`. Set to `None` to get lists of integers. 94 | The output is a dictionary containing `input_ids` and `attention_mask` as JAX arrays. The attention mask indicates which tokens are real (1) and which are padding (0). 95 | 96 | ### Decoding Token IDs 97 | 98 | The `decode` method converts token IDs back into human-readable text. 99 | 100 | ```python 101 | # Use the 'input_ids' from the previous encoding example 102 | ids_to_decode = encoded_batch['input_ids'] # Example: DeviceArray([[ 8285, 16337, 13, 50256, 50256, 50256], ...]) 103 | 104 | # Decode the batch 105 | decoded_text = tokenizer.decode(ids_to_decode, skip_special_tokens=True) 106 | print("Decoded Text:", decoded_text) 107 | # Output: Decoded Text: ['First sequence.', 'This is a second sequence.'] 108 | 109 | # Decode a single sequence (e.g., the first one from the batch) 110 | single_decoded = tokenizer.decode(ids_to_decode[0], skip_special_tokens=True) 111 | print("Single Decoded:", single_decoded) 112 | # Output: Single Decoded: First sequence. 113 | ``` 114 | 115 | **Explanation:** 116 | - `token_ids`: A list of integers, a list of lists of integers, or a JAX array. 117 | - `skip_special_tokens`: If `True` (default), removes special tokens like BOS, EOS, PAD from the output string(s). 118 | 119 | ### Applying Chat Templates 120 | 121 | For models fine-tuned on conversational data, inputs must be formatted correctly. 
The `apply_chat_template` method uses a Jinja template (either defined in the tokenizer's config or provided explicitly) to structure conversations. 122 | 123 | ```python 124 | # Load a tokenizer known to have a chat template 125 | # Note: You might need to log in to Hugging Face CLI: `huggingface-cli login` 126 | try: 127 | chat_tokenizer = Tokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") # Example 128 | except Exception as e: 129 | print(f"Skipping chat template example: Could not load Llama-2 tokenizer. Error: {e}") 130 | chat_tokenizer = None # Set to None to skip the rest of the block 131 | 132 | if chat_tokenizer and chat_tokenizer.chat_template: 133 | conversation = [ 134 | {"role": "user", "content": "Hello, how are you?"}, 135 | {"role": "assistant", "content": "I'm doing well, thank you!"}, 136 | {"role": "user", "content": "What is JAX?"} 137 | ] 138 | 139 | # Apply the template to format the conversation 140 | formatted_prompt = chat_tokenizer.apply_chat_template( 141 | conversation, 142 | add_generation_prompt=True # Adds the prompt for the assistant's turn 143 | ) 144 | print("\nFormatted Chat Prompt:\n", formatted_prompt) 145 | 146 | # Example of Llama-2 Chat format (output structure may vary slightly): 147 | # Formatted Chat Prompt: 148 | # [INST] Hello, how are you? [/INST] I'm doing well, thank you! [INST] What is JAX? [/INST] 149 | else: 150 | print("\nSkipping chat template example: No chat_tokenizer or template found.") 151 | 152 | ``` 153 | 154 | **Explanation:** 155 | - `conversation`: A list of dictionaries, each with a `role` (e.g., 'user', 'assistant', 'system') and `content` (the message text). 156 | - `chat_template`: Optional override for the tokenizer's default template. 157 | - `add_generation_prompt`: A common pattern in chat templates is to append tokens indicating it's the assistant's turn to speak. Setting this to `True` often achieves this, though the exact implementation depends on the template itself. 158 | - `**kwargs`: Allows passing additional variables to the Jinja template. 159 | The method returns the formatted string, ready to be encoded and fed into the model. 160 | 161 | ## Internal Implementation 162 | 163 | Understanding the internals helps in debugging and extending functionality. 164 | 165 | 1. **Initialization (`__init__`)**: 166 | - Stores the Hugging Face `HfTokenizer` instance. 167 | - Tries to infer `bos_token`, `eos_token`, and `pad_token` from the `HfTokenizer`'s known special tokens or common defaults (e.g., ``, ``, `[PAD]`). It logs warnings if defaults or fallbacks are used. 168 | - If `pad_token` isn't explicitly provided or found, it defaults to `eos_token` (with a warning) or raises an error if no suitable token can be determined. 169 | - It fetches the corresponding IDs (`bos_token_id`, etc.) using `hf_tokenizer.token_to_id`. 170 | - Critically, it ensures the `HfTokenizer` has padding configured using the determined `pad_token_id` and `pad_token`, either by enabling it or verifying/correcting existing padding settings. 171 | 172 | ```python 173 | # Simplified __init__ logic for special tokens 174 | def __init__(self, hf_tokenizer, chat_template=None, bos_token=None, /*...*/ pad_token=None): 175 | self.hf_tokenizer = hf_tokenizer 176 | # ... (chat_template assignment) ... 
177 | 178 | # Infer BOS/EOS (simplified example) 179 | self.bos_token = bos_token or self._infer_token(hf_tokenizer, ["[BOS]", ""]) 180 | self.eos_token = eos_token or self._infer_token(hf_tokenizer, ["[EOS]", ""]) 181 | 182 | # Infer/Validate PAD token 183 | if pad_token: 184 | self.pad_token = pad_token 185 | elif hf_tokenizer.padding and hf_tokenizer.padding.get("pad_token"): 186 | self.pad_token = hf_tokenizer.padding["pad_token"] 187 | # ... (fallback logic using EOS or searching for [PAD]/) ... 188 | else: 189 | raise ValueError("Cannot determine padding token.") 190 | 191 | # Get IDs 192 | self.bos_token_id = hf_tokenizer.token_to_id(self.bos_token) # if self.bos_token else None 193 | self.eos_token_id = hf_tokenizer.token_to_id(self.eos_token) # if self.eos_token else None 194 | self.pad_token_id = hf_tokenizer.token_to_id(self.pad_token) 195 | # ... (error if pad_token_id is None) ... 196 | 197 | # Configure HF tokenizer padding 198 | self.hf_tokenizer.enable_padding(pad_id=self.pad_token_id, pad_token=self.pad_token) 199 | # ... (logic to handle pre-existing padding config) ... 200 | 201 | # Helper function (conceptual) 202 | def _infer_token(self, hf_tok, candidates): 203 | for token_str in candidates: 204 | if hf_tok.token_to_id(token_str) is not None: 205 | return token_str 206 | return None 207 | ``` 208 | 209 | 2. **Loading (`from_pretrained`)**: 210 | - Checks if the `identifier` is a local path. 211 | - If local: Looks for `tokenizer.json` and `tokenizer_config.json` in that directory. 212 | - If not local (Hub): Uses `huggingface_hub.hf_hub_download` to fetch the files. 213 | - Loads the core tokenizer from `tokenizer.json` using `HfTokenizer.from_file`. 214 | - Loads the configuration (chat template, special tokens) from `tokenizer_config.json` if it exists. 215 | - Prioritizes explicitly passed arguments (`bos_token`, `eos_token`, etc.) over values found in `tokenizer_config.json`. 216 | - Calls the `__init__` method with the loaded `HfTokenizer` and extracted/provided configuration. 217 | 218 | 3. **Encoding (`encode`)**: 219 | - Temporarily configures the `hf_tokenizer` instance based on `padding`, `truncation`, and `max_length` arguments by calling `hf_tokenizer.enable_padding`, `hf_tokenizer.enable_truncation`, or `hf_tokenizer.no_padding`/`no_truncation`. 220 | - Configures the `hf_tokenizer.post_processor` using `TemplateProcessing` if `add_special_tokens=True` and BOS/EOS tokens are available, otherwise sets it to `None`. This handles adding BOS/EOS correctly. 221 | - Calls `hf_tokenizer.encode` (for single text) or `hf_tokenizer.encode_batch` (for list of texts). 222 | - Extracts the `ids` and `attention_mask` from the result(s). 223 | - If `return_tensors="jax"`, converts the lists of integers into `jnp.int32` arrays using `jnp.array`. It handles potential raggedness after encoding if `padding='longest'` was used by manually padding again to the true max length found in the batch *after* encoding. 224 | - Returns the dictionary `{'input_ids': ..., 'attention_mask': ...}`. 225 | 226 | ```mermaid 227 | sequenceDiagram 228 | participant User 229 | participant Tokenizer as JaxTokenizer 230 | participant HfTokenizer as HF Tokenizer 231 | participant JNP as jax.numpy 232 | 233 | User->>+JaxTokenizer: encode(text, padding=True, ...) 234 | JaxTokenizer->>+HfTokenizer: enable_padding(pad_id, pad_token, ...) 235 | JaxTokenizer->>+HfTokenizer: enable_truncation(...) 
/ no_truncation() (based on args) 236 | JaxTokenizer->>+HfTokenizer: set post_processor (for special tokens) 237 | JaxTokenizer->>+HfTokenizer: encode_batch(text) / encode(text) 238 | HfTokenizer-->>-JaxTokenizer: list[Encoding] (contains ids, attention_mask) 239 | JaxTokenizer->>JaxTokenizer: Extract ids and masks into lists 240 | alt return_tensors == "jax" 241 | JaxTokenizer->>JNP: array(ids_list, dtype=int32) 242 | JNP-->>JaxTokenizer: ids_array (jnp.ndarray) 243 | JaxTokenizer->>JNP: array(mask_list, dtype=int32) 244 | JNP-->>JaxTokenizer: mask_array (jnp.ndarray) 245 | JaxTokenizer-->>-User: {'input_ids': ids_array, 'attention_mask': mask_array} 246 | else 247 | JaxTokenizer-->>-User: {'input_ids': ids_list, 'attention_mask': mask_list} 248 | end 249 | 250 | ``` 251 | 252 | 4. **Decoding (`decode`)**: 253 | - Converts input `jnp.ndarray` to lists if necessary. 254 | - Determines if the input is a single sequence or a batch. 255 | - Calls the underlying `hf_tokenizer.decode` or `hf_tokenizer.decode_batch` method, passing `skip_special_tokens`. 256 | - Returns the resulting string or list of strings. 257 | 258 | 5. **Chat Templating (`apply_chat_template`)**: 259 | - Selects the Jinja template (either `self.chat_template` or the one passed as an argument). Raises an error if none is available. 260 | - Creates a dictionary of variables to pass to the template, including `messages`, `bos_token`, `eos_token`, and any `**kwargs`. 261 | - Performs basic validation on the template structure (optional, checks for expected variables like `messages`). 262 | - Modifies the template string if `add_generation_prompt=True` (using a common but potentially model-specific pattern). 263 | - Creates a Jinja environment and renders the template with the variables. 264 | - Returns the rendered string. 265 | 266 | ## Conclusion 267 | 268 | The `jaxgarden.Tokenizer` provides a crucial bridge between raw text and JAX-based NLP models. It leverages the power of Hugging Face `tokenizers` while ensuring compatibility with JAX workflows by returning `jnp.ndarray` objects. Key functionalities include easy loading from the Hub/local files, robust encoding/decoding with padding and truncation control, automatic handling of special tokens, and essential chat templating for conversational AI. 269 | 270 | Understanding how to use the `Tokenizer` is the first step in building or using models within `jaxgarden`. The next chapter will introduce the foundational building blocks for models themselves. 271 | 272 | **Next:** [BaseModel](basemodel.mdc) 273 | 274 | 275 | --- 276 | 277 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai) -------------------------------------------------------------------------------- /.cursor/rules/llamaforcausallm.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: Details the jaxgarden LlamaForCausalLM model for causal language modeling, covering architecture, components, and HF weight conversion. 3 | globs: jaxgarden/models/llama.py 4 | alwaysApply: false 5 | --- 6 | # Chapter 4: LlamaForCausalLM 7 | 8 | In the [previous chapter](baseconfig.mdc), we examined `BaseConfig`, the foundation for configuring models in `jaxgarden`. We saw how model-specific configurations like `LlamaConfig` inherit from it to define hyperparameters. Now, we dive into the complete model implementation that uses this configuration: `LlamaForCausalLM`. 
9 | 10 | **Motivation:** Large language models like Meta's Llama have demonstrated remarkable capabilities in natural language understanding and generation. Implementing such complex architectures efficiently within the JAX ecosystem requires careful integration of various components (attention mechanisms, normalization layers, embeddings) and adherence to best practices for performance and state management. `LlamaForCausalLM` provides a faithful and optimized implementation of the Llama architecture, ready for training and inference using JAX and Flax NNX. 11 | 12 | **Central Use Case:** Loading pretrained Llama weights (e.g., from Hugging Face Hub) into a `jaxgarden.LlamaForCausalLM` instance and using it for autoregressive text generation. This involves initializing the model with the correct `LlamaConfig`, leveraging the `from_hf` method inherited from [BaseModel](basemodel.mdc) for weight conversion, and then using the `generate` method from the [GenerationMixin](generationmixin.mdc) for text synthesis. 13 | 14 | ## Key Concepts 15 | 16 | `LlamaForCausalLM` integrates several advanced transformer components into a cohesive causal language model: 17 | 18 | 1. **Core Structure:** It inherits from [BaseModel](basemodel.mdc) for configuration, state management, and HF integration capabilities, and from [GenerationMixin](generationmixin.mdc) for text generation methods. 19 | 2. **Token Embeddings (`nnx.Embed`):** Maps input token IDs from a vocabulary to dense vector representations (hidden states). 20 | 3. **`LlamaTransformerBlock`:** The main building block, repeated multiple times (`n_layers` specified in `LlamaConfig`). Each block contains: 21 | * **`LlamaRMSNorm` (Pre-Normalization):** Applied before the attention and MLP layers for improved training stability. RMSNorm is a simpler and often faster alternative to LayerNorm. 22 | * **`LlamaAttention`:** Implements multi-head self-attention. It incorporates [Rotary Position Embeddings (RoPE)](rotary_position_embeddings__rope_.mdc) via `LlamaRotaryEmbedding` to inject positional information dynamically. It also supports Grouped Query Attention (GQA) where the number of key/value heads (`n_kv_heads`) can be smaller than the number of query heads (`n_heads`) for reduced computational cost and memory footprint, particularly during inference. 23 | * **`LlamaMLP`:** A feed-forward network using the SwiGLU activation function (`silu(gate(x)) * up(x)`), which has shown strong performance in recent models. 24 | 4. **Final Normalization & LM Head:** After the last `LlamaTransformerBlock`, a final `LlamaRMSNorm` is applied. The output is then passed through a linear layer (`lm_head`) that projects the final hidden state back to the vocabulary size, producing logits for the next token prediction. 25 | 5. **Weight Tying:** The weights of the `lm_head` are typically tied to the weights of the `token_embed` layer. This reduces the total number of parameters and can improve performance. This tying is handled during weight initialization and conversion. 26 | 6. **HF Weight Conversion (`convert_weights_from_hf`):** Implements the logic to map parameter names and shapes from Hugging Face Llama checkpoints (stored in Safetensors format) to the specific structure and naming convention of the `jaxgarden` `LlamaForCausalLM` state. 27 | 7. **Text Generation:** Inherits the `generate` method from [GenerationMixin](generationmixin.mdc), enabling autoregressive text generation with sampling strategies like temperature, top-k, and top-p. 
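To make one of these building blocks concrete, here is a minimal, self-contained Flax NNX sketch of the SwiGLU feed-forward described for `LlamaMLP` above. It is an illustration only, not the jaxgarden source; the layer names (`gate_proj`, `up_proj`, `down_proj`) are assumptions chosen to mirror common Llama implementations.

```python
import jax
import jax.numpy as jnp
from flax import nnx


class SwiGLUMLP(nnx.Module):
    """Minimal SwiGLU feed-forward: down(silu(gate(x)) * up(x))."""

    def __init__(self, dim: int, intermediate_size: int, *, rngs: nnx.Rngs):
        self.gate_proj = nnx.Linear(dim, intermediate_size, use_bias=False, rngs=rngs)
        self.up_proj = nnx.Linear(dim, intermediate_size, use_bias=False, rngs=rngs)
        self.down_proj = nnx.Linear(intermediate_size, dim, use_bias=False, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # SwiGLU: SiLU-gated branch multiplied elementwise with the linear branch
        return self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x))


# Shape check: a [batch, seq, dim] activation keeps its shape.
mlp = SwiGLUMLP(dim=512, intermediate_size=1024, rngs=nnx.Rngs(0))
y = mlp(jnp.ones((1, 10, 512)))  # -> (1, 10, 512)
```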
28 | 29 | ## Using `LlamaForCausalLM` 30 | 31 | Let's see how to initialize, load weights, and use the model. 32 | 33 | ### Initialization 34 | 35 | First, define the configuration (`LlamaConfig`) and instantiate the model using `nnx.Rngs`. 36 | 37 | ```python 38 | import jax 39 | import jax.numpy as jnp 40 | from flax import nnx 41 | from jaxgarden.models.llama import LlamaConfig, LlamaForCausalLM 42 | 43 | # Example: Configuration for a small Llama model 44 | config = LlamaConfig( 45 | dim=512, 46 | n_layers=4, 47 | n_heads=8, 48 | n_kv_heads=4, # GQA enabled (n_kv_heads < n_heads) 49 | head_dim=64, 50 | intermediate_size=1024, 51 | vocab_size=10000, # Example vocabulary size 52 | norm_eps=1e-5, 53 | rope_theta=10000.0 54 | ) 55 | 56 | # Initialize PRNG keys 57 | rngs = nnx.Rngs(params=0) # Or use more sophisticated key management 58 | 59 | # Instantiate the model 60 | model = LlamaForCausalLM(config, rngs=rngs, param_dtype=jnp.bfloat16) 61 | 62 | print(f"Initialized LlamaForCausalLM with {config.n_layers} layers.") 63 | # Output: Initialized LlamaForCausalLM with 4 layers. 64 | ``` 65 | **Explanation:** We create a `LlamaConfig` dataclass instance, specifying the architecture details. We then pass this config and an `nnx.Rngs` object to the `LlamaForCausalLM` constructor. `param_dtype=jnp.bfloat16` is often used for large models to save memory. 66 | 67 | ### Loading Pretrained Weights 68 | 69 | To load weights from a Hugging Face checkpoint (assuming you have one compatible, e.g., `"meta-llama/Llama-2-7b-hf"`), use the `from_hf` method. 70 | 71 | ```python 72 | # hf_model_id = "meta-llama/Llama-2-7b-hf" # Example HF model ID 73 | # print(f"Attempting to load weights from {hf_model_id}...") 74 | # try: 75 | # # Ensure config matches the HF model being loaded 76 | # # config_hf = LlamaConfig(...) # Load/define config matching the HF model 77 | # # model_hf = LlamaForCausalLM(config_hf, rngs=nnx.Rngs(1), param_dtype=jnp.bfloat16) 78 | # # model_hf.from_hf(hf_model_id, save_in_orbax=False, remove_hf_after_conversion=True) 79 | # # print("Weights loaded successfully.") 80 | # except Exception as e: 81 | # # print(f"Skipping HF weight loading example due to error: {e}") 82 | # pass 83 | print("Skipping actual HF weight loading execution.") 84 | ``` 85 | **Explanation:** Calling `model.from_hf(hf_model_id)` triggers the download (via `BaseModel.download_from_hf`), weight iteration (`BaseModel.iter_safetensors`), and crucially, the conversion logic defined in `LlamaForCausalLM.convert_weights_from_hf`. The model's state is updated in place with the loaded weights. Ensure the `LlamaConfig` used to initialize the model matches the architecture of the Hugging Face checkpoint. 86 | 87 | ### Forward Pass (`__call__`) 88 | 89 | The `__call__` method performs a standard forward pass, taking token IDs and an optional attention mask, and returning logits. 90 | 91 | ```python 92 | # Prepare dummy input (Batch size 1, sequence length 10) 93 | batch_size = 1 94 | seq_len = 10 95 | dummy_input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) 96 | dummy_attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) # 1s indicate valid tokens 97 | 98 | # Perform forward pass 99 | # Note: Actual model execution requires JIT compilation or running on device 100 | @jax.jit 101 | def forward_pass(model_state, input_ids, attention_mask): 102 | # Need to split model for JIT 103 | graphdef, params = nnx.split(model, nnx.Param, ...) 
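    # Sketch only: nnx.split returns a GraphDef plus one State per filter, and the
    # module is normally rebuilt with nnx.merge(graphdef, params) before being
    # applied; the direct graphdef(...) call below stands in for that step.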
104 | logits = graphdef(input_ids=input_ids, attention_mask=attention_mask) 105 | return logits 106 | 107 | # We won't execute the JITted function here, just show the structure 108 | # logits = forward_pass(model.state, dummy_input_ids, dummy_attention_mask) 109 | # print(f"Output logits shape: (batch_size, seq_len, vocab_size) = {logits.shape}") 110 | # Expected output shape if run: (1, 10, 10000) 111 | print(f"Expected output logits shape: ({batch_size}, {seq_len}, {config.vocab_size})") 112 | # Output: Expected output logits shape: (1, 10, 10000) 113 | ``` 114 | **Explanation:** The `__call__` method expects `input_ids` (shape `[batch_size, seq_len]`) and an optional `attention_mask` (same shape, 1 for real tokens, 0 for padding). It returns logits (shape `[batch_size, seq_len, vocab_size]`). For performance, the forward pass is typically JIT-compiled. Note that the current implementation asserts `batch_size == 1`. 115 | 116 | ### Text Generation (`generate`) 117 | 118 | Use the inherited `generate` method for autoregressive text generation. 119 | 120 | ```python 121 | from jaxgarden.tokenization import Tokenizer # Assuming Tokenizer is available 122 | 123 | # Assume 'model' is initialized and potentially loaded with weights 124 | # Assume 'tokenizer' is loaded, e.g., Tokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") 125 | 126 | prompt = "The capital of France is" 127 | # encoded_prompt = tokenizer.encode(prompt, return_tensors="jax") 128 | # input_ids = encoded_prompt['input_ids'] 129 | # attention_mask = encoded_prompt['attention_mask'] 130 | 131 | # # Set generation parameters 132 | # max_new_tokens = 10 133 | # target_max_length = input_ids.shape[1] + max_new_tokens 134 | 135 | # # Generate text (requires JIT or device execution) 136 | # generation_rng = jax.random.PRNGKey(42) 137 | # # output_ids = model.generate( 138 | # # input_ids, 139 | # # attention_mask=attention_mask, 140 | # # max_length=target_max_length, 141 | # # temperature=0.7, 142 | # # top_k=50, 143 | # # do_sample=True, 144 | # # pad_token_id=tokenizer.pad_token_id, 145 | # # eos_token_id=tokenizer.eos_token_id, 146 | # # rng=generation_rng, 147 | # # use_jit=True # Recommended for performance 148 | # # ) 149 | 150 | # # Decode the output 151 | # # decoded_output = tokenizer.decode(output_ids[0], skip_special_tokens=True) 152 | # # print(f"Prompt: {prompt}") 153 | # # print(f"Generated Text: {decoded_output}") 154 | # Expected Output (conceptual): 155 | # Prompt: The capital of France is 156 | # Generated Text: The capital of France is Paris. It is known for... 157 | print("Skipping actual text generation execution.") 158 | ``` 159 | **Explanation:** We encode a prompt using the [Tokenizer](tokenizer.mdc). The resulting `input_ids` and `attention_mask` are passed to `model.generate`, along with generation parameters (`max_length`, `temperature`, `top_k`, `do_sample`, etc.) and token IDs for padding and end-of-sequence. The `generate` method (detailed in [GenerationMixin](generationmixin.mdc)) handles the autoregressive sampling loop. The output token IDs are then decoded back into text. Using `use_jit=True` is highly recommended for efficient generation. 160 | 161 | ## Internal Implementation 162 | 163 | Understanding the internal structure helps in debugging and potential extensions. 164 | 165 | * **File:** `jaxgarden/models/llama.py` 166 | 167 | 1. **Initialization (`__init__`)**: 168 | * Calls `super().__init__` from [BaseModel](basemodel.mdc). 169 | * Initializes `self.token_embed` as an `nnx.Embed` layer. 
170 | * Creates a list of `LlamaTransformerBlock` instances (`self.layers`), passing the configuration and layer index to each. 171 | * Initializes the final `self.norm` as a `LlamaRMSNorm` layer. 172 | * Initializes `self.lm_head` as an `nnx.Linear` layer. Weight tying with `self.token_embed` happens implicitly during weight loading/conversion (in `convert_weights_from_hf`) or potentially during custom initialization logic not shown here. 173 | 174 | 2. **Forward Pass (`__call__`)**: 175 | 176 | ```mermaid 177 | sequenceDiagram 178 | participant Input 179 | participant LlamaForCausalLM 180 | participant Embed as nnx.Embed 181 | participant Blocks as LlamaTransformerBlock (Loop) 182 | participant Norm as LlamaRMSNorm 183 | participant Head as nnx.Linear (lm_head) 184 | participant Output 185 | 186 | Input->>+LlamaForCausalLM: __call__(input_ids, attention_mask) 187 | LlamaForCausalLM->>LlamaForCausalLM: Calculate position_ids 188 | LlamaForCausalLM->>LlamaForCausalLM: Convert attention_mask to additive mask 189 | LlamaForCausalLM->>+Embed: token_embed(input_ids) 190 | Embed-->>-LlamaForCausalLM: hidden_states (x) 191 | loop N Layers 192 | LlamaForCausalLM->>+Blocks: layer(x, position_ids, additive_mask) 193 | Blocks-->>-LlamaForCausalLM: updated hidden_states (x) 194 | end 195 | LlamaForCausalLM->>+Norm: norm(x) 196 | Norm-->>-LlamaForCausalLM: normalized_states 197 | LlamaForCausalLM->>+Head: lm_head(normalized_states) 198 | Head-->>-LlamaForCausalLM: logits 199 | LlamaForCausalLM-->>-Output: logits 200 | ``` 201 | * Calculates `position_ids` based on the sequence length. 202 | * Converts the boolean `attention_mask` into an additive mask (0.0 for valid, -inf for masked) suitable for softmax normalization in attention. 203 | * Passes `input_ids` through `self.token_embed`. 204 | * Iteratively passes the hidden states `x` through each `LlamaTransformerBlock` in `self.layers`, along with `position_ids` and the `attention_mask`. 205 | * Applies the final `self.norm`. 206 | * Applies `self.lm_head` to get the final logits. 207 | 208 | 3. **Hugging Face Weight Conversion (`convert_weights_from_hf`)**: 209 | * This method receives the model's `state` (an `nnx.State` object or dict) and an iterator yielding `(hf_key, hf_tensor)` tuples from the Safetensors files. 210 | * It iterates through the HF weights and maps the HF key names to the corresponding `jaxgarden` state keys. 211 | * Key transformations include: 212 | * Renaming: e.g., `model.layers..self_attn...` -> `state["layers"][]["attention"]...` 213 | * Transposition: Weights from `nn.Linear` layers in HF PyTorch often need to be transposed (`.T`) for Flax `nnx.Linear` kernel shape `[in_features, out_features]`. 214 | * Direct Assignment: Normalization weights (`model.layers..input_layernorm.weight`) are directly assigned (`state["layers"][]["input_layernorm"]["norm_weights"].value = tensor`). 215 | * Weight Tying: The embedding weights (`model.embed_tokens.weight`) are assigned to both `state["token_embed"].embedding` and (after transposition) `state["lm_head"].kernel`. 216 | 217 | ```python 218 | # Snippet from LlamaForCausalLM.convert_weights_from_hf 219 | def convert_weights_from_hf( 220 | self, state: nnx.State | dict[str, jnp.ndarray], weights: Iterator[tuple[Any, Any]] 221 | ) -> None: 222 | for wholekey, tensor in weights: 223 | keys = wholekey.split(".") 224 | if keys[0] == "model": # Strip 'model.' 
prefix often found in HF checkpoints 225 | keys = keys[1:] 226 | 227 | if keys[0] == "layers": 228 | layer_idx = int(keys[1]) 229 | component = keys[2] 230 | param_name = keys[3] 231 | target_state = state["layers"][layer_idx] 232 | 233 | if component == "self_attn": component = "attention" # Rename 234 | 235 | if component in ["attention", "mlp"]: 236 | # Linear layer weights need transpose 237 | target_state[component][param_name]["kernel"].value = tensor.T 238 | elif component in ["input_layernorm", "post_attention_layernorm"]: 239 | # RMSNorm weights 240 | target_state[component]["norm_weights"].value = tensor 241 | elif keys[0] == "embed_tokens": 242 | state["token_embed"].embedding.value = tensor 243 | # Tie weights with lm_head (transpose required) 244 | state["lm_head"].kernel.value = tensor.T 245 | elif keys[0] == "norm": 246 | state["norm"].norm_weights.value = tensor 247 | elif keys[0] == "lm_head": 248 | # This case might occur if weights aren't tied in HF checkpoint 249 | # If already tied via embed_tokens, this might overwrite or be redundant 250 | if not jnp.array_equal(state["lm_head"].kernel.value, tensor.T): 251 | print(f"Warning: Overwriting lm_head from separate checkpoint tensor: {wholekey}") 252 | state["lm_head"].kernel.value = tensor.T 253 | # else: # Handle potential unexpected keys 254 | # print(f"Warning: Unhandled HF key: {wholekey}") 255 | 256 | ``` 257 | 258 | ## Conclusion 259 | 260 | `LlamaForCausalLM` provides a comprehensive and efficient implementation of the Llama architecture within the `jaxgarden` framework. By leveraging `BaseModel` for structure and `GenerationMixin` for functionality, and integrating key components like `LlamaRMSNorm`, `LlamaAttention` with RoPE and GQA, and `LlamaMLP` with SwiGLU, it offers a powerful tool for causal language modeling tasks. Its ability to load pretrained weights via `convert_weights_from_hf` makes it easy to utilize existing Llama models directly in JAX. 261 | 262 | The next chapter will delve deeper into the text generation capabilities provided by the mixin class used by this model. 263 | 264 | **Next:** [GenerationMixin](generationmixin.mdc) 265 | 266 | 267 | --- 268 | 269 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai) -------------------------------------------------------------------------------- /.cursor/rules/modernbertformaskedlm.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: Details the jaxgarden ModernBERTForMaskedLM model, a bidirectional encoder with RoPE, pre-LN, and mixed attention for MLM tasks. 3 | globs: jaxgarden/models/modernbert.py 4 | alwaysApply: false 5 | --- 6 | # Chapter 6: ModernBERTForMaskedLM 7 | 8 | In the [previous chapter](generationmixin.mdc), we explored `GenerationMixin`, focusing on autoregressive text generation for causal language models. Now, we shift our focus to a different type of transformer architecture: a bidirectional encoder designed specifically for Masked Language Modeling (MLM) tasks, incorporating modern optimizations. Welcome to `ModernBERTForMaskedLM`. 9 | 10 | **Motivation:** While causal models excel at text generation, many NLP tasks benefit from understanding context bidirectionally (both left and right). Traditional BERT achieved this but often suffered from computational inefficiency, especially with long sequences. 
The ModernBERT architecture, as proposed by Answer.AI, aims to create a "Smarter, Better, Faster, Longer" bidirectional encoder by incorporating techniques like Rotary Position Embeddings (RoPE), pre-Layer Normalization, and efficient attention mechanisms, making it suitable for fast, memory-efficient training and inference, particularly for long contexts. `ModernBERTForMaskedLM` implements this architecture within `jaxgarden` for the MLM pre-training objective. 11 | 12 | **Central Use Case:** Pre-training or fine-tuning a language model using the masked language modeling objective, where the model predicts randomly masked tokens in the input sequence. This model can also serve as a powerful feature extractor for downstream NLP tasks requiring bidirectional context understanding, potentially after loading pre-trained weights using the framework provided by [BaseModel](basemodel.mdc). 13 | 14 | ## Key Concepts 15 | 16 | `ModernBERTForMaskedLM` integrates several components and modern architectural choices: 17 | 18 | 1. **Inheritance:** It inherits from [BaseModel](basemodel.mdc), gaining standardized configuration management, state handling (Flax NNX), checkpointing (`save`/`load`), and the interface for Hugging Face weight conversion (`from_hf`). 19 | 2. **`ModernBertEmbeddings`:** Handles the initial input processing. It converts token IDs into dense vectors using an embedding layer and applies Layer Normalization. Crucially, unlike traditional BERT, it does *not* include explicit learned position embeddings; positional information is injected later via RoPE. 20 | 3. **`ModernBERTEncoder`:** The core of the model, consisting of a stack of `ModernBertLayer` instances (`num_hidden_layers` defined in `ModernBERTConfig`). 21 | 4. **`ModernBertLayer`:** Implements a single transformer block using the **Pre-LayerNorm** architecture (LayerNorm applied *before* the attention and MLP sub-layers, followed by residual connections). This structure often leads to more stable training compared to the original Post-LayerNorm BERT. Each layer contains: 22 | * `ModernBertAttention`: Performs multi-head self-attention. 23 | * `ModernBertMLP`: A feed-forward network applied after attention. 24 | 5. **`ModernBertAttention`:** 25 | * Calculates query, key, and value projections from the input. 26 | * Applies **[Rotary Position Embeddings (RoPE)](rotary_position_embeddings__rope_.mdc)** directly to the query and key vectors before computing attention scores. `RoPEPositionalEmbedding` generates the necessary sinusoidal encodings. 27 | * Supports **Mixed Global and Local Attention:** Can operate in standard global attention mode (all tokens attend to all others) or use a sliding window (local attention) where tokens only attend to nearby tokens (`local_attention` parameter). This can be configured to happen only on certain layers (`global_attn_every_n_layers`) to balance global context understanding with computational efficiency. 28 | 6. **`ModernBertMLP`:** Uses GELU activation function within the feed-forward network. 29 | 7. **`ModernBERTMLMHead`:** Added on top of the `ModernBERTEncoder`. It takes the final hidden states, applies Layer Normalization, a dense layer with GELU activation, and finally projects the result to the vocabulary size to produce logits for predicting the original masked tokens. The final projection layer (decoder) is often tied to the token embedding weights. 30 | 31 | ## Using `ModernBERTForMaskedLM` 32 | 33 | Let's see how to instantiate and use the model. 
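Before the usage examples, here is a small sketch of the sliding-window (local) attention masking idea from the key concepts above. It is an illustration under assumed conventions, not the mask construction used inside `ModernBertAttention`.

```python
import jax.numpy as jnp


def sliding_window_additive_mask(seq_len: int, window: tuple[int, int]) -> jnp.ndarray:
    """Additive mask: 0.0 where j is within (left, right) tokens of i, large negative elsewhere."""
    left, right = window
    pos = jnp.arange(seq_len)
    offset = pos[None, :] - pos[:, None]  # offset[i, j] = j - i
    allowed = (offset >= -left) & (offset <= right)
    return jnp.where(allowed, 0.0, jnp.finfo(jnp.float32).min)


mask = sliding_window_additive_mask(seq_len=8, window=(2, 2))
# Adding `mask` to the attention scores before softmax restricts each token to a
# local neighbourhood; layers selected by `global_attn_every_n_layers` would skip
# this mask and attend globally instead.
```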
34 | 35 | ### Initialization 36 | 37 | Define a `ModernBERTConfig` and initialize the model. Key parameters in the config control the architecture, including attention behavior. 38 | 39 | ```python 40 | import jax 41 | import jax.numpy as jnp 42 | from flax import nnx 43 | from jaxgarden.models.modernbert import ModernBERTConfig, ModernBERTForMaskedLM 44 | 45 | # Configuration for a smaller ModernBERT model 46 | config = ModernBERTConfig( 47 | vocab_size=30522, # Example BERT vocab size 48 | hidden_size=256, 49 | num_hidden_layers=4, 50 | num_attention_heads=8, 51 | intermediate_size=512, 52 | max_position_embeddings=512, # Max sequence length for RoPE cache 53 | attention_dropout=0.1, 54 | hidden_dropout=0.1, 55 | # Use local attention (window size 128 left, 128 right) 56 | local_attention=(128, 128), 57 | # Apply global attention every 2 layers (layers 0, 2) 58 | global_attn_every_n_layers=2, 59 | pad_token_id=0 60 | ) 61 | 62 | # Initialize PRNG keys 63 | rngs = nnx.Rngs(params=0) # Or use more sophisticated key management 64 | 65 | # Instantiate the model 66 | model = ModernBERTForMaskedLM(config, rngs=rngs, param_dtype=jnp.float32) 67 | 68 | print(f"Initialized ModernBERTForMaskedLM with {config.num_hidden_layers} layers.") 69 | print(f"Local attention window: {config.local_attention}") 70 | print(f"Global attention every {config.global_attn_every_n_layers} layers.") 71 | # Output: Initialized ModernBERTForMaskedLM with 4 layers. 72 | # Output: Local attention window: (128, 128) 73 | # Output: Global attention every 2 layers. 74 | ``` 75 | **Explanation:** We create a `ModernBERTConfig` instance, customizing parameters like layer count, hidden size, and attention settings (`local_attention`, `global_attn_every_n_layers`). We then instantiate `ModernBERTForMaskedLM` with this config and an `nnx.Rngs` object. 76 | 77 | ### Forward Pass (`__call__`) 78 | 79 | The `__call__` method takes token IDs and an optional attention mask, returning a dictionary containing logits and optionally hidden states and attentions. 80 | 81 | ```python 82 | # Prepare dummy input (Batch size 2, sequence length 64) 83 | batch_size = 2 84 | seq_len = 64 85 | dummy_input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) 86 | # Mask indicating valid tokens (1) vs padding (0) 87 | dummy_attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) 88 | 89 | # Define forward pass function (typically JIT-compiled) 90 | @jax.jit 91 | def forward_pass(model_state, input_ids, attention_mask): 92 | # Split graphdef/params for JIT 93 | graphdef, params = nnx.split(model, nnx.Param, ...) # simplified split 94 | # Call the model's graph definition 95 | outputs = graphdef( 96 | input_ids=input_ids, 97 | attention_mask=attention_mask, 98 | deterministic=True # Disable dropout for inference 99 | ) 100 | return outputs['logits'] # Return only logits for this example 101 | 102 | # --- Execution --- 103 | # We won't execute the JITted function here, just show structure/shapes 104 | # logits = forward_pass(model.state, dummy_input_ids, dummy_attention_mask) 105 | # print(f"Output logits shape: {logits.shape}") 106 | # Expected output shape if run: (2, 64, 30522) 107 | print(f"Expected output logits shape: ({batch_size}, {seq_len}, {config.vocab_size})") 108 | # Output: Expected output logits shape: (2, 64, 30522) 109 | ``` 110 | **Explanation:** 111 | - `input_ids`: `[batch_size, seq_len]` tensor of token IDs. 112 | - `attention_mask`: `[batch_size, seq_len]` tensor (1 for real tokens, 0 for padding). 
This mask is used by the attention mechanism to prevent attending to padding tokens. 113 | - `deterministic`: If `True`, dropout layers are disabled. Typically `True` for evaluation/inference and `False` for training. 114 | - The output is a dictionary. The primary output for MLM is `logits` (`[batch_size, seq_len, vocab_size]`), representing the model's prediction scores for each token position. Optional keys `hidden_states` and `attentions` can be requested via `output_hidden_states=True` and `output_attentions=True`. 115 | 116 | ### Loading Pretrained Weights (Conceptual) 117 | 118 | Like other `BaseModel` subclasses, `ModernBERTForMaskedLM` can potentially load weights from Hugging Face using `from_hf`, provided a `convert_weights_from_hf` method is implemented for this specific architecture (mapping HF ModernBERT checkpoint names to `jaxgarden` state names). 119 | 120 | ```python 121 | # hf_model_id = "AnswerDotAI/modernbert-8k-S-500k" # Example HF model ID 122 | # print(f"Attempting to load weights from {hf_model_id}...") 123 | # try: 124 | # # Ensure config matches the HF model being loaded 125 | # # config_hf = ModernBERTConfig(...) # Load/define config matching HF model 126 | # # model_hf = ModernBERTForMaskedLM(config_hf, rngs=nnx.Rngs(1)) 127 | # 128 | # # This call requires a ModernBERT-specific implementation of convert_weights_from_hf 129 | # # model_hf.from_hf(hf_model_id) 130 | # 131 | # # print("Weights loaded successfully.") 132 | # except NotImplementedError: 133 | # print(f"NOTE: {type(model).__name__} does not currently implement HF weight conversion.") 134 | # except Exception as e: 135 | # # print(f"Skipping HF weight loading example due to error: {e}") 136 | # pass 137 | print("Skipping actual HF weight loading execution (requires implementation).") 138 | ``` 139 | 140 | ## Internal Implementation 141 | 142 | Understanding the flow of data within the model helps in debugging and customization. 143 | 144 | **High-Level Walkthrough (`__call__`)**: 145 | 146 | 1. **Embeddings:** Input `input_ids` are passed to `self.embeddings` (`ModernBertEmbeddings`), which performs token lookup, scales embeddings, applies Layer Normalization, and optional dropout. Result: `hidden_states`. 147 | 2. **Attention Mask Prep:** The input `attention_mask` (e.g., `[1, 1, 0]`) is converted internally by the attention mechanism into an additive mask (e.g., `[0, 0, -inf]`) suitable for adding to attention scores before softmax. Sliding window masks are generated or used if local attention is active for the current layer. 148 | 3. **Encoder:** `hidden_states` and the prepared `attention_mask` (and optionally `position_ids`) are passed to `self.encoder` (`ModernBERTEncoder`). 149 | * The encoder iterates through its list of `ModernBertLayer`s. 150 | * Each `ModernBertLayer`: 151 | * Applies LayerNorm to the input (`attn_norm`, potentially `Identity` for layer 0). 152 | * Passes the normalized input to `self.attn` (`ModernBertAttention`), which calculates attention output using RoPE and potentially local masking. 153 | * Adds the attention output back to the layer's input (first residual connection). 154 | * Applies LayerNorm to the result (`mlp_norm`). 155 | * Passes the normalized result to `self.mlp` (`ModernBertMLP`). 156 | * Adds the MLP output back (second residual connection). 157 | * The final output of the last layer is passed through a final LayerNorm (`self.encoder.final_norm`). Result: `sequence_output`. 158 | 4. 
**MLM Head:** `sequence_output` is passed to `self.mlm_head` (`ModernBERTMLMHead`). 159 | * Applies LayerNorm (`self.mlm_head.norm`). 160 | * Applies a dense layer (`self.mlm_head.dense`) followed by GELU activation. 161 | * Applies the final decoder/projection layer (`self.mlm_head.decoder`) to get scores over the vocabulary. Result: `logits`. 162 | 5. **Output:** Returns a dictionary containing `logits` and any requested optional outputs (`hidden_states`, `attentions`). 163 | 164 | ```mermaid 165 | sequenceDiagram 166 | participant Input 167 | participant Model as ModernBERTForMaskedLM 168 | participant Embed as ModernBertEmbeddings 169 | participant Encoder as ModernBERTEncoder 170 | participant Layer as ModernBertLayer (Loop) 171 | participant Attn as ModernBertAttention 172 | participant MLP as ModernBertMLP 173 | participant Head as ModernBERTMLMHead 174 | participant OutputDict 175 | 176 | Input->>+Model: __call__(input_ids, attention_mask) 177 | Model->>+Embed: embeddings(input_ids) 178 | Embed-->>-Model: hidden_states 179 | Model->>+Encoder: encoder(hidden_states, attention_mask) 180 | loop N Layers 181 | Encoder->>+Layer: layer(hidden_states, mask) 182 | Layer->>Layer: Apply attn_norm (pre-norm) 183 | Layer->>+Attn: attn(normed_states, mask) # Applies RoPE internally 184 | Attn-->>-Layer: attention_output 185 | Layer->>Layer: Residual 1 (hidden_states + attention_output) 186 | Layer->>Layer: Apply mlp_norm (pre-norm) 187 | Layer->>+MLP: mlp(normed_states) 188 | MLP-->>-Layer: mlp_output 189 | Layer->>Layer: Residual 2 (hidden_states + mlp_output) -> new hidden_states 190 | Layer-->>-Encoder: updated hidden_states 191 | end 192 | Encoder->>Encoder: Apply final_norm 193 | Encoder-->>-Model: sequence_output 194 | Model->>+Head: mlm_head(sequence_output) 195 | Head-->>-Model: logits 196 | Model->>+OutputDict: Create {'logits': logits, ...} 197 | OutputDict-->>-Model: output_dict 198 | Model-->>-Input: output_dict 199 | ``` 200 | 201 | **Code Details (from `jaxgarden/models/modernbert.py`)**: 202 | 203 | * **`ModernBERTForMaskedLM.__init__`**: Initializes the main components (`self.embeddings`, `self.encoder`, `self.mlm_head`) by passing the `config` and `rngs`. 204 | ```python 205 | # Simplified from ModernBERTForMaskedLM.__init__ 206 | class ModernBERTForMaskedLM(BaseModel): 207 | def __init__(self, config: ModernBERTConfig, *, rngs: nnx.Rngs, ...): 208 | super().__init__(config, rngs=rngs, ...) 209 | 210 | self.embeddings = ModernBertEmbeddings(rngs=rngs, ...) # Pass relevant config fields 211 | self.encoder = ModernBERTEncoder(rngs=rngs, ...) # Pass relevant config fields 212 | self.mlm_head = ModernBERTMLMHead(rngs=rngs, ...) # Pass relevant config fields 213 | ``` 214 | 215 | * **`ModernBERTForMaskedLM.__call__`**: Orchestrates the forward pass. 216 | ```python 217 | # Simplified from ModernBERTForMaskedLM.__call__ 218 | def __call__(self, input_ids, attention_mask=None, ..., deterministic=True, ...): 219 | # 1. Get embeddings 220 | hidden_states = self.embeddings( 221 | input_ids=input_ids, deterministic=deterministic, ... 222 | ) 223 | 224 | # 2. Apply encoder 225 | encoder_outputs = self.encoder( 226 | hidden_states, attention_mask=attention_mask, deterministic=deterministic, ... 227 | ) 228 | sequence_output = encoder_outputs[0] 229 | # (Handle optional hidden_states/attentions from encoder_outputs) 230 | 231 | # 3. Apply MLM head 232 | logits = self.mlm_head(sequence_output) 233 | 234 | # 4. 
Build output dictionary 235 | outputs = {"logits": logits} 236 | # (Add optional outputs if requested) 237 | return outputs 238 | ``` 239 | 240 | * **`ModernBertAttention.__call__`**: Key steps include QKV projection, RoPE application, attention score calculation, mask application (including sliding window if active), softmax, dropout, and output projection. 241 | ```python 242 | # Conceptual steps within ModernBertAttention.__call__ 243 | # qkv = self.Wqkv(hidden_states) 244 | # query, key, value = split_heads_and_transpose(qkv) 245 | # query = apply_rotary_pos_emb(query, self.rotary_emb.cache, position_ids) 246 | # key = apply_rotary_pos_emb(key, self.rotary_emb.cache, position_ids) 247 | # attention_scores = compute_scores(query, key) 248 | # if self.local_attention != (-1, -1): 249 | # sliding_window_mask = create_sliding_window_mask(...) or use provided 250 | # attention_scores += sliding_window_mask 251 | # if attention_mask is not None: 252 | # attention_scores += attention_mask # Additive mask 253 | # attention_probs = softmax(attention_scores) 254 | # # Apply dropout if needed 255 | # attention_output = combine_heads(matmul(attention_probs, value)) 256 | # attention_output = self.Wo(attention_output) 257 | ``` 258 | 259 | ## Conclusion 260 | 261 | `ModernBERTForMaskedLM` provides a `jaxgarden` implementation of a modern, efficient bidirectional transformer encoder tailored for masked language modeling. By incorporating features like RoPE, pre-Layer Normalization, and optional sliding window attention, it aims to improve upon traditional BERT architectures in terms of speed, memory usage, and handling of long sequences. It leverages the `BaseModel` foundation for consistency and integrates specialized components like `ModernBertEmbeddings`, `ModernBERTEncoder`, `ModernBertAttention`, `ModernBertMLP`, and `ModernBERTMLMHead`. 262 | 263 | Understanding the general principles of attention mechanisms is fundamental to grasping how models like ModernBERT work internally. The next chapter delves into the details of attention itself. 264 | 265 | **Next:** [Attention Mechanism (MultiHeadAttention / dot_product_attention)](attention_mechanism__multiheadattention___dot_product_attention_.mdc) 266 | 267 | 268 | --- 269 | 270 | Generated by [Rules for AI](https://github.com/altaidevorg/rules-for-ai) --------------------------------------------------------------------------------