├── .gitignore
├── .gitmodules
├── .readthedocs.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── constraints.txt
├── docs
│   ├── Makefile
│   ├── README.md
│   ├── images
│   │   ├── kithara_logo.png
│   │   ├── kithara_logo_no_text.png
│   │   └── kithara_logo_with_green_bg.png
│   ├── requirements.txt
│   └── source
│       ├── .getting_tpus.rst.swp
│       ├── .overview.rst.swp
│       ├── api
│       │   ├── kithara.dataset_api.rst
│       │   ├── kithara.model_api.rst
│       │   └── kithara.trainer_api.rst
│       ├── basic_examples
│       │   └── full_param_finetuning_example.rst
│       ├── checkpointing.rst
│       ├── conf.py
│       ├── datasets.rst
│       ├── ddp.rst
│       ├── disk_storage.rst
│       ├── dws_flex_start.rst
│       ├── finetuning_guide.rst
│       ├── getting_tpus.rst
│       ├── help_and_feedback.rst
│       ├── index.rst
│       ├── installation.rst
│       ├── installation
│       │   ├── .tpu_gke.rst.swp
│       │   ├── tpu_gke.rst
│       │   ├── tpu_qr.rst
│       │   └── tpu_vm.rst
│       ├── lora.rst
│       ├── models.rst
│       ├── observability.rst
│       ├── optimizations.rst
│       ├── optimizers.rst
│       ├── overview.rst
│       ├── overviewxyz.rst
│       ├── packing.rst
│       ├── pretraining.rst
│       ├── quickstart.rst
│       ├── scaling_with_ray.rst
│       ├── serve_with_vllm.rst
│       ├── sft.rst
│       ├── temp.rst
│       └── troubleshooting.rst
├── examples
│   ├── colab
│   │   └── SFT_with_LoRA_Gemma2-2b.ipynb
│   ├── example_datasets.py
│   ├── multihost
│   │   └── ray
│   │       ├── GPU
│   │       │   └── sft_lora_example.py
│   │       └── TPU
│   │           ├── continued_pretraining_example.py
│   │           ├── full_finetuning_example.py
│   │           ├── quick_start.py
│   │           ├── sft_lora_example.py
│   │           └── sft_lora_example.yaml
│   └── singlehost
│       ├── continued_pretraining_example.py
│       ├── full_finetuning_example.py
│       ├── quick_start.py
│       └── sft_lora_example.py
├── kithara
│   ├── __init__.py
│   ├── callbacks
│   │   ├── __init__.py
│   │   ├── checkpointer.py
│   │   └── profiler.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── default.yaml
│   │   └── pyconfig.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── dataloader.py
│   │   ├── dataset.py
│   │   ├── packed_dataset.py
│   │   ├── sft.py
│   │   ├── text_completion.py
│   │   └── utils.py
│   ├── distributed
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   └── split.py
│   │   └── sharding
│   │       ├── __init__.py
│   │       ├── _data_sharding.py
│   │       ├── _layout.py
│   │       ├── _mesh.py
│   │       ├── maxtext.py
│   │       ├── models
│   │       │   ├── __init__.py
│   │       │   ├── gemma.py
│   │       │   └── llama.py
│   │       ├── strategy.py
│   │       └── utils.py
│   ├── model
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── hf_compatibility
│   │   │   ├── __init__.py
│   │   │   ├── model_configs.py
│   │   │   ├── shape_mapping.py
│   │   │   └── to_huggingface.py
│   │   ├── kerashub
│   │   │   ├── __init__.py
│   │   │   ├── ckpt_compatibility
│   │   │   │   ├── param_mapping.py
│   │   │   │   └── to_huggingface.py
│   │   │   └── keras_hub_model.py
│   │   ├── maxtext
│   │   │   ├── __init__.py
│   │   │   ├── ckpt_compatibility
│   │   │   │   ├── __init__.py
│   │   │   │   ├── from_huggingface.py
│   │   │   │   ├── param_mapping.py
│   │   │   │   ├── to_huggingface.py
│   │   │   │   └── utils.py
│   │   │   ├── conversion_utils.py
│   │   │   ├── inference_engine.py
│   │   │   └── maxtext_model.py
│   │   ├── model.py
│   │   └── supported_models.py
│   ├── optimizers
│   │   ├── __init__.py
│   │   ├── optax_optimizer.py
│   │   └── protocol.py
│   ├── trainer
│   │   ├── __init__.py
│   │   └── trainer.py
│   └── utils
│       ├── gcs_utils.py
│       ├── logging_utils.py
│       ├── safetensor_utils.py
│       ├── torch_utils.py
│       └── tree_utils.py
├── perf
│   ├── kerashub.py
│   ├── kithara_packing.py
│   ├── maxtext_native.py
│   ├── maxtext_via_kithara.py
│   ├── sft_kithara.py
│   ├── sft_unsloth.py
│   └── trl_packing.py
├── pyproject.toml
├── ray
│   ├── GPU
│   │   ├── cluster.yaml
│   │   └── find_devices.py
│   ├── README.md
│   ├── TPU
│   │   ├── GCE
│   │   │   └── cluster.yaml
│   │   ├── GKE
│   │   │   ├── Dockerfile
│   │   │   ├── multi-host.yaml
│   │   │   └── single-host.yaml
│   │   └── QR
│   │       ├── cluster.yaml
│   │       └── qr_worker_startup_script.sh
│   └── submit_job.py
└── tests
    ├── README.md
    ├── __init__.py
    ├── callbacks
    │   ├── __init__.py
    │   └── test_orbax_checkpointer.py
    ├── dataset
    │   ├── __init__.py
    │   ├── test_dataloader_creation.py
    │   ├── test_dataset_creation.py
    │   ├── test_dataset_packing.py
    │   └── utils.py
    ├── model
    │   ├── __init__.py
    │   ├── kerashub
    │   │   ├── ckpt_compatibility
    │   │   │   ├── __init__.py
    │   │   │   ├── test_loading_models.py
    │   │   │   ├── test_lora_adapters_shape_match.py
    │   │   │   ├── test_lora_adapters_value_match.py
    │   │   │   └── test_saving_models.py
    │   │   ├── test_creation.py
    │   │   └── test_inference.py
    │   ├── maxtext
    │   │   ├── __init__.py
    │   │   ├── ckpt_compatibility
    │   │   │   ├── __init__.py
    │   │   │   ├── test_loading_models.py
    │   │   │   └── test_saving_models.py
    │   │   ├── test_creation.py
    │   │   └── test_inference.py
    │   ├── test_prompt.py
    │   └── utils.py
    ├── optimizers
    │   ├── __init__.py
    │   ├── test_optax_optimizer.py
    │   └── test_protocol.py
    ├── test_utils.py
    └── trainer
        ├── __init__.py
        ├── test_sft_e2e.py
        └── test_trainer_creation.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info/
2 | *.pyc
3 | .DS_Store
4 | .ipynb_checkpoints
5 | events*
6 | **/__pycache__
7 | docs/build
8 | tmp/
9 | dist/
10 | *.pyc
11 | build/
12 | docs/bin/
13 | docs/lib/
14 | docs/lib64
15 | docs/pyvenv.cfg
16 |
17 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "kithara/model/maxtext/maxtext"]
2 | path = kithara/model/maxtext/maxtext
3 | url = https://github.com/google/maxtext
4 | [submodule "kithara/model/maxtext/JetStream"]
5 | path = kithara/model/maxtext/JetStream
6 | url = https://github.com/AI-Hypercomputer/JetStream.git
7 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | # Required
5 | version: 2
6 |
7 | # Set the OS, Python version, and other tools you might need
8 | build:
9 | os: ubuntu-24.04
10 | tools:
11 | python: "3.11"
12 |
13 | # Build documentation in the "docs/" directory with Sphinx
14 | sphinx:
15 | configuration: docs/source/conf.py
16 |
17 | # Optionally, but recommended,
18 | # declare the Python requirements required to build your documentation
19 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
20 | python:
21 | install:
22 | - requirements: docs/requirements.txt
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Kithara - Easy Finetuning on TPUs
2 |
3 | [](https://pypi.org/project/kithara/)
4 | [](https://github.com/AI-Hypercomputer/kithara/pulls)
5 | [](https://github.com/AI-Hypercomputer/kithara/commits/main)
6 | [](https://kithara.readthedocs.io/en/latest/)
7 |
8 |
17 |
18 | ## 👋 Overview
19 |
20 | Kithara is a lightweight library offering building blocks and recipes for tuning popular open source LLMs including Gemma2 and Llama3 on Google TPUs.
21 |
22 | It provides:
23 |
24 | - **Frictionless scaling**: Distributed training abstractions intentionally built with simplicity in mind.
25 | - **Multihost training support**: Integration with Ray, GCE and GKE.
26 | - **Async, distributed checkpointing**: Multi-host & Multi-device checkpointing via Orbax.
27 | - **Distributed, streamed dataloading**: Per-process, streamed data loading via Ray.data.
28 | - **GPU/TPU fungibility**: Same code works for both GPU and TPU out of the box.
29 | - **Native integration with HuggingFace**: Tune and save models in HuggingFace format.
30 |
31 | **New to TPUs?**
32 |
33 | Using TPUs provides significant advantages in terms of performance, cost-effectiveness, and scalability, enabling faster training times and the ability to work with larger models and datasets. Check out our onboarding guide to [getting TPUs](https://kithara.readthedocs.io/en/latest/getting_tpus.html).
34 |
35 | ## 🔗 **Key links and resources**
36 | | | |
37 | | --------------------------------- | --------------------------------------------------------------------------------------------------------------------------- |
38 | | 📚 **Documentation** | [Read Our Docs](https://kithara.readthedocs.io/en/latest/) |
39 | | 💾 **Installation** | [Quick Pip Install](https://kithara.readthedocs.io/en/latest/installation.html) |
40 | | ✏️ **Get Started** | [Intro to Kithara](https://kithara.readthedocs.io/en/latest/quickstart.html) |
41 | | 🌟 **Supported Models** | [List of Models](https://kithara.readthedocs.io/en/latest/models.html) |
42 | | 🌐 **Supported Datasets** | [List of Data Formats](https://kithara.readthedocs.io/en/latest/datasets.html) |
43 | | ⌛️ **Performance Optimizations** | [Our Memory and Throughput Optimizations](https://kithara.readthedocs.io/en/latest/optimizations.html) |
44 | | 📈 **Scaling up** | [Guide for Tuning Large Models](https://kithara.readthedocs.io/en/latest/scaling_with_ray.html) |
45 |
46 |
47 | ## 🌵 **Examples**
48 |
49 | - **Quick Start Colab Notebook**: [SFT + LoRA with Gemma2-2b](https://colab.sandbox.google.com/github/AI-Hypercomputer/kithara/blob/main/examples/colab/SFT_with_LoRA_Gemma2-2b.ipynb)
50 |
51 | - **SFT + LoRA**: [Step by Step Example](https://kithara.readthedocs.io/en/latest/sft.html)
52 |
53 | - **Continued Pretraining**: [Step by Step Example](https://kithara.readthedocs.io/en/latest/pretraining.html)
54 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-autobuild
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Kithara ReadTheDocs Site
2 |
3 | To build:
4 |
5 | ```
6 | # Install Sphinx deps.
7 | pip install -r docs/requirements.txt
8 |
9 | # Build and serve the docs with live reload (the Makefile uses sphinx-autobuild).
10 | make -C docs html
11 |
12 | # Preview locally at http://127.0.0.1:8000
13 | ```
--------------------------------------------------------------------------------
/docs/images/kithara_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/docs/images/kithara_logo.png
--------------------------------------------------------------------------------
/docs/images/kithara_logo_no_text.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/docs/images/kithara_logo_no_text.png
--------------------------------------------------------------------------------
/docs/images/kithara_logo_with_green_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/docs/images/kithara_logo_with_green_bg.png
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==8.2.0
2 | sphinx-autobuild
3 | sphinxawesome-theme>=5.3.2
4 | sphinx_design
5 |
--------------------------------------------------------------------------------
/docs/source/.getting_tpus.rst.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/docs/source/.getting_tpus.rst.swp
--------------------------------------------------------------------------------
/docs/source/.overview.rst.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/docs/source/.overview.rst.swp
--------------------------------------------------------------------------------
/docs/source/api/kithara.dataset_api.rst:
--------------------------------------------------------------------------------
1 | .. _dataset_api:
2 |
3 | Datasets
4 | ========
5 |
6 | TextCompletionDataset
7 | --------------------
8 |
9 | .. py:class:: TextCompletionDataset
10 |
11 | A dataset class for standard text completion tasks.
12 |
13 | :param source: The source Ray dataset containing the text data
14 | :type source: ray.data.Dataset
15 | :param tokenizer: The tokenizer instance to use
16 | :type tokenizer: Optional[AutoTokenizer]
17 | :param tokenizer_handle: Handle/name of the HF tokenizer to load if not provided. E.g. "hf://google/gemma-2-2b"
18 | :type tokenizer_handle: Optional[str]
19 | :param column_mapping: Mapping of source column name to expected column name ("text")
20 | :type column_mapping: Optional[Dict[str, str]]
21 | :param model_type: Type of model implementation to use. Supported types: "KerasHub", "MaxText", "auto"
22 | :type model_type: ModelImplementationType | "auto"
23 | :param max_seq_len: Maximum sequence length for tokenization (default: 1024)
24 | :type max_seq_len: int
25 | :param custom_formatting_fn: A custom formatting function to apply to the raw sample before any other transformation steps
26 | :type custom_formatting_fn: Optional[callable]
27 | :param packing: Whether to enable sequence packing
28 | :type packing: bool
29 |
30 | .. py:method:: to_packed_dataset()
31 |
32 | Converts the current dataset to a PackedDataset for more efficient processing.
33 |
34 | :return: A new PackedDataset instance
35 | :rtype: PackedDataset
36 |
37 | SFTDataset
38 | ----------
39 |
40 | .. py:class:: SFTDataset(TextCompletionDataset)
41 |
42 | A dataset class for Supervised Fine-Tuning (SFT) tasks.
43 |
44 | :param source: The source Ray dataset containing the training data
45 | :type source: ray.data.Dataset
46 | :param tokenizer: HuggingFace tokenizer instance
47 | :type tokenizer: Optional[AutoTokenizer]
48 | :param tokenizer_handle: Handle/name of the HF tokenizer to load if not provided. E.g. "hf://google/gemma-2-2b"
49 | :type tokenizer_handle: Optional[str]
50 | :param column_mapping: Mapping of source column names to expected column names ("prompt" and "answer")
51 | :type column_mapping: Optional[Dict[str, str]]
52 | :param model_type: Type of model implementation to use. Supported types: "KerasHub", "MaxText", "auto"
53 | :type model_type: ModelImplementationType | "auto"
54 | :param max_seq_len: Maximum sequence length for tokenization (default: 1024)
55 | :type max_seq_len: int
56 | :param custom_formatting_fn: A custom formatting function to apply to the raw sample before any other transformation steps
57 | :type custom_formatting_fn: Optional[callable]
58 |
59 | .. py:method:: to_packed_dataset()
60 |
61 | Converts the current dataset to a PackedDataset for more efficient processing.
62 |
63 | :return: A new PackedDataset instance
64 | :rtype: PackedDataset
65 |
66 | PackedDataset
67 | -------------
68 |
69 | .. py:class:: PackedDataset
70 |
71 | A dataset class that packs multiple sequences together on the fly for more efficient processing.
72 |
73 | :param source_dataset: The source dataset containing unpacked sequences
74 | :type source_dataset: TextCompletionDataset
75 | :param pad_value: The value to use for padding (default: 0)
76 | :type pad_value: int
77 |
78 | .. note::
79 | - Packing must be used with Flash Attention enabled (which should be enabled by default)
80 | - Packing currently only works for MaxText models
81 | - Packing does not currently work for DDP training
82 |
83 |
84 | Example Usage
85 | -------------
86 |
87 | Here's a simple example of using the TextCompletionDataset::
88 |
89 | dataset = TextCompletionDataset(
90 | source=ray_dataset,
91 | tokenizer_handle="hf://google/gemma-2-2b",
92 | max_seq_len=512,
93 | )
94 |
95 | For supervised fine-tuning tasks, use the SFTDataset::
96 |
97 | sft_dataset = SFTDataset(
98 | source=ray_dataset,
99 | tokenizer_handle="hf://google/gemma-2-2b",
100 | column_mapping={"input": "prompt", "output": "answer"},
101 | max_seq_len=1024
102 | )
103 |
104 | To enable sequence packing for more efficient processing::
105 |
106 | packed_dataset = dataset.to_packed_dataset()
--------------------------------------------------------------------------------
/docs/source/api/kithara.trainer_api.rst:
--------------------------------------------------------------------------------
1 | .. _trainer_api:
2 |
3 | Trainer
4 | =======
5 |
6 | .. py:class:: Trainer
7 |
8 | A base Trainer class supporting continued pretraining and SFT, designed to be subclassed for implementing other training objectives, e.g. DPO.
9 |
10 | :param model: The model to be trained or evaluated
11 | :type model: kithara.Model
12 | :param optimizer: The optimizer used for training
13 | :type optimizer: keras.Optimizer | optax.GradientTransformation | optax.GradientTransformationExtraArgs
14 | :param train_dataloader: A dataloader that provides training batches
15 | :type train_dataloader: kithara.Dataloader
16 | :param eval_dataloader: A dataloader that provides evaluation batches
17 | :type eval_dataloader: Optional[kithara.Dataloader]
18 | :param steps: The total number of training steps to execute. Defaults to None and trains 1 epoch
19 | :type steps: Optional[int]
20 | :param epochs: The total number of training epochs to execute. Defaults to None
21 | :type epochs: Optional[int]
22 | :param log_steps_interval: The interval between logging steps. Each log includes the current loss value and performance metrics
23 | :type log_steps_interval: int
24 | :param eval_steps_interval: The interval between evaluation steps. Only one of eval_steps_interval or eval_epochs_interval can be set
25 | :type eval_steps_interval: Optional[int]
26 | :param eval_epochs_interval: The interval between evaluation epochs. Only one of eval_steps_interval or eval_epochs_interval can be set
27 | :type eval_epochs_interval: Optional[int]
28 | :param max_eval_samples: The maximum number of samples to use during evaluation. Uses the entire evaluation dataset if not provided
29 | :type max_eval_samples: int
30 | :param tensorboard_dir: The directory path for TensorBoard logs. Can be either a local directory or a Google Cloud Storage path
31 | :type tensorboard_dir: Optional[str]
32 | :param profiler: A profiler instance for monitoring performance metrics
33 | :type profiler: Optional[kithara.Profiler]
34 | :param checkpointer: A checkpointer instance for saving model checkpoints
35 | :type checkpointer: Optional[kithara.Checkpointer]
36 |
37 | .. py:method:: train()
38 |
39 | Execute the main training loop. Handles epoch iteration, batch processing, loss computation, model updates, progress logging, and periodic evaluation.
40 |
41 |
42 | Example usage:
43 |
44 | .. code-block:: python
45 |
46 | trainer = Trainer(
47 | model=my_model,
48 | optimizer=keras.optimizers.Adam(learning_rate=1e-4),
49 | train_dataloader=train_loader,
50 | eval_dataloader=eval_loader,
51 | steps=1000,
52 | log_steps_interval=10,
53 | eval_steps_interval=100,
54 | tensorboard_dir="local_dir_or_gs_bucket",
55 | checkpointer=kithara.Checkpointer(
56 | save_dir="local_dir_or_gs_bucket",
57 | save_interval=100,
58 | ),
59 | )
60 |
61 | trainer.train()
62 |
63 | Note
64 | ----
65 | - If both ``steps`` and ``epochs`` are None, defaults to training for 1 epoch
66 | - If ``eval_dataloader`` is provided but no evaluation interval is set, defaults to evaluating every epoch
67 | - The trainer automatically handles data sharding for distributed training
--------------------------------------------------------------------------------
/docs/source/checkpointing.rst:
--------------------------------------------------------------------------------
1 | .. _checkpointing:
2 |
3 | Checkpointing
4 | =============
5 |
6 | Checkpointing allows you to save and restore model states during training.
7 | Kithara supports saving checkpoints to local directories or Google Cloud Storage buckets.
8 |
9 | Kithara checkpointing is fast -- it uses `Orbax <https://github.com/google/orbax>`_ to save checkpoints asynchronously and in a distributed manner.
10 |
11 | .. note::
12 | For efficiency, checkpoints are saved in a non-HuggingFace format. Use ``model.save_in_hf_format`` to save models in HuggingFace format.
13 |
14 | Basic Usage
15 | -----------
16 |
17 | During Training
18 | ^^^^^^^^^^^^^^^^^^^
19 | Provide the ``checkpointer`` arg to the ``Trainer`` class to save checkpoints during training::
20 |
21 | from kithara import Checkpointer, Trainer
22 | # Keeps the latest 5 checkpoints, saving one every 100 steps
23 | checkpointer = Checkpointer("gs://...", save_interval_steps=100, max_to_keep=5)
24 | trainer = Trainer(..., checkpointer=checkpointer)
25 | trainer.train()
26 |
27 | Restoring a checkpoint after training::
28 |
29 | # Initialize a random model
30 | model = MaxTextModel.from_random(
31 | model_name="gemma2-9b",
32 | seq_len=4096,
33 | per_device_batch_size=1,
34 | precision="mixed_bfloat16",
35 | )
36 |
37 | # Restore from specific checkpoint
38 | checkpointer = Checkpointer(
39 | "gs://your_bucket/your_model_name/ckpt/20", # Step 20
40 | model=model
41 | )
42 | checkpointer.load()
43 |
44 | model.generate(...)
45 |
46 | As Standalone Utility
47 | ^^^^^^^^^^^^^^^^^^^^^
48 |
49 | You can also use the ``Checkpointer`` as a standalone utility to save and load checkpoints outside of the ``Trainer`` class.
50 |
51 | .. code-block:: python
52 |
53 | model = kithara.MaxTextModel.from_preset("hf://google/gemma-2-2b")
54 |
55 | # Attach the checkpointer to the model
56 | checkpointer = Checkpointer("gs://...", model)
57 |
58 | # Save checkpoint - checkpoints need to be numbered sequentially
59 | checkpointer.save(0, blocking=True)
60 |
61 | # Load latest checkpoint back to the model
62 | checkpointer.load()
63 |
64 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 |
9 | project = 'Kithara'
10 | copyright = '2024, The Kithara Authors'
11 | author = 'kithara_authors'
12 |
13 | # -- General configuration ---------------------------------------------------
14 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
15 |
16 | extensions = [
17 | 'sphinx.ext.duration',
18 | 'sphinx.ext.doctest',
19 | 'sphinx.ext.autodoc',
20 | 'sphinx.ext.autosummary',
21 | 'sphinx.ext.intersphinx',
22 | 'sphinx.ext.napoleon',
23 | 'sphinx_design',
24 | ]
25 |
26 | templates_path = ['_templates']
27 | exclude_patterns = []
28 |
29 |
30 |
31 | # -- Options for HTML output -------------------------------------------------
32 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
33 |
34 | html_theme = 'sphinxawesome_theme'
35 | html_static_path = ['_static']
36 | html_css_files = ['custom.css']
37 | html_permalinks_icon = '#'
38 |
39 | html_theme_options = {
40 | "main_nav_links": {
41 | "GitHub": "https://github.com/AI-Hypercomputer/kithara",
42 | },
43 | "logo_light": "../images/kithara_logo_no_text.png",
44 | "logo_dark": "../images/kithara_logo_no_text.png",
45 | "show_prev_next": True,
46 | }
47 |
--------------------------------------------------------------------------------
/docs/source/datasets.rst:
--------------------------------------------------------------------------------
1 | .. _datasets:
2 |
3 | Supported Data Formats
4 | ======================
5 |
6 | The expected flow is always Ray Dataset -> Kithara Dataset -> Kithara Dataloader.
7 |
8 | For example,::
9 |
10 | hf_dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)
11 | ray_dataset = ray.data.from_huggingface(hf_dataset)
12 |
13 | kithara_dataset = kithara.TextCompletionDataset(
14 | ray_dataset,
15 | tokenizer_handle="hf://google/gemma-2-2b",
16 | max_seq_len=8192
17 | )
18 | dataloader = kithara.Dataloader(
19 | kithara_dataset,
20 | per_device_batch_size=1
21 | )
22 |
23 | .. note::
24 | You can load files and other data from your local directory, S3, GCS, ABS, and any other data storage platform supported by `Ray Data <https://docs.ray.io/en/latest/data/data.html>`_.
25 |
26 | Kithara supports all file formats supported by Ray Dataset. Listed below are some examples using popular dataset types.
27 |
28 | HuggingFace Dataset
29 | ~~~~~~~~~~~~~~~~~~~
30 | Load datasets directly from the Hugging Face Hub. Streaming datasets are supported::
31 |
32 | hf_dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)
33 | ray_dataset = ray.data.from_huggingface(hf_dataset)
34 |
35 | JSON and JSONL
36 | ~~~~~~~~~~~~~~
37 | Read JSON and JSONL files.
38 |
39 | .. code-block:: python
40 |
41 | ray_dataset = ray.data.read_json("s3://anonymous@ray-example-data/log.json")
42 |
43 | CSV
44 | ~~~
45 | Read CSV files.
46 |
47 | .. code-block:: python
48 |
49 | ray_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
50 |
51 | Python Dictionary
52 | ~~~~~~~~~~~~~~~~~
53 |
54 | .. code-block:: python
55 |
56 | dataset_items = [{
57 | "prompt": "What is your name?",
58 | "answer": "My name is Kithara",
59 | } for _ in range(1000)]
60 |
61 | ray_dataset = ray.data.from_items(dataset_items)
62 |
63 | Pandas
64 | ~~~~~~
65 |
66 | .. code-block:: python
67 |
68 | df = pd.DataFrame({
69 | "food": ["spam", "ham", "eggs"],
70 | "price": [9.34, 5.37, 0.94]
71 | })
72 | ray_dataset = ray.data.from_pandas(df)
73 |
74 | Lines of Text
75 | ~~~~~~~~~~~~~
76 | Read text files.
77 |
78 | .. code-block:: python
79 |
80 | ray_dataset = ray.data.read_text("s3://anonymous@ray-example-data/this.txt")
81 |
82 | TFRecords
83 | ~~~~~~~~~
84 | Read TFRecords files.
85 |
86 | .. code-block:: python
87 |
88 |
89 | ray_dataset = ray.data.read_tfrecords("s3://anonymous@ray-example-data/iris.tfrecords")
90 |
91 | Parquet
92 | ~~~~~~~
93 | Read Parquet files.
94 |
95 | .. code-block:: python
96 |
97 | ray_dataset = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")
98 |
99 | Additional Formats
100 | ~~~~~~~~~~~~~~~~~~
101 | Kithara supports all Ray Data formats. For more information:
102 |
103 | - Complete list of supported formats: `Ray Data Input/Output <https://docs.ray.io/en/latest/data/api/input_output.html>`_
104 | - Ray Dataset transformation guide: `Transforming Data <https://docs.ray.io/en/latest/data/transforming-data.html>`_
105 |
106 | Note: You should handle data cleaning, transformation, shuffling, and splitting using Ray Dataset utilities before passing the data to Kithara Dataset.
107 |
--------------------------------------------------------------------------------
/docs/source/ddp.rst:
--------------------------------------------------------------------------------
1 | .. _ddp:
2 |
3 | Managing Large Datasets
4 | =======================
5 |
6 | Distributed Data Parallel is a data loading technique useful for multi-host training on large datasets.
7 |
8 | When to use?
9 | ^^^^^^^^^^^^^^^
10 | - When you are running multihost training, and your dataset is too large to fit in memory on a single host.
11 | - When large batch processing is the bottleneck in your training pipeline.
12 |
13 |
14 | How it works?
15 | ^^^^^^^^^^^^^^^
16 |
17 | 1. Split your dataset into N shards, where N is the number of hosts, with each shard containing the same number of samples.
18 | .. tip::
19 | Use ``Dataset.split()`` on your Ray dataset, or ``Dataset.streaming_split()`` for streamed datasets
20 | 2. Each host loads and processes data only from its assigned shard
21 |
22 | For example, with a 32,000-sample dataset split across 4 hosts, each host manages 8,000 samples instead of the full dataset.
23 |
24 | Example DDP Workflow
25 | ^^^^^^^^^^^^^^^^^^^^
26 | Here's an example of how to implement DDP in Kithara::
27 |
28 | @ray.remote(resources={"TPU": num_chips_per_host})
29 | def main(train_ds):
30 | ...
31 | dataloader = kithara.Dataloader(
32 | train_ds,
33 | per_device_batch_size=1, # or any other batch size
34 | dataset_is_sharded_per_host=True, # Enable DDP
35 | )
36 | ...
37 | trainer.train()
38 |
39 | train_ds: List[Any] = split_dataset(ray_dataset, num_hosts=num_tpu_hosts)
40 | ray.get(
41 | [
42 | main.remote(train_ds[i]) for i in range(num_tpu_hosts)
43 | ]
44 | )
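45 |
46 | The ``split_dataset`` helper above is not spelled out; a minimal sketch using Ray's ``Dataset.split()`` could look like this (the even split matches the equal-shard requirement described above)::
47 |
48 | from typing import List
49 | import ray
50 |
51 | def split_dataset(dataset: ray.data.Dataset, num_hosts: int) -> List[ray.data.Dataset]:
52 |     # One equally sized shard per host, so every host sees
53 |     # the same number of samples.
54 |     return dataset.split(num_hosts, equal=True)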
--------------------------------------------------------------------------------
/docs/source/disk_storage.rst:
--------------------------------------------------------------------------------
1 | .. _disk_storage:
2 |
3 | Dealing With Disks
4 | ==================
5 | During the process of training and serving an LLM, you could easily run out of
6 | disk space on your VM or container. Potential large storage objects may
7 | include:
8 |
9 | * Base model (e.g. from Hugging Face)
10 | * Finetuning Dataset
11 | * Output checkpoints
12 |
13 | This page presents a few potential options for provisioning additional
14 | disk space for your deployments.
15 |
16 |
17 | General Recommendations
18 | -----------------------
19 | We generally recommend the following best practices:
20 |
21 | * Download your model or dataset to a shared storage volume or bucket
22 | first, and mount this volume or bucket when you need it. This would
23 | save both disk space and time spent on downloading large files.
24 | * Prefer writing checkpoints from a single worker. This avoids having
25 | to manage multiple writes to the same shared drive.
26 |
27 |
28 | Persistent Disks
29 | ----------------
30 | Persistent Disks (PDs) are durable storage volumes that can be
31 | dynamically provisioned at deployment time. PDs provide block storage,
32 | which is similar to having a hard drive attached to your VM. These
33 | are a versatile solution and are ideal for multi-read, low latency
34 | workloads.
35 |
36 | * Provisioning on TPU VMs: https://cloud.google.com/tpu/docs/setup-persistent-disk
37 |
38 | * Provisioning on GKE: https://cloud.google.com/kubernetes-engine/docs/concepts/persistent-volumes
39 |
40 |
41 | Google Cloud Storage Buckets
42 | ----------------------------
43 | GCS buckets are an alternative solution to PDs. Unlike PDs, GCS buckets
44 | provide object storage. These are ideal for storing large, unstructured
45 | data such as datasets and model checkpoints.
46 |
47 | Kithara is designed to work seamlessly with GCS buckets. To use a GCS
48 | bucket, make sure that the service account which is used to run your
49 | workload has write access to the GCS bucket.
50 |
51 | .. code-block:: bash
52 |
53 | gcloud storage buckets add-iam-policy-binding gs://YOUR_BUCKET_NAME \
54 | --member="serviceAccount:YOUR_SERVICE_ACCOUNT_EMAIL" \
55 | --role="YOUR_IAM_ROLE"
56 |
57 |
58 | The IAM role should be either 'roles/storage.objectCreator' or 'roles/storage.objectAdmin'.
59 |
60 |
61 | GCS Fuse
62 | --------
63 | An alternative way to use GCS Buckets is to mount them as local drives,
64 | with GCS Fuse. After mounting the bucket, you can access the files in the
65 | bucket as if they were in a locally mounted volume. GCSFuse buckets can
66 | be read-only or read-write. You could even mount the same bucket on multiple
67 | workers.
68 |
69 | * GCE VM: https://cloud.google.com/storage/docs/cloud-storage-fuse/overview
70 |
71 | * GKE: https://cloud.google.com/kubernetes-engine/docs/how-to/cloud-storage-fuse-csi-driver-setup#enable
72 |
73 | * Sample Kubernetes yaml: https://github.com/AI-Hypercomputer/kithara/blob/main/ray/TPU/GKE/single-host.yaml
74 |
75 | * Note - running on GKE requires enabling workload identity first: https://cloud.google.com/kubernetes-engine/docs/how-to/workload-identity
76 |
77 |
78 |
79 |
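80 | As an illustration, mounting a bucket with the ``gcsfuse`` CLI on a VM looks roughly like this (the bucket name and mount point are placeholders):
81 |
82 | .. code-block:: bash
83 |
84 | # Create a mount point and mount the bucket (read-write by default).
85 | mkdir -p /mnt/my-bucket
86 | gcsfuse --implicit-dirs YOUR_BUCKET_NAME /mnt/my-bucket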
--------------------------------------------------------------------------------
/docs/source/finetuning_guide.rst:
--------------------------------------------------------------------------------
1 | .. _finetuning_guide:
2 |
3 | Finetuning Guide
4 | ===================
5 |
6 | Ready to fine-tune with your own data? Let's break it down into simple steps.
7 |
8 | 1. Pick Your Model
9 | ------------------
10 |
11 | Kithara uses two model gardens - MaxText and Keras Hub.
12 |
13 | MaxText
14 | ~~~~~~~
15 |
16 |
17 | MaxText is recommended for full-parameter fine-tuning.
18 |
19 | .. parsed-literal::
20 |
21 | from kithara import MaxTextModel
22 |
23 | model = **MaxTextModel**.from_preset(
24 | "hf://google/gemma-2-2b",
25 | seq_len=8192, # Seq len and batch size need to be specified up front
26 | per_device_batch_size=1
27 | )
28 |
29 | KerasHub
30 | ~~~~~~~~
31 |
32 | KerasHub is recommended for parameter efficient fine-tuning (LoRA).
33 |
34 | .. parsed-literal::
35 |
36 | from kithara import KerasHubModel
37 |
38 | model = **KerasHubModel**.from_preset(
39 | "hf://google/gemma-2-2b",
40 | lora_rank=16 # Applied to q_proj and v_proj
41 | )
42 |
43 | .. tip::
44 | - Always start model handles with ``hf://`` when loading from HuggingFace - so we know you are not loading from a local directory 😀
45 | - The default precision ``mixed_bfloat16`` is your friend - it's memory-friendly! It loads model weights in full precision and casts activations to bfloat16.
46 | - Check out our :doc:`model garden <models>` for supported architectures
47 | - Want to save your model? Simply do ``model.save_in_hf_format(destination)`` to either save it locally, to GCS, or to HuggingFace.
48 | - Check out the :doc:`Model API <api/kithara.model_api>` documentation
49 |
50 | 2. Prepare Your Data
51 | --------------------
52 |
53 | Kithara supports HuggingFace Datasets as well as Ray Datasets.
54 |
55 | HuggingFace Dataset::
56 |
57 | from datasets import load_dataset
58 | hf_dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)
59 | dataset = kithara.TextCompletionDataset(
60 | hf_dataset,
61 | tokenizer_handle="hf://google/gemma-2-2b",
62 | max_seq_len=8192
63 | )
64 |
65 | Ray Dataset::
66 |
67 | import ray
68 | ray_dataset = ray.data.read_json("s3://anonymous@ray-example-data/log.json")
69 | dataset = kithara.TextCompletionDataset(
70 | ray_dataset,
71 | tokenizer_handle="hf://google/gemma-2-2b",
72 | max_seq_len=8192
73 | )
74 |
75 | Now create a Kithara Dataloader to batchify your dataset::
76 |
77 | # Create dataloader
78 | dataloader = kithara.Dataloader(
79 | dataset,
80 | per_device_batch_size=1
81 | )
82 |
83 | The expected flow is always Ray/HF Dataset -> Kithara Dataset -> Kithara Dataloader.
84 |
85 | Quick Tips:
86 |
87 | - Your global batch size is always ``per_device_batch_size`` * ``number of devices (chips)``. For example, ``per_device_batch_size=4`` on 8 chips gives a global batch size of 32.
88 | - Check out :doc:`supported data formats <datasets>` (CSV, JSON, etc.)
89 | - Check out the :doc:`Dataset API <api/kithara.dataset_api>` documentation.
90 |
91 |
92 | 3. Choose Algorithm
93 | -------------------
94 |
95 | - **Continued pretraining**: Train your base model with large datasets to expand its knowledge
96 | - **Supervised finetuning**: Quickly and efficiently tune your model using labeled examples
97 |
98 | Check out these examples to get started.
99 |
100 | - :doc:`🌵 Continued Pretraining Example <pretraining>`
101 | - :doc:`🌵 SFT+LoRA Example <sft>`
102 |
103 | 4. Select Hardware
104 | ------------------
105 |
106 | If you haven't done so, check out :doc:`Getting TPUs <getting_tpus>` to get your TPUs ready.
107 |
108 | If your TPU topology has multiple hosts, and you are not familiar with distributed training,
109 | we recommend you follow the :doc:`Scaling up with Ray <scaling_with_ray>` guide to set up a
110 | Ray Cluster so that you can run multihost jobs.
111 |
112 |
--------------------------------------------------------------------------------
/docs/source/help_and_feedback.rst:
--------------------------------------------------------------------------------
1 | .. _help_and_feedback:
2 |
3 | Bugs, Features, Feedback
4 | ========================
5 |
6 | Send the Google team a direct email to report bugs, request features, and share feedback: `buganizer-system+1746240@google.com <mailto:buganizer-system+1746240@google.com>`_. You can use the templates below to add the required background information. Only Google can view what you share.
7 |
8 |
9 |
10 | Email template for bugs and feature requests
11 | --------------------------------------------
12 |
13 | .. parsed-literal::
14 |
15 | Issue Type: (Choose one and delete the others)
16 | * Bug Report
17 | * Feature Request
18 |
19 | Description:
20 |
21 | * What happened? (Required)
22 | * What did you expect to happen? (Required)
23 | * Error messages: (Required: Paste any error messages here and include the full log if you can. If none, write "N/A".)
24 |
25 | ---
26 | Context:
27 |
28 | * Project ID:
29 | * Company:
30 | * Kithara Version:
31 | * Python Version:
32 | * JAX Version:
33 | * Keras Version:
34 | * KerasHub Version:
35 | * TPU Resource type (QR, GKE, GCE):
36 | * TPU Generation (Trillium, v5p, etc.):
37 | * Number of Chips:
38 | * TPU VM image:
39 | * Ray Version (if running on multihost):
40 | * GKE Version (if applicable):
41 |
42 | ---
43 | Additional Information (Optional):
44 |
45 | * Model:
46 | * Dataset:
47 | * Reproduction Script: Please attach or link a minimal, self-contained, reproducible example.
48 | * Severity - Choose one: Critical (Blocking work), Medium (Workaround exists), Low (Annoying):
49 |
50 | Email template to share feedback
51 | --------------------------------
52 |
53 | .. parsed-literal::
54 |
55 | * What is your overall impression of Kithara? (What do you like? What could be improved?)
56 | * What specific features or aspects are most valuable to you?
57 | * Are there any features or workflows that you find confusing or difficult?
58 | * Do you have any suggestions for new features or improvements? What will these help you do?
59 | * Is there anything else you would like to share?
60 |
61 | General How-Tos and Support
62 | ---------------------------
63 |
64 | You can also use `GitHub Issues <https://github.com/AI-Hypercomputer/kithara/issues>`_ to get support.
65 |
--------------------------------------------------------------------------------
/docs/source/installation.rst:
--------------------------------------------------------------------------------
1 | .. _installation:
2 |
3 | Installation
4 | ============
5 |
6 | .. note::
7 |
8 | Kithara requires ``Python>=3.11``.
9 |
10 | We recommend using a virtual environment.
11 |
12 | 1. With `conda`:
13 |
14 | .. dropdown:: How to install Conda on GCP VM
15 | :open:
16 |
17 | .. code-block:: bash
18 |
19 | mkdir -p ~/miniconda3
20 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
21 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
22 | rm ~/miniconda3/miniconda.sh
23 | source ~/miniconda3/bin/activate
24 | conda init --all
25 |
26 | .. code-block:: bash
27 |
28 | conda create -n kithara_env python=3.11
29 | conda activate kithara_env
30 |
31 |
32 | 2. With `venv`:
33 |
34 | .. code-block:: bash
35 |
36 | sudo apt-get install -y python3.11
37 | sudo apt install python3.11-venv
38 |
39 | python3.11 -m venv kithara_env
40 | source kithara_env/bin/activate
41 |
42 | Installation on TPU
43 | ~~~~~~~~~~~~~~~~~~~
44 |
45 | .. code-block:: bash
46 |
47 | pip install kithara[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://download.pytorch.org/whl/cpu
48 |
49 | Installation on GPU
50 | ~~~~~~~~~~~~~~~~~~~
51 |
52 | .. warning::
53 |
54 | Our GPU support is still in beta. Please report any issues you encounter.
55 |
56 | .. code-block:: bash
57 |
58 | pip install kithara[gpu]
59 |
60 |
61 | Clone the repo
62 | ~~~~~~~~~~~~~~
63 |
64 | .. code-block:: bash
65 |
66 | git clone --recursive https://github.com/AI-Hypercomputer/kithara.git
--------------------------------------------------------------------------------
/docs/source/installation/.tpu_gke.rst.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/docs/source/installation/.tpu_gke.rst.swp
--------------------------------------------------------------------------------
/docs/source/installation/tpu_gke.rst:
--------------------------------------------------------------------------------
1 | .. _tpu_gke:
2 |
3 | Setting up Ray Cluster with TPU GKE
4 | =====================================
5 | Prerequisites
6 | -------------
7 | * Clone the Kithara repo.
8 |
9 | .. code-block:: bash
10 |
11 | git clone https://github.com/AI-Hypercomputer/kithara.git
12 | * Create a GKE Cluster with Ray add-on: https://docs.ray.io/en/latest/cluster/kubernetes/user-guides/gcp-gke-tpu-cluster.html
13 |
14 | Preparing Your GKE Cluster
15 | --------------------------
16 | * Enable GCSFuse.
17 |
18 | This step allows GCS buckets to be mounted on GKE containers as drives. This makes it easier for Kithara to save checkpoints to GCS.
19 |
20 | You can follow the instructions here: https://docs.ray.io/en/latest/cluster/kubernetes/user-guides/gke-gcs-bucket.html.
21 | * Authenticate to Your GKE Cluster:
22 |
23 | .. code-block:: bash
24 |
25 | gcloud container clusters get-credentials $CLUSTER --zone $ZONE --project $YOUR_PROJECT
26 | * Create a Hugging Face token on https://huggingface.co/docs/hub/en/security-tokens
27 | * Save the Hugging Face token to the Kubernetes cluster:
28 |
29 | .. code-block:: bash
30 |
31 | kubectl create secret generic hf-secret \
32 | --from-literal=hf_api_token=HUGGING_FACE_TOKEN
33 |
34 |
35 |
36 |
37 | Setting Up a Ray Cluster
38 | ------------------------
39 | 1. Edit one of the following manifest files:
40 |
41 | - Single-host: https://github.com/AI-Hypercomputer/kithara/blob/main/ray/TPU/GKE/single-host.yaml
42 |
43 | - Multi-host: https://github.com/AI-Hypercomputer/kithara/blob/main/ray/TPU/GKE/multi-host.yaml
44 |
45 | Make sure to replace ``YOUR_GCS_BUCKET`` with the name of the GCS bucket created in previous steps.
46 | 2. Deploy the Ray cluster:
47 |
48 | .. code-block:: bash
49 |
50 | kubectl apply -f $MANIFEST_FILE
51 | 3. Check that the cluster is running with:
52 |
53 | .. code-block:: bash
54 |
55 | kubectl get pods
56 |
57 | If everything works as expected, you should see pods running:
58 |
59 | .. code-block:: bash
60 |
61 | NAME READY STATUS RESTARTS AGE
62 | example-cluster-kuberay-head-kgxkp 2/2 Running 0 1m
63 | example-cluster-kuberay-worker-workergroup-bzrz2 2/2 Running 0 1m
64 | example-cluster-kuberay-worker-workergroup-g7k4t 2/2 Running 0 1m
65 | example-cluster-kuberay-worker-workergroup-h6zsx 2/2 Running 0 1m
66 | example-cluster-kuberay-worker-workergroup-pdf8x 2/2 Running 0 1m
67 |
68 |
69 | Running a Ray Workload
70 | ----------------------
71 | 1. Set the following environment variable:
72 |
73 | .. code-block:: bash
74 |
75 | export RAY_ADDRESS=http://localhost:8265
76 | 2. Port-forward to the Ray cluster:
77 |
78 | .. code-block:: bash
79 |
80 | kubectl port-forward svc/example-cluster-kuberay-head-svc 8265:8265 &
81 | 3. Submit a Ray job, for example:
82 |
83 | .. code-block:: bash
84 |
85 | ray job submit --working-dir . \
86 | --runtime-env-json='{"excludes": [".git", "kithara/model/maxtext/maxtext/MaxText/test_assets"]}' \
87 | -- python examples/multihost/ray/TPU/full_finetuning_example.py
88 | 4. You can visit ``http://localhost:8265`` in your browser to see the Ray dashboard and monitor job status.
89 |
90 |
91 | Clean Up
92 | --------
93 | 1. When your job is done, you can delete it by running:
94 |
95 | .. code-block:: bash
96 |
97 | kubectl delete -f $MANIFEST_FILE
98 |
99 | 2. The GKE cluster can be deleted with:
100 |
101 | .. code-block:: bash
102 |
103 | gcloud container clusters delete $CLUSTER --zone $ZONE
104 |
105 |
--------------------------------------------------------------------------------
/docs/source/installation/tpu_qr.rst:
--------------------------------------------------------------------------------
1 | .. _tpu_qr:
2 |
3 |
4 | Setting up Ray Cluster with TPU Queued Resources
5 | ================================================
6 |
7 | Prerequisites
8 | -------------
9 | * Clone the Kithara repo.
10 |
11 | .. code-block:: bash
12 |
13 | git clone https://github.com/AI-Hypercomputer/kithara.git
14 |
15 | * Modify ``ray/TPU/QR/cluster.yaml`` with your GCP project, zone, and TPU resource types.
16 |
17 | .. tip::
18 | Search for "MODIFY" in the YAML file to find required changes
19 |
20 | Setting up the Ray Cluster
21 | --------------------------
22 |
23 | 1. Create Ray Head Node
24 | -----------------------
25 |
26 | Create the Ray head node and launch the Ray cluster::
27 |
28 | ray up -y ray/TPU/QR/cluster.yaml
29 |
30 | 2. Monitor Cluster Creation
31 | ---------------------------
32 |
33 | The ``ray up`` command will take a few minutes. If you lose the terminal, monitor the cluster creation process with::
34 |
35 | ray monitor ray/TPU/QR/cluster.yaml
36 |
37 | 3. Launch Ray Dashboard
38 | -----------------------
39 |
40 | Launch Ray Cluster dashboard once the ``ray_head_default`` node is ``Active``. You should see the dashboard on your ``localhost:8265``::
41 |
42 | ray dashboard ray/TPU/QR/cluster.yaml
43 |
44 | 4. Note Head Node IP
45 | --------------------
46 |
47 | In the Ray dashboard, note down the IP of your ray head node from the ``Cluster`` panel.
48 |
49 | .. note::
50 | This should be the internal IP, not the external IP of the head node. It should start with 10.x.x.x
51 |
52 | 5. Set Resource Variables
53 | -------------------------
54 |
55 | Specify the following variables with your own TPU resources spec::
56 |
57 | export ZONE="us-central2-b"
58 | export QR_NAME="my_tpu_qr"
59 | export NODE_ID="v4-a"
60 | export PROJECT="gcp_project_name"
61 | export TPU_TYPE="v4-32"
62 | export RAY_CLUSTER_IP="your_cluster_ip"
63 |
64 | 6. Create TPU VMs
65 | -----------------
66 |
67 | Create TPU VMs via QR::
68 |
69 | gcloud alpha compute tpus queued-resources create $QR_NAME \
70 | --node-id $NODE_ID \
71 | --zone $ZONE \
72 | --project $PROJECT \
73 | --accelerator-type $TPU_TYPE \
74 | --runtime-version tpu-ubuntu2204-base \
75 | --metadata-from-file='startup-script=ray/TPU/QR/qr_worker_startup_script.sh'
76 |
77 | 7. Monitor QR Status
78 | --------------------
79 |
80 | Monitor the status of the QR creation::
81 |
82 | gcloud compute tpus queued-resources describe $QR_NAME --project $PROJECT --zone $ZONE
83 |
84 | Once the status becomes ``ACTIVE``, monitor the logs to verify package installation::
85 |
86 | gcloud alpha compute tpus queued-resources ssh $QR_NAME \
87 | --project $PROJECT \
88 | --zone $ZONE \
89 | --command="sudo cat /var/log/syslog | grep startup-script" \
90 | --worker=0 \
91 | --node=all
92 |
93 | 8. Attach TPU VMs to Ray Cluster
94 | --------------------------------
95 |
96 | Once QRs are ready, attach the TPU VMs as worker nodes::
97 |
98 | gcloud alpha compute tpus queued-resources ssh $QR_NAME \
99 | --project $PROJECT \
100 | --zone $ZONE \
101 | --command="ray start --address=$RAY_CLUSTER_IP:6379 --resources='{\"tpu_host\": 1}'" \
102 | --worker=all \
103 | --node=all
104 |
105 | Troubleshooting
106 | ~~~~~~~~~~~~~~~
107 |
108 | If you encounter Python or Ray version inconsistencies, check the worker node logs::
109 |
110 | gcloud alpha compute tpus queued-resources ssh $QR_NAME \
111 | --project $PROJECT \
112 | --zone $ZONE \
113 | --command="sudo cat /var/log/syslog | grep startup-script" \
114 | --worker=all \
115 | --node=all
116 |
117 | 9. Run Examples
118 | ---------------
119 |
120 | Your Ray Cluster is now ready. Try the examples in the ``examples/multihost/ray/TPU`` folder::
121 |
122 | python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/sft_lora_example.py" --hf-token your_token
123 |
124 | To early-stop your job::
125 |
126 | ray job stop ray_job_id
127 |
128 | 10. Remove QRs
129 | --------------
130 |
131 | To remove QRs from your Ray Cluster::
132 |
133 | gcloud compute tpus queued-resources delete $QR_NAME --project $PROJECT --zone $ZONE
134 |
135 | 11. Tear Down Cluster
136 | ---------------------
137 |
138 | When finished with your ray cluster, tear it down::
139 |
140 | ray down ray/TPU/QR/cluster.yaml
--------------------------------------------------------------------------------
/docs/source/installation/tpu_vm.rst:
--------------------------------------------------------------------------------
1 | .. _tpu_vm:
2 |
3 | Setting up Ray Cluster with TPU VMs
4 | ================================================
5 |
6 | Prerequisites
7 | -------------
8 | * Clone the Kithara repo.
9 |
10 | .. code-block:: bash
11 |
12 | git clone https://github.com/AI-Hypercomputer/kithara.git
13 |
14 | * Modify ``ray/TPU/GCE/cluster.yaml`` with your GCP project, zone, and TPU resource types.
15 |
16 | .. tip::
17 | Search for "MODIFY" in the YAML file to find required changes
18 |
19 | Setting up the Ray Cluster
20 | --------------------------
21 |
22 | 1. Launch the cluster::
23 |
24 | ray up -y ray/TPU/GCE/cluster.yaml
25 |
26 | 2. Monitor setup process::
27 |
28 | ray monitor ray/TPU/GCE/cluster.yaml
29 |
30 | 3. Launch Ray dashboard::
31 |
32 | ray dashboard ray/TPU/GCE/cluster.yaml
33 |
34 | The dashboard will be available at ``localhost:8265``
35 |
36 | Troubleshooting
37 | ~~~~~~~~~~~~~~~
38 | * ``update-failed`` errors typically don't affect proper node setup
39 | * Check node status by executing::
40 |
41 | ray attach ray/TPU/GCE/cluster.yaml
42 | ray status
43 |
44 | Running Multihost Jobs
45 | ----------------------
46 |
47 | 1. Submit job::
48 |
49 | python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/sft_lora_example.py" --hf-token your_token
50 |
51 | 2. To stop a job early::
52 |
53 | export RAY_ADDRESS="http://127.0.0.1:8265"
54 | ray job stop ray_job_id
55 |
56 | Cleanup
57 | -------
58 |
59 | When finished, tear down the cluster::
60 |
61 | ray down ray/TPU/GCE/cluster.yaml
--------------------------------------------------------------------------------
/docs/source/lora.rst:
--------------------------------------------------------------------------------
1 | .. _lora:
2 |
3 | LoRA: Low-Rank Adaptation Guide
4 | ===============================
5 |
6 | What is LoRA?
7 | -------------
8 | LoRA (Low-Rank Adaptation) is a technique that makes fine-tuning large language models much more efficient. Instead of updating all model parameters during fine-tuning, LoRA:
9 |
10 | - Keeps the original model weights frozen
11 | - Adds small trainable matrices to specific model layers
12 | - Drastically reduces the number of parameters that need to be updated
13 |
14 | For example, you can fine-tune a 9B parameter model by training only about 100M parameters (roughly 1% of the original size) or less.
15 |
16 | Simple Usage
17 | ------------
18 |
19 | To train your model with LoRA, you don't have to change anything in your training script other than specifying the ``lora_rank`` arg.
20 | LoRA will be applied to the q_proj and v_proj layers::
21 |
22 | from kithara import KerasHubModel
23 |
24 | model = KerasHubModel.from_preset(
25 | "hf://google/gemma-2-2b",
26 | lora_rank=16 # <-- One line toggle
27 | )
28 |
29 | Saving LoRA Models
30 | ------------------
31 | You have three options for saving models trained with LoRA:
32 |
33 | 1. Save Only LoRA Adapters
34 | ~~~~~~~~~~~~~~~~~~~~~~~~~~
35 | Since the base model is left unchanged, you can save just the LoRA Adapters::
36 |
37 | model.save_in_hf_format(
38 | destination,
39 | only_save_adapters=True
40 | )
41 |
42 | 2. Save Base Model and Adapters Separately
43 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
44 | Use this option if you want to save the base model as well::
45 |
46 | model.save_in_hf_format(
47 | destination,
48 | only_save_adapters=False,
49 | save_adapters_separately=True
50 | )
51 |
52 | 3. Save Merged Model
53 | ~~~~~~~~~~~~~~~~~~~~
54 | Creates a single model combining base weights and adaptations::
55 |
56 | model.save_in_hf_format(
57 | destination,
58 | save_adapters_separately=False
59 | )
60 |
61 | Load LoRA Models back into HuggingFace
62 | --------------------------------------
63 |
64 | To load a model trained with LoRA back into HuggingFace, use the following code, where ``lora_dir`` and ``model_dir`` store the weights saved by Kithara::
65 |
66 | from transformers import AutoModelForCausalLM
67 |
68 | # Load the base model plus separately saved adapters
69 | # (model_id is the base model handle, e.g. "google/gemma-2-2b")
70 | hf_model = AutoModelForCausalLM.from_pretrained(model_id)
71 | hf_model.load_adapter(lora_dir)
72 |
73 | # Load a merged model
74 | hf_model = AutoModelForCausalLM.from_pretrained(model_dir)
75 |
76 | Next Steps
77 | ----------
78 | For a complete example of using LoRA with supervised fine-tuning, see the :doc:`SFT + LoRA guide <sft>`.
--------------------------------------------------------------------------------
/docs/source/models.rst:
--------------------------------------------------------------------------------
1 | .. _models:
2 |
3 | Supported Models
4 | ================
5 |
6 | We support all safetensors-formatted models on `HuggingFace Hub <https://huggingface.co/models>`_ for the following architectures.
7 |
8 | .. list-table::
9 | :header-rows: 1
10 |
11 | * - Model
12 | - Size
13 | - HuggingFace Handle Example
14 | * - Gemma 2
15 | - 2B, 9B, 27B
16 | - google/gemma-2-2b
17 | * - Llama 3.1
18 | - 8B, 70B, 405B
19 | - meta-llama/Llama-3.1-8B
20 |
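21 | Any of these handles can be passed to a Kithara model class with the ``hf://`` prefix. For example, mirroring the :doc:`finetuning guide <finetuning_guide>`::
22 |
23 | from kithara import MaxTextModel
24 |
25 | model = MaxTextModel.from_preset(
26 | "hf://google/gemma-2-2b",
27 | seq_len=8192,
28 | per_device_batch_size=1
29 | )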
--------------------------------------------------------------------------------
/docs/source/observability.rst:
--------------------------------------------------------------------------------
1 | .. _observability:
2 |
3 | Observability
4 | =============
5 |
6 | Kithara supports Tensorboard and (soon) Weights and Biases for observability.
7 |
8 | Tensorboard
9 | -----------
10 |
11 | To use Tensorboard, simply specify the ``tensorboard_dir`` arg in the ``Trainer`` class to a local directory or a Google Cloud Storage bucket.
12 |
13 | To track training and evaluation performance, launch the tensorboard server with::
14 |
15 | tensorboard --logdir=your_tensorboard_dir
16 |
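17 | For instance, reusing the example from the :ref:`Trainer API <trainer_api>` page (``model``, ``train_dataloader``, and the ``keras``/``Trainer`` imports are assumed to be defined)::
18 |
19 | trainer = Trainer(
20 | model=model,
21 | optimizer=keras.optimizers.Adam(learning_rate=1e-4),
22 | train_dataloader=train_dataloader,
23 | tensorboard_dir="gs://your_bucket/tensorboard", # or a local directory
24 | )
25 | trainer.train()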
--------------------------------------------------------------------------------
/docs/source/optimizers.rst:
--------------------------------------------------------------------------------
1 | .. _optimizers:
2 |
3 | Optimizer Support Guide
4 | =======================
5 |
6 | Choosing an Optimizer
7 | ---------------------
8 | Effective model fine-tuning relies heavily on the chosen optimization algorithm. Kithara offers flexibility by supporting optimizers from two widely-used libraries:
9 |
10 | * **Keras Optimizers**: Standard, easy-to-use optimizers integrated directly into the Keras framework.
11 | * **Optax Optimizers**: A powerful and composable optimization library designed for JAX (which Keras can leverage), offering fine-grained control over the gradient update process.
12 |
13 | This guide explains how to utilize both types within the Kithara ``Trainer``.
14 |
15 | Using Keras Optimizers
16 | ----------------------
17 | You can directly use any optimizer available in the ``keras.optimizers`` module. Simply instantiate the desired optimizer with its configuration parameters and pass the instance to the ``Trainer``.
18 |
19 | Here's how you can use the ``AdamW`` optimizer from Keras::
20 |
21 | import keras
22 | from kithara import Trainer
23 |
24 | # Assume 'model' and 'train_dataloader' are already defined
25 | # Define hyperparameters
26 | LEARNING_RATE = 1e-4
27 | WEIGHT_DECAY = 0.01
28 |
29 | # Instantiate a Keras optimizer
30 | optimizer = keras.optimizers.AdamW(
31 | learning_rate=LEARNING_RATE,
32 | weight_decay=WEIGHT_DECAY,
33 | # Add other Keras optimizer arguments as needed
34 | )
35 |
36 | # Initialize the Trainer
37 | trainer = Trainer(
38 | model=model,
39 | optimizer=optimizer, # Pass the Keras optimizer instance
40 | train_dataloader=train_dataloader,
41 | epochs=1,
42 | # ... other Trainer arguments (log_steps_interval, etc.)
43 | )
44 |
45 | # Start training
46 | trainer.train()
47 |
48 | For a comprehensive list of available Keras optimizers and their specific arguments, consult the official `Keras Optimizers documentation `_.
49 |
50 | Using Optax Optimizers
51 | ----------------------
52 | Optax provides a highly flexible approach where optimizers are often built by chaining together multiple gradient transformations (like learning rate scheduling, gradient clipping, or weight decay) with a core optimizer function. This allows for complex optimization strategies.
53 |
54 | To use Optax, import the ``optax`` library, construct your optimizer (potentially using ``optax.chain``), and pass the resulting object to the ``Trainer``.
55 |
56 | This example demonstrates using ``optax.adamw`` combined with gradient clipping::
57 |
58 | import optax
59 | from kithara import Trainer
60 |
61 | # Assume 'model' and 'train_dataloader' are already defined
62 | # Define hyperparameters
63 | LEARNING_RATE = 1e-4
64 | WEIGHT_DECAY = 0.01
65 | CLIP_NORM = 1.0
66 |
67 | # Create an Optax optimizer chain
68 | # Common practice: apply transformations like clipping before the main optimizer step
69 | optimizer = optax.chain(
70 | optax.clip(CLIP_NORM),
71 | optax.adamw(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY),
72 | # Add other Optax transformations as needed
73 | )
74 |
75 | # Initialize the Trainer
76 | trainer = Trainer(
77 | model=model,
78 | optimizer=optimizer, # Pass the Optax optimizer instance
79 | train_dataloader=train_dataloader,
80 | epochs=1,
81 | # ... other Trainer arguments (log_steps_interval, etc.)
82 | )
83 |
84 | # Start training
85 | trainer.train()
86 |
87 | Explore the extensive capabilities of Optax, including its various optimizers and gradient transformations, in the official `Optax API documentation `_.
88 |
89 | Key Takeaway
90 | ------------
91 | Whether you choose a Keras or an Optax optimizer, the integration with Kithara is straightforward:
92 |
93 | 1. Instantiate your chosen optimizer object with the desired configuration.
94 | 2. Pass this optimizer object to the ``optimizer`` argument when initializing the ``Trainer``.
95 |
96 | The ``Trainer`` will then handle the application of the optimizer during the training loop.
97 |
--------------------------------------------------------------------------------
/docs/source/overviewxyz.rst:
--------------------------------------------------------------------------------
1 | .. _overviewxyz:
2 |
3 | Overview
4 | ========
5 |
6 |
7 |
--------------------------------------------------------------------------------
/docs/source/packing.rst:
--------------------------------------------------------------------------------
1 | .. _packing:
2 |
3 | Dataset Packing
4 | ===============
5 |
6 | What is Dataset Packing?
7 | ------------------------
8 |
9 | Dataset packing combines multiple short sequences into longer ones during training, reducing overall training time when your data consists of sequences shorter than the model's maximum sequence length.
10 |
11 | When to Use?
12 | ------------
13 | Packing does not hurt model quality, so use it whenever possible. Note, however, that packing is currently only supported for MaxText models.
14 |
15 | How to Use?
16 | -----------
17 |
18 | Convert your dataset to a packed format using ``.to_packed_dataset()``:
19 |
20 | .. code-block:: python
21 |
22 | dataset = kithara.TextCompletionDataset(
23 | ray_dataset,
24 | tokenizer="hf://google/gemma-2-2b",
25 | max_seq_len=4096
26 | ).to_packed_dataset()
27 |
28 | Important Notes
29 | ---------------
30 |
31 | * No other changes to your training script are needed
32 | * Loss calculation remains accurate as Kithara uses per-token loss
33 | * Adjust your learning rate and training steps, since each packed batch contains more effective samples and an epoch therefore takes fewer steps (see the end-to-end sketch below)
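
Putting it together, here is a minimal end-to-end sketch. Packing requires a MaxText model; the ``MaxTextModel`` import path and its ``from_preset`` arguments are illustrative assumptions, and ``ray_dataset`` is a Ray dataset of raw text rows:

.. code-block:: python

    import keras
    from kithara import Dataloader, Trainer, TextCompletionDataset
    from kithara import MaxTextModel  # assumption: exported at the package root

    model = MaxTextModel.from_preset(
        "hf://google/gemma-2-2b",
        precision="mixed_bfloat16",
    )

    # Pack multiple short sequences into each 4096-token row
    packed_dataset = TextCompletionDataset(
        ray_dataset,
        tokenizer_handle="hf://google/gemma-2-2b",
        max_seq_len=4096,
    ).to_packed_dataset()

    trainer = Trainer(
        model=model,
        optimizer=keras.optimizers.AdamW(learning_rate=2e-4),
        train_dataloader=Dataloader(packed_dataset, per_device_batch_size=1),
        steps=100,
    )
    trainer.train()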
--------------------------------------------------------------------------------
/docs/source/quickstart.rst:
--------------------------------------------------------------------------------
1 | .. _quickstart:
2 |
3 | Quick Start
4 | ===========
5 |
6 | This guide demonstrates how to fine-tune a Gemma2 2B model using LoRA with a toy dataset.
7 |
8 | The script can also be found on `Github `_.
9 |
10 | Prefer running in Colab? Check out this `Colab Notebook `_ instead.
11 |
12 | Setup
13 | -----
14 | Log into HuggingFace and import required packages::
15 |
16 | from huggingface_hub import login
17 | login(token="your_hf_token", add_to_git_credential=False)
18 |
19 | import os
20 | os.environ["KERAS_BACKEND"] = "jax"
21 | import keras
22 | import ray
23 | from kithara import (
24 | KerasHubModel,
25 | Dataloader,
26 | Trainer,
27 | SFTDataset,
28 | )
29 |
30 | .. tip::
31 | New to HuggingFace? First create an access token, then `apply for access `_ to the Gemma2 HuggingFace model, which will be used in this example.
32 |
33 | Quick Usage
34 | -----------
35 |
36 | 1. Create the Model::
37 |
38 | model = KerasHubModel.from_preset(
39 | "hf://google/gemma-2-2b",
40 | precision="mixed_bfloat16",
41 | lora_rank=4,
42 | )
43 |
44 | 2. Prepare Dataset::
45 |
46 | dataset_items = [
47 | {
48 | "prompt": "What is your name?",
49 | "answer": "My name is Kithara",
50 | }
51 | for _ in range(1000)
52 | ]
53 | dataset = ray.data.from_items(dataset_items)
54 | train_ds, eval_ds = dataset.train_test_split(test_size=500)
55 |
56 | 3. Create Dataset and Optimizer::
57 |
58 | train_dataset = SFTDataset(
59 | train_ds,
60 | tokenizer_handle="hf://google/gemma-2-2b",
61 | max_seq_len=4096,
62 | )
63 |
64 | eval_dataset = SFTDataset(
65 | eval_ds,
66 | tokenizer_handle="hf://google/gemma-2-2b",
67 | max_seq_len=4096,
68 | )
69 |
70 | optimizer = keras.optimizers.AdamW(
71 | learning_rate=2e-4,
72 | weight_decay=0.01
73 | )
74 |
75 | 4. Create Dataloaders::
76 |
77 | train_dataloader = Dataloader(
78 | train_dataset,
79 | per_device_batch_size=1
80 | )
81 |
82 | eval_dataloader = Dataloader(
83 | eval_dataset,
84 | per_device_batch_size=1
85 | )
86 |
87 | 5. Initialize and Run Trainer::
88 |
89 | trainer = Trainer(
90 | model=model,
91 | optimizer=optimizer,
92 | train_dataloader=train_dataloader,
93 | eval_dataloader=eval_dataloader,
94 | steps=200, # You can also use epochs instead of steps
95 | eval_steps_interval=10,
96 | max_eval_samples=50,
97 | log_steps_interval=10,
98 | )
99 |
100 | trainer.train()
101 |
102 | 6. Test the Model::
103 |
104 | pred = model.generate(
105 | "What is your name?",
106 | max_length=30,
107 | tokenizer_handle="hf://google/gemma-2-2b",
108 | return_decoded=True
109 | )
110 | print("Tuned model generates:", pred)
111 |
112 | Running This Example on Single Host
113 | ------------------------------------------------
114 |
115 | Simply copy this script from the Github repo and run it on your TPU VM::
116 |
117 | python examples/singlehost/quick_start.py
118 |
119 |
120 | Running This Example on Multi-host
121 | ----------------------------------
122 |
123 | Kithara works with any accelerator orchestrator. However, if you are new to distributed training, we provide a guide for :doc:`multihost training with Ray `.
124 |
125 | Once you have set up a Ray cluster, clone the Github repo and run this example on your Ray Cluster::
126 |
127 | python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/quick_start.py" --hf-token your_token
128 |
129 |
130 | Next Steps
131 | -----------
132 |
133 | Check out the :doc:`Finetuning Guide ` to craft your own finetuning job.
134 |
--------------------------------------------------------------------------------
/docs/source/scaling_with_ray.rst:
--------------------------------------------------------------------------------
1 | .. _scaling_with_ray:
2 |
3 | Running on Multihost?
4 | =====================
5 |
6 | Kithara works with any accelerator orchestrator. However, if you are new to distributed training, we provide a guide for multihost training with `Ray `_.
7 |
8 | .. admonition:: What is Ray?
9 |
10 | Ray is a popular and powerful tool for running distributed TPU and GPU workloads, offering:
11 |
12 | - Dashboard for job queueing and log visualization
13 | - Streamlined environment setup and file syncing
14 | - Simple command interface for multihost workloads
15 |
16 | Follow the instructions below to set up your Ray cluster.
17 |
18 | * :doc:`Set up Ray Cluster with TPU VMs `
19 | * :doc:`Set up Ray Cluster with TPU QRs `
20 | * :doc:`Set up Ray Cluster with TPU GKE `
21 |
22 | .. toctree::
23 | :maxdepth: 2
24 | :caption: Ray Cluster Setup
25 | :hidden:
26 |
27 | Ray Cluster with TPU VMs
28 | Ray Cluster with TPU QRs
29 | Ray Cluster with TPU GKE
30 |
31 |
32 | After setting up your Ray Cluster, you can scale up your workload to run on multiple hosts using the following recipe.
33 |
34 | .. code-block:: python
35 | :caption: my_multihost_ray_job.py
36 |
37 | import ray
38 | import jax
39 |
40 | ray.init()
41 |
42 | num_chips_per_host = 4 # <--IMPORTANT: Use 4 for v4 and v5p, 8 for v5e TPUs
43 | num_tpu_hosts = int(ray.cluster_resources()["TPU"] / num_chips_per_host)
44 | print(f"{num_tpu_hosts=}")
45 |
46 | @ray.remote(resources={"TPU": num_chips_per_host})
47 | def main():
48 |
49 | # HuggingFace login
50 | from huggingface_hub import login
51 | import os
52 |
53 | hf_token = os.getenv("HF_TOKEN")
54 | if hf_token:
55 | login(token=hf_token, add_to_git_credential=False)
56 |
57 | # Let JAX know that we are running a distributed job
58 | jax.distributed.initialize()
59 |
60 | # No need to change your single host job script, simply use it as it is.
61 | from examples.singlehost.quick_start import run_workload
62 |
63 | # Run this workload on all hosts. Don't worry, we are handling
64 | # model sharding and batch sharding for you.
65 | run_workload()
66 |
67 |
68 | ray.get([main.remote() for i in range(num_tpu_hosts)])
69 |
70 | ray.shutdown()
71 |
72 | To launch this job, run the following command::
73 |
74 | ray job submit \
75 | --address="http://localhost:8265" \
76 | --runtime-env-json='{"env_vars": {"HF_TOKEN": "your_token_here", "HF_HUB_ENABLE_HF_TRANSFER": "1"}}' \
77 | -- "python3.11 my_multihost_ray_job.py"
78 |
79 | Equivalently, use the Kithara `helper script `_::
80 |
81 | python ray/submit_job.py "python3.11 my_multihost_ray_job.py" --hf-token your_token
82 |
83 | Check out some multihost examples:
84 |
85 | - `Multihost Quickstart Example `_
86 | - `Multihost Continued Pretraining Example `_
87 | - `Multihost SFT+LoRA Example `_
88 |
--------------------------------------------------------------------------------
/docs/source/serve_with_vllm.rst:
--------------------------------------------------------------------------------
1 | .. _serve_with_vllm:
2 |
3 | Serve with vLLM
4 | ===============
5 |
6 | The fine-tuned model can be served on `vLLM:GPU `_ or `vLLM:TPU `_.
7 |
8 | Exporting your model
9 | --------------------
10 | In order to serve your finetuned model, first make sure that your output checkpoints are exported to a GCS bucket or persistent volume. You can do this by overriding the ``model_output_dir`` parameter in the config yaml:
11 |
12 | .. code-block:: yaml
13 |
14 | model_output_dir: "gs://bucket_name/ckpt/"
15 |
16 |
17 | Starting vLLM with a Finetuned Model
18 | ------------------------------------
19 | Next, start your vLLM deployment by attaching the previous output volume as a mounted drive. Please visit our GCS guide (link TBD) for details.
20 |
21 | You can start vLLM with a path to the directory containing the model files. For example, if you mounted the volume at ``/model`` and the model checkpoints are located under the ``checkpoint`` directory:
22 |
23 | .. code-block:: bash
24 |
25 | vllm serve /model/checkpoint --tensor-parallel-size 8 --max-model-len 4096
26 |
27 | where ``--tensor-parallel-size`` is equal to the number of available TPU chips. You can adjust ``--max-model-len`` depending on the available HBM.
28 |
29 | If loading the model is successful, you should see the following output in the console:
30 |
31 | .. code-block:: bash
32 |
33 | INFO 03-05 23:24:13 api_server.py:756] Using supplied chat template:
34 | INFO 03-05 23:24:13 api_server.py:756] None
35 | INFO 03-05 23:24:13 launcher.py:21] Available routes are:
36 | INFO 03-05 23:24:13 launcher.py:29] Route: /openapi.json, Methods: GET, HEAD
37 | INFO 03-05 23:24:13 launcher.py:29] Route: /docs, Methods: GET, HEAD
38 | INFO 03-05 23:24:13 launcher.py:29] Route: /docs/oauth2-redirect, Methods: GET, HEAD
39 | INFO 03-05 23:24:13 launcher.py:29] Route: /redoc, Methods: GET, HEAD
40 | INFO 03-05 23:24:13 launcher.py:29] Route: /health, Methods: GET
41 | INFO 03-05 23:24:13 launcher.py:29] Route: /ping, Methods: POST, GET
42 | INFO 03-05 23:24:13 launcher.py:29] Route: /tokenize, Methods: POST
43 | INFO 03-05 23:24:13 launcher.py:29] Route: /detokenize, Methods: POST
44 | INFO 03-05 23:24:13 launcher.py:29] Route: /v1/models, Methods: GET
45 | INFO 03-05 23:24:13 launcher.py:29] Route: /version, Methods: GET
46 | INFO 03-05 23:24:13 launcher.py:29] Route: /v1/chat/completions, Methods: POST
47 | INFO 03-05 23:24:13 launcher.py:29] Route: /v1/completions, Methods: POST
48 | INFO 03-05 23:24:13 launcher.py:29] Route: /v1/embeddings, Methods: POST
49 | INFO 03-05 23:24:13 launcher.py:29] Route: /pooling, Methods: POST
50 | INFO 03-05 23:24:13 launcher.py:29] Route: /score, Methods: POST
51 | INFO 03-05 23:24:13 launcher.py:29] Route: /v1/score, Methods: POST
52 | INFO 03-05 23:24:13 launcher.py:29] Route: /rerank, Methods: POST
53 | INFO 03-05 23:24:13 launcher.py:29] Route: /v1/rerank, Methods: POST
54 | INFO 03-05 23:24:13 launcher.py:29] Route: /v2/rerank, Methods: POST
55 | INFO 03-05 23:24:13 launcher.py:29] Route: /invocations, Methods: POST
56 | INFO: Started server process [67252]
57 | INFO: Waiting for application startup.
58 | INFO: Application startup complete.
59 | INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
60 |
61 |
62 | Send a request to the vLLM server to verify that it's working:
63 |
64 | .. code-block:: bash
65 |
66 | curl http://localhost:8000/v1/completions \
67 | -H "Content-Type: application/json" \
68 | -d '{
69 | "model": "/model/checkpoint/",
70 | "prompt": "your prompt",
71 | "max_tokens": 256,
72 | "temperature": 0.7
73 | }'
74 |
75 |
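Since vLLM exposes an OpenAI-compatible API, you can equivalently query the server from Python. A minimal sketch, assuming the ``openai`` client package is installed:

.. code-block:: python

    from openai import OpenAI

    # vLLM does not check the API key, but the client requires one to be set
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    response = client.completions.create(
        model="/model/checkpoint/",
        prompt="your prompt",
        max_tokens=256,
        temperature=0.7,
    )
    print(response.choices[0].text)
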
--------------------------------------------------------------------------------
/docs/source/temp.rst:
--------------------------------------------------------------------------------
1 | #.. _overviewxyz:
2 |
3 | 👋 Welcome to Kithara!
4 | ===================================
5 | .. note::
6 |
7 | This project is under active development.
8 |
9 | Introduction
10 | ------------
11 |
12 | What is Kithara?
13 | ~~~~~~~~~~~~~~~~
14 |
15 | **Kithara** is a lightweight library offering building blocks and recipes for tuning popular open source LLMs like Llama 3 and Gemma 2 on Google TPUs.
16 |
17 | Kithara makes post-training easier to do on TPUs. Post-training algorithms include continued pretraining, fine-tuning, PEFT, RLHF, and quantization. These algorithms adapt a pre-trained language model to a specific task or domain by training it on a custom dataset. Using TPUs provides significant advantages in terms of performance, cost-effectiveness, and scalability, enabling faster training times and the ability to work with larger models and datasets.
18 |
19 |
--------------------------------------------------------------------------------
/docs/source/troubleshooting.rst:
--------------------------------------------------------------------------------
1 | .. _troubleshooting:
2 |
3 | Troubleshooting
4 | ===============
5 |
6 | 1. Disk OOM when loading HF model checkpoint
7 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8 |
9 | First try emptying your cache by running the following code on your Ray Cluster.
10 |
11 | .. code-block:: python
12 |
13 | import shutil
14 | shutil.rmtree("/home/ubuntu/.cache/huggingface/hub/", ignore_errors=True)
15 | shutil.rmtree("/home/ubuntu/.keras/models", ignore_errors=True)
16 |
17 | If you are using a single VM, the path may be different.
18 |
19 | .. code-block:: python
20 |
21 | import os, shutil
22 | shutil.rmtree(os.path.expanduser("~/.cache/huggingface/hub/"), ignore_errors=True)
23 | shutil.rmtree(os.path.expanduser("~/.keras/models"), ignore_errors=True)
24 |
25 | If emptying the cache still doesn't help, try attaching a disk to your VM and changing the HF cache directory using the environment variable ``export HF_HOME=``.
26 |
27 | You may have to copy your HF token to this new cache directory with ``cp .cache/huggingface/token /token``.
28 |
29 | 2. Permission denied error when uploading checkpoint to GCS
30 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
31 |
32 | First make sure that your bucket is in the same project as your TPU VM.
33 |
34 | Otherwise, verify your current authentication:
35 |
36 | .. code-block:: bash
37 |
38 | gcloud auth list
39 | gsutil ls gs://your_bucket
40 |
41 | For your Python code, you likely need to ensure you're using the same credentials.
42 |
43 | .. code-block:: bash
44 |
45 | gcloud auth application-default login
46 |
47 | 3. jaxlib.xla_extension.XlaRuntimeError errors
48 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
49 |
50 | Try uninstalling and reinstalling ``jax`` and ``jaxlib``.
51 |
52 | .. code-block:: bash
53 |
54 | pip uninstall jax jaxlib
55 | pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
56 | pip install libtpu-nightly==0.1.dev20250128+nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
57 |
58 | 4. Unable to initialize backend 'tpu': INTERNAL: Failed to get global TPU topology.
59 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
60 |
61 | Try adding ``JAX_PLATFORMS=''`` to your environment variables.
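
For example, you can set this variable in Python, before JAX initializes its backend. A minimal sketch::

    import os

    # Must be set before JAX is imported and initialized
    os.environ["JAX_PLATFORMS"] = ""

    import jax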
--------------------------------------------------------------------------------
/examples/example_datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from typing import Tuple
18 | from datasets import load_dataset
19 | import ray
20 | from ray.data import Dataset
21 |
22 |
23 | def example_datasets(option: str) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
24 | """
25 | Load example datasets based on the specified option.
26 |
27 | Returns:
28 | Tuple[ray.data.Dataset, ray.data.Dataset]: A tuple of (training_dataset, evaluation_dataset)
29 |
30 | Example usage:
31 | train_ds, eval_ds = example_datasets("hf")
32 | """
33 | valid_options = {"hf", "finetune_toy", "sft_toy", "files"}
34 | if option not in valid_options:
35 | raise ValueError(
36 | f"Invalid option: {option}. Must be one of {', '.join(valid_options)}"
37 | )
38 |
39 | if option == "hf":
40 | return _load_huggingface_dataset()
41 | elif option == "finetune_toy":
42 | return _create_finetune_toy_dataset()
43 | elif option == "sft_toy":
44 | return _create_sft_toy_dataset()
45 | elif option == "files":
46 | raise NotImplementedError(
47 | "Loading from files is not yet implemented. Please use another option."
48 | )
49 |
50 |
51 | def _load_huggingface_dataset() -> Tuple[Dataset, Dataset]:
52 | """Load the C4 dataset from HuggingFace."""
53 | hf_train_dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)
54 | hf_val_dataset = load_dataset(
55 | "allenai/c4", "en", split="validation", streaming=True
56 | )
57 |
58 | return (
59 | ray.data.from_huggingface(hf_train_dataset),
60 | ray.data.from_huggingface(hf_val_dataset),
61 | )
62 |
63 |
64 | def _create_finetune_toy_dataset() -> Tuple[Dataset, Dataset]:
65 | """Create a toy dataset for finetuning with 1000 examples."""
66 | dataset_items = [
67 | {"text": f"{i} What is your name? My name is Mary."} for i in range(1000)
68 | ]
69 | dataset = ray.data.from_items(dataset_items)
70 | return dataset.train_test_split(test_size=500)
71 |
72 |
73 | def _create_sft_toy_dataset() -> Tuple[Dataset, Dataset]:
74 | """Create a toy dataset for supervised fine-tuning with 1000 examples."""
75 | dataset_items = [{
76 | "prompt": "What is your name?",
77 | "answer": "My name is Mary",
78 | } for _ in range(1000)]
79 | dataset = ray.data.from_items(dataset_items)
80 | return dataset.train_test_split(test_size=500)
81 |
--------------------------------------------------------------------------------
/examples/multihost/ray/GPU/sft_lora_example.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import ray
18 | from typing import List, Any
19 | from kithara.distributed.data import split_dataset
20 |
21 | ray.init()
22 |
23 | num_chips_per_host = 4
24 | num_gpu_devices = int(ray.cluster_resources()["GPU"])
25 | print(f"{num_gpu_devices=}")
26 |
27 |
28 | @ray.remote(num_gpus=num_chips_per_host)
29 | def main(train_ds, eval_ds, split_data_across_host):
30 | from huggingface_hub import login
31 | import os
32 |
33 | hf_token = os.getenv("HF_TOKEN")
34 | if hf_token:
35 | login(token=hf_token, add_to_git_credential=False)
36 |
37 | from examples.singlehost.full_finetuning_example import run_workload
38 |
39 | run_workload(
40 | train_source=train_ds,
41 | eval_source=eval_ds,
42 | dataset_is_sharded_per_host=split_data_across_host,
43 | )
44 |
45 |
46 | # Create multi-host datasets
47 | dataset_items = [
48 | {
49 | "prompt": "What is your name?",
50 | "answer": "My name is Mary",
51 | }
52 | for _ in range(1000)
53 | ]
54 | dataset = ray.data.from_items(dataset_items)
55 | train_ds, eval_ds = dataset.train_test_split(test_size=500)
56 |
57 | split_data_across_host = False
58 | if split_data_across_host:
59 | train_ds: List[Any] = split_dataset(train_ds, num_hosts=num_gpu_devices)
60 | eval_ds: List[Any] = split_dataset(eval_ds, num_hosts=num_gpu_devices)
61 | ray.get(
62 | [
63 | main.remote(train_ds[i], eval_ds[i], split_data_across_host)
64 | for i in range(num_gpu_devices)
65 | ]
66 | )
67 | else:
68 | ray.get(
69 | [
70 | main.remote(train_ds, eval_ds, split_data_across_host)
71 | for i in range(num_gpu_devices)
72 | ]
73 | )
74 |
75 | ray.shutdown()
76 |
--------------------------------------------------------------------------------
/examples/multihost/ray/TPU/continued_pretraining_example.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """Full parameter continued pretraining a Gemma2 9B model.
18 |
19 | This script demonstrates how to
20 |
21 | 1. Load HuggingFace Gemma2 checkpoint
22 | 2. Load a large HuggingFace dataset
23 | 3. Configure data loading and preprocessing
24 | 4. Run training across TPU/GPU devices
25 | 5. Save checkpoint to GCS periodically
26 | 6. Generate text using the trained model
27 | 7. Save model in HuggingFace format to GCS
28 |
29 | This script can be run on both single-host and multi-host.
30 | For multi-host setup, please follow https://kithara.readthedocs.io/en/latest/scaling_with_ray.html.
31 |
32 | Singlehost: python examples/singlehost/continued_pretraining_example.py
33 | Multihost: python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/continued_pretraining_example.py" --hf-token your_token
34 | """
35 |
36 | import ray
37 | import jax
38 |
39 | ray.init()
40 |
41 | # Verify TPU resources
42 | num_chips_per_host = 4 # <-- IMPORTANT: Use 4 for v4 and v5p, 8 for v5e
43 | num_tpu_hosts = int(ray.cluster_resources()["TPU"] / num_chips_per_host)
44 | print(f"{num_tpu_hosts=}")
45 |
46 |
47 | @ray.remote(resources={"TPU": num_chips_per_host})
48 | def main():
49 |
50 | import subprocess
51 |
52 | subprocess.run(["rm", "-rf", "/tmp/libtpu_lockfile", "/tmp/tpu_logs"])
53 | subprocess.run(["rm", "-rf", "/home/ubuntu/.cache/huggingface/hub"])
54 |
55 | # HuggingFace login
56 | from huggingface_hub import login
57 | import os
58 |
59 | hf_token = os.getenv("HF_TOKEN")
60 | if hf_token:
61 | login(token=hf_token, add_to_git_credential=False)
62 |
63 | jax.distributed.initialize()
64 |
65 | from singlehost.continued_pretraining_example import run_workload
66 |
67 | # Save model outputs to GCS
68 | run_workload(
69 | gemma2_model_size="9b", model_output_dir="gs://wenxindong-vm/model_output"
70 | )
71 |
72 |
73 | ray.get([main.remote() for i in range(num_tpu_hosts)])
74 |
75 | ray.shutdown()
76 |
--------------------------------------------------------------------------------
/examples/multihost/ray/TPU/full_finetuning_example.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """Full parameter finetune a Gemma2 9B model.
18 |
19 | This script demonstrates how to:
20 | 1. Set up a Gemma2 model for full parameter finetuning
21 | 2. Load HuggingFace Gemma2 checkpoint
22 | 3. Configure data loading and preprocessing
23 | 4. Run training across TPU/GPU devices
24 | 5. Save checkpoint to GCS periodically
25 | 6. Generate text using the trained model
26 | 7. Save model in HuggingFace format to GCS
27 |
28 | This script should be run on multihost, since gemma2-9b will not fit on a single host. However,
29 | you can change the model to `gemma2-2b` to run on single host.
30 |
31 | Singlehost: python examples/singlehost/full_finetuning_example.py
32 | Multihost: python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/full_finetuning_example.py" --hf-token your_token
33 |
34 | """
35 |
36 | import ray
37 | from absl import app
38 | from typing import List, Any
39 | from kithara.config.pyconfig import load_config
40 | from kithara.distributed.data import split_dataset
41 | import jax
42 |
43 | ray.init()
44 |
45 | # Verify TPU resources
46 | num_chips_per_host = 4 # 4 for v4 and v5p, 8 for v5e
47 | num_tpu_devices = int(ray.cluster_resources()["TPU"])
48 | num_tpu_hosts = num_tpu_devices // num_chips_per_host
49 | print(f"{num_tpu_devices=}")
50 | print(f"{num_tpu_hosts=}")
51 |
52 |
53 | @ray.remote(resources={"TPU": num_chips_per_host})
54 | def run(config, train_ds, eval_ds, dataset_is_sharded_per_host):
55 |
56 | import subprocess
57 |
58 | subprocess.run(["rm", "-rf", "/tmp/libtpu_lockfile", "/tmp/tpu_logs"])
59 |
60 | # HuggingFace login
61 | from huggingface_hub import login
62 | import os
63 |
64 | hf_token = os.getenv("HF_TOKEN")
65 | if hf_token:
66 | login(token=hf_token, add_to_git_credential=False)
67 |
68 | jax.distributed.initialize()
69 |
70 | # Run workload in SPMD mode
71 | from examples.singlehost.full_finetuning_example import run_workload
72 |
73 | run_workload(
74 | config=config,
75 | train_source=train_ds,
76 | eval_source=eval_ds,
77 | dataset_is_sharded_per_host=dataset_is_sharded_per_host,
78 | )
79 |
80 |
81 | def main(argv):
82 | config = load_config(argv)
83 |
84 | # Create multi-host datasets
85 | dataset_items = [
86 | {"text": f"{i} What is your name? My name is Mary."} for i in range(1000)
87 | ]
88 | dataset = ray.data.from_items(dataset_items)
89 | train_ds, eval_ds = dataset.train_test_split(test_size=500)
90 |
91 | split_data_across_host = False
92 | if split_data_across_host:
93 | train_ds: List[Any] = split_dataset(train_ds, num_hosts=num_tpu_hosts)
94 | eval_ds: List[Any] = split_dataset(eval_ds, num_hosts=num_tpu_hosts)
95 | ray.get(
96 | [
97 | run.remote(config, train_ds[i], eval_ds[i], split_data_across_host)
98 | for i in range(num_tpu_hosts)
99 | ]
100 | )
101 | else:
102 | ray.get(
103 | [
104 | run.remote(config, train_ds, eval_ds, split_data_across_host)
105 | for _ in range(num_tpu_hosts)
106 | ]
107 | )
108 |
109 | ray.shutdown()
110 |
111 |
112 | if __name__ == "__main__":
113 | app.run(main)
114 |
--------------------------------------------------------------------------------
/examples/multihost/ray/TPU/quick_start.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """Quick Start Example
18 |
19 | This script demonstrates how to run LoRA SFT on a toy dataset
20 |
21 | 1. Load HuggingFace Gemma2 checkpoint
22 | 2. Configure data loading and preprocessing
23 | 3. Run training across TPU/GPU devices
24 |
25 | This script can be run on both single-host and multi-host.
26 | For multi-host setup, please follow https://kithara.readthedocs.io/en/latest/scaling_with_ray.html.
27 |
28 |
29 | Singlehost: python examples/singlehost/quick_start.py
30 | Multihost: python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/quick_start.py" --hf-token your_token
31 | """
32 |
33 | import ray
34 | import jax
35 |
36 | ray.init()
37 |
38 | num_chips_per_host = 4 # <--IMPORTANT: Use 4 for v4 and v5p, 8 for v5e
39 | num_tpu_hosts = int(ray.cluster_resources()["TPU"] / num_chips_per_host)
40 | print(f"{num_tpu_hosts=}")
41 |
42 | @ray.remote(resources={"TPU": num_chips_per_host})
43 | def main():
44 |
45 | import subprocess
46 |
47 | # This is not strictly necessary, but helps to remove TPU deadlocks if you ever run into them.
48 | subprocess.run(["rm", "-rf", "/tmp/libtpu_lockfile", "/tmp/tpu_logs"])
49 |
50 | # HuggingFace login
51 | from huggingface_hub import login
52 | import os
53 |
54 | hf_token = os.getenv("HF_TOKEN")
55 | if hf_token:
56 | login(token=hf_token, add_to_git_credential=False)
57 |
58 | # Let JAX know that we are running a distributed job
59 | jax.distributed.initialize()
60 |
61 | # No need to change your single host job script, simply use it as it is.
62 | from examples.singlehost.quick_start import run_workload
63 |
64 | # Run this workload on all hosts. Don't worry, we are handling
65 | # all the model sharding and batch sharding for you.
66 | run_workload()
67 |
68 | ray.get([main.remote() for i in range(num_tpu_hosts)])
69 |
70 | ray.shutdown()
71 |
--------------------------------------------------------------------------------
/examples/multihost/ray/TPU/sft_lora_example.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """SFT a Gemma2 2B model using LoRA on TPU or GPU on an Alpaca Dataset
18 |
19 | This script demonstrates how to:
20 | 1. Set up a Gemma2 model for LoRA SFT
21 | 2. Load HuggingFace Gemma2 checkpoint
22 | 3. Load HuggingFace Dataset
23 | 4. Configure data loading and preprocessing
24 | 5. Run training across TPU/GPU devices
25 | 6. Save the LoRA adapters
26 |
27 | This script can be run on both single-host and multi-host.
28 | For multi-host setup, please follow https://kithara.readthedocs.io/en/latest/scaling_with_ray.html.
29 |
30 | Singlehost: python examples/singlehost/sft_lora_example.py
31 | Multihost: python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/sft_lora_example.py" --hf-token your_token
32 | """
33 |
34 | import ray
35 | import jax
36 |
37 | ray.init()
38 |
39 | # Verify TPU resources
40 | num_chips_per_host = 4 # <-- IMPORTANT: Use 4 for v4 and v5p, 8 for v5e
41 | num_tpu_hosts = int(ray.cluster_resources()["TPU"] / num_chips_per_host)
42 | print(f"{num_tpu_hosts=}")
43 |
44 |
45 | @ray.remote(resources={"TPU": num_chips_per_host})
46 | def main():
47 |
48 | import subprocess
49 |
50 | subprocess.run(["rm", "-rf", "/tmp/libtpu_lockfile", "/tmp/tpu_logs"])
51 |
52 |
53 | # HuggingFace login
54 | from huggingface_hub import login
55 | import os
56 |
57 | hf_token = os.getenv("HF_TOKEN")
58 | if hf_token:
59 | login(token=hf_token, add_to_git_credential=False)
60 |
61 | jax.distributed.initialize()
62 |
63 | # Run workload in SPMD mode
64 | from singlehost.sft_lora_example import run_workload
65 |
66 | # Save your model in cloud storage. Use None to skip model saving.
67 | run_workload(model_output_dir=None)
68 |
69 |
70 | ray.get([main.remote() for i in range(num_tpu_hosts)])
71 | ray.shutdown()
72 |
--------------------------------------------------------------------------------
/examples/multihost/ray/TPU/sft_lora_example.yaml:
--------------------------------------------------------------------------------
1 | # Launch this example using the following command:
2 |
3 | # Important: Replace v5e with your tpu generation
4 | # python ray/submit_job.py "python3.11 kithara/config/launcher.py --override_config=examples/multihost/ray/TPU/sft_lora_example.yaml --override tpu_generation=v5e" --hf-token your_token
5 |
6 | # -------------------------
7 | # Orchestration Configuration
8 | # -------------------------
9 | single_host: False
10 |
11 | # -------------------------
12 | # Model Configuration
13 | # -------------------------
14 | seq_len: 8192
15 | model_handle: "hf://google/gemma-2-9b"
16 |
17 | # -------------------------
18 | # Training Settings
19 | # -------------------------
20 | training_steps: 100
21 | eval_steps_interval: 20
22 | log_steps_interval: 10
23 |
24 | # -------------------------
25 | # Checkpointing
26 | # -------------------------
27 | save_checkpoint_interval: 20
28 | max_checkpoints_to_keep: 3
29 | checkpoint_dir: "gs://wenxindong-tpu-prod-env-multipod-bucket/checkpoints/"
30 |
31 | # -------------------------
32 | # Logging
33 | # -------------------------
34 | tensorboard_dir: "gs://wenxindong-tpu-prod-env-multipod-bucket/tensorboard_dir/"
35 |
36 | # -------------------------
37 | # Train Dataset
38 | # -------------------------
39 | train_dataset_column_mapping: {"prompt":"input", "answer":"output"}
40 | train_eval_split: 50 # Use 50 test samples
41 |
42 | # -------------------------
43 | # Output Configuration
44 | # -------------------------
45 | model_output_dir: "gs://wenxindong-tpu-prod-env-multipod-bucket/model_output/"
--------------------------------------------------------------------------------
/examples/singlehost/quick_start.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """Quick Start Example
18 |
19 | This script demonstrates how to run LoRA SFT on a toy dataset
20 |
21 | 1. Load HuggingFace Gemma2 checkpoint
22 | 2. Configure data loading and preprocessing
23 | 3. Run training across TPU/GPU devices
24 |
25 | This script can be run on both single-host and multi-host.
26 | For multi-host setup, please follow https://kithara.readthedocs.io/en/latest/scaling_with_ray.html.
27 |
28 | Singlehost: python examples/singlehost/quick_start.py
29 | Multihost: python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/quick_start.py" --hf-token your_token
30 | """
31 |
32 | import os
33 |
34 | os.environ["KERAS_BACKEND"] = "jax"
35 | import keras
36 | import ray
37 | from kithara import (
38 | KerasHubModel,
39 | Dataloader,
40 | Trainer,
41 | SFTDataset,
42 | )
43 |
44 | config = {
45 | "model_handle": "hf://google/gemma-2-2b",
46 | "tokenizer_handle": "hf://google/gemma-2-2b",
47 | "seq_len": 4096,
48 | "lora_rank": 16,
49 | "precision": "mixed_bfloat16",
50 | "training_steps": 60,
51 | "eval_steps_interval": 10,
52 | "log_steps_interval": 1,
53 | "per_device_batch_size": 1,
54 | "max_eval_samples": 50,
55 | "learning_rate": 2e-4,
56 | }
57 |
58 |
59 | def run_workload():
60 |
61 | # Create a toy dataset
62 | dataset_items = [
63 | {
64 | "prompt": "What is your name?",
65 | "answer": "My name is Kithara",
66 | }
67 | for _ in range(1000)
68 | ]
69 | dataset = ray.data.from_items(dataset_items)
70 | train_source, eval_source = dataset.train_test_split(test_size=500)
71 |
72 | # Create model
73 | model = KerasHubModel.from_preset(
74 | config["model_handle"],
75 | precision=config["precision"],
76 | lora_rank=config["lora_rank"],
77 | )
78 |
79 | # Create datasets
80 | train_dataset = SFTDataset(
81 | train_source,
82 | tokenizer_handle=config["tokenizer_handle"],
83 | max_seq_len=config["seq_len"],
84 | )
85 | eval_dataset = SFTDataset(
86 | eval_source,
87 | tokenizer_handle=config["tokenizer_handle"],
88 | max_seq_len=config["seq_len"],
89 | )
90 |
91 | # Create optimizer
92 | optimizer = keras.optimizers.AdamW(
93 | learning_rate=config["learning_rate"], weight_decay=0.01
94 | )
95 |
96 | # Create data loaders
97 | train_dataloader = Dataloader(
98 | train_dataset,
99 | per_device_batch_size=config["per_device_batch_size"],
100 | )
101 | eval_dataloader = Dataloader(
102 | eval_dataset,
103 | per_device_batch_size=config["per_device_batch_size"],
104 | )
105 |
106 | # Initialize trainer
107 | trainer = Trainer(
108 | model=model,
109 | optimizer=optimizer,
110 | train_dataloader=train_dataloader,
111 | eval_dataloader=eval_dataloader,
112 | steps=config["training_steps"],
113 | eval_steps_interval=config["eval_steps_interval"],
114 | max_eval_samples=config["max_eval_samples"],
115 | log_steps_interval=config["log_steps_interval"],
116 | )
117 |
118 | # Start training
119 | trainer.train()
120 |
121 | print("Finished training. Prompting model...")
122 |
123 | # Test after tuning
124 | pred = model.generate(
125 | "What is your name?",
126 | max_length=30,
127 | tokenizer_handle=config["tokenizer_handle"],
128 | return_decoded=True,
129 | )
130 | print("Tuned model generates:", pred)
131 |
132 |
133 | if __name__ == "__main__":
134 |
135 | run_workload()
136 |
--------------------------------------------------------------------------------
/kithara/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | from kithara.callbacks.profiler import *
2 | from kithara.callbacks.checkpointer import *
--------------------------------------------------------------------------------
/kithara/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/kithara/config/__init__.py
--------------------------------------------------------------------------------
/kithara/config/default.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Model parameters
16 | model_handle: "hf://google/gemma-2-2b"
17 | tokenizer_handle: "hf://google/gemma-2-2b"
18 | seq_len: 4096
19 | precision: "mixed_bfloat16"
20 |
21 | # Training parameters
22 | training_steps: 200
23 | eval_steps_interval: 100
24 | log_steps_interval: 1
25 | per_device_batch_size: 1
26 | max_eval_samples: 50
27 | learning_rate: 5e-05
28 |
29 | # Output parameters
30 | model_output_dir: "gs://bucket_name/ckpt/"
31 |
--------------------------------------------------------------------------------
/kithara/config/pyconfig.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import sys
16 | from typing import Any
17 | import yaml
18 |
19 |
20 | DEFAULT_CONFIG = "kithara/config/default.yaml"
21 |
22 |
23 | def load_config(argv) -> dict[str, Any]:
24 | """Loads the YAML config from a file with a given name."""
25 | config_name = DEFAULT_CONFIG
26 | if len(argv) > 1:
27 | config_name = argv[1]
28 |
29 | with open(config_name, "r", encoding="utf-8") as yaml_file:
30 | config = yaml.safe_load(yaml_file)
31 |
32 | # Some float parameters like learning_rate could be parsed
33 | # as strings instead of floats, so we need to explicitly
34 | # convert them.
35 | if 'learning_rate' in config:
36 | config['learning_rate'] = float(config['learning_rate'])
37 |
38 |
39 | print(f"Parsed config: {config}")
40 | return config
41 |
--------------------------------------------------------------------------------
/kithara/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from kithara.dataset.dataloader import Dataloader
2 | from kithara.dataset.sft import SFTDataset
3 | from kithara.dataset.text_completion import TextCompletionDataset
4 |
--------------------------------------------------------------------------------
/kithara/dataset/dataloader.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from typing import Iterator, Any, Iterable, Dict
18 | import jax
19 | from kithara.dataset.dataset import Dataset
20 | from jax.tree_util import tree_map
21 | import numpy as np
22 |
23 | class Dataloader(Iterable):
24 | """Kithara Dataloader class. This dataloader class supports distributed
25 | data loading for multi-host training.
26 |
27 | Attributes:
28 | dataset: A Kithara Dataset instance, or any Iterable instance
29 | that returns individual, tokenized samples formatted
30 | as expected by the model.
31 | per_device_batch_size: Number of samples per batch per device.
32 | If you experience HBM OOM errors, try reducing this value.
33 | dataset_is_sharded_per_host: True if dataset is already sharded
34 | and each host is provided with a local dataset shard. False if
35 | every host is loading from the same dataset.
36 | """
37 |
38 | def __init__(
39 | self,
40 | dataset: Dataset,
41 | per_device_batch_size: int,
42 | dataset_is_sharded_per_host: bool = False,
43 | ):
44 | self.dataset = dataset
45 | self.per_device_batch_size = per_device_batch_size
46 | self.dataset_is_sharded_per_host = dataset_is_sharded_per_host
47 | self._iterator = None
48 | self.num_hosts = jax.process_count()
49 | self.host_id = jax.process_index()
50 |
51 | def __iter__(self) -> Iterator[Any]:
52 | """Return iterator over batches in the dataset.
53 |
54 | Yields:
55 | Per-host batch input
56 | """
57 | self._iterator = iter(self.dataset)
58 | return self
59 |
60 | def __next__(self) -> Dict[str, np.ndarray]:
61 | """Get next batch of data from the dataset.
62 |
63 | Returns:
64 | Per-host batch input
65 | """
66 | if self.dataset_is_sharded_per_host:
67 | # For sharded datasets, each host processes per_host_batch_size samples
68 | samples = [next(self._iterator) for _ in range(self.per_host_batch_size)]
69 | else:
70 | # For non-sharded datasets:
71 | # 1. Each host loads global_batch_size samples
72 | # 2. Only keeps samples corresponding to its host_id
73 | samples = []
74 | for i in range(self.global_batch_size):
75 | sample = next(self._iterator)
76 | if i % self.num_hosts == self.host_id:
77 | samples.append(sample)
78 |
79 | samples = tree_map(
80 | lambda *arrays: np.concatenate(arrays), *samples
81 | )
82 | return samples
83 |
84 | def __len__(self) -> int:
85 | """Get the total number of batches in the dataset.
86 |
87 | Returns:
88 | Total number of batches that will be yielded by this dataloader.
89 | For non-sharded datasets, this is the same across all hosts.
90 | For sharded datasets, each host has its own length.
91 | """
92 | if not hasattr(self.dataset, '__len__'):
93 | raise TypeError("Dataset must implement __len__ to support length calculation")
94 |
95 | total_samples = len(self.dataset)
96 | if self.dataset_is_sharded_per_host:
97 | batches = total_samples // self.per_host_batch_size
98 | else:
99 | batches = total_samples // (self.num_hosts * self.per_host_batch_size)
100 | return batches
101 |
102 | @property
103 | def global_batch_size(self):
104 | num_devices = jax.device_count()
105 | return self.per_device_batch_size * num_devices
106 | @property
107 | def per_host_batch_size(self):
108 | num_devices_per_host = jax.local_device_count()
109 | return self.per_device_batch_size * num_devices_per_host
110 |
111 |
--------------------------------------------------------------------------------
/kithara/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import ray
18 | from typing import Any, Iterator, Iterable
19 | from abc import ABC
20 | from datasets import Dataset as HF_Dataset, IterableDataset as HF_IterableDataset
21 |
22 |
23 | class Dataset(Iterable, ABC):
24 | """A base dataset class that serves as a base for other dataset
25 | implementations, including SFTDataset and TextCompletionDataset.
26 | It supports both regular and streaming Ray datasets.
27 |
28 | Args:
29 | source (ray.data.Dataset): The underlying Ray dataset to wrap.
30 |
31 | """
32 |
33 | def __init__(self, source: ray.data.Dataset):
34 | self.source = self._maybe_convert_to_ray_dataset(source)
35 | self._length = None
36 | self._iterator = None
37 |
38 | def _maybe_convert_to_ray_dataset(self, source):
39 | # HuggingFace datasets
40 | if isinstance(source, (HF_Dataset, HF_IterableDataset)):
41 | return ray.data.from_huggingface(source)
42 | # TODO: Add adapters for other dataset formats
43 | return source
44 |
45 | def process_sample(self, sample):
46 | return sample
47 |
48 | def __len__(self) -> int:
49 | if self._length is None:
50 | try:
51 | print(
52 | (
53 | "Warning: If your dataset is a streaming dataset, "
54 | "this operation (__len__) might trigger its lazy executation."
55 | )
56 | )
57 | self._length = self.source.count()
58 | except Exception:
59 | # For streamed datasets where count() might not work
60 | print(
61 | (
62 | "Unable to get __len__ of dataset. Dataset may be a streaming "
63 | "dataset. Setting length to be 0."
64 | )
65 | )
66 | self._length = 0
67 | return self._length
68 |
69 | def __next__(self) -> Any:
70 | """Returns the next element in the dataset."""
71 | if self._iterator is None:
72 | self._iterator = iter(self.source.iter_rows())
73 | sample = next(self._iterator)
74 | return self.process_sample(sample)
75 |
76 | def __iter__(self) -> Iterator[Any]:
77 | """Return an iterator over the dataset."""
78 | self._iterator = iter(self.source.iter_rows())
79 | return self
80 |
--------------------------------------------------------------------------------
/kithara/dataset/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from typing import List, Dict
18 | from transformers import AutoTokenizer as HFTokenizer
19 | import numpy as np
20 | from functools import lru_cache
21 |
22 | @lru_cache(maxsize=5)
23 | def initialize_tokenizer(tokenizer_handle, tokenizer=None):
24 | """
25 | Creates an HuggingFace AutoTokenizer with the tokenizer_handle.
26 | Try to specify a pad_token for the tokenizer if not specified.
27 | """
28 | if tokenizer is None:
29 | if tokenizer_handle.startswith("hf://"):
30 | tokenizer_handle = tokenizer_handle.removeprefix("hf://")
31 | try:
32 | tokenizer = HFTokenizer.from_pretrained(tokenizer_handle)
33 | except ValueError as e:
34 | print(f"Tokenizer handle {tokenizer_handle} is not a valid HuggingFace tokenizer handle.")
35 |
36 | # Llama tokenizers don't have a default pad_token, we must add it here
37 | if tokenizer.pad_token is None:
38 | # Gemma 2
39 | if tokenizer.get_vocab().get("<pad>") is not None:
40 | tokenizer.add_special_tokens({'pad_token': '<pad>'})
41 | # Llama 3 uses "<|finetune_right_pad_id|>"
42 | elif tokenizer.get_vocab().get("<|finetune_right_pad_id|>") is not None:
43 | tokenizer.add_special_tokens({'pad_token': '<|finetune_right_pad_id|>'})
44 | # Llama 2 doesn't have a padding token, fall back to "<unk>"
45 | elif tokenizer.get_vocab().get("<unk>") is not None:
46 | tokenizer.add_special_tokens({'pad_token': '<unk>'})
47 | else:
48 | print("WARNING: Tokenizer doesn't have the attribute pad_token")
49 | return tokenizer
50 |
51 |
52 | def HFtokenize(
53 | text: List[str],
54 | tokenizer: HFTokenizer,
55 | seq_len: int,
56 | padding: str = "max_length",
57 | ) -> Dict[str, np.ndarray]:
58 | """Tokenizes text using a HuggingFace tokenizer with specific parameters.
59 |
60 | This function handles the tokenization of input text, applying padding and
61 | truncation as needed. It converts text into a format suitable for model input.
62 |
63 | Args:
64 | text (Union[str, List[str]]): Input text or list of texts to tokenize
65 | tokenizer (HFTokenizer): HuggingFace tokenizer instance
66 | seq_len (int): Maximum sequence length for tokenization
67 | padding (str, optional): Padding strategy. Defaults to "max_length".
68 | Options include:
69 | - "max_length": Pad to seq_len
70 | - "longest": Pad to longest sequence in batch
71 | - "do_not_pad": No padding
72 |
73 | Returns:
74 | Dict[str, np.ndarray]: Dictionary containing:
75 | - 'input_ids': Token IDs (shape: [batch_size, seq_len])
76 | - 'attention_mask': Mask indicating real vs padded tokens
77 | (shape: [batch_size, seq_len])
78 |
79 | Example:
80 | >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
81 | >>> result = HFtokenize(["hello world"], tokenizer, seq_len=10)
82 | >>> print(result['input_ids'].shape)
83 | (1, 10)
84 | """
85 |
86 | return tokenizer(
87 | text,
88 | max_length=seq_len,
89 | padding=padding,
90 | padding_side="right",
91 | truncation=True,
92 | add_special_tokens=False,
93 | return_tensors="np",
94 | )
95 |
--------------------------------------------------------------------------------
/kithara/distributed/__init__.py:
--------------------------------------------------------------------------------
1 | from kithara.distributed.sharding import *
2 | from kithara.distributed.data.split import split_dataset
--------------------------------------------------------------------------------
/kithara/distributed/data/__init__.py:
--------------------------------------------------------------------------------
1 | from kithara.distributed.data.split import split_dataset
--------------------------------------------------------------------------------
/kithara/distributed/data/split.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import ray
18 |
19 | def split_dataset(ds: ray.data.Dataset, num_hosts: int):
20 | if ds is None:
21 | return [None for _ in range(num_hosts)]
22 | return ds.streaming_split(num_hosts, equal=True)
23 |
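A sketch (editor's illustration) of how `split_dataset` is meant to be used on a multi-host cluster; note that with `equal=True` each shard should be consumed by its own host-side process:

```python
import ray

ds = ray.data.range(1024)                # stand-in for a real training dataset
shards = split_dataset(ds, num_hosts=4)  # four equal streaming shards

# Schematically, host i consumes only shards[i]:
for batch in shards[0].iter_batches(batch_size=32):
    pass  # feed the batch to the host-local training step
```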
--------------------------------------------------------------------------------
/kithara/distributed/sharding/__init__.py:
--------------------------------------------------------------------------------
1 | from kithara.distributed.sharding.strategy import ShardingStrategy, PredefinedShardingStrategy
2 |
--------------------------------------------------------------------------------
/kithara/distributed/sharding/_data_sharding.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from kithara.distributed.sharding._mesh import Axis
18 | from jax.sharding import NamedSharding, Mesh
19 | import jax
20 | from jax.sharding import PartitionSpec as P
21 | from dataclasses import dataclass
22 | from typing import ClassVar
23 |
24 |
25 | @dataclass
26 | class DataSharding:
27 |
28 | # Class-level dictionary to store mesh types
29 | _mesh_types: ClassVar[dict] = {
30 | "tp": lambda: DataSharding.tp(),
31 | "fsdp": lambda: DataSharding.fsdp(),
32 | "fully_replicated": lambda: DataSharding.fully_replicated()
33 | }
34 |
35 | def __class_getitem__(cls, key: str):
36 | if key not in cls._mesh_types:
37 | raise KeyError(f"Unknown mesh type: {key}")
38 | return cls._mesh_types[key]()
39 |
40 | @classmethod
41 | def fsdp(cls):
42 | return NamedSharding(
43 | Mesh(jax.devices(), (Axis.FSDP,)),
44 | P(Axis.FSDP, None),
45 | )
46 |
47 | @classmethod
48 | def tp(cls):
49 | return NamedSharding(
50 | Mesh(jax.devices(), (Axis.TP,)),
51 | P(None, Axis.TP),
52 | )
53 |
54 | @classmethod
55 | def fully_replicated(cls):
56 | return NamedSharding(
57 | Mesh(jax.devices(), ("devices",)),
58 | P(None),
59 | )
60 |
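For illustration (editor's sketch), the `__class_getitem__` registry lets callers look up a predefined sharding by name and apply it to a host array:

```python
import jax
import numpy as np

# "fully_replicated" places a full copy of the batch on every device;
# "fsdp" would instead shard the leading (batch) dimension.
sharding = DataSharding["fully_replicated"]
batch = jax.device_put(np.zeros((8, 128), dtype=np.float32), sharding)
print(batch.sharding)
```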
--------------------------------------------------------------------------------
/kithara/distributed/sharding/_layout.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from kithara.distributed.sharding.models import GEMMA_LAYOUT, LLAMA_LAYOUT
18 | from dataclasses import dataclass
19 | from typing import ClassVar
20 |
21 |
22 | # Layout configurations for different model architectures
23 | @dataclass
24 | class Layout:
25 |
26 | _mesh_types: ClassVar[dict] = {}
27 |
28 | @classmethod
29 | def initialize_mesh_types(cls):
30 | # Avoid circular import
31 | from kithara.model import supported_models
32 | cls._mesh_types = {
33 | supported_models.GEMMA2_2B: lambda: Layout.gemma(),
34 | supported_models.GEMMA2_9B: lambda: Layout.gemma(),
35 | supported_models.GEMMA2_27B: lambda: Layout.gemma(),
36 | supported_models.LLAMA31_8B: lambda: Layout.llama(),
37 | supported_models.LLAMA31_70B: lambda: Layout.llama(),
38 | supported_models.LLAMA31_405B: lambda: Layout.llama(),
39 | supported_models.LLAMA32_1B: lambda: Layout.llama(),
40 | supported_models.LLAMA32_3B: lambda: Layout.llama(),
41 | }
42 |
43 | def __class_getitem__(cls, key: str):
44 | if not cls._mesh_types:
45 | cls.initialize_mesh_types()
46 | if key not in cls._mesh_types:
47 | raise KeyError(f"Unknown mesh type: {key}")
48 | return cls._mesh_types[key]()
49 |
50 | @classmethod
51 | def gemma(cls):
52 | return GEMMA_LAYOUT
53 |
54 | @classmethod
55 | def llama(cls):
56 | return LLAMA_LAYOUT
57 |
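A small sketch (editor's addition) of the lazy registry in action; the returned table maps a parallelism mode to the per-weight regex rules defined under `sharding/models/`:

```python
# The first subscript triggers initialize_mesh_types(); later lookups
# reuse the populated registry.
layout = Layout["gemma2-2b"]  # -> GEMMA_LAYOUT
fsdp_rules = layout["fsdp"]
print(fsdp_rules[".*token_embedding.embeddings.*"])  # (None, Axis.FSDP)
```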
--------------------------------------------------------------------------------
/kithara/distributed/sharding/_mesh.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from keras.distribution import DeviceMesh
18 | import jax
19 | from enum import Enum
20 | from dataclasses import dataclass
21 | from typing import ClassVar
22 |
23 |
24 | class Axis(str, Enum):
25 | """Enumeration of supported sharding axes in predefined shardings.
26 |
27 | Attributes:
28 | FSDP: Fully Sharded Data Parallel axis
29 | TP: Tensor Parallel axis
30 | """
31 |
32 | FSDP = "fsdp"
33 | TP = "tp"
34 |
35 |
36 | @dataclass
37 | class Mesh:
38 | # Class-level dictionary to store mesh types
39 | _mesh_types: ClassVar[dict] = {"tp": lambda: Mesh.tp(), "fsdp": lambda: Mesh.fsdp()}
40 |
41 | def __class_getitem__(cls, key: str):
42 | if key not in cls._mesh_types:
43 | raise KeyError(f"Unknown mesh type: {key}")
44 | return cls._mesh_types[key]()
45 |
46 | @classmethod
47 | def fsdp(cls):
48 | return DeviceMesh(
49 | shape=(len(jax.devices()),),
50 | axis_names=(Axis.FSDP,),
51 | devices=jax.devices(),
52 | )
53 |
54 | @classmethod
55 | def tp(cls):
56 | return DeviceMesh(
57 | shape=(len(jax.devices()),),
58 | axis_names=(Axis.TP,),
59 | devices=jax.devices(),
60 | )
61 |
--------------------------------------------------------------------------------
/kithara/distributed/sharding/maxtext.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from kithara.distributed.sharding.strategy import ShardingStrategy
18 | from dataclasses import dataclass
19 | from keras.distribution import DeviceMesh, LayoutMap
20 | from keras.src.distribution.distribution_lib import Distribution
21 | from jax.sharding import Sharding, PartitionSpec as P
22 | from typing import Any
23 | from jax.tree_util import tree_flatten_with_path
24 | from keras.src.distribution.distribution_lib import TensorLayout
25 | import keras
26 | from kithara.distributed.sharding.utils import convert_jax_mesh_to_keras_mesh
27 | from jax.sharding import Mesh
28 | import jax
29 |
30 | # Hacky monkey-patch for now: Keras' TensorLayout._validate_axes currently does not accept nested tuples as PartitionSpec values.
31 | # TODO: Patch Keras and remove this logic
32 | TensorLayout._validate_axes = lambda x: x
33 |
34 |
35 | @dataclass
36 | class MaxTextSharding(ShardingStrategy):
37 | """Sharding strategy from MaxText config.
38 | # TODO: Add docstring
39 | """
40 |
41 | jax_mesh: Mesh
42 | state_shardings: Any
43 | maxtext_config: 'pyconfig.Hyperparameter'
44 |
45 | def __post_init__(self) -> None:
46 | """Initialize the derived sharding configuration."""
47 |
48 | self._jax_mesh = self.jax_mesh
49 | self.jax_devices = self.jax_mesh.devices
50 |
51 |
52 | self._mesh = self._configure_mesh(self.jax_mesh)
53 | self._layout_map = self._configure_layout_map(self.state_shardings)
54 | self._data_sharding = self._configure_data_sharding()
55 | self._distribution = self._configure_distribution()
56 |
57 | @property
58 | def mesh(self) -> DeviceMesh:
59 | return self._mesh
60 |
61 | @property
62 | def layout_map(self) -> LayoutMap:
63 | return self._layout_map
64 |
65 | @property
66 | def data_sharding(self) -> Sharding:
67 | return self._data_sharding
68 |
69 | @property
70 | def distribution(self) -> Sharding:
71 | return self._distribution
72 |
73 | def _configure_mesh(self, jax_mesh) -> DeviceMesh:
74 | return convert_jax_mesh_to_keras_mesh(jax_mesh)
75 |
76 | def _configure_layout_map(self, state_shardings) -> LayoutMap:
77 | layout_map = LayoutMap(self._mesh)
78 |
79 | # Maps regex string to tuple
80 | # E.g. .*params.decoder.decoder.norm.scale.* -> ('tensor',)
81 | var_path_and_sharding, *_ = tree_flatten_with_path(
82 | state_shardings.params)
83 | mappings = {
84 | ".*"
85 | + ".".join(str(k.key) for k in var_path)
86 | + ".*": tuple(var_sharding.spec)
87 | for var_path, var_sharding in var_path_and_sharding
88 | }
89 |
90 | for pattern, layout in mappings.items():
91 | layout_map[pattern] = layout
92 |
93 | return layout_map
94 |
95 | def _configure_data_sharding(self) -> Sharding:
96 | data_pspec = P(*self.maxtext_config.data_sharding)
97 | data_sharding = jax.tree_util.tree_map(
98 | lambda p: jax.sharding.NamedSharding(self._jax_mesh, p), data_pspec)
99 | return data_sharding
100 |
101 | def _configure_distribution(self) -> Distribution:
102 | return keras.distribution.ModelParallel(layout_map=self._layout_map)
103 |
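To make the regex construction in `_configure_layout_map` concrete, here is an editor's toy illustration (not an actual MaxText state; the real code reads `NamedSharding.spec` leaves) of how flattened parameter paths become `LayoutMap` patterns:

```python
from jax.tree_util import tree_flatten_with_path
from jax.sharding import PartitionSpec as P

fake_params = {"decoder": {"norm": {"scale": P("tensor")}}}
path_and_spec, _ = tree_flatten_with_path(fake_params)
for var_path, spec in path_and_spec:
    pattern = ".*" + ".".join(str(k.key) for k in var_path) + ".*"
    print(pattern, tuple(spec))
# Prints: .*decoder.norm.scale.* ('tensor',)
```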
--------------------------------------------------------------------------------
/kithara/distributed/sharding/models/__init__.py:
--------------------------------------------------------------------------------
1 | from kithara.distributed.sharding.models.gemma import GEMMA_LAYOUT
2 | from kithara.distributed.sharding.models.llama import LLAMA_LAYOUT
3 |
--------------------------------------------------------------------------------
/kithara/distributed/sharding/models/gemma.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from kithara.distributed.sharding._mesh import Axis
18 |
19 | GEMMA_FSDP = {
20 | ".*token_embedding.embeddings.*": (None, Axis.FSDP),
21 | ".*decoder_block.*attention.*(query|key|value).kernel.*": (
22 | None,
23 | Axis.FSDP,
24 | ),
25 | ".*decoder_block.*attention_output.kernel.*": (None, None, Axis.FSDP),
26 | ".*decoder_block.*ffw_gating.kernel.*": (None, Axis.FSDP),
27 | ".*decoder_block.*ffw_gating_2.kernel.*": (None, Axis.FSDP),
28 | ".*decoder_block.*ffw_linear.kernel.*": (Axis.FSDP, None),
29 | # Lora layers
30 | ".*decoder_block.*attention.*(query|key|value).lora_kernel.*": (
31 | None,
32 | Axis.FSDP,
33 | ),
34 | }
35 |
36 | GEMMA_LAYOUT = {
37 | "fsdp": GEMMA_FSDP,
38 | }
39 |
--------------------------------------------------------------------------------
/kithara/distributed/sharding/models/llama.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from kithara.distributed.sharding._mesh import Axis
18 |
19 | LLAMA_FSDP = {
20 | ".*token_embedding.embeddings.*": (None, Axis.FSDP),
21 | ".*token_embedding.reverse_embeddings.*": (Axis.FSDP, None),
22 | ".*transformer_layer.*self_attention.*(query|key|value).kernel*": (
23 | None,
24 | Axis.FSDP,
25 | ),
26 | ".*transformer_layer.*self_attention.attention_output.kernel*": (
27 | None,
28 | None,
29 | Axis.FSDP,
30 | ),
31 | ".*transformer_layer.*feedforward_gate_dense.kernel*": (None, Axis.FSDP),
32 | ".*transformer_layer.*feedforward_intermediate_dense.kernel*": (None, Axis.FSDP),
33 | ".*transformer_layer.*feedforward_output_dense.kernel*": (None, Axis.FSDP),
34 | # Lora layers
35 | ".*transformer_layer.*self_attention.*(query|key|value).lora_kernel.*": (
36 | None,
37 | Axis.FSDP,
38 | ),
39 | }
40 |
41 | LLAMA_LAYOUT = {
42 | "fsdp": LLAMA_FSDP,
43 | }
44 |
--------------------------------------------------------------------------------
/kithara/model/README.md:
--------------------------------------------------------------------------------
1 | # Kithara Model Support Guide
2 |
3 | ## Overview
4 |
5 | Kithara is designed as a model-agnostic framework that focuses on model tuning rather than implementation. It provides integration with existing model implementations through supported providers:
6 |
7 | - MaxText
8 | - KerasHub
9 |
10 | Additional model implementation providers may be added in the future; community contributions are always welcome.
11 |
12 | ## Supported Models
13 |
14 | The complete list of supported models is maintained in `supported_models.py`. For the most up-to-date information, please refer to this file.
15 |
16 | ## Adding Support for New Models
17 |
18 | This guide focuses on adding support for new KerasHub model architectures. The process involves several key steps:
19 |
20 | ### 1. Configure Model Sharding
21 |
22 | 1. Navigate to `distributed/sharding/models/`
23 | 2. Create or modify the FSDP (Fully Sharded Data Parallel) model sharding strategy for your model
24 | 3. Update the `PredefinedShardingStrategy` class to include support for your model (see the sketch below)
25 |
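For illustration, a new layout file under `distributed/sharding/models/` for a hypothetical model might look like the following (names and regexes are placeholders patterned on `gemma.py` and `llama.py`):

```python
from kithara.distributed.sharding._mesh import Axis

# Hypothetical example: shard embeddings and attention projections
# along the FSDP axis, mirroring the existing Gemma/Llama layouts.
MYMODEL_FSDP = {
    ".*token_embedding.embeddings.*": (None, Axis.FSDP),
    ".*self_attention.*(query|key|value).kernel.*": (None, Axis.FSDP),
}

MYMODEL_LAYOUT = {
    "fsdp": MYMODEL_FSDP,
}
```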
26 | ### 2. Model Implementation
27 |
28 | First, create a KerasHub model instance to examine its structure:
29 |
30 | ```python
31 | model = KerasHubModel.from_preset(
32 | your_model_handle
33 | )
34 |
35 | # Examine model weights and shapes
36 | for weight in model.weights:
37 | print(f"Path: {weight.path}, Shape: {weight.value.shape}")
38 | ```
39 |
40 | ### 3. HuggingFace Compatibility
41 |
42 | #### 3.1 Analyze HuggingFace Model Structure
43 |
44 | Create and examine the equivalent HuggingFace model:
45 |
46 | ```python
47 | from transformers import AutoModelForCausalLM
48 |
49 | model = AutoModelForCausalLM.from_pretrained(preset_handle)
50 |
51 | # Examine model state dictionary
52 | for key, value in model.state_dict().items():
53 | print(f"Key: {key}, Shape: {value.shape}")
54 | ```
55 |
56 | #### 3.2 Update Configuration Files
57 |
58 | 1. Add HuggingFace model configuration:
59 | - File: `hf_compatibility/model_configs.py`
60 | - Add the specific configuration for your model
61 |
62 | 2. Define weight shapes:
63 | - File: `hf_compatibility/shape_mapping.py`
64 | - Document the shapes of all weights in the HuggingFace model
65 |
66 | 3. Create weight mappings:
67 | - File: `model/kerashub/ckpt_compatibility/param_mapping.py`
68 | - Define the mapping between HuggingFace and Kithara weights
69 | - Implement weight translation hook functions
70 |
71 | ### 4. Testing and Validation
72 |
73 | 1. Create unit tests in `tests/model/models/kerashub/ckpt_compatibility/`
74 | 2. Follow the pattern established in `gemma2-2b.py`
75 | 3. Verify numerical correctness of checkpoint translations
76 |
77 | ### Debugging Tips
78 |
79 | When working on weight translations:
80 |
81 | 1. Fork the transformers library for debugging
82 | 2. Add logging statements to compare intermediate tensor outputs
83 | 3. Verify tensor values at each layer
84 | 4. To verify MaxText model parameters and shapes, inspect the model directly:
85 |
86 | ```python
87 | from kithara import MaxTextModel
88 | model = MaxTextModel.from_random(
89 | "llama3.1-8b"
90 | )
91 |
92 | # Examine model weights and shapes
93 | for weight in model.weights:
94 | print(f"Path: {weight.path}, Shape: {weight.value.shape}")
95 | ```
96 |
97 |
98 | ## Best Practices
99 |
100 | 1. Support all model sizes within the model family you add
101 | 2. Document all weight mappings thoroughly and include examples of expected tensor shapes
102 |
--------------------------------------------------------------------------------
/kithara/model/__init__.py:
--------------------------------------------------------------------------------
1 | from kithara.model.model import *
2 | from kithara.model.maxtext.maxtext_model import *
3 | from kithara.model.kerashub.keras_hub_model import *
--------------------------------------------------------------------------------
/kithara/model/hf_compatibility/__init__.py:
--------------------------------------------------------------------------------
1 | from kithara.model.hf_compatibility.to_huggingface import *
2 | from kithara.model.hf_compatibility.model_configs import *
3 | from kithara.model.hf_compatibility.shape_mapping import *
4 |
5 |
--------------------------------------------------------------------------------
/kithara/model/kerashub/__init__.py:
--------------------------------------------------------------------------------
1 | from kithara.model.kerashub.keras_hub_model import *
--------------------------------------------------------------------------------
/kithara/model/maxtext/__init__.py:
--------------------------------------------------------------------------------
1 | from kithara.model.maxtext.maxtext_model import *
--------------------------------------------------------------------------------
/kithara/model/maxtext/ckpt_compatibility/__init__.py:
--------------------------------------------------------------------------------
1 | from kithara.model.maxtext.ckpt_compatibility.from_huggingface import *
2 | from kithara.model.maxtext.ckpt_compatibility.to_huggingface import *
3 | from kithara.model.maxtext.ckpt_compatibility.utils import *
4 | from kithara.model.maxtext.ckpt_compatibility.param_mapping import *
--------------------------------------------------------------------------------
/kithara/model/maxtext/ckpt_compatibility/to_huggingface.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """
18 | Module to convert MaxText model weights to HuggingFace format.
19 | """
20 |
21 | import torch
22 | import contextlib
23 | import time
24 | from kithara.model.hf_compatibility import (
25 | MODEL_CONFIGS,
26 | )
27 | from kithara.model.hf_compatibility import process_weight, save_model_files
28 | from kithara.model.maxtext.ckpt_compatibility.param_mapping import (
29 | HOOK_FNS,
30 | PARAM_MAPPING,
31 | )
32 | from kithara.model.hf_compatibility import SHAPE_MAPPING
33 | import jax
34 |
35 | def _get_model_mappings(model_name: str, scan_layers: bool, config: dict):
36 | """Retrieves parameter, shape, and hook function mappings for the model."""
37 | return {
38 | "param_mapping": PARAM_MAPPING[model_name](config.to_dict(), scan_layers),
39 | "shape_mapping": SHAPE_MAPPING[model_name](config.to_dict()),
40 | "hook_fn_mapping": HOOK_FNS[model_name](
41 | config.to_dict(), scan_layers, saving_to_hf=True
42 | ),
43 | }
44 |
45 |
46 | def _save_checkpoint(
47 | maxtext_model: "kithara.MaxTextModel", output_dir: str, parallel_threads=8
48 | ):
49 | """Main function to save a MaxText model checkpoint in HuggingFace format."""
50 | if maxtext_model.model_name not in MODEL_CONFIGS:
51 | raise ValueError(
52 | f"Model {maxtext_model.model_name} is not supported. "
53 | f"Supported models are {list(MODEL_CONFIGS.keys())}"
54 | )
55 |
56 | config = MODEL_CONFIGS[maxtext_model.model_name]
57 | mappings = _get_model_mappings(
58 | maxtext_model.model_name, maxtext_model.scan_layers, config
59 | )
60 |
61 | # Process weights
62 | start_time = time.time()
63 | weight_arrays = {}
64 | for variable in maxtext_model.weights:
65 | weight_dict = process_weight(variable, mappings)
66 | weight_arrays.update(weight_dict)
67 | print(
68 | f"\n✅ Weights converted into HuggingFace format in {time.time() - start_time}s"
69 | )
70 |
71 | # Save all model files
72 | save_model_files(weight_arrays, config, output_dir, parallel_threads)
73 |
74 |
75 | def save_maxtext_model_in_hf_format(
76 | model: "MaxTextModel", output_dir: str, dtype: str = "auto", parallel_threads=8
77 | ):
78 | """Convert and save a MaxText model in HuggingFace format.
79 |
80 | Args:
81 | model: MaxTextModel instance to save
82 | output_dir: Directory to save the HuggingFace checkpoint
83 | dtype: dtype for the converted model ("auto" uses source model's dtype)
84 | """
85 |
86 | @contextlib.contextmanager
87 | def _set_default_tensor_type(dtype: torch.dtype):
88 | """Sets the default torch dtype to the given dtype."""
89 | torch.set_default_dtype(dtype)
90 | yield
91 | torch.set_default_dtype(torch.float)
92 |
93 | if dtype == "auto":
94 | dtype = model.weight_dtype
95 |
96 | print(f"-> Saving model with {dtype=}...")
97 | with _set_default_tensor_type(getattr(torch, dtype)):
98 | _save_checkpoint(model, output_dir, parallel_threads)
99 |
100 | jax.experimental.multihost_utils.sync_global_devices("saving_completed")
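A usage sketch (editor's addition): convert a freshly initialized MaxText model into a HuggingFace checkpoint. `from_random` follows the call shown in `kithara/model/README.md`; the output directory is a placeholder.

```python
from kithara import MaxTextModel

model = MaxTextModel.from_random("gemma2-2b")
save_maxtext_model_in_hf_format(model, output_dir="/tmp/gemma2-2b-hf")
```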
--------------------------------------------------------------------------------
/kithara/model/maxtext/ckpt_compatibility/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
18 | from transformers import AutoModelForCausalLM
19 |
20 | def get_hf_safetensor_weight_keys(preset_handle):
21 | with SafetensorLoader(preset_handle) as loader:
22 | weight_map_keys = loader.safetensor_config["weight_map"].keys()
23 | return list(weight_map_keys)
24 |
25 | def get_hf_model_weight_shapes(preset_handle):
26 |
27 | weight_keys = get_hf_safetensor_weight_keys(preset_handle)
28 | model = AutoModelForCausalLM.from_pretrained(preset_handle)
29 |
30 | weight_to_shape = {}
31 | for path in weight_keys:
32 | path = path.removesuffix(".weight")
33 | hf_module = model.get_submodule(path)
34 | target_shape = hf_module.state_dict()["weight"].shape
35 | weight_to_shape[path] = target_shape
36 | return weight_to_shape
37 |
--------------------------------------------------------------------------------
/kithara/model/maxtext/inference_engine.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from kithara.model.maxtext.maxtext.MaxText.maxengine import MaxEngine
16 | import jax
17 | import max_utils
18 |
19 | class MaxtextInferenceEngine(MaxEngine):
20 | """This is a patched version of MaxEngine
21 |
22 | Changes:
23 | - Added the `load_existing_params` function to allow
24 | the engine to run with an existing model instead of initializing
25 | a new one.
26 | """
27 | def load_existing_params(self, params, rng=None) -> "Params":
28 |
29 | if rng is None:
30 | rng = jax.random.PRNGKey(0)
31 |
32 | self.abstract_params = jax.tree_util.tree_map(
33 | lambda x: (
34 | jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
35 | if isinstance(x, jax.Array)
36 | else None
37 | ),
38 | params,
39 | )
40 |
41 | self.prefill_kv_cache_annotations = max_utils.get_prefill_kv_cache_annotations(
42 | self.model, self.config, rng, self._mesh
43 | )
44 | self.prefill_kv_cache_shardings = jax.tree_util.tree_map(
45 | lambda x: jax.sharding.NamedSharding(self._mesh, x),
46 | self.prefill_kv_cache_annotations,
47 | )
48 |
49 | if self.config.stack_prefill_result_cache:
50 | # Add extra axis for the axis generated by the stack.
51 | self.prefill_kv_cache_shardings = jax.tree_util.tree_map(
52 | lambda x: jax.sharding.NamedSharding(
53 | self._mesh, jax.sharding.PartitionSpec(None, *x.spec)
54 | ),
55 | self.prefill_kv_cache_shardings,
56 | )
57 | self.prefill_kv_cache_shardings = self.prefill_kv_cache_shardings[
58 | "decoder"
59 | ]["layers_0"]
60 |
61 | self.kv_cache_annotations = max_utils.get_kv_cache_annotations(
62 | self.model, self.config, rng, self._mesh
63 | )
64 | self.kv_cache_shardings = jax.tree_util.tree_map(
65 | lambda x: jax.sharding.NamedSharding(self._mesh, x),
66 | self.kv_cache_annotations,
67 | )
68 | return params
69 |
--------------------------------------------------------------------------------
/kithara/model/supported_models.py:
--------------------------------------------------------------------------------
1 | GEMMA2_2B = "gemma2-2b"
2 | GEMMA2_9B = "gemma2-9b"
3 | GEMMA2_27B = "gemma2-27b"
4 | LLAMA31_8B = "llama3.1-8b"
5 | LLAMA31_70B = "llama3.1-70b"
6 | LLAMA31_405B = "llama3.1-405b"
7 | LLAMA32_1B = "llama3.2-1b"
8 | LLAMA32_3B = "llama3.2-3b"
9 |
--------------------------------------------------------------------------------
/kithara/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from jaxtyping import PyTree
2 | from kithara.optimizers.protocol import KitharaOptimizer
3 | from kithara.optimizers.optax_optimizer import OptaxOptimizerInterface
4 | import optax
5 |
6 | def convert_to_kithara_optimizer(
7 | optimizer: optax.GradientTransformation, trainable_variables: PyTree) -> KitharaOptimizer:
8 | if isinstance(optimizer, optax.GradientTransformation):
9 | return OptaxOptimizerInterface(optimizer, trainable_variables)
10 | raise NotImplementedError(f"Unsupported optimizer type: {type(optimizer)}")
11 |
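Sketch of the intended call site (editor's illustration; `model` stands in for any Kithara model whose `trainable_variables` form the parameter pytree):

```python
import optax

kithara_optimizer = convert_to_kithara_optimizer(
    optax.adamw(learning_rate=5e-5, weight_decay=0.01),
    model.trainable_variables,
)
print(kithara_optimizer.get_optimizer_memory_usage(), "bytes of optimizer state")
```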
--------------------------------------------------------------------------------
/kithara/optimizers/optax_optimizer.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import jax
4 | from jaxtyping import PyTree
5 | from keras.src import backend
6 | import optax
7 |
8 | from kithara.optimizers.protocol import KitharaOptimizer, TrainableVariable, OptimizerVariable
9 |
10 |
11 | class OptaxOptimizerInterface(KitharaOptimizer):
12 | def __init__(self, optimizer: optax.GradientTransformation, params: PyTree[backend.Variable]):
13 | self.optimizer, state = optimizer, optimizer.init(params)
14 | self.state_in_keras_variable = self._convert_state_to_keras_variables(state, params)
15 |
16 | with backend.name_scope("OptaxOptimizerInterface"):
17 | iterations = backend.Variable(0, name="iterations")
18 |
19 | self._iterations = iterations
20 |
21 | @property
22 | def learning_rate(self) -> None:
23 | return
24 |
25 | @property
26 | def variables(self) -> PyTree[backend.Variable]:
27 | return self.state_in_keras_variable
28 |
29 | @property
30 | def iterations(self) -> backend.Variable:
31 | return self._iterations
32 |
33 | def stateless_apply(self, optimizer_variable: OptimizerVariable, gradients: PyTree,
34 | trainable_variables: TrainableVariable
35 | ) -> Tuple[TrainableVariable, OptimizerVariable]:
36 |
37 | updates, opt_state = self.optimizer.update(
38 | gradients, optimizer_variable, trainable_variables)
39 | new_trainable_variables = optax.apply_updates(trainable_variables, updates)
40 |
41 | return new_trainable_variables, opt_state
42 |
43 | def get_optimizer_memory_usage(self):
44 | return sum([leaf.nbytes for leaf in jax.tree_util.tree_leaves(self.variables)])
45 |
46 | def _convert_state_to_keras_variables(
47 | self, opt_state: PyTree, trainable_vars: TrainableVariable) -> PyTree[backend.Variable]:
48 | """Convert optax State to a `PyTree` of `keras.Variable`
49 |
50 | Creates a structure (PyTree) of `keras.Variable` objects that mirrors
51 | the structure of the optimizer state (`opt_state`).
52 |
53 | For each subtree (e.g., momentum, variance) within the `opt_state` PyTree:
54 | - Corresponding `keras.Variable` objects are created, one for each
55 | associated trainable variable in the `trainable_vars` PyTree.
56 | - Each created `keras.Variable` is named based on the path of its associated node in `trainable_vars`:
57 | 1. Take the trainable variable's path.
58 | 2. Replace all forward slashes (`/`) with underscores (`_`).
59 | 3. Append the suffix `_optstate`.
60 |
61 | For example:
62 | - If a trainable variable's path in `trainable_vars` is
63 | `"decoder-block-0/layer-1/kernel"`.
64 | - The corresponding `keras.Variable` holding its optimizer state component
65 | will be named `"decoder-block-0_layer-1_kernel_optstate"`.
66 |
67 | Args:
68 | opt_state (PyTree): the optax optimizer state
69 | trainable_vars (PyTree[backend.Variable]): trainable variables
70 |
71 | Returns:
72 | PyTree[backend.Variable]: the transformed optax state.
73 | """
74 |
75 | _, trainable_tree_def = jax.tree.flatten(trainable_vars)
76 |
77 | def is_same_shape_with_trainable_subtree(opt_state_subtree: PyTree):
78 | _, opt_state_tree_def = jax.tree.flatten(opt_state_subtree)
79 | return opt_state_tree_def == trainable_tree_def
80 |
81 | def create_keras_variable_pytree(
82 | subtree, trainable_vars) -> PyTree[backend.Variable]:
83 |
84 | if not is_same_shape_with_trainable_subtree(subtree):
85 | return jax.tree.map(
86 | lambda leaf: backend.Variable(leaf,), subtree
87 | )
88 |
89 | return jax.tree.map(
90 | lambda leaf, keras_var: backend.Variable(
91 | leaf, name=f"{keras_var.path.replace('/', '_')}_optstate",),
92 | subtree,
93 | trainable_vars)
94 |
95 | return jax.tree.map(
96 | lambda subtree: create_keras_variable_pytree(subtree, trainable_vars),
97 | opt_state,
98 | is_leaf=is_same_shape_with_trainable_subtree
99 | )
100 |
--------------------------------------------------------------------------------
/kithara/optimizers/protocol.py:
--------------------------------------------------------------------------------
1 | from typing import (
2 | Protocol, Tuple, TypeAlias, List,
3 | runtime_checkable)
4 |
5 | from jaxtyping import PyTree
6 | from keras.src.backend import Variable
7 |
8 | TrainableVariable: TypeAlias = List[Variable]
9 | OptimizerVariable: TypeAlias = PyTree[Variable]
10 |
11 | @runtime_checkable
12 | class KitharaOptimizer(Protocol):
13 |
14 | @property
15 | def learning_rate(self) -> float | None:
16 | ...
17 | @property
18 | def variables(self) -> OptimizerVariable:
19 | ...
20 | @variables.setter
21 | def variables(self, state: PyTree):
22 | ...
23 | @property
24 | def iterations(self) -> Variable:
25 | ...
26 | def stateless_apply(self, optimizer_variable: OptimizerVariable, gradients: PyTree,
27 | trainable_variables: TrainableVariable) -> Tuple[TrainableVariable, OptimizerVariable]:
28 | """Apply the optimizer to the trainable variables and gradients.
29 |
30 | This method must be stateless, i.e. it must not modify the internal state of the
31 | optimizer. Any optimizer state should be passed in explicitly as `optimizer_variable`.
32 | If `optimizer_variable` is not provided, the implementation should fall back to the
33 | internal state tracked by `KitharaOptimizer`, not by the wrapped optimizer itself.
34 |
35 | Args:
36 | optimizer_variable (OptimizerVariable): The explicit optimizer state.
37 | gradients (PyTree): The gradients to apply to the trainable variables.
38 | The shape of the gradients should match the shape of the `trainable variables`.
39 | trainable_variables (TrainableVariable): The trainable variables to apply the optimizer to
40 |
41 | Returns:
42 | Tuple[TrainableVariable, OptimizerVariable]: The tuple of the updated trainable variables and the updated optimizer state.
43 | """
44 | def get_optimizer_memory_usage(self) -> int:
45 | """Get the memory usage of the state/variables of the optimizer in bytes.
46 |
47 | Returns:
48 | int: bytes used by the optimizer state/variables.
49 | """
50 |
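The contract above mirrors optax's functional update pattern. A conceptual, optax-only rendering (editor's sketch; plain JAX arrays stand in for Keras variables):

```python
import jax.numpy as jnp
import optax

# State in, (new_params, new_state) out; nothing is mutated.
params = {"w": jnp.ones((2, 2))}
grads = {"w": jnp.full((2, 2), 0.1)}

opt = optax.sgd(learning_rate=1e-2)
opt_state = opt.init(params)

updates, new_opt_state = opt.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
```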
--------------------------------------------------------------------------------
/kithara/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from kithara.trainer.trainer import *
--------------------------------------------------------------------------------
/kithara/utils/gcs_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from google.cloud.storage import Client, transfer_manager
18 | import os
19 | import time
20 |
21 |
22 | def upload_file_to_gcs(
23 | local_file: str, gs_bucket_path: str, remove_local_file_after_upload=False
24 | ):
25 | """Uploads a single file to Google Cloud Storage.
26 |
27 | Args:
28 | local_file: Path to local file
29 | gs_bucket_path: GCS destination (e.g. "gs://my-bucket/path/file.txt" or "my-bucket/path/file.txt")
30 | """
31 | # Standardize bucket path format
32 | gs_bucket_path = gs_bucket_path.removeprefix("gs://")
33 | bucket_name = gs_bucket_path.split("/")[0]
34 | blob_name = gs_bucket_path[len(bucket_name) :].lstrip("/")
35 |
36 | print(f"-> Uploading {local_file} to {gs_bucket_path}...")
37 | # Upload file
38 | storage_client = Client()
39 | bucket = storage_client.bucket(bucket_name)
40 | blob = bucket.blob(blob_name)
41 | blob.upload_from_filename(local_file)
42 |
43 | print(f"✅ Uploaded {local_file} to {bucket.name}/{blob_name}")
44 |
45 | if remove_local_file_after_upload:
46 | os.remove(local_file)
47 | print(f"✅ Deleted {local_file}")
48 |
49 |
50 | def upload_folder_to_gcs(local_folder: str, gs_bucket_path: str, num_workers: int = 8):
51 | """Uploads all files from a local folder to Google Cloud Storage.
52 |
53 | Args:
54 | local_folder: Path to local folder (e.g. "data/images")
55 | gs_bucket_path: GCS destination (e.g. "gs://my-bucket/images" or "my-bucket/images")
56 | num_workers: Number of parallel upload workers
57 | """
58 | start_time = time.time()
59 |
60 | # Standardize bucket path format
61 | gs_bucket_path = gs_bucket_path.removeprefix("gs://")
62 | bucket_name = gs_bucket_path.split("/")[0]
63 | # Ensure destination ends with "/"
64 | destination_dir = gs_bucket_path[len(bucket_name) :]
65 | if destination_dir.startswith("/"):
66 | destination_dir = destination_dir[1:]
67 | if destination_dir != "" and not destination_dir.endswith("/"):
68 | destination_dir += "/"
69 |
70 | # Get files to upload
71 | files_in_local_folder = os.listdir(local_folder)
72 | # Set up GCS client
73 | storage_client = Client()
74 | bucket = storage_client.bucket(bucket_name)
75 |
76 | # Upload files in parallel
77 | results = transfer_manager.upload_many_from_filenames(
78 | bucket,
79 | files_in_local_folder,
80 | source_directory=local_folder,
81 | max_workers=num_workers,
82 | blob_name_prefix=destination_dir,
83 | )
84 |
85 | # Report results
86 | for name, result in zip(files_in_local_folder, results):
87 | if isinstance(result, Exception):
88 | print(f"Failed to upload {name}: {result}")
89 | else:
90 | print(f"✅ Uploaded {name} to {bucket.name}/{destination_dir}{name}")
91 |
92 | print(f"Upload completed in {time.time() - start_time}s")
93 |
94 |
95 | def find_cache_root_dir():
96 | if "KERAS_HOME" in os.environ:
97 | cache_dir = os.environ.get("KERAS_HOME")
98 | else:
99 | cache_dir = os.path.expanduser(os.path.join("~", ".keras"))
100 | if not os.access(cache_dir, os.W_OK):
101 | cache_dir = os.path.join("/tmp", ".keras")
102 | return cache_dir
103 |
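Usage sketch (editor's addition; bucket and file names are placeholders, and GCS credentials are assumed to be configured). Both `gs://`-prefixed and bare bucket paths are accepted:

```python
upload_file_to_gcs("model.safetensors", "gs://my-bucket/checkpoints/model.safetensors")
upload_folder_to_gcs("output/", "gs://my-bucket/checkpoints", num_workers=8)
```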
--------------------------------------------------------------------------------
/kithara/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import platform
3 |
4 |
5 | def get_device_stats():
6 | device_count = jax.device_count()
7 | device_kind = jax.devices()[0].device_kind
8 | memory_info = jax.devices()[0].memory_stats()
9 | memory_per_device = memory_info["bytes_limit"] / (1024**3) # Convert to GB
10 | total_memory = memory_per_device * device_count
11 | return {
12 | "device_count": device_count,
13 | "device_kind": device_kind,
14 | "memory_per_device_gb": round(memory_per_device, 2),
15 | "total_memory_gb": round(total_memory, 2),
16 | }
17 |
18 |
19 | def print_kithara_logo_and_platform_info():
20 | platform_system = platform.system()
21 | tpu_stats = get_device_stats()
22 | statistics = (
23 | f" '==='\n"
24 | f" |||\n"
25 | f" '- ||| -'\n"
26 | f" / ||||| \\ Kithara. Platform: {platform_system}. JAX: {jax.__version__}\n"
27 | f" | (|||) | Hardware: {tpu_stats['device_kind']}. Device count: {tpu_stats['device_count']}.\n"
28 | f" | |◕‿◕| | HBM Per Device: {tpu_stats['memory_per_device_gb']} GB. Total HBM Memory: {tpu_stats['total_memory_gb']} GB\n"
29 | f" \\ ||||| / Free Apache license: http://github.com/ai-hypercomputer/kithara\n"
30 | f" --|===|--"
31 | )
32 | print(statistics)
33 |
--------------------------------------------------------------------------------
/kithara/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import torch
18 | import numpy as np
19 | from jax.experimental import multihost_utils
20 | from typing import Optional
21 |
22 |
23 | def convert_jax_weight_to_torch(
24 | weight: "jax.Array", dtype: Optional[str] = None
25 | ) -> torch.Tensor:
26 | expected_dtype = str(weight.dtype) if dtype is None else dtype
27 | expected_shape = weight.shape
28 | weight = multihost_utils.process_allgather(weight)
29 | weight = np.array(weight, dtype="float32")
30 | torch_dtype = getattr(torch, expected_dtype)
31 | torch_array = torch.from_numpy(weight).to(torch_dtype).reshape(expected_shape)
32 | return torch_array
33 |
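A small sketch (editor's addition) of the conversion helper; on a single host `process_allgather` is effectively a no-op gather:

```python
import jax.numpy as jnp

w = jnp.ones((4, 8), dtype=jnp.bfloat16)
t = convert_jax_weight_to_torch(w)               # keeps bfloat16 and shape (4, 8)
t32 = convert_jax_weight_to_torch(w, "float32")  # explicit dtype override
```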
--------------------------------------------------------------------------------
/kithara/utils/tree_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import jax
18 |
19 | def tree_path_to_string(path, sep=None):
20 | keys = []
21 | for key in path:
22 | if isinstance(key, jax.tree_util.SequenceKey):
23 | keys.append(str(key.idx))
24 | elif isinstance(key, jax.tree_util.DictKey):
25 | keys.append(str(key.key))
26 | elif isinstance(key, jax.tree_util.GetAttrKey):
27 | keys.append(str(key.name))
28 | elif isinstance(key, jax.tree_util.FlattenedIndexKey):
29 | keys.append(str(key.key))
30 | else:
31 | keys.append(str(key))
32 | if sep is None:
33 | return tuple(keys)
34 | return sep.join(keys)
35 |
36 |
37 | def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
38 | """An extended version of jax.tree_util.tree_map, where the mapped function
39 | f takes both the name (path) and the tree leaf as input.
40 | """
41 | return jax.tree_util.tree_map_with_path(
42 | lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r),
43 | tree,
44 | *rest,
45 | is_leaf=is_leaf,
46 | )
47 |
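An editor's sketch of `named_tree_map` on a toy pytree, showing the joined path names the callback receives (dict keys are traversed in sorted order):

```python
import numpy as np

tree = {"layer_0": {"kernel": np.zeros((2, 3)), "bias": np.zeros((3,))}}
named_tree_map(lambda name, leaf: print(name, leaf.shape), tree, sep="/")
# Prints:
# layer_0/bias (3,)
# layer_0/kernel (2, 3)
```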
--------------------------------------------------------------------------------
/perf/kerashub.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | """
19 | Benchmark for training KerasHub models via Kithara.
20 |
21 | This benchmark script runs multi-host training with the specified KerasHub model.
22 |
23 | Metrics: step time, Tokens/s/device
24 | Artifact: Tensorboard, Xplane (Uploaded to BASE_OUTPUT_DIR)
25 |
26 | Purpose: Compare performance of KerasHub training via Kithara against native MaxText.
27 |
28 | Launch Script: python ray/submit_job.py "python perf/kerashub.py"
29 |
30 | TODO: Launch benchmarks via YAML config.
31 | """
32 |
33 | import ray
34 |
35 | def run_benchmark():
36 | import os
37 | os.environ["KERAS_BACKEND"] = "jax"
38 | import keras
39 | from examples.example_datasets import example_datasets
40 | from kithara import KerasHubModel
41 | from kithara.dataset import Dataloader, TextCompletionDataset
42 | from kithara.trainer import Trainer
43 | from kithara.distributed import PredefinedShardingStrategy
44 | from kithara.callbacks import Profiler
45 |
46 | # Run parameters
47 | BASE_OUTPUT_DIR = "gs://wenxindong-vm/jan24/perf" # MODIFY with your GS bucket
48 | MODEL_HANDLE = "hf://google/gemma-2-9b"
49 | SEQ_LEN = 2048
50 | PER_DEVICE_BATCH_SIZE = 1
51 |
52 | train_data, eval_data = example_datasets(option="finetune_toy")
53 |
54 | model = KerasHubModel.from_preset(
55 | MODEL_HANDLE,
56 | precision="mixed_bfloat16",
57 | sharding_strategy=PredefinedShardingStrategy(parallelism="fsdp", model="gemma2-9b"),
58 | remat_mode=None,
59 | flash_attention=True,
60 | )
61 |
62 | # Create Keras optimizer
63 | optimizer = keras.optimizers.AdamW(
64 | learning_rate=5e-5,
65 | weight_decay=0.01,
66 | )
67 |
68 | # Create Dataset
69 | train_ds = TextCompletionDataset(
70 | source = train_data,
71 | tokenizer_handle=MODEL_HANDLE,
72 | max_seq_len=SEQ_LEN,
73 | )
74 |
75 | # Create Dataloader
76 | train_dataloader = Dataloader(train_ds, per_device_batch_size=PER_DEVICE_BATCH_SIZE)
77 |
78 | # Create Xprof Profiler
79 | profiler = Profiler(
80 | mode="xplane",
81 | output_path=BASE_OUTPUT_DIR,
82 | max_profile_steps=5,
83 | skip_first_n_steps=5,
84 | optional_postfix="kerashub",
85 | )
86 |
87 | # Initialize trainer
88 | trainer = Trainer(
89 | model=model,
90 | optimizer=optimizer,
91 | train_dataloader=train_dataloader,
92 | steps=10,
93 | log_steps_interval=1,
94 | tensorboard_dir=BASE_OUTPUT_DIR,
95 | profiler=profiler,
96 | )
97 |
98 | # Start training
99 | trainer.train()
100 |
101 |
102 | if __name__ == "__main__":
103 | ray.init()
104 |
105 | num_chips_per_host = 4 # 4 for v4 and v5, 8 for v4e and v5e
106 | num_tpu_devices = int(ray.cluster_resources()["TPU"])
107 | num_tpu_hosts = num_tpu_devices // num_chips_per_host
108 |
109 | @ray.remote(resources={"TPU": num_chips_per_host})
110 | def main():
111 | import sys
112 |
113 | # Add the MaxText directory to the Python path
114 | maxtext_dir = "maxtext/MaxText"
115 | sys.path.append(maxtext_dir)
116 |
117 | import subprocess
118 | subprocess.run(["pip install -e .[tpu] --no-deps"], shell=True)
119 | subprocess.run(["pip uninstall keras keras-hub keras-nlp"], shell=True)
120 | subprocess.run(["pip install keras-hub"], shell=True)
121 | subprocess.run(["pip install --upgrade keras-nightly"], shell=True)
122 | subprocess.run(["pip install git+https://github.com/divyashreepathihalli/keras-nlp.git@gemma_fs_bool"], shell=True)
123 | run_benchmark()
124 |
125 | ray.get([main.remote() for _ in range(num_tpu_hosts)])
126 |
127 | ray.shutdown()
128 |
--------------------------------------------------------------------------------
/perf/maxtext_native.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """
18 | Benchmark for training MaxText models via MaxText.
19 |
20 | This benchmark script runs multi-host training with the specified MaxText model.
21 |
22 | Metrics: step time, TFLOP/s/device, Tokens/s/device
23 | Artifact: Tensorboard, Xplane (Uploaded to BASE_OUTPUT_DIR)
24 |
25 | Purpose: Compare native MaxText performance against performance of MaxText via Kithara.
26 |
27 | Launch Script: python ray/submit_job.py "python perf/maxtext_native.py"
28 |
29 | TODO: Launch benchmarks via YAML config.
30 | """
31 |
32 | import ray
33 |
34 | if __name__ == "__main__":
35 | ray.init()
36 |
37 | num_chips_per_host = 4 # 4 for v4 and v5, 8 for v4e and v5e
38 | num_tpu_devices = int(ray.cluster_resources()["TPU"])
39 | num_tpu_hosts = num_tpu_devices // num_chips_per_host
40 |
41 | print(f"{num_tpu_devices=}")
42 | print(f"{num_tpu_hosts=}")
43 |
44 | @ray.remote(resources={"TPU": num_chips_per_host})
45 | def main():
46 | import sys
47 | import subprocess
48 |
49 | # Add the MaxText directory to the Python path
50 | maxtext_dir = "maxtext/MaxText"
51 | sys.path.append(maxtext_dir)
52 |
53 | # Run parameters
54 | BASE_OUTPUT_DIR = "GS_BUCKET" # MODIFY with your GS bucket
55 | MODEL_NAME = "gemma2-9b"
56 | SEQ_LEN = 2048
57 | PER_DEVICE_BATCH_SIZE = 1
58 |
59 | subprocess.call(
60 | [
61 | "python",
62 | "maxtext/MaxText/train.py",
63 | "kithara/model/maxtext/maxtext/MaxText/configs/base.yml",
64 | f"model_name={MODEL_NAME}",
65 | "run_name=maxtext_native",
66 | f"max_target_length={SEQ_LEN}",
67 | f"per_device_batch_size={PER_DEVICE_BATCH_SIZE}",
68 | "steps=10",
69 | "enable_checkpointing=false",
70 | "dataset_type=synthetic",
71 | "profiler=xplane",
72 | "scan_layers=false",
73 | "skip_first_n_steps_for_profiler=5",
74 | "gcs_metrics=true",
75 | "profiler_steps=5",
76 | "remat_policy=minimal",
77 | "attention=flash",
78 | f"base_output_directory={BASE_OUTPUT_DIR}",
79 | "dataset_path=gs://max-datasets-rogue/",
80 | ]
81 | )
82 |
83 | ray.get([main.remote() for _ in range(num_tpu_hosts)])
84 |
85 | ray.shutdown()
86 |
--------------------------------------------------------------------------------
/perf/maxtext_via_kithara.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | """
19 | Benchmark for training MaxText models via Kithara.
20 |
21 | This benchmark script runs multi-host training with the specified MaxText model.
22 |
23 | Metrics: step time, Tokens/s/device
24 | Artifact: Tensorboard, Xplane (Uploaded to BASE_OUTPUT_DIR)
25 |
26 | Purpose: Compare native MaxText performance against performance of MaxText via Kithara.
27 |
28 | Launch Script: python ray/submit_job.py "python perf/maxtext_via_kithara.py"
29 |
30 | TODO: Launch benchmarks via YAML config.
31 | """
32 |
33 | import ray
34 |
35 | def run_benchmark():
36 |
37 | import os
38 |
39 | os.environ["KERAS_BACKEND"] = "jax"
40 | import keras
41 | from examples.example_datasets import example_datasets
42 | from kithara import MaxTextModel, Trainer, Dataloader, TextCompletionDataset
43 | from kithara.callbacks import Profiler
44 |
45 | # Run parameters
46 | BASE_OUTPUT_DIR = "GS_BUCKET" # MODIFY with your GS bucket
47 | MODEL_NAME = "gemma2-9b"
48 | SEQ_LEN = 2048
49 | PER_DEVICE_BATCH_SIZE = 1
50 |
51 | train_data, eval_data = example_datasets(option="finetune_toy")
52 |
53 | # Create a randomly initialized MaxText Model
54 | model = MaxTextModel.from_random(
55 | model_name=MODEL_NAME,
56 | seq_len=SEQ_LEN,
57 | per_device_batch_size=PER_DEVICE_BATCH_SIZE,
58 | scan_layers=False,
59 | maxtext_config_args={"remat_policy": "minimal"},
60 | )
61 |
62 | # Create Keras optimizer
63 | optimizer = keras.optimizers.AdamW(
64 | learning_rate=5e-5,
65 | weight_decay=0.01,
66 | )
67 |
68 | # Create Dataset
69 | train_ds = TextCompletionDataset(
70 | source=train_data,
71 | tokenizer_handle="hf://google/gemma-2-2b",
72 | max_seq_len=SEQ_LEN,
73 | )
74 |
75 | # Create Dataloader
76 | train_dataloader = Dataloader(train_ds, per_device_batch_size=PER_DEVICE_BATCH_SIZE)
77 |
78 | # Create Xprof Profiler
79 | profiler = Profiler(
80 | mode="xplane",
81 | output_path=BASE_OUTPUT_DIR,
82 | max_profile_steps=5,
83 | skip_first_n_steps=5,
84 | optional_postfix="maxtext_via_kithara",
85 | )
86 |
87 | # Initialize trainer
88 | trainer = Trainer(
89 | model=model,
90 | optimizer=optimizer,
91 | train_dataloader=train_dataloader,
92 | steps=10,
93 | log_steps_interval=1,
94 | tensorboard_dir=BASE_OUTPUT_DIR,
95 | profiler=profiler,
96 | )
97 |
98 | # Start training
99 | trainer.train()
100 |
101 |
102 | if __name__ == "__main__":
103 | ray.init()
104 |
105 | num_chips_per_host = 4 # 4 for v4 and v5, 8 for v4e and v5e
106 | num_tpu_devices = int(ray.cluster_resources()["TPU"])
107 | num_tpu_hosts = num_tpu_devices // num_chips_per_host
108 |
109 | @ray.remote(resources={"TPU": num_chips_per_host})
110 | def main():
111 | import sys
112 |
113 | # Add the MaxText directory to the Python path
114 | maxtext_dir = "maxtext/MaxText"
115 | sys.path.append(maxtext_dir)
116 |
117 | run_benchmark()
118 |
119 | ray.get([main.remote() for _ in range(num_tpu_hosts)])
120 |
121 | ray.shutdown()
122 |
--------------------------------------------------------------------------------
/perf/sft_unsloth.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """
18 | Benchmark for SFT via unsloth.
19 |
20 | This benchmark script runs supervised fine-tuning using LoRA with the
21 | specified model.
22 |
23 | Metrics: train_samples_per_second, eval_samples_per_second
24 |
25 | Purpose: Compare performance of SFT using unsloth.
26 |
27 | Set up:
28 | - Create 2 X A100 40Gb GPU
29 | - pip install unsloth tensorboardx
30 |
31 | Launch Script: python kithara/perf/sft_unsloth.py
32 | """
33 |
34 | from unsloth import FastLanguageModel
35 | import torch
36 |
37 | max_seq_length = 4096
38 |
39 | model, tokenizer = FastLanguageModel.from_pretrained(
40 | model_name="unsloth/gemma-2-2b",
41 | max_seq_length=max_seq_length,
42 | dtype=torch.bfloat16,
43 | load_in_4bit=False,
44 | )
45 |
46 | model = FastLanguageModel.get_peft_model(
47 | model,
48 | r = 16,
49 | target_modules = ["q_proj", "v_proj"],
50 | lora_alpha = 16,
51 | lora_dropout = 0,
52 | bias = "none",
53 | # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
54 | use_gradient_checkpointing = False, # True or "unsloth" for very long context
55 | random_state = 3407
56 | )
57 |
58 | alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
59 |
60 | ### Instruction:
61 | {}
62 |
63 | ### Input:
64 | {}
65 |
66 | ### Response:
67 | {}"""
68 |
69 | def formatting_prompts_func(examples):
70 | EOS_TOKEN = tokenizer.eos_token
71 | instructions = examples["instruction"]
72 | inputs = examples["input"]
73 | outputs = examples["output"]
74 | texts = []
75 | for instruction, input, output in zip(instructions, inputs, outputs):
76 | # Must add EOS_TOKEN, otherwise your generation will go on forever!
77 | text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
78 | texts.append(text)
79 | return { "text" : texts, }
80 |
81 | from datasets import load_dataset
82 | dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
83 | dataset = dataset.map(formatting_prompts_func, batched = True,)
84 | datasets = dataset.train_test_split(test_size=64, shuffle=False)
85 | train_dataset, test_dataset = datasets["train"], datasets["test"]
86 |
87 | from trl import SFTTrainer
88 | from transformers import TrainingArguments
89 | from unsloth import is_bfloat16_supported
90 |
91 |
92 | trainer = SFTTrainer(
93 | model=model,
94 | tokenizer=tokenizer,
95 | train_dataset=train_dataset,
96 | eval_dataset=test_dataset,
97 | dataset_text_field="text",
98 | max_seq_length=max_seq_length,
99 | dataset_num_proc=2,
100 | packing=True, # Can make training 5x faster for short sequences.
101 | args=TrainingArguments(
102 | per_device_train_batch_size=2,
103 | do_predict=True,
104 | per_device_eval_batch_size=2,
105 | max_steps=500,
106 | eval_strategy="steps",
107 | eval_steps=100,
108 | learning_rate=2e-4,
109 | fp16=not is_bfloat16_supported(),
110 | bf16=is_bfloat16_supported(), #This is True
111 | logging_steps=10,
112 | optim="adamw_torch",
113 | weight_decay=0.01,
114 | lr_scheduler_type="constant",
115 | seed=3407,
116 | output_dir="unsloth_outputs",
117 | report_to="tensorboard",
118 | ),
119 | )
120 |
121 | trainer_stats = trainer.train()
122 |
123 | FastLanguageModel.for_inference(model)
124 | inputs = tokenizer(
125 | [
126 | alpaca_prompt.format(
127 | "Continue the fibonnaci sequence.", # instruction
128 | "1, 1, 2, 3, 5, 8", # input
129 | "", # output - leave this blank for generation!
130 | )
131 | ], return_tensors = "pt").to("cuda")
132 |
133 | outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True)
134 | print("generated", tokenizer.batch_decode(outputs))
135 |
136 |
--------------------------------------------------------------------------------
/perf/trl_packing.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """
18 | Benchmark for full parameter finetuning via trl.
19 |
20 | Metrics: train_samples_per_second
21 |
22 | Purpose: Compare performance of full parameter finetuning using trl.
23 |
24 | Launch Script: python kithara/perf/trl_packing.py
25 | """
26 |
27 |
28 | """
29 | Set up:
30 | pip install trl peft transformers datasets hf_transfer huggingface_hub[cli]
31 |
32 | Run this script on a single host: HF_HUB_ENABLE_HF_TRANSFER=1 python kithara/perf/trl_packing.py
33 | """
34 |
35 | import torch._dynamo
36 | torch._dynamo.config.suppress_errors = True
37 |
38 | from datasets import load_dataset
39 | import torch
40 | from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
41 | from peft import LoraConfig
42 | from trl import SFTTrainer, SFTConfig
43 |
44 | model_id = "google/gemma-2-2b"
45 | tokenizer = AutoTokenizer.from_pretrained(model_id)
46 | alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
47 |
48 | ### Instruction:
49 | {}
50 |
51 | ### Input:
52 | {}
53 |
54 | ### Response:
55 | {}"""
56 |
57 | def formatting_prompts_func(examples):
58 | EOS_TOKEN = tokenizer.eos_token
59 | instructions = examples["instruction"]
60 | inputs = examples["input"]
61 | outputs = examples["output"]
62 | texts = []
63 | for instruction, input, output in zip(instructions, inputs, outputs):
64 | # Must add EOS_TOKEN, otherwise your generation will go on forever!
65 | text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
66 | texts.append(text)
67 | return { "text" : texts, }
68 |
69 | from datasets import load_dataset
70 | dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
71 | dataset = dataset.map(formatting_prompts_func, batched = True,)
72 | datasets = dataset.train_test_split(test_size=200, shuffle=False)
73 | train_dataset, test_dataset = datasets["train"], datasets["test"]
74 |
75 | model = AutoModelForCausalLM.from_pretrained(
76 | model_id, torch_dtype=torch.bfloat16
77 | )
78 |
79 | trainer = SFTTrainer(
80 | model = model,
81 | tokenizer = tokenizer,
82 | train_dataset=train_dataset, # use the train split so held-out examples stay out of training
83 | eval_dataset = test_dataset,
84 | args = SFTConfig(
85 | dataset_text_field="text",
86 | per_device_train_batch_size=2,
87 | max_steps=1000,
88 | logging_steps=10,
89 | # eval_strategy="steps",
90 | # eval_steps=100,
91 | learning_rate = 2e-4,
92 | optim = "adamw_torch",
93 | lr_scheduler_type="constant",
94 | bf16=True,
95 | seed = 3407,
96 | report_to = "tensorboard",
97 | max_seq_length = 8192,
98 | dataset_num_proc = 1,
99 | packing = False,
100 | output_dir = "trl_packing/no_packing"
101 | ),
102 | )
103 |
104 | trainer.train()
105 |
106 | inputs = tokenizer(
107 | [
108 | alpaca_prompt.format(
109 | "Continue the fibonnaci sequence.", # instruction
110 | "1, 1, 2, 3, 5, 8", # input
111 | "", # output - leave this blank for generation!
112 | )
113 | ], return_tensors = "pt").to("cuda")
114 |
115 | outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True)
116 | print("generated", tokenizer.batch_decode(outputs))
117 |
118 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["flit_core >=3.2,<4"]
3 | build-backend = "flit_core.buildapi"
4 |
5 | [project]
6 | name = "kithara"
7 | description = "LLM post-training library"
8 | version = "0.0.10"
9 | readme = "README.md"
10 | license = { file = 'LICENSE' }
11 | authors = [{ name = 'Kithara Authors' }]
12 | requires-python = ">= 3.11"
13 | dependencies = [
14 | "flax>=0.7.0",
15 | "datasets",
16 | "huggingface-hub",
17 | "keras>=3.8.0",
18 | "transformers>=4.45.1",
19 | "keras-hub>=0.18.1",
20 | "google-api-python-client",
21 | "google-auth-httplib2",
22 | "google-auth-oauthlib",
23 | "ray[default]==2.40.0",
24 | "jax[cpu]",
25 | "peft",
26 | "hf_transfer",
27 | "tabulate",
28 | "cryptography",
29 | "huggingface_hub",
30 | # MaxText dependencies
31 | "aqtp",
32 | "grain-nightly",
33 | "orbax-checkpoint>=0.10.3",
34 | "google-cloud-logging",
35 | "tensorboardx",
36 | "ml-collections",
37 | "tensorflow_datasets",
38 | "sentencepiece",
39 | "tiktoken",
40 | "cloud-accelerator-diagnostics",
41 | "cloud-tpu-diagnostics",
42 | "ml-goodput-measurement",
43 | "google-cloud-monitoring",
44 | "omegaconf",
45 | # Jetstream dependencies
46 | "setuptools==61.0",
47 | "jaxtyping",
48 | "clu",
49 | "editdistance",
50 | "pyglove",
51 | "tensorflow_datasets",
52 | "tfds-nightly" #"tfds-nightly==4.9.8.dev202503240044"
53 | ]
54 |
55 | [project.optional-dependencies]
56 | tpu = ["jax[tpu]", "torch==2.4.0+cpu"]
57 | gpu = ["jax[cuda]", "torch==2.4.0"]
58 | cpu = ["jax[cpu]", "torch==2.4.0"]
59 | dev = [
60 | "twine",
61 | "flit",
62 | "sphinx==8.2.0",
63 | "sphinx-autobuild",
64 | "sphinxawesome-theme>=5.3.2",
65 | "sphinx_design",
66 | ]
67 |
68 | [tool.flit.module]
69 | name = "kithara"
70 |
71 | [project.urls]
72 | Homepage = 'https://github.com/AI-Hypercomputer/kithara'
73 | Repository = 'https://github.com/AI-Hypercomputer/kithara'
74 | Documentation = "https://kithara.readthedocs.io/en/latest/index.html"
75 |
76 | [tool.flit.sdist]
77 | exclude = [
78 | "kithara/model/maxtext/maxtext/MaxText/test_assets/",
79 | "**/*.Dockerfile",
80 | ]
81 |
--------------------------------------------------------------------------------
/ray/GPU/cluster.yaml:
--------------------------------------------------------------------------------
1 | cluster_name: my-gpu-cluster # MODIFY choose your cluster name
2 | max_workers: 100
3 |
4 | provider:
5 | type: gcp
6 | region: us-central1 #MODIFY this with your region
7 | availability_zone: us-central1-a #MODIFY this with your availability zone
8 | project_id: #MODIFY this with your project id
9 |
10 | head_node_type: ray_head_node
11 |
12 | available_node_types:
13 | ray_head_node:
14 | node_config:
15 | machineType: n1-standard-8
16 | disks:
17 | - boot: true
18 | autoDelete: true
19 | type: PERSISTENT
20 | initializeParams:
21 | diskSizeGb: 50
22 | sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-2204-lts
23 | metadata:
24 | install-nvidia-driver: 'false'
25 | serviceAccounts:
26 | - email: # MODIFY this with your service account email
27 | scopes:
28 | - https://www.googleapis.com/auth/cloud-platform
29 | resources: {"CPU": 8}
30 | min_workers: 0
31 | max_workers: 0
32 |
33 | ray_worker_node:
34 | node_config:
35 | machineType: g2-standard-48 # Required for L4 GPUs
36 | guestAccelerators:
37 | - acceleratorType: nvidia-l4 # Use L4 GPUs
38 | acceleratorCount: 4 # Other GPU generations may have different accelerator counts
39 | disks:
40 | - boot: true
41 | autoDelete: true
42 | type: PERSISTENT
43 | initializeParams:
44 | diskSizeGb: 50
45 | sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-2204-lts
46 | metadata:
47 | install-nvidia-driver: 'true'
48 | serviceAccounts:
49 | - email: # MODIFY this with your service account email
50 | scopes:
51 | - https://www.googleapis.com/auth/cloud-platform
52 | scheduling:
53 | onHostMaintenance: "TERMINATE" # Disable live migration
54 | resources: {"CPU": 48, "GPU": 4} # Should match number of CPUs and GPUs of worker node
55 | min_workers: 1
56 | max_workers: 4
57 |
58 | initialization_commands:
59 | # Don't stall on Ubuntu's interactive needrestart prompt
60 | - sudo sed -i 's/#$nrconf{restart} = '"'"'i'"'"';/$nrconf{restart} = '"'"'a'"'"';/g' /etc/needrestart/needrestart.conf
61 | - sudo add-apt-repository -y ppa:deadsnakes/ppa
62 | - sudo apt-get update
63 | - sudo apt-get install -y python3.11
64 | - sudo apt-get install -y python3-pip python-is-python3
65 | - sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1
66 | - sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1
67 | - echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.bashrc
68 | - source ~/.bashrc
69 | - python -m pip install --upgrade pip
70 |
71 | # Install dependencies needed by both the head and worker nodes
72 | # If you missed some dependencies during the set up process, you can
73 | # install them later during runtime.
74 | setup_commands:
75 | - pip install "ray[default]==2.40.0"
76 |
77 | # Install dependencies needed only by the head node
78 | head_setup_commands:
79 | - pip install google-api-python-client
80 | - pip install -U kithara[cpu]
81 |
82 | # Ensure Python, CUDA, and PyTorch are installed on the worker nodes
83 | worker_setup_commands:
84 | # Install CUDA
85 | - wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin
86 | - sudo mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600
87 | - wget https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb
88 | - sudo dpkg -i cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb
89 | - sudo cp /var/cuda-repo-ubuntu2204-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/
90 | - sudo apt-get update
91 | - sudo apt-get -y install cuda-toolkit-12-4
92 | - sudo apt-get install -y cuda-drivers-550
93 | - sudo nvidia-smi # allows for sudoless nvidia-smi
94 | - pip install -U kithara[gpu]
--------------------------------------------------------------------------------
/ray/GPU/find_devices.py:
--------------------------------------------------------------------------------
1 | # Test your GPU Ray cluster set up with
2 | # ray job submit --working-dir . -- python3 ray/GPU/find_devices.py
3 |
4 | import ray
5 |
6 | ray.init()
7 |
8 | num_chips_per_host = 4
9 | num_gpu_devices = int(ray.cluster_resources()["GPU"])
10 | print(f"{num_gpu_devices=}")
11 |
12 | @ray.remote(num_gpus=num_chips_per_host)
13 | def main():
14 | import jax
15 | print(jax.devices())
16 |
17 | ray.get([main.remote() for _ in range(num_gpu_devices // num_chips_per_host)]) # one task per host
18 |
19 | ray.shutdown()
20 |
--------------------------------------------------------------------------------
/ray/TPU/GCE/cluster.yaml:
--------------------------------------------------------------------------------
1 | # A unique identifier for the head node and workers of this cluster.
2 | cluster_name: ray-cluster # MODIFY choose your cluster name
3 | max_workers: 1
4 | available_node_types:
5 | # The ray head node is a CPU node
6 | ray_head_default:
7 | min_workers: 0
8 | max_workers: 0
9 | resources: { "CPU": 0 } # Don't change this
10 | # Provider-specific config for this node type, e.g. instance type. By default
11 | # Ray will auto-configure unspecified fields such as subnets and ssh-keys.
12 | # For more documentation on available fields, see:
13 | # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert
14 | node_config:
15 | machineType: n1-standard-16
16 | disks:
17 | - boot: true
18 | autoDelete: true
19 | type: PERSISTENT
20 | initializeParams:
21 | diskSizeGb: 50 # Increase this if you need more disk space
22 | # See https://cloud.google.com/compute/docs/images for more images
23 | sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-2204-lts
24 | # Uncomment and modify the following if you want to use a custom VPC network
25 |
26 | # networkInterfaces:
27 | # - kind: compute#networkInterface
28 | # subnetwork: https://www.googleapis.com/compute/v1/projects/{{project_name}}/regions/{{zone}}/subnetworks/{{subnet_name}}
29 | # accessConfigs:
30 | # - kind: compute#accessConfig
31 | # networkTier: STANDARD
32 | # name: "External NAT"
33 | # type: "ONE_TO_ONE_NAT"
34 |
35 | # Additional supported GCP configs: https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert
36 |
37 | # The Ray worker nodes are TPU nodes
38 | ray_tpu:
39 | min_workers: 1
40 | max_workers: 1
41 | resources: { "TPU": 4 } # MODIFY this to be the number of chips per host for your TPU type. Should be 4 for v4 and v5 and 8 for v4e and v5e
42 | node_config:
43 | acceleratorType: v4-16 # MODIFY this to be your TPU type
44 | runtimeVersion: tpu-ubuntu2204-base
45 | provider: # MODIFY this section with your provider spec
46 | type: gcp
47 | region: us-central2 #MODIFY this with your region
48 | availability_zone: us-central2-b #MODIFY this with your availability_zone
49 | project_id: project_id #MODIFY this with your project id
50 |
51 | initialization_commands:
52 | # Don't stall on Ubuntu's interactive needrestart prompt
53 | - sudo sed -i 's/#$nrconf{restart} = '"'"'i'"'"';/$nrconf{restart} = '"'"'a'"'"';/g' /etc/needrestart/needrestart.conf
54 | - sudo add-apt-repository -y ppa:deadsnakes/ppa
55 | - sudo apt-get update
56 | - sudo apt-get install -y python3.11
57 | - sudo apt-get install -y python3-pip python-is-python3
58 | - sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1
59 | - sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1
60 | - echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.bashrc
61 | - source ~/.bashrc
62 | - python -m pip install --upgrade pip
63 |
64 |
65 | # Install dependencies needed by both the head and worker nodes
66 | # If you missed some dependencies during the set up process, you can
67 | # install them later during runtime.
68 | setup_commands:
69 | - pip install "ray[default]==2.40.0"
70 |
71 | # Install dependencies needed only by the head node
72 | head_setup_commands:
73 | - pip install google-api-python-client
74 | - pip install -U kithara[cpu]
75 |
76 | # Install dependencies needed only by the worker node
77 | worker_setup_commands:
78 | - pip install -U kithara[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://download.pytorch.org/whl/cpu
79 |
80 | # Specify the node type of the head node (as configured above).
81 | head_node_type: ray_head_default
--------------------------------------------------------------------------------
/ray/TPU/GKE/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM rayproject/ray:2.42.0-py311
2 |
3 | RUN python -m pip install --upgrade pip
4 |
5 | RUN pip install -U kithara[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://download.pytorch.org/whl/cpu
6 |
7 | USER root
8 |
--------------------------------------------------------------------------------
/ray/TPU/QR/cluster.yaml:
--------------------------------------------------------------------------------
1 | # A unique identifier for the head node and workers of this cluster.
2 | cluster_name: my-qr-cluster # MODIFY choose your cluster name
3 | max_workers: 1
4 | available_node_types:
5 | # The ray head node is a CPU node
6 | ray_head_default:
7 | min_workers: 0
8 | max_workers: 0
9 | resources: { "CPU": 0 } # Don't change this
10 | # Provider-specific config for this node type, e.g. instance type. By default
11 | # Ray will auto-configure unspecified fields such as subnets and ssh-keys.
12 | # For more documentation on available fields, see:
13 | # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert
14 | node_config:
15 | machineType: n1-standard-16
16 | disks:
17 | - boot: true
18 | autoDelete: true
19 | type: PERSISTENT
20 | initializeParams:
21 | diskSizeGb: 50 # Increase this if you need more disk space
22 | # See https://cloud.google.com/compute/docs/images for more images
23 | sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-2204-lts
24 |
25 | # Uncomment and modify the following if you want to use a custom VPC network
26 |
27 | # networkInterfaces:
28 | # - kind: compute#networkInterface
29 | # subnetwork: https://www.googleapis.com/compute/v1/projects/{{project_name}}/regions/{{zone}}/subnetworks/{{subnet_name}}
30 | # accessConfigs:
31 | # - kind: compute#accessConfig
32 | # networkTier: STANDARD
33 | # name: "External NAT"
34 | # type: "ONE_TO_ONE_NAT"
35 |
36 | # Additional supported GCP configs: https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert
37 |
38 | provider: # MODIFY this section with your provider spec
39 | type: gcp
40 | region: us-west4 #MODIFY this with your region
41 | availability_zone: us-west4-a #MODIFY this with your availability_zone
42 | project_id: #MODIFY this with your project id
43 |
44 | initialization_commands:
45 | # Don't stall on Ubuntu's interactive needrestart prompt
46 | - sudo sed -i 's/#$nrconf{restart} = '"'"'i'"'"';/$nrconf{restart} = '"'"'a'"'"';/g' /etc/needrestart/needrestart.conf
47 | - sudo add-apt-repository -y ppa:deadsnakes/ppa
48 | - sudo apt-get update
49 | - sudo apt-get install -y python3.11
50 | - sudo apt-get install -y python3-pip python-is-python3
51 | - sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1
52 | - sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1
53 | - echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.bashrc
54 | - source ~/.bashrc
55 | - python -m pip install --upgrade pip
56 |
57 | # Install dependencies needed by both the head and worker nodes
58 | # If you missed some dependencies during the set up process, you can
59 | # install them later during runtime.
60 | setup_commands:
61 | - pip install "ray[default]==2.40.0"
62 |
63 | # Install dependencies needed only by the head node
64 | head_setup_commands:
65 | - pip install google-api-python-client
66 | - pip install -U kithara[cpu]
67 |
68 | # Specify the node type of the head node (as configured above).
69 | head_node_type: ray_head_default
--------------------------------------------------------------------------------
/ray/TPU/QR/qr_worker_startup_script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | # Function to check and release dpkg lock
5 | check_dpkg_lock() {
6 | if sudo fuser -s /var/lib/dpkg/lock-frontend ; then
7 | sudo systemctl stop unattended-upgrades
8 | fi
9 | }
10 |
11 | check_dpkg_lock
12 | sudo sed -i 's/#$nrconf{restart} = '"'"'i'"'"';/$nrconf{restart} = '"'"'a'"'"';/g' /etc/needrestart/needrestart.conf
13 |
14 | check_dpkg_lock
15 | sudo add-apt-repository -y ppa:deadsnakes/ppa
16 |
17 | check_dpkg_lock
18 | sudo apt-get update
19 |
20 | check_dpkg_lock
21 | sudo apt-get install -y python3.11
22 |
23 | check_dpkg_lock
24 | sudo apt-get install -y python3-pip python-is-python3
25 |
26 | check_dpkg_lock
27 | sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1
28 | sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1
29 | echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.bashrc
30 | source ~/.bashrc
31 |
32 | check_dpkg_lock
33 | sudo python -m pip install --upgrade pip
34 |
35 | check_dpkg_lock
36 | sudo pip install -U kithara[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://download.pytorch.org/whl/cpu
37 |
--------------------------------------------------------------------------------
/ray/submit_job.py:
--------------------------------------------------------------------------------
1 | from ray.job_submission import JobSubmissionClient, JobStatus
2 | import time
3 | import argparse
4 |
5 |
6 | def wait_until_status(client, job_id, status_to_wait_for, timeout_seconds=60):
7 | start = time.time()
8 | while time.time() - start <= timeout_seconds:
9 | status = client.get_job_status(job_id)
10 | print(f"status: {status}")
11 | if status in status_to_wait_for:
12 | break
13 | time.sleep(1)
14 |
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser(
18 | description="Submit a Ray job with specified entrypoint"
19 | )
20 | parser.add_argument(
21 | "job_file",
22 | type=str,
23 | help='The entrypoint script to run (e.g., "python examples/multihost/ray/TPU/sft_lora_example.py")',
24 | )
25 | parser.add_argument(
26 | "--ray-address",
27 | type=str,
28 | default="http://localhost:8265",
29 | help="Ray cluster address (default: http://localhost:8265)",
30 | )
31 | parser.add_argument(
32 | "--hf-token",
33 | type=str,
34 | default="",
35 | help="Optional. Your Huggingface token if HuggingFace login is need. ",
36 | )
37 |
38 | return parser.parse_args()
39 |
40 |
41 | def main():
42 | args = parse_args()
43 | client = JobSubmissionClient(args.ray_address)
44 | print("Initialized job client")
45 |
46 | job_id = client.submit_job(
47 | entrypoint=args.job_file,
48 | runtime_env={
49 | "working_dir": "./",
50 | # These files are too large and are safe to exclude
51 | "excludes": ["**/test_assets/", ".git/", "*.Dockerfile", "*.pkl"],
52 | "pip": [], # Any missing dependencies to run job goes here
53 | "env_vars": {"HF_TOKEN": args.hf_token, "HF_HUB_ENABLE_HF_TRANSFER": "1"},
54 | },
55 | )
56 |
57 | print(f"Submitted job with {job_id=}")
58 |
59 | wait_until_status(
60 | client, job_id, {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}
61 | )
62 |
63 | # See full logs in the Ray Dashboard
64 | logs = client.get_job_logs(job_id)
65 | print(logs)
66 |
67 | if __name__ == "__main__":
68 |
69 | main()
70 |
71 |
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | # Testing Guide
2 |
3 | ## Overview
4 | This guide explains how to add and run unit tests. All tests are located in the `/tests` directory, which mirrors the structure of the `/kithara` directory. Test files are excluded from the distributed wheel package.
5 |
6 | ## Test Discovery
7 | Tests are automatically discovered by `unittest` when they meet these criteria:
8 | - Files are named with the pattern `test_*.py`
9 | - Test functions start with `test_`
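
For example, `tests/dataset/test_dataset_creation.py` and its `test_creating_sft_dataset_from_dict` method are discovered automatically; helper modules such as `tests/dataset/utils.py` are not.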
10 |
11 | ## Test Categories
12 | We use three categories of tests, controlled by environment variables:
13 |
14 | ### 1. Light Tests
15 | - Default test category
16 | - Quick to execute (under 3 minutes)
17 | - Sanity checks for basic correctness
18 | - No special decorator needed
19 |
20 | ### 2. Heavy Tests
21 | - Takes longer than 3 minutes to execute
22 | - Should run before merging to main branch
23 | - Requires the following decorator:
24 | ```python
25 | @unittest.skipIf(int(os.getenv('RUN_LIGHT_TESTS_ONLY', 0)) == 1, "Heavy Test")
26 | ```
27 |
28 | ### 3. Manual Tests
29 | - Tests model/task correctness
30 | - Results unlikely to change
31 | - Run on-demand only
32 | - Requires the following decorator:
33 | ```python
34 | @unittest.skipIf(int(os.getenv('RUN_SKIPPED_TESTS', 0)) != 1, "Manual Test")
35 | ```
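
Putting the three categories together, a test module might look like the following minimal sketch (the class and method names here are hypothetical):

```python
import os
import unittest


class TestMyFeature(unittest.TestCase):
    def test_output_format(self):
        """Light test: no decorator, runs in every invocation."""
        self.assertEqual(1 + 1, 2)

    @unittest.skipIf(int(os.getenv('RUN_LIGHT_TESTS_ONLY', 0)) == 1, "Heavy Test")
    def test_end_to_end_training(self):
        """Heavy test: skipped when RUN_LIGHT_TESTS_ONLY=1."""

    @unittest.skipIf(int(os.getenv('RUN_SKIPPED_TESTS', 0)) != 1, "Manual Test")
    def test_checkpoint_conversion(self):
        """Manual test: runs only when RUN_SKIPPED_TESTS=1."""
```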
36 |
37 | ## Choosing Test Categories
38 |
39 | Use this decision flow to categorize your tests:
40 |
41 | 1. **Manual Test** if:
42 | - Tests numerical correctness, such as checkpoint conversion correctness.
43 | - Results should be stable across runs
44 | - Requires significant computation time
45 |
46 | 2. **Heavy Test** if:
47 | - Must run before main branch merges
48 | - Takes >3 minutes to execute
49 | - Tests critical functionality
50 |
51 | 3. **Light Test** if:
52 | - Quick to execute
53 | - Tests syntax, basic functionality
54 | - Should run frequently during development
55 |
56 | ## Environment Variables
57 |
58 | | Variable | Purpose | Values |
59 | |----------|----------|---------|
60 | | `RUN_LIGHT_TESTS_ONLY` | Run only light tests | `1` = light tests only, `0` = include heavy tests |
61 | | `RUN_SKIPPED_TESTS` | Include manual tests | `1` = run all tests, `0` = skip manual tests |
62 |
63 | ## Running Tests
64 |
65 | ### Quick Development Check (Light Tests Only)
66 | ```bash
67 | RUN_LIGHT_TESTS_ONLY=1 python -m unittest
68 | ```
69 |
70 | ### Pre-merge Check (Light + Heavy Tests)
71 | ```bash
72 | python -m unittest
73 | ```
74 |
75 | ### Full Test Suite (All Tests)
76 | ```bash
77 | RUN_SKIPPED_TESTS=1 python -m unittest
78 | ```
79 |
80 | ## Best Practices
81 |
82 | 1. **Test Organization**: Group related tests in the same file, following the source code structure
83 |
84 | 2. **Documentation**: Include a docstring at the top of every test file describing the purpose of the tests and how to run them in a standalone manner. E.g.
85 | ```
86 | These tests validate that the MaxText implementations produces logits that are
87 | numerically comparable to the reference HuggingFace implementation.
88 |
89 | Usage:
90 | Run script on single host VM: RUN_SKIPPED_TESTS=1 python -m unittest tests/model/maxtext/ckpt_compatibility/test_loading_models.py
91 | ```
92 |
93 | ## Contributing Tests
94 |
95 | 1. Create test file in appropriate `/tests` subdirectory
96 | 2. Choose appropriate test category
97 | 3. Add necessary decorators
98 | 4. Include clear docstrings and comments
99 | 5. Verify tests pass
100 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/__init__.py
--------------------------------------------------------------------------------
/tests/callbacks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/callbacks/__init__.py
--------------------------------------------------------------------------------
/tests/callbacks/test_orbax_checkpointer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """End to end test for model checkpointing correctness
18 |
19 | Run script on single host VM: python -m unittest tests/callbacks/test_orbax_checkpointer.py
20 | """
21 |
22 | import os
23 |
24 | os.environ["KERAS_BACKEND"] = "jax"
25 |
26 | import shutil
27 | import unittest
28 | import numpy as np
29 | from kithara import KerasHubModel, MaxTextModel
30 | from kithara.callbacks.checkpointer import Checkpointer
31 | from kithara.utils.gcs_utils import find_cache_root_dir
32 | from tests.model.utils import check_arrays_match
33 | import jax
34 |
35 |
36 | class TestOrbaxCheckpointing(unittest.TestCase):
37 | @classmethod
38 | def setUpClass(cls):
39 | cls.TMP_DIR = os.path.join(find_cache_root_dir(), "test/ckpt")
40 |
41 | def setUp(self):
42 | shutil.rmtree(self.TMP_DIR, ignore_errors=True)
43 |
44 | def tearDown(self):
45 | shutil.rmtree(self.TMP_DIR, ignore_errors=True)
46 |
47 | def _test_checkpoint_save_load(self, model, model_input):
48 | """Test that model state is preserved after save and load operations."""
49 | # Get initial model output
50 | logits_before, _ = model.stateless_call(
51 | model.trainable_variables,
52 | model.non_trainable_variables,
53 | model_input,
54 | )
55 |
56 | # Save and load model
57 | checkpointer = Checkpointer(self.TMP_DIR, model=model)
58 | checkpointer.save(0, blocking=True)
59 | checkpointer.load()
60 |
61 | # Get model output after loading
62 | logits_after, _ = model.stateless_call(
63 | model.trainable_variables,
64 | model.non_trainable_variables,
65 | model_input,
66 | )
67 |
68 | # Verify outputs match within tolerance
69 | try:
70 | check_arrays_match(logits_before, logits_after, atol=0.0001)
71 | except AssertionError as e:
72 | self.fail(f"Model outputs don't match after save/load: {str(e)}")
73 |
74 | def test_kerashub_models(self):
75 | # Initialize model
76 | model = KerasHubModel.from_preset(
77 | "hf://google/gemma-2-2b",
78 | lora_rank=6
79 | )
80 |
81 | # Prepare test input
82 | model_input = {
83 | "token_ids": np.array([[1, 2, 3]]),
84 | "padding_mask": np.array([[1, 1, 1]]),
85 | }
86 | self._test_checkpoint_save_load(model, model_input)
87 |
88 | def test_maxtext_model(self):
89 | model = MaxTextModel.from_random(
90 | model_name="default", seq_len=100, per_device_batch_size=1, scan_layers=True
91 | )
92 |
93 | model_input = {
94 | "tokens": np.array([[i for i in range(100)]] * jax.device_count()),
95 | "positions": np.array([[i for i in range(100)]] * jax.device_count()),
96 | "segment_ids": np.array([[1 for _ in range(100)]] * jax.device_count()),
97 | }
98 | self._test_checkpoint_save_load(model, model_input)
99 |
100 |
101 | if __name__ == "__main__":
102 | unittest.main(verbosity=2)
103 |
--------------------------------------------------------------------------------
/tests/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/dataset/__init__.py
--------------------------------------------------------------------------------
/tests/dataset/test_dataloader_creation.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """Unit tests for creating Kithara Dataloaders
18 |
19 | Run test on a TPU VM: python -m unittest tests/dataset/test_dataloader_creation.py
20 | """
21 |
22 | import unittest
23 | from kithara import TextCompletionDataset, Dataloader
24 | import time
25 | import unittest.result
26 | import tests.dataset.utils as dataset_utils
27 | import numpy as np
28 | import os
29 |
30 | class TestDataloaderCreation(unittest.TestCase):
31 |
32 | def setUp(self):
33 | print(f"\nStarting test: {self._testMethodName}")
34 | self.start_time = time.time()
35 |
36 | def tearDown(self):
37 | duration = time.time() - self.start_time
38 | print(f"Completed test: {self._testMethodName} in {duration:.2f} seconds\n")
39 |
40 | def _run_test(self, data_source):
41 | dataset = TextCompletionDataset(
42 | data_source,
43 | tokenizer_handle="hf://google/gemma-2-2b",
44 | model_type="KerasHub",
45 | max_seq_len=100,
46 | )
47 | dataloader = Dataloader(dataset, per_device_batch_size=2)
48 | for batch in dataloader:
49 | self.assertTrue(isinstance(batch, dict))
50 | self.assertTrue("x" in batch)
51 | self.assertTrue("y" in batch)
52 | self.assertTrue("token_ids" in batch["x"])
53 | self.assertTrue("padding_mask" in batch["x"])
54 | self.assertTrue(isinstance(batch["x"]["token_ids"], np.ndarray))
55 | self.assertTrue(isinstance(batch["x"]["padding_mask"], np.ndarray))
56 | self.assertTrue(isinstance(batch["y"], np.ndarray))
57 | self.assertEqual(len(batch["x"]["token_ids"].shape), 2)
58 | break
59 |
60 | def test_dataloader_drops_last_batch(self):
61 | n = 500
62 | dict_dataset = dataset_utils.create_dict_ray_dataset(n)
63 | dataset = TextCompletionDataset(
64 | dict_dataset,
65 | tokenizer_handle="hf://google/gemma-2-2b",
66 | )
67 | dataloader = Dataloader(dataset, per_device_batch_size=n+1)
68 | self.assertEqual(len(dataloader), 0)
69 |
70 | dataloader = Dataloader(dataset, per_device_batch_size=3)
71 | batch_size = dataloader.per_host_batch_size
72 |
73 | self.assertEqual(len(dataloader), n//batch_size)
74 | total = 0
75 | for batch in dataloader:
76 | total += 1
77 | self.assertEqual(total, n//batch_size)
78 |
79 |
80 | def test_creating_dataloader_for_fixed_data(self):
81 | dict_dataset = dataset_utils.create_dict_ray_dataset()
82 | self._run_test(dict_dataset)
83 |
84 | @unittest.skipIf(int(os.getenv('RUN_LIGHT_TESTS_ONLY', 0)) == 1, "Heavy Test")
85 | def test_creating_dataloader_for_streaming_data(self):
86 | streaming_dataset = dataset_utils.create_hf_streaming_ray_dataset()
87 | self._run_test(streaming_dataset)
88 |
89 |
90 | if __name__ == "__main__":
91 | unittest.main(verbosity=2)
92 |
--------------------------------------------------------------------------------
/tests/dataset/test_dataset_creation.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """Unit tests for creating Kithara Datasets
18 |
19 | Run test on a TPU VM: python -m unittest tests/dataset/test_dataset_creation.py
20 | """
21 |
22 | import unittest
23 | from kithara import SFTDataset, TextCompletionDataset
24 | import time
25 | import unittest.result
26 | import tests.dataset.utils as dataset_utils
27 | import numpy as np
28 | import os
29 |
30 |
31 | class TestDatasetCreation(unittest.TestCase):
32 |
33 | def setUp(self):
34 | print(f"\nStarting test: {self._testMethodName}")
35 | self.start_time = time.time()
36 |
37 | def tearDown(self):
38 | duration = time.time() - self.start_time
39 | print(f"Completed test: {self._testMethodName} in {duration:.2f} seconds\n")
40 |
41 | def _check_dataset_item_format(self, element):
42 | self.assertTrue(isinstance(element, dict))
43 | self.assertTrue("x" in element)
44 | self.assertTrue("y" in element)
45 | self.assertTrue("token_ids" in element["x"])
46 | self.assertTrue("padding_mask" in element["x"])
47 | self.assertTrue(isinstance(element["x"]["token_ids"], np.ndarray))
48 | self.assertTrue(isinstance(element["x"]["padding_mask"], np.ndarray))
49 | self.assertTrue(isinstance(element["y"], np.ndarray))
50 |
51 | def test_creating_sft_dataset_from_dict(self):
52 | dict_dataset = dataset_utils.create_dict_ray_dataset()
53 | dataset = SFTDataset(
54 | dict_dataset,
55 | tokenizer_handle="hf://google/gemma-2-2b",
56 | column_mapping={"prompt": "text", "answer": "text"},
57 | model_type="KerasHub",
58 | )
59 | self.assertIsNotNone(len(dataset))
60 | for element in dataset:
61 | self._check_dataset_item_format(element)
62 | break
63 |
64 | def test_creating_sft_dataset_fail_when_tokenizer_is_not_provided(self):
65 | dict_dataset = dataset_utils.create_dict_ray_dataset()
66 | with self.assertRaises(AssertionError):
67 | SFTDataset(dict_dataset, model_type="KerasHub")
68 |
69 | def test_creating_sft_dataset_from_csv(self):
70 | csv_dataset = dataset_utils.create_csv_ray_dataset()
71 | dataset = SFTDataset(
72 | csv_dataset,
73 | tokenizer_handle="hf://google/gemma-2-2b",
74 | column_mapping={
75 | "prompt": "sepal length (cm)",
76 | "answer": "sepal length (cm)",
77 | },
78 | model_type="KerasHub",
79 | )
80 | self.assertIsNotNone(len(dataset))
81 | for element in dataset:
82 | self._check_dataset_item_format(element)
83 | break
84 |
85 | @unittest.skipIf(int(os.getenv("RUN_LIGHT_TESTS_ONLY", 0)) == 1, "Heavy Test")
86 | def test_creating_text_completion_dataset_from_streaming_hf(self):
87 | streaming_dataset = dataset_utils.create_hf_streaming_ray_dataset()
88 | dataset = TextCompletionDataset(
89 | streaming_dataset,
90 | tokenizer_handle="hf://google/gemma-2-2b",
91 | model_type="KerasHub",
92 | )
93 | self.assertTrue(len(dataset) == 364608)
94 | for element in dataset:
95 | self._check_dataset_item_format(element)
96 | break
97 |
98 | @unittest.skipIf(int(os.getenv("RUN_LIGHT_TESTS_ONLY", 0)) == 1, "Heavy Test")
99 | def test_creating_dataset_from_raw_hf_dataset(self):
100 | from datasets import load_dataset
101 |
102 | hf_dataset = load_dataset(
103 | "allenai/c4", "en", split="validation", streaming=True
104 | )
105 | dataset = TextCompletionDataset(
106 | hf_dataset,
107 | tokenizer_handle="hf://google/gemma-2-2b",
108 | model_type="KerasHub",
109 | )
110 | self.assertTrue(len(dataset) == 364608)
111 | for element in dataset:
112 | self._check_dataset_item_format(element)
113 | break
114 |
115 |
116 | if __name__ == "__main__":
117 | unittest.main(verbosity=2)
118 |
--------------------------------------------------------------------------------
/tests/dataset/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import ray
18 | from datasets import load_dataset
19 |
20 | def create_dict_ray_dataset(n=1000):
21 | dataset_items = [
22 | {"text": f"{i} What is your name? My name is Kithara."} for i in range(n)
23 | ]
24 | return ray.data.from_items(dataset_items)
25 |
26 | def create_hf_streaming_ray_dataset():
27 | hf_val_dataset = load_dataset(
28 | "allenai/c4", "en", split="validation", streaming=True
29 | )
30 | return ray.data.from_huggingface(hf_val_dataset)
31 |
32 | def create_json_ray_dataset():
33 | ds = ray.data.read_json("s3://anonymous@ray-example-data/log.json")
34 | return ds
35 |
36 | def create_csv_ray_dataset():
37 | ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
38 | return ds
--------------------------------------------------------------------------------
/tests/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/model/__init__.py
--------------------------------------------------------------------------------
/tests/model/kerashub/ckpt_compatibility/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/model/kerashub/ckpt_compatibility/__init__.py
--------------------------------------------------------------------------------
/tests/model/kerashub/test_creation.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """Unit tests for creating a KerasHub model
18 |
19 | Run test on a TPU VM: python -m unittest tests/model/kerashub/test_creation.py
20 | """
21 | import unittest
22 | from kithara import KerasHubModel
23 | import time
24 | import unittest.result
25 |
26 | class TestModelCreation(unittest.TestCase):
27 |
28 | def setUp(self):
29 | print(f"\nStarting test: {self._testMethodName}")
30 | self.start_time = time.time()
31 |
32 | def tearDown(self):
33 | duration = time.time() - self.start_time
34 | print(f"Completed test: {self._testMethodName} in {duration:.2f} seconds\n")
35 |
36 | def test_creating_gemma2_2b_model(self):
37 | KerasHubModel.from_preset(
38 | "hf://google/gemma-2-2b"
39 | )
40 | def test_creating_gemma2_2b_model_with_lora(self):
41 | KerasHubModel.from_preset(
42 | "hf://google/gemma-2-2b",
43 | lora_rank=6
44 | )
45 |
46 | def test_creating_model_with_precision(self):
47 | model = KerasHubModel.from_preset(
48 | "hf://google/gemma-2-2b",
49 | precision="float32"
50 | )
51 | self.assertTrue(str(model.trainable_variables[0].value.dtype) == "float32")
52 |
53 | def test_creating_gemma2_9b_model(self):
54 | KerasHubModel.from_preset(
55 | "hf://google/gemma-2-9b"
56 | )
57 | def test_creating_gemma2_9b_model_with_lora(self):
58 | KerasHubModel.from_preset(
59 | "hf://google/gemma-2-9b",
60 | lora_rank=16
61 | )
62 |
63 | if __name__ == '__main__':
64 | unittest.main(verbosity=2)
65 |
--------------------------------------------------------------------------------
/tests/model/maxtext/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/model/maxtext/__init__.py
--------------------------------------------------------------------------------
/tests/model/maxtext/ckpt_compatibility/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/model/maxtext/ckpt_compatibility/__init__.py
--------------------------------------------------------------------------------
/tests/model/maxtext/test_creation.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | """Unit tests for creating a MaxText model
18 |
19 | Run test on a TPU VM: python -m unittest tests/model/maxtext/test_creation.py
20 | """
21 | import unittest
22 | from kithara import MaxTextModel
23 | import time
24 | import unittest.result
25 | import os
26 |
27 | class TestModelCreation(unittest.TestCase):
28 |
29 | def setUp(self):
30 | print(f"\nStarting test: {self._testMethodName}")
31 | self.start_time = time.time()
32 |
33 | def tearDown(self):
34 | duration = time.time() - self.start_time
35 | print(f"Completed test: {self._testMethodName} in {duration:.2f} seconds\n")
36 |
37 | def test_creating_random_gemma2_2b_model(self):
38 | MaxTextModel.from_random(
39 | "gemma2-2b",
40 | seq_len=100,
41 | per_device_batch_size=1,
42 | )
43 |
44 | def test_creating_preset_gemma2_2b_model(self):
45 | MaxTextModel.from_preset(
46 | "hf://google/gemma-2-2b",
47 | seq_len=100,
48 | per_device_batch_size=1,
49 | )
50 |
51 | def test_creating_preset_gemma2_2b_model_with_scan(self):
52 | MaxTextModel.from_preset(
53 | "hf://google/gemma-2-2b",
54 | seq_len=100,
55 | per_device_batch_size=1,
56 | scan_layers=True
57 | )
58 |
59 | def test_creating_random_model_with_precision(self):
60 | model = MaxTextModel.from_random(
61 | "gemma2-2b",
62 | seq_len=100,
63 | per_device_batch_size=1,
64 | precision="bfloat16",
65 | )
66 | self.assertTrue(str(model.trainable_variables[0].value.dtype) == "bfloat16")
67 |
68 | def test_creating_preset_model_with_precision(self):
69 | model = MaxTextModel.from_preset(
70 | "hf://google/gemma-2-2b",
71 | seq_len=100,
72 | per_device_batch_size=1,
73 | precision="bfloat16",
74 | )
75 | self.assertTrue(str(model.trainable_variables[0].value.dtype) == "bfloat16")
76 |
77 | def test_creating_random_gemma2_9b_model(self):
78 | MaxTextModel.from_random(
79 | "gemma2-9b",
80 | seq_len=100,
81 | per_device_batch_size=1,
82 | )
83 |
84 | @unittest.skipIf(int(os.getenv('RUN_LIGHT_TESTS_ONLY', 0)) == 1, "Heavy Test")
85 | def test_creating_preset_gemma2_9b_model(self):
86 | MaxTextModel.from_preset(
87 | "hf://google/gemma-2-9b",
88 | seq_len=100,
89 | per_device_batch_size=1,
90 | )
91 |
92 | @unittest.skipIf(int(os.getenv('RUN_LIGHT_TESTS_ONLY', 0)) == 1, "Heavy Test")
93 | def test_creating_preset_gemma2_9b_model_with_scan(self):
94 | MaxTextModel.from_preset(
95 | "hf://google/gemma-2-9b",
96 | seq_len=100,
97 | per_device_batch_size=1,
98 | scan_layers=True
99 | )
100 |
101 | if __name__ == '__main__':
102 | unittest.main(verbosity=2)
103 |
--------------------------------------------------------------------------------
/tests/model/test_prompt.py:
--------------------------------------------------------------------------------
1 | TEST_PROMPT="""# Activation dtypes.
2 | dtype: "bfloat16"
3 | # Used to configure quantization in the transformer layers, defaults to null implying bf16.
4 | # Possible alternative settings are as follows:
5 | # 'int8' for dynamic range quantization using 8-bits
6 | # 'int8w' for weights only quantization using 8-bits
7 | # 'int4w' for weights only quantization using 4-bits
8 | # 'intmp' for mixed precision weight only quantization based on config file
9 | # 'fp8' for 8-bit floating-point GeMMs on NVIDIA GPUs.
10 | quantization: ""
11 | # Choose one of default, high, and highest.
12 | # https://kolonist26-jax-kr.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
13 | matmul_precision: "default"
14 | activations_in_float32: False # Sets activations to float32 before nonlinearity it true, else dtype
15 | # Path to file with quantization config - only used for mixed precision.
16 | # Example configs in ../Maxtext/configs/quantization
17 | # Allows us to configure different bits, tiling and scale for quantizing selected weights.
18 | # Bits represents number of bits to quantize to,
19 | # tile-size represents the tiling sized used in AQT tiled_dot_general,
20 | # Value of scale is used to scale the abs_max value used for AQT quantization
21 | # Defaults values are 8 bits, tile-size=-1 (no tiling) and scale=1.
22 | quant_cfg_path: ""
23 | quantize_kvcache: False # Set to True to quantize KV Cache values, defaults to False
24 | # Valid kv_quant_axis values:
25 | # - "" is valid only when quantize_kvcache is False
26 | # - "dkv" indicates quantize kv cache over the cache_kv, i.e. kv dimension axis
27 | # - "heads_and_dkv" indicates quantize kv cache over cache_heads and cache_kv axes
28 | # Default to "heads_and_dkv" for faster compution, kv_quant_axis is not used when quantize_kvcache is False
29 | # - "dkv" is expected with better accuracy but degraded computation
30 | kv_quant_axis: "heads_and_dkv"
31 | kv_quant_dtype: "int8"
32 | checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint
33 | # Saves params quantized on fly at following path
34 | save_quantized_params_path: ""
35 | # Shard the range finding operation for quantization. By default this is set to number of slices.
36 | quantization_local_shard_count: -1
37 | decoder_block: "llama2" # which style of DecoderBlock to use.
38 | # Global parameter scale needs to be a power of 2. If you want finer grained control of the model sizes
39 | # then you should explicitly set base_embed_dim, base_num_query_heads, base_num_kv_heads,
40 | # base_mlp_dim, base_num_decoder_layers and/or head_dim.
41 | weight_dtype: float32
42 | global_parameter_scale: 1
43 | base_emb_dim: 2048
44 | base_num_query_heads: 16
45 | base_num_kv_heads: 16
46 | base_mlp_dim: 7168
47 | base_num_decoder_layers: 16
48 | head_dim: 128
49 | mlp_activations: ["silu", "linear"]
50 | dropout_rate: 0
51 | logits_via_embedding: False
52 | normalize_embedding_logits: True # whether to normlize pre-softmax logits if logits_via_embedding is true
53 | logits_dot_in_fp32: True # whether to use fp32 in logits_dense or shared_embedding dot product for stability"""
--------------------------------------------------------------------------------
/tests/optimizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/optimizers/__init__.py
--------------------------------------------------------------------------------
/tests/optimizers/test_optax_optimizer.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import unittest.mock
3 |
4 | import jax
5 | import jax.numpy as jnp
6 | import optax
7 | from keras.src import backend
8 |
9 | from kithara.optimizers import OptaxOptimizerInterface
10 |
11 | class TestOptaxOptimizerInterface(unittest.TestCase):
12 |
13 | def setUp(self):
14 | self.dummy_keras_variables = [
15 | backend.Variable(jnp.ones((1024, 1024)), name="dummy_keras_variable")
16 | for _ in range(3)]
17 |
18 | return super().setUp()
19 |
20 | def test_learning_rate_is_none(self):
21 | optax_optimizer_interface = OptaxOptimizerInterface(
22 | optax.adam(1e-3), self.dummy_keras_variables)
23 | self.assertIsNone(optax_optimizer_interface.learning_rate)
24 |
25 | def test_optimizer_initialization(self):
26 | optax_optimizer_interface = OptaxOptimizerInterface(
27 | optax.adam(1e-3), self.dummy_keras_variables)
28 | self.assertTrue(
29 | jax.tree.reduce(
30 | lambda agg, leaf: agg and isinstance(leaf, backend.Variable),
31 | optax_optimizer_interface.variables, True))
32 |
33 | def test_stateless_apply(self):
34 | optax_optimizer_interface = OptaxOptimizerInterface(
35 | optax.adam(1e-3), self.dummy_keras_variables)
36 | new_trainables, new_opt_variables = optax_optimizer_interface.stateless_apply(
37 | optax_optimizer_interface.variables, self.dummy_keras_variables, self.dummy_keras_variables,)
38 | print(new_trainables, new_opt_variables)
--------------------------------------------------------------------------------
/tests/optimizers/test_protocol.py:
--------------------------------------------------------------------------------
1 | """Unit tests for testing the protocols integrity and the conversion of optimizers to Kithara optimizers.
2 |
3 | Run test on a TPU VM: python -m unittest tests/optimizers/test_protocol.py
4 | """
5 | import unittest
6 | import unittest.mock
7 |
8 | import keras
9 | import optax
10 |
11 | from kithara.optimizers import KitharaOptimizer, convert_to_kithara_optimizer
12 |
13 | class TestKitharaOptimizer(unittest.TestCase):
14 |
15 | def test_kithara_optimizer_interface_integrity(self):
16 |
17 | with self.assertRaises(TypeError):
18 | KitharaOptimizer()
19 | self.assertTrue(callable(KitharaOptimizer.stateless_apply))
20 | self.assertTrue(callable(KitharaOptimizer.get_optimizer_memory_usage))
21 | self.assertTrue(hasattr(KitharaOptimizer, "learning_rate"))
22 | self.assertTrue(hasattr(KitharaOptimizer, "iterations"))
23 | self.assertTrue(hasattr(KitharaOptimizer, "variables"))
24 |
25 | @unittest.mock.patch("kithara.optimizers.OptaxOptimizerInterface")
26 | def test_convert_to_kithara_optimizer(
27 | self, mocked_optax_optimizer_interface: unittest.mock.Mock):
28 |
29 | convert_to_kithara_optimizer(optax.adam(1e-3), 1)
30 | mocked_optax_optimizer_interface.assert_called_once()
31 |
32 | with self.assertRaises(NotImplementedError):
33 | convert_to_kithara_optimizer(keras.Optimizer(.01), 1.0)
34 |
35 | if __name__ == "__main__":
36 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2025 Google LLC
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | https://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import signal
18 | from functools import wraps
19 |
20 | def timeout(seconds):
21 | def decorator(func):
22 | @wraps(func)
23 | def wrapper(*args, **kwargs):
24 | def handler(signum, frame):
25 | raise TimeoutError(f"Test timed out after {seconds} seconds")
26 |
27 | # Set the timeout handler
28 | original_handler = signal.signal(signal.SIGALRM, handler)
29 | signal.alarm(seconds)
30 |
31 | try:
32 | result = func(*args, **kwargs)
33 | finally:
34 | # Restore the original handler and disable the alarm
35 | signal.alarm(0)
36 | signal.signal(signal.SIGALRM, original_handler)
37 | return result
38 | return wrapper
39 | return decorator
40 |
--------------------------------------------------------------------------------
/tests/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/trainer/__init__.py
--------------------------------------------------------------------------------