├── .gitignore ├── .gitmodules ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── constraints.txt ├── docs ├── Makefile ├── README.md ├── images │ ├── kithara_logo.png │ ├── kithara_logo_no_text.png │ └── kithara_logo_with_green_bg.png ├── requirements.txt └── source │ ├── .getting_tpus.rst.swp │ ├── .overview.rst.swp │ ├── api │ ├── kithara.dataset_api.rst │ ├── kithara.model_api.rst │ └── kithara.trainer_api.rst │ ├── basic_examples │ └── full_param_finetuning_example.rst │ ├── checkpointing.rst │ ├── conf.py │ ├── datasets.rst │ ├── ddp.rst │ ├── disk_storage.rst │ ├── dws_flex_start.rst │ ├── finetuning_guide.rst │ ├── getting_tpus.rst │ ├── help_and_feedback.rst │ ├── index.rst │ ├── installation.rst │ ├── installation │ ├── .tpu_gke.rst.swp │ ├── tpu_gke.rst │ ├── tpu_qr.rst │ └── tpu_vm.rst │ ├── lora.rst │ ├── models.rst │ ├── observability.rst │ ├── optimizations.rst │ ├── optimizers.rst │ ├── overview.rst │ ├── overviewxyz.rst │ ├── packing.rst │ ├── pretraining.rst │ ├── quickstart.rst │ ├── scaling_with_ray.rst │ ├── serve_with_vllm.rst │ ├── sft.rst │ ├── temp.rst │ └── troubleshooting.rst ├── examples ├── colab │ └── SFT_with_LoRA_Gemma2-2b.ipynb ├── example_datasets.py ├── multihost │ └── ray │ │ ├── GPU │ │ └── sft_lora_example.py │ │ └── TPU │ │ ├── continued_pretraining_example.py │ │ ├── full_finetuning_example.py │ │ ├── quick_start.py │ │ ├── sft_lora_example.py │ │ └── sft_lora_example.yaml └── singlehost │ ├── continued_pretraining_example.py │ ├── full_finetuning_example.py │ ├── quick_start.py │ └── sft_lora_example.py ├── kithara ├── __init__.py ├── callbacks │ ├── __init__.py │ ├── checkpointer.py │ └── profiler.py ├── config │ ├── __init__.py │ ├── default.yaml │ └── pyconfig.py ├── dataset │ ├── __init__.py │ ├── dataloader.py │ ├── dataset.py │ ├── packed_dataset.py │ ├── sft.py │ ├── text_completion.py │ └── utils.py ├── distributed │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ └── split.py │ └── sharding │ │ ├── __init__.py │ │ ├── _data_sharding.py │ │ ├── _layout.py │ │ ├── _mesh.py │ │ ├── maxtext.py │ │ ├── models │ │ ├── __init__.py │ │ ├── gemma.py │ │ └── llama.py │ │ ├── strategy.py │ │ └── utils.py ├── model │ ├── README.md │ ├── __init__.py │ ├── hf_compatibility │ │ ├── __init__.py │ │ ├── model_configs.py │ │ ├── shape_mapping.py │ │ └── to_huggingface.py │ ├── kerashub │ │ ├── __init__.py │ │ ├── ckpt_compatibility │ │ │ ├── param_mapping.py │ │ │ └── to_huggingface.py │ │ └── keras_hub_model.py │ ├── maxtext │ │ ├── __init__.py │ │ ├── ckpt_compatibility │ │ │ ├── __init__.py │ │ │ ├── from_huggingface.py │ │ │ ├── param_mapping.py │ │ │ ├── to_huggingface.py │ │ │ └── utils.py │ │ ├── conversion_utils.py │ │ ├── inference_engine.py │ │ └── maxtext_model.py │ ├── model.py │ └── supported_models.py ├── optimizers │ ├── __init__.py │ ├── optax_optimizer.py │ └── protocol.py ├── trainer │ ├── __init__.py │ └── trainer.py └── utils │ ├── gcs_utils.py │ ├── logging_utils.py │ ├── safetensor_utils.py │ ├── torch_utils.py │ └── tree_utils.py ├── perf ├── kerashub.py ├── kithara_packing.py ├── maxtext_native.py ├── maxtext_via_kithara.py ├── sft_kithara.py ├── sft_unsloth.py └── trl_packing.py ├── pyproject.toml ├── ray ├── GPU │ ├── cluster.yaml │ └── find_devices.py ├── README.md ├── TPU │ ├── GCE │ │ └── cluster.yaml │ ├── GKE │ │ ├── Dockerfile │ │ ├── multi-host.yaml │ │ └── single-host.yaml │ └── QR │ │ ├── cluster.yaml │ │ └── qr_worker_startup_script.sh └── submit_job.py └── tests ├── README.md ├── __init__.py ├── 
callbacks ├── __init__.py └── test_orbax_checkpointer.py ├── dataset ├── __init__.py ├── test_dataloader_creation.py ├── test_dataset_creation.py ├── test_dataset_packing.py └── utils.py ├── model ├── __init__.py ├── kerashub │ ├── ckpt_compatibility │ │ ├── __init__.py │ │ ├── test_loading_models.py │ │ ├── test_lora_adapters_shape_match.py │ │ ├── test_lora_adapters_value_match.py │ │ └── test_saving_models.py │ ├── test_creation.py │ └── test_inference.py ├── maxtext │ ├── __init__.py │ ├── ckpt_compatibility │ │ ├── __init__.py │ │ ├── test_loading_models.py │ │ └── test_saving_models.py │ ├── test_creation.py │ └── test_inference.py ├── test_prompt.py └── utils.py ├── optimizers ├── __init__.py ├── test_optax_optimizer.py └── test_protocol.py ├── test_utils.py └── trainer ├── __init__.py ├── test_sft_e2e.py └── test_trainer_creation.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info/ 2 | *.pyc 3 | .DS_Store 4 | .ipynb_checkpoints 5 | events* 6 | **/__pycache__ 7 | docs/build 8 | tmp/ 9 | dist/ 10 | *.pyc 11 | build/ 12 | docs/bin/ 13 | docs/lib/ 14 | docs/lib64 15 | docs/pyvenv.cfg 16 | 17 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "kithara/model/maxtext/maxtext"] 2 | path = kithara/model/maxtext/maxtext 3 | url = https://github.com/google/maxtext 4 | [submodule "kithara/model/maxtext/JetStream"] 5 | path = kithara/model/maxtext/JetStream 6 | url = https://github.com/AI-Hypercomputer/JetStream.git 7 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version, and other tools you might need 8 | build: 9 | os: ubuntu-24.04 10 | tools: 11 | python: "3.11" 12 | 13 | # Build documentation in the "docs/" directory with Sphinx 14 | sphinx: 15 | configuration: docs/source/conf.py 16 | 17 | # Optionally, but recommended, 18 | # declare the Python requirements required to build your documentation 19 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 20 | python: 21 | install: 22 | - requirements: docs/requirements.txt 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kithara - Easy Finetuning on TPUs 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/kithara)](https://pypi.org/project/kithara/) 4 | [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/AI-Hypercomputer/kithara/pulls) 5 | [![GitHub last commit](https://img.shields.io/github/last-commit/AI-Hypercomputer/kithara)](https://github.com/AI-Hypercomputer/kithara/commits/main) 6 | [![Documentation](https://img.shields.io/badge/docs-latest-brightgreen)](https://kithara.readthedocs.io/en/latest/) 7 | 8 |
9 | 10 | 11 | 12 | 13 | kithara logo 14 | 15 | 16 |
17 | 18 | ## 👋 Overview 19 | 20 | Kithara is a lightweight library offering building blocks and recipes for tuning popular open source LLMs including Gemma2 and Llama3 on Google TPUs. 21 | 22 | It provides: 23 | 24 | - **Frictionless scaling**: Distributed training abstractions intentionally built with simplicity in mind. 25 | - **Multihost training support**: Integration with Ray, GCE and GKE. 26 | - **Async, distributed checkpointing**: Multi-host & Multi-device checkpointing via Orbax. 27 | - **Distributed, streamed dataloading**: Per-process, streamed data loading via Ray.data. 28 | - **GPU/TPU fungibility**: Same code works for both GPU and TPU out of the box. 29 | - **Native integration with HuggingFace**: Tune and save models in HuggingFace format. 30 | 31 | **New to TPUs?** 32 | 33 | Using TPUs provides significant advantages in terms of performance, cost-effectiveness, and scalability, enabling faster training times and the ability to work with larger models and datasets. Check out our onboarding guide to [getting TPUs](https://kithara.readthedocs.io/en/latest/getting_tpus.html). 34 | 35 | ## 🔗 **Key links and resources** 36 | | | | 37 | | --------------------------------- | --------------------------------------------------------------------------------------------------------------------------- | 38 | | 📚 **Documentation** | [Read Our Docs](https://kithara.readthedocs.io/en/latest/) | 39 | | 💾 **Installation** | [Quick Pip Install](https://kithara.readthedocs.io/en/latest/installation.html) | 40 | | ✏️ **Get Started** | [Intro to Kithara](https://kithara.readthedocs.io/en/latest/quickstart.html) | 41 | | 🌟 **Supported Models** | [List of Models](https://kithara.readthedocs.io/en/latest/models.html) | 42 | | 🌐 **Supported Datasets** | [List of Data Formats](https://kithara.readthedocs.io/en/latest/datasets.html) | 43 | | ⌛️ **Performance Optimizations** | [Our Memory and Throughput Optimizations](https://kithara.readthedocs.io/en/latest/optimizations.html) | 44 | | 📈 **Scaling up** | [Guide for Tuning Large Models](https://kithara.readthedocs.io/en/latest/scaling_with_ray.html) | 45 | 46 | 47 | ## 🌵 **Examples** 48 | 49 | - **Quick Start Colab Notebook**: [SFT + LoRA with Gemma2-2b](https://colab.sandbox.google.com/github/AI-Hypercomputer/kithara/blob/main/examples/colab/SFT_with_LoRA_Gemma2-2b.ipynb) 50 | 51 | - **SFT + LoRA**: [Step by Step Example](https://kithara.readthedocs.io/en/latest/sft.html) 52 | 53 | - **Continued Pretraining**: [Step by Step Example](https://kithara.readthedocs.io/en/latest/pretraining.html) 54 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-autobuild 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Kithara ReadTheDocs Site 2 | 3 | To build: 4 | 5 | ``` 6 | # Install Sphinx deps. 7 | pip install -r docs/requirements.txt 8 | 9 | make -C docs html 10 | 11 | # Preview locally 12 | 13 | Go to http://127.0.0.1:8000 14 | ``` -------------------------------------------------------------------------------- /docs/images/kithara_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/docs/images/kithara_logo.png -------------------------------------------------------------------------------- /docs/images/kithara_logo_no_text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/docs/images/kithara_logo_no_text.png -------------------------------------------------------------------------------- /docs/images/kithara_logo_with_green_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/docs/images/kithara_logo_with_green_bg.png -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==8.2.0 2 | sphinx-autobuild 3 | sphinxawesome-theme>=5.3.2 4 | sphinx_design 5 | -------------------------------------------------------------------------------- /docs/source/.getting_tpus.rst.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/docs/source/.getting_tpus.rst.swp -------------------------------------------------------------------------------- /docs/source/.overview.rst.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/docs/source/.overview.rst.swp -------------------------------------------------------------------------------- /docs/source/api/kithara.dataset_api.rst: -------------------------------------------------------------------------------- 1 | .. _dataset_api: 2 | 3 | Datasets 4 | ======== 5 | 6 | TextCompletionDataset 7 | -------------------- 8 | 9 | .. py:class:: TextCompletionDataset 10 | 11 | A dataset class for standard text completion tasks. 12 | 13 | :param source: The source Ray dataset containing the text data 14 | :type source: ray.data.Dataset 15 | :param tokenizer: The tokenizer instance to use 16 | :type tokenizer: Optional[AutoTokenizer] 17 | :param tokenizer_handle: Handle/name of the HF tokenizer to load if not provided. E.g. "hf://google/gemma-2-2b" 18 | :type tokenizer_handle: Optional[str] 19 | :param column_mapping: Mapping of source column name to expected column name ("text") 20 | :type column_mapping: Optional[Dict[str, str]] 21 | :param model_type: Type of model implementation to use. 
Supported types: "KerasHub", "MaxText", "auto" 22 | :type model_type: ModelImplementationType | "auto" 23 | :param max_seq_len: Maximum sequence length for tokenization (default: 1024) 24 | :type max_seq_len: int 25 | :param custom_formatting_fn: A custom formatting function to apply to the raw sample before any other transformation steps 26 | :type custom_formatting_fn: Optional[callable] 27 | :param packing: Whether to enable sequence packing 28 | :type packing: bool 29 | 30 | .. py:method:: to_packed_dataset() 31 | 32 | Converts the current dataset to a PackedDataset for more efficient processing. 33 | 34 | :return: A new PackedDataset instance 35 | :rtype: PackedDataset 36 | 37 | SFTDataset 38 | ---------- 39 | 40 | .. py:class:: SFTDataset(TextCompletionDataset) 41 | 42 | A dataset class for Supervised Fine-Tuning (SFT) tasks. 43 | 44 | :param source: The source Ray dataset containing the training data 45 | :type source: ray.data.Dataset 46 | :param tokenizer: HuggingFace tokenizer instance 47 | :type tokenizer: Optional[AutoTokenizer] 48 | :param tokenizer_handle: Handle/name of the HF tokenizer to load if not provided. E.g. "hf://google/gemma-2-2b" 49 | :type tokenizer_handle: Optional[str] 50 | :param column_mapping: Mapping of source column names to expected column names ("prompt" and "answer") 51 | :type column_mapping: Optional[Dict[str, str]] 52 | :param model_type: Type of model implementation to use. Supported types: "KerasHub", "MaxText", "auto" 53 | :type model_type: ModelImplementationType | "auto" 54 | :param max_seq_len: Maximum sequence length for tokenization (default: 1024) 55 | :type max_seq_len: int 56 | :param custom_formatting_fn: A custom formatting function to apply to the raw sample before any other transformation steps 57 | :type custom_formatting_fn: Optional[callable] 58 | 59 | .. py:method:: to_packed_dataset() 60 | 61 | Converts the current dataset to a PackedDataset for more efficient processing. 62 | 63 | :return: A new PackedDataset instance 64 | :rtype: PackedDataset 65 | 66 | PackedDataset 67 | ------------ 68 | 69 | .. py:class:: PackedDataset 70 | 71 | A dataset class that packs multiple sequences together on the fly for more efficient processing. 72 | 73 | :param source_dataset: The source dataset containing unpacked sequences 74 | :type source_dataset: TextCompletionDataset 75 | :param pad_value: The value to use for padding (default: 0) 76 | :type pad_value: int 77 | 78 | .. note:: 79 | - Packing must be used with Flash Attention enabled (which should be enabled by default) 80 | - Packing currently only works for MaxText models 81 | - Packing does not currently work for DDP training 82 | 83 | 84 | Example Usage 85 | ------------ 86 | 87 | Here's a simple example of using the TextCompletionDataset:: 88 | 89 | dataset = TextCompletionDataset( 90 | source=ray_dataset, 91 | tokenizer_handle="hf://google/gemma-2-2b", 92 | max_seq_len=512, 93 | ) 94 | 95 | For supervised fine-tuning tasks, use the SFTDataset:: 96 | 97 | sft_dataset = SFTDataset( 98 | source=ray_dataset, 99 | tokenizer_handle="hf://google/gemma-2-2b", 100 | column_mapping={"input": "prompt", "output": "answer"}, 101 | max_seq_len=1024 102 | ) 103 | 104 | To enable sequence packing for more efficient processing:: 105 | 106 | packed_dataset = dataset.to_packed_dataset() -------------------------------------------------------------------------------- /docs/source/api/kithara.trainer_api.rst: -------------------------------------------------------------------------------- 1 | .. 
_trainer_api: 2 | 3 | Trainer 4 | ======= 5 | 6 | .. py:class:: Trainer 7 | 8 | A base Trainer class supporting continued pretraining and SFT, designed to be subclassed for implementing other training objectives, e.g. DPO. 9 | 10 | :param model: The model to be trained or evaluated 11 | :type model: kithara.Model 12 | :param optimizer: The optimizer used for training 13 | :type optimizer: keras.Optimizer | optax.GradientTransformation | optax.GradientTransformationExtraArgs 14 | :param train_dataloader: A dataloader that provides training batches 15 | :type train_dataloader: kithara.Dataloader 16 | :param eval_dataloader: A dataloader that provides evaluation batches 17 | :type eval_dataloader: Optional[kithara.Dataloader] 18 | :param steps: The total number of training steps to execute. Defaults to None and trains 1 epoch 19 | :type steps: Optional[int] 20 | :param epochs: The total number of training epochs to execute. Defaults to None 21 | :type epochs: Optional[int] 22 | :param log_steps_interval: The interval between logging steps. Each log includes the current loss value and performance metrics 23 | :type log_steps_interval: int 24 | :param eval_steps_interval: The interval between evaluation steps. Only one of eval_steps_interval or eval_epochs_interval can be set 25 | :type eval_steps_interval: Optional[int] 26 | :param eval_epochs_interval: The interval between evaluation epochs. Only one of eval_steps_interval or eval_epochs_interval can be set 27 | :type eval_epochs_interval: Optional[int] 28 | :param max_eval_samples: The maximum number of samples to use during evaluation. Uses the entire evaluation dataset if not provided 29 | :type max_eval_samples: int 30 | :param tensorboard_dir: The directory path for TensorBoard logs. Can be either a local directory or a Google Cloud Storage path 31 | :type tensorboard_dir: Optional[str] 32 | :param profiler: A profiler instance for monitoring performance metrics 33 | :type profiler: Optional[kithara.Profiler] 34 | :param checkpointer: A checkpointer instance for saving model checkpoints 35 | :type checkpointer: Optional[kithara.Checkpointer] 36 | 37 | .. py:method:: train() 38 | 39 | Execute the main training loop. Handles epoch iteration, batch processing, loss computation, model updates, progress logging, and periodic evaluation. 40 | 41 | 42 | Example usage: 43 | 44 | .. code-block:: python 45 | 46 | trainer = Trainer( 47 | model=my_model, 48 | optimizer=keras.optimizers.Adam(learning_rate=1e-4), 49 | train_dataloader=train_loader, 50 | eval_dataloader=eval_loader, 51 | steps=1000, 52 | log_steps_interval=10, 53 | eval_steps_interval=100, 54 | tensorboard_dir="local_dir_or_gs_bucket", 55 | checkpointer= kithara.Checkpointer( 56 | save_dir="local_dir_or_gs_bucket", 57 | save_interval=100, 58 | ), 59 | ) 60 | 61 | trainer.train() 62 | 63 | Note 64 | ---- 65 | - If both ``steps`` and ``epochs`` are None, defaults to training for 1 epoch 66 | - If ``eval_dataloader`` is provided but no evaluation interval is set, defaults to evaluating every epoch 67 | - The trainer automatically handles data sharding for distributed training -------------------------------------------------------------------------------- /docs/source/checkpointing.rst: -------------------------------------------------------------------------------- 1 | .. _checkpointing: 2 | 3 | Checkpointing 4 | ============= 5 | 6 | Checkpointing allows you to save and restore model states during training. 
7 | Kithara supports saving checkpoints to local directories or Google Cloud Storage buckets. 8 | 9 | Kithara checkpointing is fast -- it uses `Orbax `_ to save checkpoints asynchronously and in a distributed manner. 10 | 11 | .. note:: 12 | For efficiency, checkpoints are saved in a non-HuggingFace format. Use ``model.save_in_hf_format`` to save models in HuggingFace format. 13 | 14 | Basic Usage 15 | ----------- 16 | 17 | During Training 18 | ^^^^^^^^^^^^^^^^^^^ 19 | Provide the ``checkpointer`` arg to the ``Trainer`` class to save checkpoints during training:: 20 | 21 | from kithara import Checkpointer, Trainer 22 | # Keeps the latest 5 checkpoints, saving one every 100 steps 23 | checkpointer = Checkpointer("gs://...", save_interval_steps=100, max_to_keep=5) 24 | trainer = Trainer(..., checkpointer=checkpointer) 25 | trainer.train() 26 | 27 | Restore a checkpoint after training:: 28 | 29 | # Initialize a random model 30 | model = MaxTextModel.from_random( 31 | model_name="gemma2-9b", 32 | seq_len=4096, 33 | per_device_batch_size=1, 34 | precision="mixed_bfloat16", 35 | ) 36 | 37 | # Restore from specific checkpoint 38 | checkpointer = Checkpointer( 39 | "gs://your_bucket/your_model_name/ckpt/20", # Step 20 40 | model=model 41 | ) 42 | checkpointer.load() 43 | 44 | model.generate(...) 45 | 46 | As Standalone Utility 47 | ^^^^^^^^^^^^^^^^^^^^^ 48 | 49 | You can also use the ``Checkpointer`` as a standalone utility to save and load checkpoints outside of the ``Trainer`` class. 50 | 51 | .. code-block:: python 52 | 53 | model = kithara.MaxTextModel.from_preset("hf://google/gemma-2-2b") 54 | 55 | # Attach the checkpointer to the model 56 | checkpointer = Checkpointer("gs://...", model) 57 | 58 | # Save checkpoint - checkpoints need to be numbered sequentially 59 | checkpointer.save(0, blocking=True) 60 | 61 | # Load latest checkpoint back to the model 62 | checkpointer.load() 63 | 64 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = 'Kithara' 10 | copyright = '2024, The Kithara Authors' 11 | author = 'kithara_authors' 12 | 13 | # -- General configuration --------------------------------------------------- 14 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 15 | 16 | extensions = [ 17 | 'sphinx.ext.duration', 18 | 'sphinx.ext.doctest', 19 | 'sphinx.ext.autodoc', 20 | 'sphinx.ext.autosummary', 21 | 'sphinx.ext.intersphinx', 22 | 'sphinx.ext.napoleon', 23 | 'sphinx_design', 24 | ] 25 | 26 | templates_path = ['_templates'] 27 | exclude_patterns = [] 28 | 29 | 30 | 31 | # -- Options for HTML output ------------------------------------------------- 32 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 33 | 34 | html_theme = 'sphinxawesome_theme' 35 | html_static_path = ['_static'] 36 | html_css_files = ['custom.css'] 37 | html_permalinks_icon = '#' 38 | 39 | html_theme_options = { 40 | "main_nav_links": { 41 | "GitHub": "https://github.com/AI-Hypercomputer/kithara", 42 | }, 43 | "logo_light": "../images/kithara_logo_no_text.png", 44 | "logo_dark": "../images/kithara_logo_no_text.png", 45 | "show_prev_next": True, 46 | } 47 | -------------------------------------------------------------------------------- /docs/source/datasets.rst: -------------------------------------------------------------------------------- 1 | .. _datasets: 2 | 3 | Supported Data Formats 4 | ====================== 5 | 6 | The expected flow is always Ray Dataset -> Kithara Dataset -> Kithara Dataloader. 7 | 8 | For example:: 9 | 10 | hf_dataset = load_dataset("allenai/c4", "en", split="train", streaming=True) 11 | ray_dataset = ray.data.from_huggingface(hf_dataset) 12 | 13 | kithara_dataset = kithara.TextCompletionDataset( 14 | ray_dataset, 15 | tokenizer_handle="hf://google/gemma-2-2b", 16 | max_seq_len=8192 17 | ) 18 | dataloader = kithara.Dataloader( 19 | kithara_dataset, 20 | per_device_batch_size=1 21 | ) 22 | 23 | .. note:: 24 | You can load files and other data from your local directory, S3, GCS, ABS, and any other data storage platforms supported by `Ray Data `_. 25 | 26 | Kithara supports all file formats supported by Ray Dataset. Listed below are some examples using popular dataset types. 27 | 28 | HuggingFace Dataset 29 | ~~~~~~~~~~~~~~~~~~~ 30 | Load datasets directly from the Hugging Face Hub. Streaming datasets are supported:: 31 | 32 | hf_dataset = load_dataset("allenai/c4", "en", split="train", streaming=True) 33 | ray_dataset = ray.data.from_huggingface(hf_dataset) 34 | 35 | JSON and JSONL 36 | ~~~~~~~~~~~~~~ 37 | Read JSON and JSONL files. 38 | 39 | .. code-block:: python 40 | 41 | ray_dataset = ray.data.read_json("s3://anonymous@ray-example-data/log.json") 42 | 43 | CSV 44 | ~~~ 45 | Read CSV files. 46 | 47 | .. code-block:: python 48 | 49 | ray_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv") 50 | 51 | Python Dictionary 52 | ~~~~~~~~~~~~~~~~~ 53 | 54 | .. 
code-block:: python 55 | 56 | dataset_items = [{ 57 | "prompt": "What is your name?", 58 | "answer": "My name is Kithara", 59 | } for _ in range(1000)] 60 | 61 | ray_dataset = ray.data.from_items(dataset_items) 62 | 63 | Pandas 64 | ~~~~~~ 65 | 66 | .. code-block:: python 67 | 68 | df = pd.DataFrame({ 69 | "food": ["spam", "ham", "eggs"], 70 | "price": [9.34, 5.37, 0.94] 71 | }) 72 | ray_dataset = ray.data.from_pandas(df) 73 | 74 | Lines of Text 75 | ~~~~~~~~~~~~~ 76 | Read text files. 77 | 78 | .. code-block:: python 79 | 80 | ray_dataset = ray.data.read_text("s3://anonymous@ray-example-data/this.txt") 81 | 82 | TFRecords 83 | ~~~~~~~~~ 84 | Read TFRecords files. 85 | 86 | .. code-block:: python 87 | 88 | 89 | ray_dataset = ray.data.read_tfrecords("s3://anonymous@ray-example-data/iris.tfrecords") 90 | 91 | Parquet 92 | ~~~~~~~ 93 | Read Parquet files. 94 | 95 | .. code-block:: python 96 | 97 | ray_dataset = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet") 98 | 99 | Additional Formats 100 | ~~~~~~~~~~~~~~~~~~ 101 | Kithara supports all Ray Data formats. For more information: 102 | 103 | - Complete list of supported formats: `Ray Data Input/Output `_ 104 | - Ray Dataset transformation guide: `Ray Data API `_ 105 | 106 | Note: You should handle data cleaning, transformation, shuffling, and splitting using Ray Dataset utilities before passing the data to Kithara Dataset. 107 | -------------------------------------------------------------------------------- /docs/source/ddp.rst: -------------------------------------------------------------------------------- 1 | .. _ddp: 2 | 3 | Managing Large Datasets 4 | ======================= 5 | 6 | Distributed Data Parallel (DDP) is a data-loading technique that is useful for multi-host training on large datasets. 7 | 8 | When to use? 9 | ^^^^^^^^^^^^^^^ 10 | - When you are running multihost training, and your dataset is too large to fit in memory on a single host. 11 | - When large batch processing is the bottleneck in your training pipeline. 12 | 13 | 14 | How it works? 15 | ^^^^^^^^^^^^^^^ 16 | 17 | 1. Split your dataset into N shards, where N is the number of hosts, with each shard containing the same number of samples. 18 | .. tip:: 19 | Use ``ray.data.split()`` on your Ray dataset, or ``ray.data.streaming_split()`` for streamed datasets. 20 | 2. Each host loads and processes data only from its assigned shard. 21 | 22 | For example, with a 32,000-sample dataset split across 4 hosts, each host manages 8,000 samples instead of the full dataset. 23 | 24 | Example DDP Workflow 25 | ^^^^^^^^^^^^^^^^^^^^ 26 | Here's an example of how to implement DDP in Kithara:: 27 | 28 | @ray.remote(resources={"TPU": num_chips_per_host}) 29 | def main(train_ds): 30 | ... 31 | dataloader = kithara.Dataloader( 32 | train_ds, 33 | per_device_batch_size=1, # or any other batch size 34 | dataset_is_sharded_per_host=True, # Enable DDP 35 | ) 36 | ... 37 | trainer.train() 38 | 39 | train_ds: List[Any] = split_dataset(ray_dataset, num_hosts=num_tpu_hosts) 40 | ray.get( 41 | [ 42 | main.remote(train_ds[i]) for i in range(num_tpu_hosts) 43 | ] 44 | )
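 45 | 46 | The ``split_dataset`` helper above is not defined in the snippet. A minimal sketch, assuming an in-memory Ray dataset, can simply wrap ``ray.data.Dataset.split``:: 47 | 48 | from typing import List 49 | import ray 50 | 51 | def split_dataset(dataset: ray.data.Dataset, num_hosts: int) -> List[ray.data.Dataset]: 52 | # Produce num_hosts disjoint shards; equal=True keeps shard sizes identical. 53 | return dataset.split(num_hosts, equal=True) 54 | 55 | For streamed datasets, ``ray.data.Dataset.streaming_split`` is the analogous entry point.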
 -------------------------------------------------------------------------------- /docs/source/disk_storage.rst: -------------------------------------------------------------------------------- 1 | .. _disk_storage: 2 | 3 | Dealing With Disks 4 | ================== 5 | During the process of training and serving an LLM, you can easily run out of 6 | disk space on your VM or container. Potential large storage objects may 7 | include: 8 | 9 | * Base model (e.g. from Hugging Face) 10 | * Finetuning dataset 11 | * Output checkpoints 12 | 13 | This page presents a few potential options for provisioning additional 14 | disk space for your deployments. 15 | 16 | 17 | General Recommendations 18 | ----------------------- 19 | We generally recommend the following best practices: 20 | 21 | * Download your model or dataset to a shared storage volume or bucket 22 | first, and mount this volume or bucket when you need it. This saves 23 | both disk space and time spent on downloading large files. 24 | * Prefer writing checkpoints from a single worker. This avoids having 25 | to manage multiple writes to the same shared drive. 26 | 27 | 28 | Persistent Disks 29 | ---------------- 30 | Persistent Disks (PDs) are durable storage volumes that can be 31 | dynamically provisioned at deployment time. PDs provide block storage, 32 | which is similar to having a hard drive attached to your VM. These 33 | are a versatile solution and are ideal for multi-read, low latency 34 | workloads. 35 | 36 | * Provisioning on TPU VMs: https://cloud.google.com/tpu/docs/setup-persistent-disk 37 | 38 | * Provisioning on GKE: https://cloud.google.com/kubernetes-engine/docs/concepts/persistent-volumes 39 | 40 | 41 | Google Cloud Storage Buckets 42 | ---------------------------- 43 | GCS buckets are an alternative solution to PDs. Unlike PDs, GCS buckets 44 | provide object storage. These are ideal for storing large, unstructured 45 | data such as datasets and model checkpoints. 46 | 47 | Kithara is designed to work seamlessly with GCS buckets. To use a GCS 48 | bucket, make sure that the service account which is used to run your 49 | workload has write access to the GCS bucket. 50 | 51 | .. code-block:: bash 52 | 53 | gcloud storage buckets add-iam-policy-binding gs://YOUR_BUCKET_NAME \ 54 | --member="serviceAccount:YOUR_SERVICE_ACCOUNT_EMAIL" \ 55 | --role="YOUR_IAM_ROLE" 56 | 57 | 58 | The IAM role should be either 'roles/storage.objectCreator' or 'roles/storage.objectAdmin'. 59 | 60 | 61 | GCS Fuse 62 | -------- 63 | An alternative way to use GCS buckets is to mount them as local drives, 64 | with GCS Fuse. After mounting the bucket, you can access the files in the 65 | bucket as if they were in a locally mounted volume. GCSFuse buckets can 66 | be read-only or read-write. You can even mount the same bucket on multiple 67 | workers. 68 | 69 | * GCE VM: https://cloud.google.com/storage/docs/cloud-storage-fuse/overview 70 | 71 | * GKE: https://cloud.google.com/kubernetes-engine/docs/how-to/cloud-storage-fuse-csi-driver-setup#enable 72 | 73 | * Sample Kubernetes yaml: https://github.com/AI-Hypercomputer/kithara/blob/main/ray/TPU/GKE/single-host.yaml 74 | 75 | * Note - running on GKE requires enabling workload identity first: https://cloud.google.com/kubernetes-engine/docs/how-to/workload-identity 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /docs/source/finetuning_guide.rst: -------------------------------------------------------------------------------- 1 | .. _finetuning_guide: 2 | 3 | Finetuning Guide 4 | =================== 5 | 6 | Ready to fine-tune with your own data? Let's break it down into simple steps. 7 | 8 | 1. 
Pick Your Model 9 | ------------------ 10 | 11 | Kithara uses two model gardens - MaxText and Keras Hub. 12 | 13 | MaxText 14 | ~~~~~~~ 15 | 16 | 17 | MaxText is recommended for full-parameter fine-tuning. 18 | 19 | .. parsed-literal:: 20 | 21 | from kithara import MaxTextModel 22 | 23 | model = **MaxTextModel**.from_preset( 24 | "hf://google/gemma-2-2b", 25 | seq_len=8192, # Seq len and batch size need to be specified up front 26 | per_device_batch_size=1 27 | ) 28 | 29 | KerasHub 30 | ~~~~~~~~ 31 | 32 | KerasHub is recommended for parameter-efficient fine-tuning (LoRA). 33 | 34 | .. parsed-literal:: 35 | 36 | from kithara import KerasHubModel 37 | 38 | model = **KerasHubModel**.from_preset( 39 | "hf://google/gemma-2-2b", 40 | lora_rank=16 # Applied to q_proj and v_proj 41 | ) 42 | 43 | .. tip:: 44 | - Always start model handles with ``hf://`` when loading from HuggingFace - so we know you are not loading from a local directory 😀 45 | - The default precision ``mixed_bfloat16`` is your friend - it's memory-friendly! It loads model weights in full precision and casts activations to bfloat16. 46 | - Check out our :doc:`model garden ` for supported architectures 47 | - Want to save your model? Simply do ``model.save_in_hf_format(destination)`` to either save it locally, to GCS, or to HuggingFace. 48 | - Check out :doc:`Model API ` documentation 49 | 50 | 2. Prepare Your Data 51 | -------------------- 52 | 53 | Kithara supports HuggingFace Datasets as well as Ray Datasets. 54 | 55 | HuggingFace Dataset:: 56 | 57 | from datasets import load_dataset 58 | hf_dataset = load_dataset("allenai/c4", "en", split="train", streaming=True) 59 | dataset = kithara.TextCompletionDataset( 60 | hf_dataset, 61 | tokenizer_handle="hf://google/gemma-2-2b", 62 | max_seq_len=8192 63 | ) 64 | 65 | Ray Dataset:: 66 | 67 | import ray 68 | ray_dataset = ray.data.read_json("s3://anonymous@ray-example-data/log.json") 69 | dataset = kithara.TextCompletionDataset( 70 | ray_dataset, 71 | tokenizer_handle="hf://google/gemma-2-2b", 72 | max_seq_len=8192 73 | ) 74 | 75 | Now create a Kithara Dataloader to batchify your dataset:: 76 | 77 | # Create dataloader 78 | dataloader = kithara.Dataloader( 79 | dataset, 80 | per_device_batch_size=1 81 | ) 82 | 83 | The expected flow is always Ray/HF Dataset -> Kithara Dataset -> Kithara Dataloader. 84 | 85 | Quick Tips: 86 | 87 | - Your global batch size is always `per_device_batch_size` * `number of devices (chips)`. 88 | - Check out :doc:`supported data formats ` (CSV, JSON, etc.) 89 | - Check out :doc:`Dataset API ` documentation. 90 | 91 | 92 | 3. Choose Algorithm 93 | ------------------- 94 | 95 | - **Continued pretraining**: Train your base model with large datasets to expand its knowledge 96 | - **Supervised finetuning**: Quickly and efficiently tune your model using labeled examples 97 | 98 | Check out these examples to get started. 99 | 100 | - :doc:`🌵 Continued Pretraining Example ` 101 | - :doc:`🌵 SFT+LoRA Example ` 102 | 103 | 4. Select Hardware 104 | ------------------ 105 | 106 | If you haven't done so, check out :doc:`Getting TPUs ` to get your TPUs ready. 107 | 108 | If your TPU topology has multiple hosts, and you are not familiar with distributed training, 109 | we recommend you follow the :doc:`Scaling up with Ray ` guide to set up a 110 | Ray Cluster so that you can run multihost jobs. 
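111 | 112 | Putting It Together 113 | ------------------- 114 | 115 | As a rough sketch (not part of the original guide), here is how the pieces above fit together, reusing the toy SFT setup from the :doc:`quickstart `: 116 | 117 | .. code-block:: python 118 | 119 | import keras 120 | import ray 121 | from kithara import KerasHubModel, SFTDataset, Dataloader, Trainer 122 | 123 | # 1. Model (KerasHub implementation, LoRA on q_proj and v_proj) 124 | model = KerasHubModel.from_preset("hf://google/gemma-2-2b", lora_rank=16) 125 | 126 | # 2. Data: Ray Dataset -> Kithara Dataset -> Kithara Dataloader 127 | ray_dataset = ray.data.from_items( 128 | [{"prompt": "What is your name?", "answer": "My name is Kithara"}] * 1000 129 | ) 130 | dataset = SFTDataset(ray_dataset, tokenizer_handle="hf://google/gemma-2-2b", max_seq_len=4096) 131 | dataloader = Dataloader(dataset, per_device_batch_size=1) 132 | 133 | # 3. Algorithm: supervised finetuning with AdamW 134 | trainer = Trainer( 135 | model=model, 136 | optimizer=keras.optimizers.AdamW(learning_rate=2e-4, weight_decay=0.01), 137 | train_dataloader=dataloader, 138 | steps=200, 139 | log_steps_interval=10, 140 | ) 141 | trainer.train() 142 | 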
-------------------------------------------------------------------------------- /docs/source/help_and_feedback.rst: -------------------------------------------------------------------------------- 1 | .. _help_and_feedback: 2 | 3 | Bugs, Features, Feedback 4 | ======================== 5 | 6 | Send the Google team a direct email to report bugs, request features, and share feedback: `buganizer-system+1746240@google.com `_. You can use the templates below to add the required background information. Only Google can view what you share. 7 | 8 | 9 | 10 | Email template for bugs and feature requests 11 | -------------------------------------------- 12 | 13 | .. parsed-literal:: 14 | 15 | Issue Type: (Choose one and delete the others) 16 | * Bug Report 17 | * Feature Request 18 | 19 | Description: 20 | 21 | * What happened? (Required) 22 | * What did you expect to happen? (Required) 23 | * Error messages: (Required: Paste any error messages here and include the full log if you can. If none, write "N/A".) 24 | 25 | --- 26 | Context: 27 | 28 | * Project ID: 29 | * Company: 30 | * Kithara Version: 31 | * Python Version: 32 | * JAX Version: 33 | * Keras Version: 34 | * KerasHub Version: 35 | * TPU Resource type (QR, GKE, GCE): 36 | * TPU Generation (Trillium, v5p, etc.): 37 | * Number of Chips: 38 | * TPU VM image: 39 | * Ray Version (if running on multihost): 40 | * GKE Version (if applicable): 41 | 42 | --- 43 | Additional Information (Optional): 44 | 45 | * Model: 46 | * Dataset: 47 | * Reproduction Script: Please attach or link a minimal, self-contained, reproducible example. 48 | * Severity - Choose one: Critical (Blocking work), Medium (Workaround exists), Low (Annoying): 49 | 50 | Email template to share feedback 51 | -------------------------------- 52 | 53 | .. parsed-literal:: 54 | 55 | * What is your overall impression of Kithara? (What do you like? What could be improved?) 56 | * What specific features or aspects are most valuable to you? 57 | * Are there any features or workflows that you find confusing or difficult? 58 | * Do you have any suggestions for new features or improvements? What will these help you do? 59 | * Is there anything else you would like to share? 60 | 61 | General How-Tos and Support 62 | --------------------------- 63 | 64 | You can also use `Github Issues `_ to get support. 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | .. 
_installation: 2 | 3 | Installation 4 | ============ 5 | 6 | .. note:: 7 | 8 | Kithara requires ``Python>=3.11``. 9 | 10 | We recommend using a virtual environment. 11 | 12 | 1. With `conda`: 13 | 14 | .. dropdown:: How to install Conda on GCP VM 15 | :open: 16 | 17 | .. code-block:: bash 18 | 19 | mkdir -p ~/miniconda3 20 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh 21 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 22 | rm ~/miniconda3/miniconda.sh 23 | source ~/miniconda3/bin/activate 24 | conda init --all 25 | 26 | .. code-block:: bash 27 | 28 | conda create -n kithara_env python=3.11 29 | conda activate kithara_env 30 | 31 | 32 | 2. With `venv`: 33 | 34 | .. code-block:: bash 35 | 36 | sudo apt-get install -y python3.11 37 | sudo apt install python3.11-venv 38 | 39 | python3.11 -m venv kithara_env 40 | source kithara_env/bin/activate 41 | 42 | Installation on TPU 43 | ~~~~~~~~~~~~~~~~~~~ 44 | 45 | .. code-block:: bash 46 | 47 | pip install kithara[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://download.pytorch.org/whl/cpu 48 | 49 | Installation on GPU 50 | ~~~~~~~~~~~~~~~~~~~ 51 | 52 | .. warning:: 53 | 54 | Our GPU support is still in beta. Please report any issues you encounter. 55 | 56 | .. code-block:: bash 57 | 58 | pip install kithara[gpu] 59 | 60 | 61 | Clone the repo 62 | ~~~~~~~~~~~~~~ 63 | 64 | .. code-block:: bash 65 | 66 | git clone --recursive https://github.com/AI-Hypercomputer/kithara.git 67 | 
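68 | Verify the install 69 | ~~~~~~~~~~~~~~~~~~ 70 | 71 | As a quick smoke test (a suggested addition, not from the original docs, assuming the JAX backend), check that JAX can see your accelerators: 72 | 73 | .. code-block:: python 74 | 75 | import jax 76 | 77 | # On a TPU VM this should list TpuDevice entries. 78 | print(jax.devices()) 79 | 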
-------------------------------------------------------------------------------- /docs/source/installation/.tpu_gke.rst.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/docs/source/installation/.tpu_gke.rst.swp -------------------------------------------------------------------------------- /docs/source/installation/tpu_gke.rst: -------------------------------------------------------------------------------- 1 | .. _tpu_gke: 2 | 3 | Setting up Ray Cluster with TPU GKE 4 | ===================================== 5 | Prerequisites 6 | ------------- 7 | * Clone the Kithara repo. 8 | 9 | .. code-block:: bash 10 | 11 | git clone https://github.com/AI-Hypercomputer/kithara.git 12 | * Create a GKE Cluster with Ray add-on: https://docs.ray.io/en/latest/cluster/kubernetes/user-guides/gcp-gke-tpu-cluster.html 13 | 14 | Preparing Your GKE Cluster 15 | -------------------------- 16 | * Enable GCSFuse. 17 | 18 | This step allows GCS buckets to be mounted on GKE containers as drives. This makes it easier for Kithara to save checkpoints to GCS. 19 | 20 | You can follow the instructions here: https://docs.ray.io/en/latest/cluster/kubernetes/user-guides/gke-gcs-bucket.html. 21 | * Authenticate to Your GKE Cluster: 22 | 23 | .. code-block:: bash 24 | 25 | gcloud container clusters get-credentials $CLUSTER --zone $ZONE --project $YOUR_PROJECT 26 | * Create a Hugging Face token on https://huggingface.co/docs/hub/en/security-tokens 27 | * Save the Hugging Face token to the Kubernetes cluster: 28 | 29 | .. code-block:: bash 30 | 31 | kubectl create secret generic hf-secret \ 32 | --from-literal=hf_api_token=HUGGING_FACE_TOKEN 33 | 34 | 35 | 36 | 37 | Setting Up a Ray Cluster 38 | ------------------------ 39 | 1. Edit one of the following manifest files: 40 | 41 | - Single-host: https://github.com/AI-Hypercomputer/kithara/blob/main/ray/TPU/GKE/single-host.yaml 42 | 43 | - Multi-host: https://github.com/AI-Hypercomputer/kithara/blob/main/ray/TPU/GKE/multi-host.yaml 44 | 45 | Make sure to replace ``YOUR_GCS_BUCKET`` with the name of the GCS bucket created in previous steps. 46 | 2. Deploy the Ray cluster: 47 | 48 | .. code-block:: bash 49 | 50 | kubectl apply -f $MANIFEST_FILE 51 | 3. Check that the cluster is running with: 52 | 53 | .. code-block:: bash 54 | 55 | kubectl get pods 56 | 57 | If everything works as expected, you should see pods running: 58 | 59 | .. code-block:: bash 60 | 61 | NAME READY STATUS RESTARTS AGE 62 | example-cluster-kuberay-head-kgxkp 2/2 Running 0 1m 63 | example-cluster-kuberay-worker-workergroup-bzrz2 2/2 Running 0 1m 64 | example-cluster-kuberay-worker-workergroup-g7k4t 2/2 Running 0 1m 65 | example-cluster-kuberay-worker-workergroup-h6zsx 2/2 Running 0 1m 66 | example-cluster-kuberay-worker-workergroup-pdf8x 2/2 Running 0 1m 67 | 68 | 69 | Running a Ray Workload 70 | ---------------------- 71 | 1. Set the following environment variable: 72 | 73 | .. code-block:: bash 74 | 75 | export RAY_ADDRESS=http://localhost:8265 76 | 2. Port-forward to the Ray cluster: 77 | 78 | .. code-block:: bash 79 | 80 | kubectl port-forward svc/example-cluster-kuberay-head-svc 8265:8265 & 81 | 3. Submit a Ray job, for example: 82 | 83 | .. code-block:: bash 84 | 85 | ray job submit --working-dir . \ 86 | --runtime-env-json='{"excludes": [".git", "kithara/model/maxtext/maxtext/MaxText/test_assets"]}' \ 87 | -- python examples/multihost/ray/TPU/full_finetuning_example.py 88 | 4. You can visit ``http://localhost:8265`` in your browser to see the Ray dashboard and monitor job status. 89 | 90 | 91 | Clean Up 92 | -------- 93 | 1. When your job is done, you can delete the Ray cluster by running: 94 | 95 | .. code-block:: bash 96 | 97 | kubectl delete -f $MANIFEST_FILE 98 | 99 | 2. The GKE cluster can be deleted with: 100 | 101 | .. code-block:: bash 102 | 103 | gcloud container clusters delete $CLUSTER --zone $ZONE 104 | 105 | -------------------------------------------------------------------------------- /docs/source/installation/tpu_qr.rst: -------------------------------------------------------------------------------- 1 | .. _tpu_qr: 2 | 3 | 4 | Setting up Ray Cluster with TPU Queued Resources 5 | ================================================ 6 | 7 | Prerequisites 8 | ------------- 9 | * Clone the Kithara repo. 10 | 11 | .. code-block:: bash 12 | 13 | git clone https://github.com/AI-Hypercomputer/kithara.git 14 | 15 | * Modify ``ray/TPU/QR/cluster.yaml`` with your GCP project, zone, and TPU resource types. 16 | 17 | .. tip:: 18 | Search for "MODIFY" in the YAML file to find required changes 19 | 20 | Setting up the Ray Cluster 21 | -------------------------- 22 | 23 | 1. Create Ray Head Node 24 | ----------------------- 25 | 26 | Create the Ray head node and launch the Ray cluster:: 27 | 28 | ray up -y ray/TPU/QR/cluster.yaml 29 | 30 | 2. Monitor Cluster Creation 31 | --------------------------- 32 | 33 | The ``ray up`` command will take a few minutes. If you lose the terminal, monitor the cluster creation process with:: 34 | 35 | ray monitor ray/TPU/QR/cluster.yaml 36 | 37 | 3. Launch Ray Dashboard 38 | ----------------------- 39 | 40 | Launch the Ray Cluster dashboard once the ``ray_head_default`` node is ``Active``. 
You should see the dashboard on your ``localhost:8265``:: 41 | 42 | ray dashboard ray/TPU/QR/cluster.yaml 43 | 44 | 4. Note Head Node IP 45 | -------------------- 46 | 47 | In the Ray dashboard, note down the IP of your Ray head node from the ``Cluster`` panel. 48 | 49 | .. note:: 50 | This should be the internal IP, not the external IP of the head node. It should start with 10.x.x.x 51 | 52 | 5. Set Resource Variables 53 | ------------------------- 54 | 55 | Specify the following variables with your own TPU resources spec:: 56 | 57 | export ZONE="us-central2-b" 58 | export QR_NAME="my_tpu_qr" 59 | export NODE_ID="v4-a" 60 | export PROJECT="gcp_project_name" 61 | export TPU_TYPE="v4-32" 62 | export RAY_CLUSTER_IP="your_cluster_ip" 63 | 64 | 6. Create TPU VMs 65 | ----------------- 66 | 67 | Create TPU VMs via QR:: 68 | 69 | gcloud alpha compute tpus queued-resources create $QR_NAME \ 70 | --node-id $NODE_ID \ 71 | --zone $ZONE \ 72 | --project $PROJECT \ 73 | --accelerator-type $TPU_TYPE \ 74 | --runtime-version tpu-ubuntu2204-base \ 75 | --metadata-from-file='startup-script=ray/TPU/QR/qr_worker_startup_script.sh' 76 | 77 | 7. Monitor QR Status 78 | -------------------- 79 | 80 | Monitor the status of the QR creation:: 81 | 82 | gcloud compute tpus queued-resources describe $QR_NAME --project $PROJECT --zone $ZONE 83 | 84 | Once the status becomes ``ACTIVE``, monitor the logs to verify package installation:: 85 | 86 | gcloud alpha compute tpus queued-resources ssh $QR_NAME \ 87 | --project $PROJECT \ 88 | --zone $ZONE \ 89 | --command="sudo cat /var/log/syslog | grep startup-script" \ 90 | --worker=0 \ 91 | --node=all 92 | 93 | 8. Attach TPU VMs to Ray Cluster 94 | -------------------------------- 95 | 96 | Once QRs are ready, attach the TPU VMs as worker nodes:: 97 | 98 | gcloud alpha compute tpus queued-resources ssh $QR_NAME \ 99 | --project $PROJECT \ 100 | --zone $ZONE \ 101 | --command="ray start --address=$RAY_CLUSTER_IP:6379 --resources='{\"tpu_host\": 1}'" \ 102 | --worker=all \ 103 | --node=all 104 | 105 | Troubleshooting 106 | ~~~~~~~~~~~~~~~ 107 | 108 | If you encounter Python or Ray version inconsistencies, check the worker node logs:: 109 | 110 | gcloud alpha compute tpus queued-resources ssh $QR_NAME \ 111 | --project $PROJECT \ 112 | --zone $ZONE \ 113 | --command="sudo cat /var/log/syslog | grep startup-script" \ 114 | --worker=all \ 115 | --node=all 116 | 117 | 9. Run Examples 118 | --------------- 119 | 120 | Your Ray Cluster is now ready. Try examples in the ``examples/multihost/ray/TPU`` folder:: 121 | 122 | python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/sft_lora_example.py" --hf-token your_token 123 | 124 | To early-stop your job:: 125 | 126 | ray job stop ray_job_id 127 | 128 | 10. Remove QRs 129 | -------------- 130 | 131 | To remove QRs from your Ray Cluster:: 132 | 133 | gcloud compute tpus queued-resources delete $QR_NAME --project $PROJECT --zone $ZONE 134 | 135 | 11. Tear Down Cluster 136 | --------------------- 137 | 138 | When finished with your Ray cluster, tear it down:: 139 | 140 | ray down ray/TPU/QR/cluster.yaml -------------------------------------------------------------------------------- /docs/source/installation/tpu_vm.rst: -------------------------------------------------------------------------------- 1 | .. _tpu_vm: 2 | 3 | Setting up Ray Cluster with TPU VMs 4 | ================================================ 5 | 6 | Prerequisites 7 | ------------- 8 | * Clone the Kithara repo. 9 | 10 | .. 
code-block:: bash 11 | 12 | git clone https://github.com/AI-Hypercomputer/kithara.git 13 | 14 | * Modify ``ray/TPU/GCE/cluster.yaml`` with your GCP project, zone, and TPU resource types. 15 | 16 | .. tip:: 17 | Search for "MODIFY" in the YAML file to find required changes 18 | 19 | Setting up the Ray Cluster 20 | -------------------------- 21 | 22 | 1. Launch the cluster:: 23 | 24 | ray up -y ray/TPU/GCE/cluster.yaml 25 | 26 | 2. Monitor setup process:: 27 | 28 | ray monitor ray/TPU/GCE/cluster.yaml 29 | 30 | 3. Launch Ray dashboard:: 31 | 32 | ray dashboard ray/TPU/GCE/cluster.yaml 33 | 34 | The dashboard will be available at ``localhost:8265`` 35 | 36 | Troubleshooting 37 | ~~~~~~~~~~~~~~~ 38 | * ``update-failed`` errors typically don't affect proper node setup 39 | * Check node status by executing:: 40 | 41 | ray attach ray/TPU/GCE/cluster.yaml 42 | ray status 43 | 44 | Running Multihost Jobs 45 | ---------------------- 46 | 47 | 1. Submit job:: 48 | 49 | python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/sft_lora_example.py" --hf-token your_token 50 | 51 | 2. To stop a job early:: 52 | 53 | export RAY_ADDRESS="http://127.0.0.1:8265" 54 | ray job stop ray_job_id 55 | 56 | Cleanup 57 | ------- 58 | 59 | When finished, tear down the cluster:: 60 | 61 | ray down ray/TPU/GCE/cluster.yaml -------------------------------------------------------------------------------- /docs/source/lora.rst: -------------------------------------------------------------------------------- 1 | .. _lora: 2 | 3 | LoRA: Low-Rank Adaptation Guide 4 | =============================== 5 | 6 | What is LoRA? 7 | ------------- 8 | LoRA (Low-Rank Adaptation) is a technique that makes fine-tuning large language models much more efficient. Instead of updating all model parameters during fine-tuning, LoRA: 9 | 10 | - Keeps the original model weights frozen 11 | - Adds small trainable matrices to specific model layers 12 | - Drastically reduces the number of parameters that need to be updated 13 | 14 | For example, you can fine-tune a 9B parameter model by training only about 100M parameters (roughly 1% of the original size) or less. 15 | 16 | Simple Usage 17 | ------------ 18 | 19 | To train your model with LoRA, you don't have to worry about changing anything in your training script other than specifying the ``lora_rank`` arg. 20 | LoRA will be applied to the q_proj and v_proj layers:: 21 | 22 | from kithara import KerasHubModel 23 | 24 | model = KerasHubModel.from_preset( 25 | "hf://google/gemma-2-2b", 26 | lora_rank=16 # <-- One line toggle 27 | ) 28 | 29 | Saving LoRA Models 30 | ------------------ 31 | You have three options for saving models trained with LoRA: 32 | 33 | 1. Save Only LoRA Adapters 34 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 35 | Since the base model is left unchanged, you can save just the LoRA adapters:: 36 | 37 | model.save_in_hf_format( 38 | destination, 39 | only_save_adapters=True 40 | ) 41 | 42 | 2. Save Base Model and Adapters Separately 43 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 44 | Use this in case you want to save the base model as well:: 45 | 46 | model.save_in_hf_format( 47 | destination, 48 | only_save_adapters=False, 49 | save_adapters_separately=True 50 | ) 51 | 52 | 3. 
Save Merged Model 53 | ~~~~~~~~~~~~~~~~~~~~ 54 | Creates a single model combining base weights and adaptations:: 55 | 56 | model.save_in_hf_format( 57 | destination, 58 | save_adapters_separately=False 59 | ) 60 | 61 | Load LoRA Models back into HuggingFace 62 | -------------------------------------- 63 | 64 | To load a model trained with LoRA back into HuggingFace, you can use the following code, where ``lora_dir`` and ``model_dir`` store the weights saved by Kithara:: 65 | 66 | # Load adapters separately; model_id is the base model's HF handle 67 | hf_model = AutoModelForCausalLM.from_pretrained(model_id) 68 | hf_model.load_adapter(lora_dir) 69 | 70 | # Load merged model 71 | hf_model = AutoModelForCausalLM.from_pretrained(model_dir) 72 | 73 | 74 | Next Steps 75 | ---------- 76 | For a complete example of using LoRA with supervised fine-tuning, see the :doc:`SFT + LoRA guide `. -------------------------------------------------------------------------------- /docs/source/models.rst: -------------------------------------------------------------------------------- 1 | .. _models: 2 | 3 | Supported Models 4 | ================ 5 | 6 | We support all safetensors-formatted models on `HuggingFace Hub `_ for the following architectures. 7 | 8 | .. list-table:: 9 | :header-rows: 1 10 | 11 | * - Model 12 | - Size 13 | - HuggingFace Handle Example 14 | * - Gemma 2 15 | - 2B, 9B, 27B 16 | - google/gemma-2-2b 17 | * - Llama 3.1 18 | - 8B, 70B, 405B 19 | - meta-llama/Llama-3.1-8B -------------------------------------------------------------------------------- /docs/source/observability.rst: -------------------------------------------------------------------------------- 1 | .. _observability: 2 | 3 | Observability 4 | ============= 5 | 6 | Kithara supports Tensorboard and (soon) Weights and Biases for observability. 7 | 8 | Tensorboard 9 | ----------- 10 | 11 | To use Tensorboard, simply set the ``tensorboard_dir`` arg of the ``Trainer`` class to a local directory or a Google Cloud Storage bucket. 12 | 13 | To track training and evaluation performance, launch the tensorboard server with:: 14 | 15 | tensorboard --logdir=your_tensorboard_dir -------------------------------------------------------------------------------- /docs/source/optimizers.rst: -------------------------------------------------------------------------------- 1 | .. _optimizers: 2 | 3 | Optimizer Support Guide 4 | ======================= 5 | 6 | Choosing an Optimizer 7 | --------------------- 8 | Effective model fine-tuning relies heavily on the chosen optimization algorithm. Kithara offers flexibility by supporting optimizers from two widely-used libraries: 9 | 10 | * **Keras Optimizers**: Standard, easy-to-use optimizers integrated directly into the Keras framework. 11 | * **Optax Optimizers**: A powerful and composable optimization library designed for JAX (which Keras can leverage), offering fine-grained control over the gradient update process. 12 | 13 | This guide explains how to utilize both types within the Kithara ``Trainer``. 14 | 15 | Using Keras Optimizers 16 | ---------------------- 17 | You can directly use any optimizer available in the ``keras.optimizers`` module. Simply instantiate the desired optimizer with its configuration parameters and pass the instance to the ``Trainer``. 
18 | 19 | Here's how you can use the ``AdamW`` optimizer from Keras:: 20 | 21 | import keras 22 | from kithara import Trainer 23 | 24 | # Assume 'model' and 'train_dataloader' are already defined 25 | # Define hyperparameters 26 | LEARNING_RATE = 1e-4 27 | WEIGHT_DECAY = 0.01 28 | 29 | # Instantiate a Keras optimizer 30 | optimizer = keras.optimizers.AdamW( 31 | learning_rate=LEARNING_RATE, 32 | weight_decay=WEIGHT_DECAY, 33 | # Add other Keras optimizer arguments as needed 34 | ) 35 | 36 | # Initialize the Trainer 37 | trainer = Trainer( 38 | model=model, 39 | optimizer=optimizer, # Pass the Keras optimizer instance 40 | train_dataloader=train_dataloader, 41 | epochs=1, 42 | # ... other Trainer arguments (log_steps_interval, etc.) 43 | ) 44 | 45 | # Start training 46 | trainer.train() 47 | 48 | For a comprehensive list of available Keras optimizers and their specific arguments, consult the official `Keras Optimizers documentation `_. 49 | 50 | Using Optax Optimizers 51 | ---------------------- 52 | Optax provides a highly flexible approach where optimizers are often built by chaining together multiple gradient transformations (like learning rate scheduling, gradient clipping, or weight decay) with a core optimizer function. This allows for complex optimization strategies. 53 | 54 | To use Optax, import the ``optax`` library, construct your optimizer (potentially using ``optax.chain``), and pass the resulting object to the ``Trainer``. 55 | 56 | This example demonstrates using ``optax.adamw`` combined with gradient clipping:: 57 | 58 | import optax 59 | from kithara import Trainer 60 | 61 | # Assume 'model' and 'train_dataloader' are already defined 62 | # Define hyperparameters 63 | LEARNING_RATE = 1e-4 64 | WEIGHT_DECAY = 0.01 65 | CLIP_NORM = 1.0 66 | 67 | # Create an Optax optimizer chain 68 | # Common practice: apply transformations like clipping before the main optimizer step 69 | optimizer = optax.chain( 70 | optax.clip(CLIP_NORM), 71 | optax.adamw(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY), 72 | # Add other Optax transformations as needed 73 | ) 74 | 75 | # Initialize the Trainer 76 | trainer = Trainer( 77 | model=model, 78 | optimizer=optimizer, # Pass the Optax optimizer instance 79 | train_dataloader=train_dataloader, 80 | epochs=1, 81 | # ... other Trainer arguments (log_steps_interval, etc.) 82 | ) 83 | 84 | # Start training 85 | trainer.train() 86 | 87 | Explore the extensive capabilities of Optax, including its various optimizers and gradient transformations, in the official `Optax API documentation `_. 88 | 89 | Key Takeaway 90 | ------------ 91 | Whether you choose a Keras or an Optax optimizer, the integration with Kithara is straightforward: 92 | 93 | 1. Instantiate your chosen optimizer object with the desired configuration. 94 | 2. Pass this optimizer object to the ``optimizer`` argument when initializing the ``Trainer``. 95 | 96 | The ``Trainer`` will then handle the application of the optimizer during the training loop. -------------------------------------------------------------------------------- /docs/source/overviewxyz.rst: -------------------------------------------------------------------------------- 1 | .. _overviewxyz: 2 | 3 | Overview 4 | ======== 5 | 6 | 7 | -------------------------------------------------------------------------------- /docs/source/packing.rst: -------------------------------------------------------------------------------- 1 | .. 
_packing: 2 | 3 | Dataset Packing 4 | =============== 5 | 6 | What is Dataset Packing? 7 | ------------------------ 8 | 9 | Dataset packing combines multiple short sequences into longer ones during training, reducing overall training time when your data consists of sequences shorter than the model's maximum sequence length. 10 | 11 | When to Use? 12 | ------------ 13 | Using packing will not hurt model performance, so you should use it whenever possible. Note, however, that packing is currently only supported for MaxText models. 14 | 15 | How to Use? 16 | ----------- 17 | 18 | Convert your dataset to a packed format using ``.to_packed_dataset()``: 19 | 20 | .. code-block:: python 21 | 22 | dataset = kithara.TextCompletionDataset( 23 | ray_dataset, 24 | tokenizer="hf://google/gemma-2-2b", 25 | max_seq_len=4096 26 | ).to_packed_dataset() 27 | 28 | Important Notes 29 | --------------- 30 | 31 | * No other changes to your training script are needed 32 | * Loss calculation remains accurate because Kithara uses per-token loss 33 | * Adjust your learning rate and training steps: packing increases the effective number of samples in each batch, so fewer steps are needed to cover the same data -------------------------------------------------------------------------------- /docs/source/quickstart.rst: -------------------------------------------------------------------------------- 1 | .. _quickstart: 2 | 3 | Quick Start 4 | =========== 5 | 6 | This guide demonstrates how to fine-tune a Gemma2 2B model using LoRA with a toy dataset. 7 | 8 | The script can also be found on `GitHub `_. 9 | 10 | Prefer running in Colab? Check out `this Colab notebook `_ instead. 11 | 12 | Setup 13 | ----- 14 | Log into HuggingFace and import required packages:: 15 | 16 | from huggingface_hub import login 17 | login(token="your_hf_token", add_to_git_credential=False) 18 | 19 | import os 20 | os.environ["KERAS_BACKEND"] = "jax" 21 | import keras 22 | import ray 23 | from kithara import ( 24 | KerasHubModel, 25 | Dataloader, 26 | Trainer, 27 | SFTDataset, 28 | ) 29 | 30 | .. tip:: 31 | New to HuggingFace? First create an access token, then `apply for access `_ to the Gemma2 HuggingFace model used in this example. 32 | 33 | Quick Usage 34 | ----------- 35 | 36 | 1. Create the Model:: 37 | 38 | model = KerasHubModel.from_preset( 39 | "hf://google/gemma-2-2b", 40 | precision="mixed_bfloat16", 41 | lora_rank=4, 42 | ) 43 | 44 | 2. Prepare Dataset:: 45 | 46 | dataset_items = [ 47 | { 48 | "prompt": "What is your name?", 49 | "answer": "My name is Kithara", 50 | } 51 | for _ in range(1000) 52 | ] 53 | dataset = ray.data.from_items(dataset_items) 54 | train_ds, eval_ds = dataset.train_test_split(test_size=500) 55 | 56 | 3. Create Datasets and Optimizer:: 57 | 58 | train_dataset = SFTDataset( 59 | train_ds, 60 | tokenizer_handle="hf://google/gemma-2-2b", 61 | max_seq_len=4096, 62 | ) 63 | 64 | eval_dataset = SFTDataset( 65 | eval_ds, 66 | tokenizer_handle="hf://google/gemma-2-2b", 67 | max_seq_len=4096, 68 | ) 69 | 70 | optimizer = keras.optimizers.AdamW( 71 | learning_rate=2e-4, 72 | weight_decay=0.01 73 | ) 74 | 75 | 4. Create Dataloaders:: 76 | 77 | train_dataloader = Dataloader( 78 | train_dataset, 79 | per_device_batch_size=1 80 | ) 81 | 82 | eval_dataloader = Dataloader( 83 | eval_dataset, 84 | per_device_batch_size=1 85 | ) 86 | 87 | 5. 
Initialize and Run Trainer:: 88 | 89 | trainer = Trainer( 90 | model=model, 91 | optimizer=optimizer, 92 | train_dataloader=train_dataloader, 93 | eval_dataloader=eval_dataloader, 94 | steps=200, # You can also use epochs instead of steps 95 | eval_steps_interval=10, 96 | max_eval_samples=50, 97 | log_steps_interval=10, 98 | ) 99 | 100 | trainer.train() 101 | 102 | 6. Test the Model:: 103 | 104 | pred = model.generate( 105 | "What is your name?", 106 | max_length=30, 107 | tokenizer_handle="hf://google/gemma-2-2b", 108 | return_decoded=True 109 | ) 110 | print("Tuned model generates:", pred) 111 | 112 | Running This Example on Single Host 113 | ------------------------------------------------ 114 | 115 | Simply copy this script from the GitHub repo and run it on your TPU VM:: 116 | 117 | python examples/singlehost/quick_start.py 118 | 119 | 120 | Running This Example on Multi-host 121 | --------------------------------- 122 | 123 | Kithara works with any accelerator orchestrator. However, if you are new to distributed training, we provide a guide for :doc:`multihost training with Ray <scaling_with_ray>`. 124 | 125 | Once you have set up a Ray cluster, clone the GitHub repo and run this example on your Ray cluster:: 126 | 127 | python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/quick_start.py" --hf-token your_token 128 | 129 | 130 | Next Steps 131 | ----------- 132 | 133 | Check out the :doc:`Finetuning Guide <finetuning_guide>` to craft your own finetuning job. 134 | -------------------------------------------------------------------------------- /docs/source/scaling_with_ray.rst: -------------------------------------------------------------------------------- 1 | .. _scaling_with_ray: 2 | 3 | Running on Multihost? 4 | ===================== 5 | 6 | Kithara works with any accelerator orchestrator. However, if you are new to distributed training, we provide a guide for multihost training with `Ray <https://docs.ray.io/>`_. 7 | 8 | .. admonition:: What is Ray? 9 | 10 | Ray is a popular and powerful tool for running distributed TPU and GPU workloads, offering: 11 | 12 | - Dashboard for job queueing and log visualization 13 | - Streamlined environment setup and file syncing 14 | - Simple command interface for multihost workloads 15 | 16 | Follow the instructions below to set up your Ray cluster. 17 | 18 | * :doc:`Set up Ray Cluster with TPU VMs <installation/tpu_vm>` 19 | * :doc:`Set up Ray Cluster with TPU QRs <installation/tpu_qr>` 20 | * :doc:`Set up Ray Cluster with TPU GKE <installation/tpu_gke>` 21 | 22 | .. toctree:: 23 | :maxdepth: 2 24 | :caption: Ray Cluster Setup 25 | :hidden: 26 | 27 | Ray Cluster with TPU VMs <installation/tpu_vm> 28 | Ray Cluster with TPU QRs <installation/tpu_qr> 29 | Ray Cluster with TPU GKE <installation/tpu_gke> 30 | 31 | 32 | After setting up your Ray cluster, you can scale up your workload to run on multiple hosts using the following recipe. 33 | 34 | .. code-block:: python 35 | :caption: my_multihost_ray_job.py 36 | 37 | import ray 38 | import jax 39 | 40 | ray.init() 41 | 42 | num_chips_per_host = 4 # <--IMPORTANT: Use 4 for v4 and v5, 8 for v4e and v5e TPU 43 | num_tpu_hosts = int(ray.cluster_resources()["TPU"] / num_chips_per_host) 44 | print(f"{num_tpu_hosts=}") 45 | 46 | @ray.remote(resources={"TPU": num_chips_per_host}) 47 | def main(): 48 | 49 | # HuggingFace login 50 | from huggingface_hub import login 51 | import os 52 | 53 | hf_token = os.getenv("HF_TOKEN") 54 | if hf_token: 55 | login(token=hf_token, add_to_git_credential=False) 56 | 57 | # Let JAX know that we are running a distributed job 58 | jax.distributed.initialize() 59 | 60 | # No need to change your single host job script, simply use it as it is. 
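        # Each host executes this same remote function. Once
        # jax.distributed.initialize() has been called, JAX treats all hosts
        # as one SPMD job, so the single-host script runs on every host
        # unmodified.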
61 | from examples.singlehost.quick_start import run_workload 62 | 63 | # Run this workload on all hosts. Don't worry, we are handling 64 | # model sharding and batch sharding for you. 65 | run_workload() 66 | 67 | 68 | ray.get([main.remote() for i in range(num_tpu_hosts)]) 69 | 70 | ray.shutdown() 71 | 72 | To launch this job, run the following command:: 73 | 74 | ray job submit \ 75 | --address="http://localhost:8265" \ 76 | --runtime-env-json='{"env_vars": {"HF_TOKEN": "your_token_here", "HF_HUB_ENABLE_HF_TRANSFER": "1"}}' \ 77 | -- "python3.11 my_multihost_ray_job.py" 78 | 79 | Equivalently, use the Kithara `helper script `_:: 80 | 81 | python ray/submit_job.py "python3.11 my_multihost_ray_job.py" --hf-token your_token 82 | 83 | Check out some multihost examples: 84 | 85 | - `Multihost Quickstart Example `_ 86 | - `Multihost Continued Pretraining Example `_ 87 | - `Multihost SFT+LoRA Example `_ 88 | -------------------------------------------------------------------------------- /docs/source/serve_with_vllm.rst: -------------------------------------------------------------------------------- 1 | .. _serve_with_vllm: 2 | 3 | Serve with vLLM 4 | =============== 5 | 6 | The fine-tuned model can be served on `vLLM:GPU `_ or `vLLM:TPU `_. 7 | 8 | Exporting your model 9 | -------------------- 10 | In order to serve your finetuned model, first make sure that your output checkpoints are exported to a GCS bucket or persistent volume. You can do this by overriding the `model_output_dir` parameter in the config yaml: 11 | 12 | .. code-block:: yaml 13 | 14 | model_output_dir: "gs://bucket_name/ckpt/" 15 | 16 | 17 | Starting vLLM with a Finetuned Model 18 | ------------------------------------ 19 | Next, start your vLLM deployment by attaching the previous output volume as a mounted drive. Please visit our GCS guide (link TBD) for details. 20 | 21 | You can start vLLM with a path to the directory containing the model files. For example, if you mounted the volume at `/model` and the model checkpoints are located under the `checkpoint` directory: 22 | 23 | .. code-block:: bash 24 | 25 | vllm serve /model/checkpoint --tensor-parallel-size 8 --max-model-len 4096 26 | 27 | where `tensor-parallel-size` is equal to the number of TPU chips available. You can adjust `max-model-len` depending on the available HBM. 28 | 29 | If loading the model is successful, you should see the following output in the console: 30 | 31 | .. 
code-block:: bash 32 | 33 | INFO 03-05 23:24:13 api_server.py:756] Using supplied chat template: 34 | INFO 03-05 23:24:13 api_server.py:756] None 35 | INFO 03-05 23:24:13 launcher.py:21] Available routes are: 36 | INFO 03-05 23:24:13 launcher.py:29] Route: /openapi.json, Methods: GET, HEAD 37 | INFO 03-05 23:24:13 launcher.py:29] Route: /docs, Methods: GET, HEAD 38 | INFO 03-05 23:24:13 launcher.py:29] Route: /docs/oauth2-redirect, Methods: GET, HEAD 39 | INFO 03-05 23:24:13 launcher.py:29] Route: /redoc, Methods: GET, HEAD 40 | INFO 03-05 23:24:13 launcher.py:29] Route: /health, Methods: GET 41 | INFO 03-05 23:24:13 launcher.py:29] Route: /ping, Methods: POST, GET 42 | INFO 03-05 23:24:13 launcher.py:29] Route: /tokenize, Methods: POST 43 | INFO 03-05 23:24:13 launcher.py:29] Route: /detokenize, Methods: POST 44 | INFO 03-05 23:24:13 launcher.py:29] Route: /v1/models, Methods: GET 45 | INFO 03-05 23:24:13 launcher.py:29] Route: /version, Methods: GET 46 | INFO 03-05 23:24:13 launcher.py:29] Route: /v1/chat/completions, Methods: POST 47 | INFO 03-05 23:24:13 launcher.py:29] Route: /v1/completions, Methods: POST 48 | INFO 03-05 23:24:13 launcher.py:29] Route: /v1/embeddings, Methods: POST 49 | INFO 03-05 23:24:13 launcher.py:29] Route: /pooling, Methods: POST 50 | INFO 03-05 23:24:13 launcher.py:29] Route: /score, Methods: POST 51 | INFO 03-05 23:24:13 launcher.py:29] Route: /v1/score, Methods: POST 52 | INFO 03-05 23:24:13 launcher.py:29] Route: /rerank, Methods: POST 53 | INFO 03-05 23:24:13 launcher.py:29] Route: /v1/rerank, Methods: POST 54 | INFO 03-05 23:24:13 launcher.py:29] Route: /v2/rerank, Methods: POST 55 | INFO 03-05 23:24:13 launcher.py:29] Route: /invocations, Methods: POST 56 | INFO: Started server process [67252] 57 | INFO: Waiting for application startup. 58 | INFO: Application startup complete. 59 | INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) 60 | 61 | 62 | Send a request to the vLLM server to verify that it's working: 63 | 64 | .. code-block:: bash 65 | 66 | curl http://localhost:8000/v1/completions \ 67 | -H "Content-Type: application/json" \ 68 | -d '{ 69 | "model": "/model/checkpoint/", 70 | "prompt": "your prompt", 71 | "max_tokens": 256, 72 | "temperature": 0.7 73 | }' 74 | 75 | -------------------------------------------------------------------------------- /docs/source/temp.rst: -------------------------------------------------------------------------------- 1 | #.. _overviewxyz: 2 | 3 | 👋 Welcome to Kithara! 4 | =================================== 5 | .. note:: 6 | 7 | This project is under active development. 8 | 9 | Introduction 10 | ------------ 11 | 12 | What is Kithara? 13 | ~~~~~~~~~~~~~~~~ 14 | 15 | **Kithara** is a lightweight library offering building blocks and recipes for tuning popular open source LLMs like Llama 3 and Gemma 2 on Google TPUs. 16 | 17 | Kithara makes post-training easier to do on TPUs. Post-training algorithms include continued pretraining, fine-tuning, PEFT, RLHF, and quantization. These algorithms adapt a pre-trained language model to a specific task or domain by training it on a custom dataset. Using TPUs provides significant advantages in terms of performance, cost-effectiveness, and scalability, enabling faster training times and the ability to work with larger models and datasets. 18 | 19 | -------------------------------------------------------------------------------- /docs/source/troubleshooting.rst: -------------------------------------------------------------------------------- 1 | .. 
_troubleshooting: 2 | 3 | Troubleshooting 4 | =============== 5 | 6 | 1. Disk OOM when loading HF model checkpoint 7 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 8 | 9 | First try emptying your cache by running the following code on your Ray cluster. 10 | 11 | .. code-block:: python 12 | 13 | import shutil 14 | shutil.rmtree("/home/ubuntu/.cache/huggingface/hub/", ignore_errors=True) 15 | shutil.rmtree("/home/ubuntu/.keras/models", ignore_errors=True) 16 | 17 | If you are using a single VM, the path may be different. 18 | 19 | .. code-block:: python 20 | 21 | import shutil 22 | shutil.rmtree("~/.cache/huggingface/hub/", ignore_errors=True) 23 | shutil.rmtree("~/.keras/models", ignore_errors=True) 24 | 25 | If emptying the cache still doesn't help, try attaching a disk to your VM and changing the HF cache directory using the environment variable ``export HF_HOME=``. 26 | 27 | You may have to copy your HF token to this new cache directory with ``cp .cache/huggingface/token /token``. 28 | 29 | 2. Permission denied error when uploading checkpoint to GCS 30 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 31 | 32 | First make sure that your bucket is in the same project as your TPU VM. 33 | 34 | Otherwise, verify your current authentication: 35 | 36 | .. code-block:: bash 37 | 38 | gcloud auth list 39 | gsutil ls gs://your_bucket 40 | 41 | For your Python code, you likely need to ensure you're using the same credentials. 42 | 43 | .. code-block:: bash 44 | 45 | gcloud auth application-default login 46 | 47 | 3. jaxlib.xla_extension.XlaRuntimeError errors 48 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 49 | 50 | Try uninstalling and reinstalling ``jax`` and ``jaxlib``. 51 | 52 | .. code-block:: bash 53 | 54 | pip uninstall jax jaxlib 55 | pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 56 | pip install libtpu-nightly==0.1.dev20250128+nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 57 | 58 | 4. Unable to initialize backend 'tpu': INTERNAL: Failed to get global TPU topology. 59 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 60 | 61 | Try adding ``JAX_PLATFORMS=''`` to your environment variables. -------------------------------------------------------------------------------- /examples/example_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from typing import Tuple 18 | from datasets import load_dataset 19 | import ray 20 | from ray.data import Dataset 21 | 22 | 23 | def example_datasets(option: str) -> Tuple[ray.data.Dataset, ray.data.Dataset]: 24 | """ 25 | Load example datasets based on the specified option. 
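
    Args:
        option (str): One of "hf", "finetune_toy", "sft_toy", or "files".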
26 | 27 | Returns: 28 | Tuple[ray.data.Dataset, ray.data.Dataset]: A tuple of (training_dataset, evaluation_dataset) 29 | 30 | Example usage: 31 | train_ds, eval_ds = example_datasets("hf") 32 | """ 33 | valid_options = {"hf", "finetune_toy", "sft_toy", "files"} 34 | if option not in valid_options: 35 | raise ValueError( 36 | f"Invalid option: {option}. Must be one of {', '.join(valid_options)}" 37 | ) 38 | 39 | if option == "hf": 40 | return _load_huggingface_dataset() 41 | elif option == "finetune_toy": 42 | return _create_finetune_toy_dataset() 43 | elif option == "sft_toy": 44 | return _create_sft_toy_dataset() 45 | elif option == "files": 46 | raise NotImplementedError( 47 | "Loading from files is not yet implemented. Please use another option." 48 | ) 49 | 50 | 51 | def _load_huggingface_dataset() -> Tuple[Dataset, Dataset]: 52 | """Load the C4 dataset from HuggingFace.""" 53 | hf_train_dataset = load_dataset("allenai/c4", "en", split="train", streaming=True) 54 | hf_val_dataset = load_dataset( 55 | "allenai/c4", "en", split="validation", streaming=True 56 | ) 57 | 58 | return ( 59 | ray.data.from_huggingface(hf_train_dataset), 60 | ray.data.from_huggingface(hf_val_dataset), 61 | ) 62 | 63 | 64 | def _create_finetune_toy_dataset() -> Tuple[Dataset, Dataset]: 65 | """Create a toy dataset for finetuning with 1000 examples.""" 66 | dataset_items = [ 67 | {"text": f"{i} What is your name? My name is Mary."} for i in range(1000) 68 | ] 69 | dataset = ray.data.from_items(dataset_items) 70 | return dataset.train_test_split(test_size=500) 71 | 72 | 73 | def _create_sft_toy_dataset() -> Tuple[Dataset, Dataset]: 74 | """Create a toy dataset for supervised fine-tuning with 1000 examples.""" 75 | dataset_items = [{ 76 | "prompt": "What is your name?", 77 | "answer": "My name is Mary", 78 | } for _ in range(1000)] 79 | dataset = ray.data.from_items(dataset_items) 80 | return dataset.train_test_split(test_size=500) 81 | -------------------------------------------------------------------------------- /examples/multihost/ray/GPU/sft_lora_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import ray 18 | from typing import List, Any 19 | from kithara.distributed.data import split_dataset 20 | 21 | ray.init() 22 | 23 | num_chips_per_host = 4 24 | num_gpu_devices = int(ray.cluster_resources()["GPU"]) 25 | print(f"{num_gpu_devices=}") 26 | 27 | 28 | @ray.remote(num_gpus=num_chips_per_host) 29 | def main(train_ds, eval_ds, split_data_across_host): 30 | from huggingface_hub import login 31 | import os 32 | 33 | hf_token = os.getenv("HF_TOKEN") 34 | if hf_token: 35 | login(token=hf_token, add_to_git_credential=False) 36 | 37 | from examples.singlehost.full_finetuning_example import run_workload 38 | 39 | run_workload( 40 | train_source=train_ds, 41 | eval_source=eval_ds, 42 | dataset_is_sharded_per_host=split_data_across_host, 43 | ) 44 | 45 | 46 | # Create mulit-host datasets 47 | dataset_items = [ 48 | { 49 | "prompt": "What is your name?", 50 | "answer": "My name is Mary", 51 | } 52 | for _ in range(1000) 53 | ] 54 | dataset = ray.data.from_items(dataset_items) 55 | train_ds, eval_ds = dataset.train_test_split(test_size=500) 56 | 57 | split_data_across_host = False 58 | if split_data_across_host: 59 | train_ds: List[Any] = split_dataset(train_ds, num_hosts=num_gpu_devices) 60 | eval_ds: List[Any] = split_dataset(eval_ds, num_hosts=num_gpu_devices) 61 | ray.get( 62 | [ 63 | main.remote(train_ds[i], eval_ds[i], split_data_across_host) 64 | for i in range(num_gpu_devices) 65 | ] 66 | ) 67 | else: 68 | ray.get( 69 | [ 70 | main.remote(train_ds, eval_ds, split_data_across_host) 71 | for i in range(num_gpu_devices) 72 | ] 73 | ) 74 | 75 | ray.shutdown() 76 | -------------------------------------------------------------------------------- /examples/multihost/ray/TPU/continued_pretraining_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | """Full parameter continued pretraining a Gemma2 9B model. 18 | 19 | This script demonstrates how to 20 | 21 | 1. Load HuggingFace Gemma2 checkpoint 22 | 2. Loading large HuggingFace dataset 23 | 3. Configure data loading and preprocessing 24 | 4. Run training across TPU/GPU devices 25 | 5. Save checkpoint to GCS periodically 26 | 6. Generate text using the trained model 27 | 7. Save model in HuggingFace format to GCS 28 | 29 | This script can be run on both single-host and multi-host. 30 | For mulit-host set up, please follow https://kithara.readthedocs.io/en/latest/scaling_with_ray.html. 
31 | 32 | Singlehost: python examples/singlehost/continued_pretraining_example.py 33 | Multihost: python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/continued_pretraining_example.py" --hf-token your_token 34 | """ 35 | 36 | import ray 37 | import jax 38 | 39 | ray.init() 40 | 41 | # Verify TPU resources 42 | num_chips_per_host = 4 # <-- IMPORTANT: Use 4 for v4 and v5, 8 for v4e and v5e 43 | num_tpu_hosts = int(ray.cluster_resources()["TPU"] / num_chips_per_host) 44 | print(f"{num_tpu_hosts=}") 45 | 46 | 47 | @ray.remote(resources={"TPU": num_chips_per_host}) 48 | def main(): 49 | 50 | import subprocess 51 | 52 | subprocess.run(["rm", "-rf", "/tmp/libtpu_lockfile", "/tmp/tpu_logs"]) 53 | subprocess.run(["rm", "-rf", "/home/ubuntu/.cache/huggingface/hub"]) 54 | 55 | # HuggingFace login 56 | from huggingface_hub import login 57 | import os 58 | 59 | hf_token = os.getenv("HF_TOKEN") 60 | if hf_token: 61 | login(token=hf_token, add_to_git_credential=False) 62 | 63 | jax.distributed.initialize() 64 | 65 | from singlehost.continued_pretraining_example import run_workload 66 | 67 | # Save model outputs to GCS 68 | run_workload( 69 | gemma2_model_size="9b", model_output_dir="gs://wenxindong-vm/model_output" 70 | ) 71 | 72 | 73 | ray.get([main.remote() for i in range(num_tpu_hosts)]) 74 | 75 | ray.shutdown() 76 | -------------------------------------------------------------------------------- /examples/multihost/ray/TPU/full_finetuning_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | """Full parameter finetune a Gemma2 9B model. 18 | 19 | This script demonstrates how to: 20 | 1. Set up a Gemma2 model for full parameter finetuning 21 | 2. Load HuggingFace Gemma2 checkpoint 22 | 3. Configure data loading and preprocessing 23 | 4. Run training across TPU/GPU devices 24 | 5. Save checkpoint to GCS periodically 25 | 6. Generate text using the trained model 26 | 7. Save model in HuggingFace format to GCS 27 | 28 | This script should be run on multihost, since gemma2-9b will not fit on a single host. However, 29 | you can change the model to `gemma2-2b` to run on single host. 
30 | 31 | Singlehost: python examples/singlehost/full_finetuning_example.py 32 | Multihost: python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/full_finetuning_example.py" --hf-token your_token 33 | 34 | """ 35 | 36 | import ray 37 | from absl import app 38 | from typing import List, Any 39 | from kithara.config.pyconfig import load_config 40 | from kithara.distributed.data import split_dataset 41 | import jax 42 | 43 | ray.init() 44 | 45 | # Verify TPU resources 46 | num_chips_per_host = 4 # 4 for v4 and v5, 8 for v4e and v5e 47 | num_tpu_devices = int(ray.cluster_resources()["TPU"]) 48 | num_tpu_hosts = num_tpu_devices // num_chips_per_host 49 | print(f"{num_tpu_devices=}") 50 | print(f"{num_tpu_hosts=}") 51 | 52 | 53 | @ray.remote(resources={"TPU": num_chips_per_host}) 54 | def run(config, train_ds, eval_ds, dataset_is_sharded_per_host): 55 | 56 | import subprocess 57 | 58 | subprocess.run(["rm", "-rf", "/tmp/libtpu_lockfile", "/tmp/tpu_logs"]) 59 | 60 | # HuggingFace login 61 | from huggingface_hub import login 62 | import os 63 | 64 | hf_token = os.getenv("HF_TOKEN") 65 | if hf_token: 66 | login(token=hf_token, add_to_git_credential=False) 67 | 68 | jax.distributed.initialize() 69 | 70 | # Run workload in SPMD mode 71 | from examples.singlehost.full_finetuning_example import run_workload 72 | 73 | run_workload( 74 | config=config, 75 | train_source=train_ds, 76 | eval_source=eval_ds, 77 | dataset_is_sharded_per_host=dataset_is_sharded_per_host, 78 | ) 79 | 80 | 81 | def main(argv): 82 | config = load_config(argv) 83 | 84 | # Create multi-host datasets 85 | dataset_items = [ 86 | {"text": f"{i} What is your name? My name is Mary."} for i in range(1000) 87 | ] 88 | dataset = ray.data.from_items(dataset_items) 89 | train_ds, eval_ds = dataset.train_test_split(test_size=500) 90 | 91 | split_data_across_host = False 92 | if split_data_across_host: 93 | train_ds: List[Any] = split_dataset(train_ds, num_hosts=num_tpu_hosts) 94 | eval_ds: List[Any] = split_dataset(eval_ds, num_hosts=num_tpu_hosts) 95 | ray.get( 96 | [ 97 | run.remote(config, train_ds[i], eval_ds[i], split_data_across_host) 98 | for i in range(num_tpu_hosts) 99 | ] 100 | ) 101 | else: 102 | ray.get( 103 | [ 104 | run.remote(config, train_ds, eval_ds, split_data_across_host) 105 | for _ in range(num_tpu_hosts) 106 | ] 107 | ) 108 | 109 | ray.shutdown() 110 | 111 | 112 | if __name__ == "__main__": 113 | app.run(main) 114 | -------------------------------------------------------------------------------- /examples/multihost/ray/TPU/quick_start.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | """Quick Start Example 18 | 19 | This script demonstrates how to run LoRA SFT on a toy dataset 20 | 21 | 1. Load HuggingFace Gemma2 checkpoint 22 | 2. Configure data loading and preprocessing 23 | 3. 
Run training across TPU/GPU devices 24 | 25 | This script can be run on both single-host and multi-host. 26 | For multi-host setup, please follow https://kithara.readthedocs.io/en/latest/scaling_with_ray.html. 27 | 28 | 29 | Singlehost: python examples/singlehost/quick_start.py 30 | Multihost: python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/quick_start.py" --hf-token your_token 31 | """ 32 | 33 | import ray 34 | import jax 35 | 36 | ray.init() 37 | 38 | num_chips_per_host = 4 # <--IMPORTANT: Use 4 for v4 and v5, 8 for v4e and v5e 39 | num_tpu_hosts = int(ray.cluster_resources()["TPU"] / num_chips_per_host) 40 | print(f"{num_tpu_hosts=}") 41 | 42 | @ray.remote(resources={"TPU": num_chips_per_host}) 43 | def main(): 44 | 45 | import subprocess 46 | 47 | # This is not strictly necessary, but helps to remove TPU deadlocks if you ever run into them. 48 | subprocess.run(["rm", "-rf", "/tmp/libtpu_lockfile", "/tmp/tpu_logs"]) 49 | 50 | # HuggingFace login 51 | from huggingface_hub import login 52 | import os 53 | 54 | hf_token = os.getenv("HF_TOKEN") 55 | if hf_token: 56 | login(token=hf_token, add_to_git_credential=False) 57 | 58 | # Let JAX know that we are running a distributed job 59 | jax.distributed.initialize() 60 | 61 | # No need to change your single host job script, simply use it as it is. 62 | from examples.singlehost.quick_start import run_workload 63 | 64 | # Run this workload on all hosts. Don't worry, we are handling 65 | # all the model sharding and batch sharding for you. 66 | run_workload() 67 | 68 | ray.get([main.remote() for i in range(num_tpu_hosts)]) 69 | 70 | ray.shutdown() 71 | -------------------------------------------------------------------------------- /examples/multihost/ray/TPU/sft_lora_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | """SFT a Gemma2 2B model using LoRA on TPU or GPU on an Alpaca Dataset 18 | 19 | This script demonstrates how to: 20 | 1. Set up a Gemma2 model for LoRA SFT 21 | 2. Load HuggingFace Gemma2 checkpoint 22 | 3. Load HuggingFace Dataset 23 | 4. Configure data loading and preprocessing 24 | 5. Run training across TPU/GPU devices 25 | 6. Save the LoRA adapters 26 | 27 | This script can be run on both single-host and multi-host. 28 | For multi-host setup, please follow https://kithara.readthedocs.io/en/latest/scaling_with_ray.html. 
29 | 30 | Singlehost: python examples/singlehost/sft_lora_example.py 31 | Multihost: python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/sft_lora_example.py" --hf-token your_token 32 | """ 33 | 34 | import ray 35 | import jax 36 | 37 | ray.init() 38 | 39 | # Verify TPU resources 40 | num_chips_per_host = 4 # <-- IMPORTANT: Use 4 for v4 and v5, 8 for v4e and v5e 41 | num_tpu_hosts = int(ray.cluster_resources()["TPU"] / num_chips_per_host) 42 | print(f"{num_tpu_hosts=}") 43 | 44 | 45 | @ray.remote(resources={"TPU": num_chips_per_host}) 46 | def main(): 47 | 48 | import subprocess 49 | 50 | subprocess.run(["rm", "-rf", "/tmp/libtpu_lockfile", "/tmp/tpu_logs"]) 51 | 52 | 53 | # HuggingFace login 54 | from huggingface_hub import login 55 | import os 56 | 57 | hf_token = os.getenv("HF_TOKEN") 58 | if hf_token: 59 | login(token=hf_token, add_to_git_credential=False) 60 | 61 | jax.distributed.initialize() 62 | 63 | # Run workload in SPMD mode 64 | from singlehost.sft_lora_example import run_workload 65 | 66 | # Save your model in cloud storage. Use None to skip model saving. 67 | run_workload(model_output_dir=None) 68 | 69 | 70 | ray.get([main.remote() for i in range(num_tpu_hosts)]) 71 | ray.shutdown() 72 | -------------------------------------------------------------------------------- /examples/multihost/ray/TPU/sft_lora_example.yaml: -------------------------------------------------------------------------------- 1 | # Launch this example using the following command: 2 | 3 | # Important: Replace v5e with your tpu generation 4 | # python ray/submit_job.py "python3.11 kithara/config/launcher.py --override_config=examples/multihost/ray/TPU/sft_lora_example.yaml --override tpu_generation=v5e" --hf-token your_token 5 | 6 | # ------------------------- 7 | # Orchestration Configuration 8 | # ------------------------- 9 | single_host: False 10 | 11 | # ------------------------- 12 | # Model Configuration 13 | # ------------------------- 14 | seq_len: 8192 15 | model_handle: "hf://google/gemma-2-9b" 16 | 17 | # ------------------------- 18 | # Training Settings 19 | # ------------------------- 20 | training_steps: 100 21 | eval_steps_interval: 20 22 | log_steps_interval: 10 23 | 24 | # ------------------------- 25 | # Checkpointing 26 | # ------------------------- 27 | save_checkpoint_interval: 20 28 | max_checkpoints_to_keep: 3 29 | checkpoint_dir: "gs://wenxindong-tpu-prod-env-multipod-bucket/checkpoints/" 30 | 31 | # ------------------------- 32 | # Logging 33 | # ------------------------- 34 | tensorboard_dir: "gs://wenxindong-tpu-prod-env-multipod-bucket/tensorboard_dir/" 35 | 36 | # ------------------------- 37 | # Train Dataset 38 | # ------------------------- 39 | train_dataset_column_mapping: {"prompt":"input", "answer":"output"} 40 | train_eval_split: 50 # Use 50 test samples 41 | 42 | # ------------------------- 43 | # Output Configuration 44 | # ------------------------- 45 | model_output_dir: "gs://wenxindong-tpu-prod-env-multipod-bucket/model_output/" -------------------------------------------------------------------------------- /examples/singlehost/quick_start.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | """Quick Start Example 18 | 19 | This script demonstrates how to run LoRA SFT on a toy dataset 20 | 21 | 1. Load HuggingFace Gemma2 checkpoint 22 | 2. Configure data loading and preprocessing 23 | 3. Run training across TPU/GPU devices 24 | 25 | This script can be run on both single-host and multi-host. 26 | For multi-host setup, please follow https://kithara.readthedocs.io/en/latest/scaling_with_ray.html. 27 | 28 | Singlehost: python examples/singlehost/quick_start.py 29 | Multihost: python ray/submit_job.py "python3.11 examples/multihost/ray/TPU/quick_start.py" --hf-token your_token 30 | """ 31 | 32 | import os 33 | 34 | os.environ["KERAS_BACKEND"] = "jax" 35 | import keras 36 | import ray 37 | from kithara import ( 38 | KerasHubModel, 39 | Dataloader, 40 | Trainer, 41 | SFTDataset, 42 | ) 43 | 44 | config = { 45 | "model_handle": "hf://google/gemma-2-2b", 46 | "tokenizer_handle": "hf://google/gemma-2-2b", 47 | "seq_len": 4096, 48 | "lora_rank": 16, 49 | "precision": "mixed_bfloat16", 50 | "training_steps": 60, 51 | "eval_steps_interval": 10, 52 | "log_steps_interval": 1, 53 | "per_device_batch_size": 1, 54 | "max_eval_samples": 50, 55 | "learning_rate": 2e-4, 56 | } 57 | 58 | 59 | def run_workload(): 60 | 61 | # Create a toy dataset 62 | dataset_items = [ 63 | { 64 | "prompt": "What is your name?", 65 | "answer": "My name is Kithara", 66 | } 67 | for _ in range(1000) 68 | ] 69 | dataset = ray.data.from_items(dataset_items) 70 | train_source, eval_source = dataset.train_test_split(test_size=500) 71 | 72 | # Create model 73 | model = KerasHubModel.from_preset( 74 | config["model_handle"], 75 | precision=config["precision"], 76 | lora_rank=config["lora_rank"], 77 | ) 78 | 79 | # Create datasets 80 | train_dataset = SFTDataset( 81 | train_source, 82 | tokenizer_handle=config["tokenizer_handle"], 83 | max_seq_len=config["seq_len"], 84 | ) 85 | eval_dataset = SFTDataset( 86 | eval_source, 87 | tokenizer_handle=config["tokenizer_handle"], 88 | max_seq_len=config["seq_len"], 89 | ) 90 | 91 | # Create optimizer 92 | optimizer = keras.optimizers.AdamW( 93 | learning_rate=config["learning_rate"], weight_decay=0.01 94 | ) 95 | 96 | # Create data loaders 97 | train_dataloader = Dataloader( 98 | train_dataset, 99 | per_device_batch_size=config["per_device_batch_size"], 100 | ) 101 | eval_dataloader = Dataloader( 102 | eval_dataset, 103 | per_device_batch_size=config["per_device_batch_size"], 104 | ) 105 | 106 | # Initialize trainer 107 | trainer = Trainer( 108 | model=model, 109 | optimizer=optimizer, 110 | train_dataloader=train_dataloader, 111 | eval_dataloader=eval_dataloader, 112 | steps=config["training_steps"], 113 | eval_steps_interval=config["eval_steps_interval"], 114 | max_eval_samples=config["max_eval_samples"], 115 | log_steps_interval=config["log_steps_interval"], 116 | ) 117 | 118 | # Start training 119 | trainer.train() 120 | 121 | print("Finished training. 
Prompting model...") 122 | 123 | # Test after tuning 124 | pred = model.generate( 125 | "What is your name?", 126 | max_length=30, 127 | tokenizer_handle=config["tokenizer_handle"], 128 | return_decoded=True, 129 | ) 130 | print("Tuned model generates:", pred) 131 | 132 | 133 | if __name__ == "__main__": 134 | 135 | run_workload() 136 | -------------------------------------------------------------------------------- /kithara/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from kithara.callbacks.profiler import * 2 | from kithara.callbacks.checkpointer import * -------------------------------------------------------------------------------- /kithara/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/kithara/config/__init__.py -------------------------------------------------------------------------------- /kithara/config/default.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Model parameters 16 | model_handle: "hf://google/gemma-2-2b" 17 | tokenizer_handle: "hf://google/gemma-2-2b" 18 | seq_len: 4096 19 | precision: "mixed_bfloat16" 20 | 21 | # Training parameters 22 | training_steps: 200 23 | eval_steps_interval: 100 24 | log_steps_interval: 1 25 | per_device_batch_size: 1 26 | max_eval_samples: 50 27 | learning_rate: 5e-05 28 | 29 | # Output parameters 30 | model_output_dir: "gs://bucket_name/ckpt/" 31 | -------------------------------------------------------------------------------- /kithara/config/pyconfig.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
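# Usage sketch (illustrative; the script name below is hypothetical):
# load_config reads DEFAULT_CONFIG unless a YAML path is passed as the first
# CLI argument, e.g.
#
#   python your_training_script.py examples/multihost/ray/TPU/sft_lora_example.yaml
#
#   config = load_config(sys.argv)
#   learning_rate = config["learning_rate"]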
14 | import os 15 | import sys 16 | from typing import Any 17 | import yaml 18 | 19 | 20 | DEFAULT_CONFIG = "kithara/config/default.yaml" 21 | 22 | 23 | def load_config(argv) -> dict[str, Any]: 24 | """Loads the YAML config from a file with a given name.""" 25 | config_name = DEFAULT_CONFIG 26 | if len(argv) > 1: 27 | config_name = argv[1] 28 | 29 | with open(config_name, "r", encoding="utf-8") as yaml_file: 30 | config = yaml.safe_load(yaml_file) 31 | 32 | # Some float parameters like learning_rate could be parsed 33 | # as strings instead of floats, so we need to explicitly 34 | # convert them. 35 | if 'learning_rate' in config: 36 | config['learning_rate'] = float(config['learning_rate']) 37 | 38 | 39 | print(f"Parsed config: {config}") 40 | return config 41 | -------------------------------------------------------------------------------- /kithara/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from kithara.dataset.dataloader import Dataloader 2 | from kithara.dataset.sft import SFTDataset 3 | from kithara.dataset.text_completion import TextCompletionDataset 4 | -------------------------------------------------------------------------------- /kithara/dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from typing import Iterator, Any, Iterable, Dict 18 | import jax 19 | from kithara.dataset.dataset import Dataset 20 | from jax.tree_util import tree_map 21 | import numpy as np 22 | 23 | class Dataloader(Iterable): 24 | """Kithara Dataloader class. This dataloader class supports distributed 25 | data loading for multi-host training. 26 | 27 | Attributes: 28 | dataset: A Kithara Dataset instance, or any Iterable instance 29 | that returns individual, tokenized samples formatted 30 | as expected by the model. 31 | per_device_batch_size: Number of samples per batch per device. 32 | If you experience HBM OOM errors, try reducing this value. 33 | dataset_is_sharded_per_host: True if dataset is already sharded 34 | and each host is provided with a local dataset shard. False if 35 | every host is loading from the same dataset. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | dataset: Dataset, 41 | per_device_batch_size: int, 42 | dataset_is_sharded_per_host: bool = False, 43 | ): 44 | self.dataset = dataset 45 | self.per_device_batch_size = per_device_batch_size 46 | self.dataset_is_sharded_per_host = dataset_is_sharded_per_host 47 | self._iterator = None 48 | self.num_hosts = jax.process_count() 49 | self.host_id = jax.process_index() 50 | 51 | def __iter__(self) -> Iterator[Any]: 52 | """Return iterator over batches in the dataset. 53 | 54 | Yields: 55 | Per-host batch input 56 | """ 57 | self._iterator = iter(self.dataset) 58 | return self 59 | 60 | def __next__(self) -> Dict[str, Any]: 61 | """Get next batch of data from the dataset. 
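        For sharded datasets, this pulls ``per_host_batch_size`` samples from
        the local shard; otherwise it pulls ``global_batch_size`` samples and
        keeps every ``num_hosts``-th one, so each host sees a disjoint slice
        of the global batch.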
62 | 63 | Returns: 64 | Per-host batch input 65 | """ 66 | if self.dataset_is_sharded_per_host: 67 | # For sharded datasets, each host processes per_host_batch_size samples 68 | samples = [next(self._iterator) for _ in range(self.per_host_batch_size)] 69 | else: 70 | # For non-sharded datasets: 71 | # 1. Each host loads global_batch_size samples 72 | # 2. Only keeps samples corresponding to its host_id 73 | samples = [] 74 | for i in range(self.global_batch_size): 75 | sample = next(self._iterator) 76 | if i % self.num_hosts == self.host_id: 77 | samples.append(sample) 78 | 79 | samples = tree_map( 80 | lambda *arrays: np.concatenate(arrays), *samples 81 | ) 82 | return samples 83 | 84 | def __len__(self) -> int: 85 | """Get the total number of batches in the dataset. 86 | 87 | Returns: 88 | Total number of batches that will be yielded by this dataloader. 89 | For non-sharded datasets, this is the same across all hosts. 90 | For sharded datasets, each host has its own length. 91 | """ 92 | if not hasattr(self.dataset, '__len__'): 93 | raise TypeError("Dataset must implement __len__ to support length calculation") 94 | 95 | total_samples = len(self.dataset) 96 | if self.dataset_is_sharded_per_host: 97 | batches = total_samples // self.per_host_batch_size 98 | else: 99 | batches = total_samples // (self.num_hosts * self.per_host_batch_size) 100 | return batches 101 | 102 | @property 103 | def global_batch_size(self): 104 | num_devices = jax.device_count() 105 | return self.per_device_batch_size * num_devices 106 | @property 107 | def per_host_batch_size(self): 108 | num_devices_per_host = jax.local_device_count() 109 | return self.per_device_batch_size * num_devices_per_host 110 | 111 | -------------------------------------------------------------------------------- /kithara/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import ray 18 | from typing import Any, Iterator, Iterable 19 | from abc import ABC 20 | from datasets import Dataset as HF_Dataset, IterableDataset as HF_IterableDataset 21 | 22 | 23 | class Dataset(Iterable, ABC): 24 | """A base dataset class that serves as a base for other dataset 25 | implementations, including SFTDataset and TextCompletionDataset. 26 | It supports both regular and streaming Ray datasets. 27 | 28 | Args: 29 | source (ray.data.Dataset): The underlying Ray dataset to wrap. 
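
    Example (illustrative, via the ``SFTDataset`` subclass; argument names
    follow the quickstart examples in these docs):

        >>> import ray
        >>> from kithara import SFTDataset
        >>> source = ray.data.from_items(
        ...     [{"prompt": "What is your name?", "answer": "Kithara"}]
        ... )
        >>> dataset = SFTDataset(
        ...     source, tokenizer_handle="hf://google/gemma-2-2b", max_seq_len=512
        ... )
        >>> sample = next(iter(dataset))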
30 | 31 | """ 32 | 33 | def __init__(self, source: ray.data.Dataset): 34 | self.source = self._maybe_convert_to_ray_dataset(source) 35 | self._length = None 36 | self._iterator = None 37 | 38 | def _maybe_convert_to_ray_dataset(self, source): 39 | # HuggingFace datasets 40 | if isinstance(source, (HF_Dataset, HF_IterableDataset)): 41 | return ray.data.from_huggingface(source) 42 | # TODO: Add adapters for other dataset formats 43 | return source 44 | 45 | def process_sample(self, sample): 46 | return sample 47 | 48 | def __len__(self) -> int: 49 | if self._length is None: 50 | try: 51 | print( 52 | ( 53 | "Warning: If your dataset is a streaming dataset, " 54 | "this operation (__len__) might trigger its lazy executation." 55 | ) 56 | ) 57 | self._length = self.source.count() 58 | except: 59 | # For streamed datasets where count() might not work 60 | print( 61 | ( 62 | "Unable to get __len__ of dataset. Dataset may be a streaming " 63 | "dataset. Setting length to be 0." 64 | ) 65 | ) 66 | self._length = 0 67 | return self._length 68 | 69 | def __next__(self) -> Any: 70 | """Returns the next element in the dataset.""" 71 | if self._iterator is None: 72 | self._iterator = iter(self.source.iter_rows()) 73 | sample = next(self._iterator) 74 | return self.process_sample(sample) 75 | 76 | def __iter__(self) -> Iterator[Any]: 77 | """Return an iterator over the dataset.""" 78 | self._iterator = iter(self.source.iter_rows()) 79 | return self 80 | -------------------------------------------------------------------------------- /kithara/dataset/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from typing import List, Dict 18 | from transformers import AutoTokenizer as HFTokenizer 19 | import numpy as np 20 | from functools import lru_cache 21 | 22 | @lru_cache(maxsize=5) 23 | def initialize_tokenizer(tokenizer_handle, tokenizer=None): 24 | """ 25 | Creates an HuggingFace AutoTokenizer with the tokenizer_handle. 26 | Try to specify a pad_token for the tokenizer if not specified. 
27 | """ 28 | if tokenizer is None: 29 | if tokenizer_handle.startswith("hf://"): 30 | tokenizer_handle = tokenizer_handle.removeprefix("hf://") 31 | try: 32 | tokenizer = HFTokenizer.from_pretrained(tokenizer_handle) 33 | except ValueError as e: 34 | print(f"Tokenizer handle {tokenizer_handle} is not a valid HuggingFace tokenizer handle.") 35 | 36 | # Llama tokenizers don't have a default pad_token, we must add it here 37 | if tokenizer.pad_token is None: 38 | # Gemma 2 39 | if tokenizer.get_vocab().get("") is not None: 40 | tokenizer.add_special_tokens({'pad_token': ''}) 41 | # Llama 3 uses "<|finetune_right_pad_id|>" 42 | elif tokenizer.get_vocab().get("<|finetune_right_pad_id|>") is not None: 43 | tokenizer.add_special_tokens({'pad_token': '<|finetune_right_pad_id|>'}) 44 | # Llama 2 doesn't have a padding token, fall back to "" 45 | elif tokenizer.get_vocab().get("") is not None: 46 | tokenizer.add_special_tokens({'pad_token': ''}) 47 | else: 48 | print("WARNING: Tokenizer doesn't have the attribute pad_token") 49 | return tokenizer 50 | 51 | 52 | def HFtokenize( 53 | text: List[str], 54 | tokenizer: HFTokenizer, 55 | seq_len: int, 56 | padding: str = "max_length", 57 | ) -> Dict[str, np.ndarray]: 58 | """Tokenizes text using a HuggingFace tokenizer with specific parameters. 59 | 60 | This function handles the tokenization of input text, applying padding and 61 | truncation as needed. It converts text into a format suitable for model input. 62 | 63 | Args: 64 | text (Union[str, List[str]]): Input text or list of texts to tokenize 65 | tokenizer (HFTokenizer): HuggingFace tokenizer instance 66 | seq_len (int): Maximum sequence length for tokenization 67 | padding (str, optional): Padding strategy. Defaults to "max_length". 68 | Options include: 69 | - "max_length": Pad to seq_len 70 | - "longest": Pad to longest sequence in batch 71 | - "do_not_pad": No padding 72 | 73 | Returns: 74 | Dict[str, np.ndarray]: Dictionary containing: 75 | - 'input_ids': Token IDs (shape: [batch_size, seq_len]) 76 | - 'attention_mask': Mask indicating real vs padded tokens 77 | (shape: [batch_size, seq_len]) 78 | 79 | Example: 80 | >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 81 | >>> result = HFtokenize(["hello world"], tokenizer, seq_len=10) 82 | >>> print(result['input_ids'].shape) 83 | (1, 10) 84 | """ 85 | 86 | return tokenizer( 87 | text, 88 | max_length=seq_len, 89 | padding=padding, 90 | padding_side="right", 91 | truncation=True, 92 | add_special_tokens=False, 93 | return_tensors="np", 94 | ) 95 | -------------------------------------------------------------------------------- /kithara/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from kithara.distributed.sharding import * 2 | from kithara.distributed.data.split import split_dataset -------------------------------------------------------------------------------- /kithara/distributed/data/__init__.py: -------------------------------------------------------------------------------- 1 | from kithara.distributed.data.split import split_dataset -------------------------------------------------------------------------------- /kithara/distributed/data/split.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import ray 18 | 19 | def split_dataset(ds: ray.data.Dataset, num_hosts: int): 20 | if ds is None: 21 | return [None for _ in range(num_hosts)] 22 | return ds.streaming_split(num_hosts, equal=True) 23 | -------------------------------------------------------------------------------- /kithara/distributed/sharding/__init__.py: -------------------------------------------------------------------------------- 1 | from kithara.distributed.sharding.strategy import ShardingStrategy, PredefinedShardingStrategy 2 | -------------------------------------------------------------------------------- /kithara/distributed/sharding/_data_sharding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from kithara.distributed.sharding._mesh import Axis 18 | from jax.sharding import NamedSharding, Mesh 19 | import jax 20 | from jax.sharding import PartitionSpec as P 21 | from dataclasses import dataclass 22 | from typing import ClassVar 23 | 24 | 25 | @dataclass 26 | class DataSharding: 27 | 28 | # Class-level dictionary to store mesh types 29 | _mesh_types: ClassVar[dict] = { 30 | "tp": lambda: DataSharding.tp(), 31 | "fsdp": lambda: DataSharding.fsdp(), 32 | "fully_replicated": lambda: DataSharding.fully_replicated() 33 | } 34 | 35 | def __class_getitem__(cls, key: str): 36 | if key not in cls._mesh_types: 37 | raise KeyError(f"Unknown mesh type: {key}") 38 | return cls._mesh_types[key]() 39 | 40 | @classmethod 41 | def fsdp(cls): 42 | return NamedSharding( 43 | Mesh(jax.devices(), (Axis.FSDP,)), 44 | P(Axis.FSDP, None), 45 | ) 46 | 47 | @classmethod 48 | def tp(cls): 49 | return NamedSharding( 50 | Mesh(jax.devices(), (Axis.TP,)), 51 | P(None, Axis.TP), 52 | ) 53 | 54 | @classmethod 55 | def fully_replicated(cls): 56 | return NamedSharding( 57 | Mesh(jax.devices(), ("devices",)), 58 | P(None), 59 | ) 60 | -------------------------------------------------------------------------------- /kithara/distributed/sharding/_layout.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from kithara.distributed.sharding.models import GEMMA_LAYOUT, LLAMA_LAYOUT 18 | from dataclasses import dataclass 19 | from typing import ClassVar 20 | 21 | 22 | # Layout configurations for different model architectures 23 | @dataclass 24 | class Layout: 25 | 26 | _mesh_types: ClassVar[dict] = {} 27 | 28 | @classmethod 29 | def initialize_mesh_types(cls): 30 | # Avoid circular import 31 | from kithara.model import supported_models 32 | cls._mesh_types = { 33 | supported_models.GEMMA2_2B: lambda: Layout.gemma(), 34 | supported_models.GEMMA2_9B: lambda: Layout.gemma(), 35 | supported_models.GEMMA2_27B: lambda: Layout.gemma(), 36 | supported_models.LLAMA31_8B: lambda: Layout.llama(), 37 | supported_models.LLAMA31_70B: lambda: Layout.llama(), 38 | supported_models.LLAMA31_405B: lambda: Layout.llama(), 39 | supported_models.LLAMA32_1B: lambda: Layout.llama(), 40 | supported_models.LLAMA32_3B: lambda: Layout.llama(), 41 | } 42 | 43 | def __class_getitem__(cls, key: str): 44 | if not cls._mesh_types: 45 | cls.initialize_mesh_types() 46 | if key not in cls._mesh_types: 47 | raise KeyError(f"Unknown mesh type: {key}") 48 | return cls._mesh_types[key]() 49 | 50 | @classmethod 51 | def gemma(cls): 52 | return GEMMA_LAYOUT 53 | 54 | @classmethod 55 | def llama(cls): 56 | return LLAMA_LAYOUT 57 | -------------------------------------------------------------------------------- /kithara/distributed/sharding/_mesh.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from keras.distribution import DeviceMesh 18 | import jax 19 | from enum import Enum 20 | from dataclasses import dataclass 21 | from typing import ClassVar 22 | 23 | 24 | class Axis(str, Enum): 25 | """Enumeration of supported sharding axes in predefined shardings. 
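These axis names are used both as DeviceMesh axis names and as the values in the per-model layout tables (see models/gemma.py and models/llama.py).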
26 | 27 | Attributes: 28 | FSDP: Fully Sharded Data Parallel axis 29 | TP: Tensor Parallel axis 30 | """ 31 | 32 | FSDP = "fsdp" 33 | TP = "tp" 34 | 35 | 36 | @dataclass 37 | class Mesh: 38 | # Class-level dictionary to store mesh types 39 | _mesh_types: ClassVar[dict] = {"tp": lambda: Mesh.tp(), "fsdp": lambda: Mesh.fsdp()} 40 | 41 | def __class_getitem__(cls, key: str): 42 | if key not in cls._mesh_types: 43 | raise KeyError(f"Unknown mesh type: {key}") 44 | return cls._mesh_types[key]() 45 | 46 | @classmethod 47 | def fsdp(cls): 48 | return DeviceMesh( 49 | shape=(len(jax.devices()),), 50 | axis_names=(Axis.FSDP,), 51 | devices=jax.devices(), 52 | ) 53 | 54 | @classmethod 55 | def tp(cls): 56 | return DeviceMesh( 57 | shape=(len(jax.devices()),), 58 | axis_names=(Axis.TP,), 59 | devices=jax.devices(), 60 | ) 61 | -------------------------------------------------------------------------------- /kithara/distributed/sharding/maxtext.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from kithara.distributed.sharding.strategy import ShardingStrategy 18 | from dataclasses import dataclass 19 | from keras.distribution import DeviceMesh, LayoutMap 20 | from keras.src.distribution.distribution_lib import Distribution 21 | from jax.sharding import Sharding, PartitionSpec as P 22 | from typing import Any 23 | from jax.tree_util import tree_flatten_with_path 24 | from keras.src.distribution.distribution_lib import TensorLayout 25 | import keras 26 | from kithara.distributed.sharding.utils import convert_jax_mesh_to_keras_mesh 27 | from jax.sharding import Mesh 28 | import jax 29 | 30 | # Hacky monkey patching for now. Keras validate_axes currently does not accept nested tuples as the PartitionSpec value. 31 | # TODO: Patch Keras and remove this logic 32 | TensorLayout._validate_axes = lambda x: x 33 | 34 | 35 | @dataclass 36 | class MaxTextSharding(ShardingStrategy): 37 | """Sharding strategy from MaxText config.
Derives the Keras mesh, layout map, data sharding, and distribution from a MaxText mesh, its state shardings, and the MaxText config. 39 | """ 40 | 41 | jax_mesh: Mesh 42 | state_shardings: Any 43 | maxtext_config: 'pyconfig.Hyperparameter' 44 | 45 | def __post_init__(self) -> None: 46 | """Initialize the sharding strategy configuration.""" 47 | 48 | self._jax_mesh = self.jax_mesh 49 | self.jax_devices = self.jax_mesh.devices 50 | 51 | 52 | self._mesh = self._configure_mesh(self.jax_mesh) 53 | self._layout_map = self._configure_layout_map(self.state_shardings) 54 | self._data_sharding = self._configure_data_sharding() 55 | self._distribution = self._configure_distribution() 56 | 57 | @property 58 | def mesh(self) -> DeviceMesh: 59 | return self._mesh 60 | 61 | @property 62 | def layout_map(self) -> LayoutMap: 63 | return self._layout_map 64 | 65 | @property 66 | def data_sharding(self) -> Sharding: 67 | return self._data_sharding 68 | 69 | @property 70 | def distribution(self) -> Distribution: 71 | return self._distribution 72 | 73 | def _configure_mesh(self, jax_mesh) -> DeviceMesh: 74 | return convert_jax_mesh_to_keras_mesh(jax_mesh) 75 | 76 | def _configure_layout_map(self, state_shardings) -> LayoutMap: 77 | layout_map = LayoutMap(self._mesh) 78 | 79 | # Maps regex string to tuple 80 | # E.g. .*params.decoder.decoder.norm.scale.* -> ('tensor',) 81 | var_path_and_sharding, *_ = \ 82 | tree_flatten_with_path(state_shardings.params) 83 | mappings = { 84 | ".*" 85 | + ".".join(str(k.key) for k in var_path) 86 | + ".*": tuple(var_sharding.spec) 87 | for var_path, var_sharding in var_path_and_sharding 88 | } 89 | 90 | for pattern, layout in mappings.items(): 91 | layout_map[pattern] = layout 92 | 93 | return layout_map 94 | 95 | def _configure_data_sharding(self) -> Sharding: 96 | data_pspec = P(*self.maxtext_config.data_sharding) 97 | data_sharding = jax.tree_util.tree_map( 98 | lambda p: jax.sharding.NamedSharding(self._jax_mesh, p), data_pspec) 99 | return data_sharding 100 | 101 | def _configure_distribution(self) -> Distribution: 102 | return keras.distribution.ModelParallel(layout_map=self._layout_map) 103 | -------------------------------------------------------------------------------- /kithara/distributed/sharding/models/__init__.py: -------------------------------------------------------------------------------- 1 | from kithara.distributed.sharding.models.gemma import GEMMA_LAYOUT 2 | from kithara.distributed.sharding.models.llama import LLAMA_LAYOUT 3 | -------------------------------------------------------------------------------- /kithara/distributed/sharding/models/gemma.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | """ 16 | 17 | from kithara.distributed.sharding._mesh import Axis 18 | 19 | GEMMA_FSDP = { 20 | ".*token_embedding.embeddings.*": (None, Axis.FSDP), 21 | ".*decoder_block.*attention.*(query|key|value).kernel.*": ( 22 | None, 23 | Axis.FSDP, 24 | ), 25 | ".*decoder_block.*attention_output.kernel.*": (None, None, Axis.FSDP), 26 | ".*decoder_block.*ffw_gating.kernel.*": (None, Axis.FSDP), 27 | ".*decoder_block.*ffw_gating_2.kernel.*": (None, Axis.FSDP), 28 | ".*decoder_block.*ffw_linear.kernel.*": (Axis.FSDP, None), 29 | # Lora layers 30 | ".*decoder_block.*attention.*(query|key|value).lora_kernel.*": ( 31 | None, 32 | Axis.FSDP, 33 | ), 34 | } 35 | 36 | GEMMA_LAYOUT = { 37 | "fsdp": GEMMA_FSDP, 38 | } 39 | -------------------------------------------------------------------------------- /kithara/distributed/sharding/models/llama.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from kithara.distributed.sharding._mesh import Axis 18 | 19 | LLAMA_FSDP = { 20 | ".*token_embedding.embeddings.*": (None, Axis.FSDP), 21 | ".*token_embedding.reverse_embeddings.*": (Axis.FSDP, None), 22 | ".*transformer_layer.*self_attention.*(query|key|value).kernel*": ( 23 | None, 24 | Axis.FSDP, 25 | ), 26 | ".*transformer_layer.*self_attention.attention_output.kernel*": ( 27 | None, 28 | None, 29 | Axis.FSDP, 30 | ), 31 | ".*transformer_layer.*feedforward_gate_dense.kernel*": (None, Axis.FSDP), 32 | ".*transformer_layer.*feedforward_intermediate_dense.kernel*": (None, Axis.FSDP), 33 | ".*transformer_layer.*feedforward_output_dense.kernel*": (None, Axis.FSDP), 34 | # Lora layers 35 | ".*transformer_layer.*self_attention.*(query|key|value).lora_kernel.*": ( 36 | None, 37 | Axis.FSDP, 38 | ), 39 | } 40 | 41 | LLAMA_LAYOUT = { 42 | "fsdp": LLAMA_FSDP, 43 | } 44 | -------------------------------------------------------------------------------- /kithara/model/README.md: -------------------------------------------------------------------------------- 1 | # Kithara Model Support Guide 2 | 3 | ## Overview 4 | 5 | Kithara is designed as a model-agnostic framework that focuses on model tuning rather than implementation. It provides integration with existing model implementations through supported providers: 6 | 7 | - MaxText 8 | - KerasHub 9 | 10 | Additional model implementation providers may be added in the future, community contributions are always welcomed. 11 | 12 | ## Supported Models 13 | 14 | The complete list of supported models is maintained in `supported_models.py`. For the most up-to-date information, please refer to this file. 15 | 16 | ## Adding Support for New Models 17 | 18 | This guide focuses on adding support for new KerasHub model architectures. The process involves several key steps: 19 | 20 | ### 1. Configure Model Sharding 21 | 22 | 1. Navigate to `distributed/sharding/models/` 23 | 2. 
Create or modify the FSDP (Fully Sharded Data Parallel) model sharding strategy for your model 24 | 3. Update the `PredefinedShardingStrategy` class to include support for your model 25 | 26 | ### 2. Model Implementation 27 | 28 | First, create a KerasHub model instance to examine its structure: 29 | 30 | ```python 31 | model = KerasHubModel.from_preset( 32 | your_model_handle 33 | ) 34 | 35 | # Examine model weights and shapes 36 | for weight in model.weights: 37 | print(f"Path: {weight.path}, Shape: {weight.value.shape}") 38 | ``` 39 | 40 | ### 3. HuggingFace Compatibility 41 | 42 | #### 3.1 Analyze HuggingFace Model Structure 43 | 44 | Create and examine the equivalent HuggingFace model: 45 | 46 | ```python 47 | from transformers import AutoModelForCausalLM 48 | 49 | model = AutoModelForCausalLM.from_pretrained(preset_handle) 50 | 51 | # Examine model state dictionary 52 | for key, value in model.state_dict().items(): 53 | print(f"Key: {key}, Shape: {value.shape}") 54 | ``` 55 | 56 | #### 3.2 Update Configuration Files 57 | 58 | 1. Add HuggingFace model configuration: 59 | - File: `hf_compatibility/model_configs.py` 60 | - Add the specific configuration for your model 61 | 62 | 2. Define weight shapes: 63 | - File: `hf_compatibility/shape_mapping.py` 64 | - Document the shapes of all weights in the HuggingFace model 65 | 66 | 3. Create weight mappings: 67 | - File: `model/kerashub/ckpt_compatibility/param_mapping.py` 68 | - Define the mapping between HuggingFace and Kithara weights 69 | - Implement weight translation hook functions 70 | 71 | ### 4. Testing and Validation 72 | 73 | 1. Create unit tests in `tests/model/models/kerashub/ckpt_compatibility/` 74 | 2. Follow the pattern established in `gemma2-2b.py` 75 | 3. Verify numerical correctness of checkpoint translations 76 | 77 | ### Debugging Tips 78 | 79 | When working on weight translations: 80 | 81 | 1. Fork the transformers library for debugging 82 | 2. Add logging statements to compare intermediate tensor outputs 83 | 3. Verify tensor values at each layer 84 | 4. To verify MaxText model parameters and shapes, inspect a randomly initialized model: 85 | 86 | ```python 87 | from kithara import MaxTextModel 88 | model = MaxTextModel.from_random( 89 | "llama3.1-8b" 90 | ) 91 | 92 | # Examine model weights and shapes 93 | for weight in model.weights: 94 | print(f"Path: {weight.path}, Shape: {weight.value.shape}") 95 | ``` 96 | 97 | 98 | ## Best Practices 99 | 100 | 1. Support all sizes of the supported model family 101 | 2. Document all weight mappings thoroughly and include examples of expected tensor shapes (an illustrative sketch of a mapping is shown below)
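To make the weight-mapping step (3.2) concrete, the following is a hypothetical, simplified sketch of what a mapping can look like for a Llama-style model. The function name and the exact Kithara paths here are examples only; the real mappings, together with the hook functions that transpose and reshape tensors during translation, live in `model/kerashub/ckpt_compatibility/param_mapping.py`.

```python
# Hypothetical sketch, not the actual contents of param_mapping.py.
# Keys are Kithara (KerasHub) weight paths; values are the matching
# HuggingFace state_dict keys.
def example_llama_param_mapping(num_layers: int) -> dict:
    mapping = {
        "token_embedding/embeddings": "model.embed_tokens.weight",
    }
    for i in range(num_layers):
        mapping[f"transformer_layer_{i}/self_attention/query/kernel"] = (
            f"model.layers.{i}.self_attn.q_proj.weight"
        )
    return mapping
```

A hook function then reconciles shape differences between the two sides, e.g. a query kernel stored as `(hidden_dim, num_heads, head_dim)` on the Keras side versus the flattened `(num_heads * head_dim, hidden_dim)` HuggingFace layout.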
102 | -------------------------------------------------------------------------------- /kithara/model/__init__.py: -------------------------------------------------------------------------------- 1 | from kithara.model.model import * 2 | from kithara.model.maxtext.maxtext_model import * 3 | from kithara.model.kerashub.keras_hub_model import * -------------------------------------------------------------------------------- /kithara/model/hf_compatibility/__init__.py: -------------------------------------------------------------------------------- 1 | from kithara.model.hf_compatibility.to_huggingface import * 2 | from kithara.model.hf_compatibility.model_configs import * 3 | from kithara.model.hf_compatibility.shape_mapping import * 4 | 5 | -------------------------------------------------------------------------------- /kithara/model/kerashub/__init__.py: -------------------------------------------------------------------------------- 1 | from kithara.model.kerashub.keras_hub_model import * -------------------------------------------------------------------------------- /kithara/model/maxtext/__init__.py: -------------------------------------------------------------------------------- 1 | from kithara.model.maxtext.maxtext_model import * -------------------------------------------------------------------------------- /kithara/model/maxtext/ckpt_compatibility/__init__.py: -------------------------------------------------------------------------------- 1 | from kithara.model.maxtext.ckpt_compatibility.from_huggingface import * 2 | from kithara.model.maxtext.ckpt_compatibility.to_huggingface import * 3 | from kithara.model.maxtext.ckpt_compatibility.utils import * 4 | from kithara.model.maxtext.ckpt_compatibility.param_mapping import * -------------------------------------------------------------------------------- /kithara/model/maxtext/ckpt_compatibility/to_huggingface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | """ 18 | Module to convert MaxText model weights to HuggingFace format.
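Typical usage (a sketch; the bucket path is illustrative):

    save_maxtext_model_in_hf_format(model, output_dir="gs://my-bucket/hf_ckpt")

converts the weights of a kithara.MaxTextModel into the HuggingFace layout and writes the model files to output_dir.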
19 | """ 20 | 21 | import torch 22 | import contextlib 23 | import time 24 | from kithara.model.hf_compatibility import ( 25 | MODEL_CONFIGS, 26 | ) 27 | from kithara.model.hf_compatibility import process_weight, save_model_files 28 | from kithara.model.maxtext.ckpt_compatibility.param_mapping import ( 29 | HOOK_FNS, 30 | PARAM_MAPPING, 31 | ) 32 | from kithara.model.hf_compatibility import SHAPE_MAPPING 33 | import jax 34 | 35 | def _get_model_mappings(model_name: str, scan_layers: bool, config: dict): 36 | """Retrieves parameter, shape, and hook function mappings for the model.""" 37 | return { 38 | "param_mapping": PARAM_MAPPING[model_name](config.to_dict(), scan_layers), 39 | "shape_mapping": SHAPE_MAPPING[model_name](config.to_dict()), 40 | "hook_fn_mapping": HOOK_FNS[model_name]( 41 | config.to_dict(), scan_layers, saving_to_hf=True 42 | ), 43 | } 44 | 45 | 46 | def _save_checkpoint( 47 | maxtext_model: "kithara.MaxTextModel", output_dir: str, parallel_threads=8 48 | ): 49 | """Main function to save a MaxText model checkpoint in HuggingFace format.""" 50 | if maxtext_model.model_name not in MODEL_CONFIGS: 51 | raise ValueError( 52 | f"Model {maxtext_model.model_name} is not supported. " 53 | f"Supported models are {list(MODEL_CONFIGS.keys())}" 54 | ) 55 | 56 | config = MODEL_CONFIGS[maxtext_model.model_name] 57 | mappings = _get_model_mappings( 58 | maxtext_model.model_name, maxtext_model.scan_layers, config 59 | ) 60 | 61 | # Process weights 62 | start_time = time.time() 63 | weight_arrays = {} 64 | for variable in maxtext_model.weights: 65 | weight_dict = process_weight(variable, mappings) 66 | weight_arrays.update(weight_dict) 67 | print( 68 | f"\n✅ Weights converted into HuggingFace format in {time.time() - start_time}s" 69 | ) 70 | 71 | # Save all model files 72 | save_model_files(weight_arrays, config, output_dir, parallel_threads) 73 | 74 | 75 | def save_maxtext_model_in_hf_format( 76 | model: "MaxTextModel", output_dir: str, dtype: str = "auto", parallel_threads=8 77 | ): 78 | """Convert and save a MaxText model in HuggingFace format. 79 | 80 | Args: 81 | model: MaxTextModel instance to save 82 | output_dir: Directory to save the HuggingFace checkpoint 83 | dtype: dtype for the converted model ("auto" uses source model's dtype) 84 | """ 85 | 86 | @contextlib.contextmanager 87 | def _set_default_tensor_type(dtype: torch.dtype): 88 | """Sets the default torch dtype to the given dtype.""" 89 | torch.set_default_dtype(dtype) 90 | yield 91 | torch.set_default_dtype(torch.float) 92 | 93 | if dtype == "auto": 94 | dtype = model.weight_dtype 95 | 96 | print(f"-> Saving model with {dtype=}...") 97 | with _set_default_tensor_type(getattr(torch, dtype)): 98 | _save_checkpoint(model, output_dir, parallel_threads) 99 | 100 | jax.experimental.multihost_utils.sync_global_devices("saving_completed") -------------------------------------------------------------------------------- /kithara/model/maxtext/ckpt_compatibility/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader 18 | from transformers import AutoModelForCausalLM 19 | 20 | def get_hf_safetensor_weight_keys(preset_handle): 21 | with SafetensorLoader(preset_handle) as loader: 22 | weight_map_keys = loader.safetensor_config["weight_map"].keys() 23 | return list(weight_map_keys) 24 | 25 | def get_hf_model_weight_shapes(preset_handle): 26 | weight_keys = get_hf_safetensor_weight_keys(preset_handle) 27 | model = AutoModelForCausalLM.from_pretrained(preset_handle) 28 | 29 | weight_to_shape = {} 30 | for path in weight_keys: 31 | path = path.removesuffix(".weight") 32 | hf_module = model.get_submodule(path) 33 | target_shape = hf_module.state_dict()["weight"].shape 34 | weight_to_shape[path] = target_shape 35 | return weight_to_shape 36 | -------------------------------------------------------------------------------- /kithara/model/maxtext/inference_engine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from kithara.model.maxtext.maxtext.MaxText.maxengine import MaxEngine 16 | import jax 17 | import max_utils 18 | 19 | class MaxtextInferenceEngine(MaxEngine): 20 | """This is a patched version of MaxEngine 21 | 22 | Changes: 23 | - Added the `load_existing_params` function to allow 24 | the engine to run with an existing model instead of initializing 25 | a new one. 26 | """ 27 | def load_existing_params(self, params, rng=None) -> "Params": 28 | 29 | if rng is None: 30 | rng = jax.random.PRNGKey(0) 31 | 32 | self.abstract_params = jax.tree_util.tree_map( 33 | lambda x: ( 34 | jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) 35 | if isinstance(x, jax.Array) 36 | else None 37 | ), 38 | params, 39 | ) 40 | 41 | self.prefill_kv_cache_annotations = max_utils.get_prefill_kv_cache_annotations( 42 | self.model, self.config, rng, self._mesh 43 | ) 44 | self.prefill_kv_cache_shardings = jax.tree_util.tree_map( 45 | lambda x: jax.sharding.NamedSharding(self._mesh, x), 46 | self.prefill_kv_cache_annotations, 47 | ) 48 | 49 | if self.config.stack_prefill_result_cache: 50 | # Add extra axis for the axis generated by the stack.
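# (When stack_prefill_result_cache is enabled, the per-layer prefill caches are stacked into a single array; the tree_map below therefore prepends an unsharded layer axis to each PartitionSpec via P(None, *x.spec).)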
51 | self.prefill_kv_cache_shardings = jax.tree_util.tree_map( 52 | lambda x: jax.sharding.NamedSharding( 53 | self._mesh, jax.sharding.PartitionSpec(None, *x.spec) 54 | ), 55 | self.prefill_kv_cache_shardings, 56 | ) 57 | self.prefill_kv_cache_shardings = self.prefill_kv_cache_shardings[ 58 | "decoder" 59 | ]["layers_0"] 60 | 61 | self.kv_cache_annotations = max_utils.get_kv_cache_annotations( 62 | self.model, self.config, rng, self._mesh 63 | ) 64 | self.kv_cache_shardings = jax.tree_util.tree_map( 65 | lambda x: jax.sharding.NamedSharding(self._mesh, x), 66 | self.kv_cache_annotations, 67 | ) 68 | return params 69 | -------------------------------------------------------------------------------- /kithara/model/supported_models.py: -------------------------------------------------------------------------------- 1 | GEMMA2_2B = "gemma2-2b" 2 | GEMMA2_9B = "gemma2-9b" 3 | GEMMA2_27B = "gemma2-27b" 4 | LLAMA31_8B = "llama3.1-8b" 5 | LLAMA31_70B = "llama3.1-70b" 6 | LLAMA31_405B = "llama3.1-405b" 7 | LLAMA32_1B = "llama3.2-1b" 8 | LLAMA32_3B = "llama3.2-3b" 9 | -------------------------------------------------------------------------------- /kithara/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxtyping import PyTree 2 | from kithara.optimizers.protocol import KitharaOptimizer 3 | from kithara.optimizers.optax_optimizer import OptaxOptimizerInterface 4 | import optax 5 | 6 | def convert_to_kithara_optimizer( 7 | optimizer: optax.GradientTransformation, trainable_variables: PyTree) -> KitharaOptimizer: 8 | if isinstance(optimizer, optax.GradientTransformation): 9 | return OptaxOptimizerInterface(optimizer, trainable_variables) 10 | raise NotImplementedError(f"Unsupported optimizer type: {type(optimizer)}") 11 | -------------------------------------------------------------------------------- /kithara/optimizers/optax_optimizer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | from jaxtyping import PyTree 5 | from keras.src import backend 6 | import optax 7 | 8 | from kithara.optimizers.protocol import KitharaOptimizer, TrainableVariable, OptimizerVariable 9 | 10 | 11 | class OptaxOptimizerInterface(KitharaOptimizer): 12 | def __init__(self, optimizer: optax.GradientTransformation, params: PyTree[backend.Variable]): 13 | self.optimizer, state = optimizer, optimizer.init(params) 14 | self.state_in_keras_variable = self._convert_state_to_keras_variables(state, params) 15 | 16 | with backend.name_scope("OptaxOptimizerInterface"): 17 | iterations = backend.Variable(0, name="iterations") 18 | 19 | self._iterations = iterations 20 | 21 | @property 22 | def learning_rate(self) -> None: 23 | return 24 | 25 | @property 26 | def variables(self) -> PyTree[backend.Variable]: 27 | return self.state_in_keras_variable 28 | 29 | @property 30 | def iterations(self) -> backend.Variable: 31 | return self._iterations 32 | 33 | def stateless_apply(self, optimizer_variable: OptimizerVariable, gradients: PyTree, 34 | trainable_variables: TrainableVariable 35 | ) -> Tuple[TrainableVariable, OptimizerVariable]: 36 | 37 | updates, opt_state = self.optimizer.update( 38 | gradients, optimizer_variable, trainable_variables) 39 | new_trainable_variables = optax.apply_updates(trainable_variables, updates) 40 | 41 | return new_trainable_variables, opt_state 42 | 43 | def get_optimizer_memory_usage(self): 44 | return sum([leaf.nbytes for leaf in 
jax.tree_util.tree_leaves(self.variables)]) 45 | 46 | def _convert_state_to_keras_variables( 47 | self, opt_state: PyTree, trainable_vars: TrainableVariable) -> PyTree[backend.Variable]: 48 | """Convert an optax state to a `PyTree` of `keras.Variable` 49 | 50 | Creates a structure (PyTree) of `keras.Variable` objects that mirrors 51 | the structure of the optimizer state (`opt_state`). 52 | 53 | For each subtree (e.g., momentum, variance) within the `opt_state` PyTree: 54 | - Corresponding `keras.Variable` objects are created, one for each 55 | associated trainable variable in the `trainable_vars` PyTree. 56 | - Each created `keras.Variable` is named based on the path of its associated node in `trainable_vars`: 57 | 1. Take the trainable variable's path. 58 | 2. Replace all forward slashes (`/`) with underscores (`_`). 59 | 3. Append the suffix `_optstate`. 60 | 61 | For example: 62 | - If a trainable variable's path in `trainable_vars` is 63 | `"decoder-block-0/layer-1/kernel"`. 64 | - The corresponding `keras.Variable` holding its optimizer state component 65 | will be named `"decoder-block-0_layer-1_kernel_optstate"`. 66 | 67 | Args: 68 | opt_state (PyTree): the optax optimizer state 69 | trainable_vars (PyTree[backend.Variable]): trainable variables 70 | 71 | Returns: 72 | PyTree[backend.Variable]: the transformed optax state. 73 | """ 74 | 75 | _, trainable_tree_def = jax.tree.flatten(trainable_vars) 76 | 77 | def is_same_shape_with_trainable_subtree(opt_state_subtree: PyTree): 78 | _, opt_state_tree_def = jax.tree.flatten(opt_state_subtree) 79 | return opt_state_tree_def == trainable_tree_def 80 | 81 | def create_keras_variable_pytree( 82 | subtree, trainable_vars) -> PyTree[backend.Variable]: 83 | 84 | if not is_same_shape_with_trainable_subtree(subtree): 85 | return jax.tree.map( 86 | lambda leaf: backend.Variable(leaf), subtree 87 | ) 88 | 89 | return jax.tree.map( 90 | lambda leaf, keras_var: backend.Variable( 91 | leaf, name=f"{keras_var.path.replace('/', '_')}_optstate",), 92 | subtree, 93 | trainable_vars) 94 | 95 | return jax.tree.map( 96 | lambda subtree: create_keras_variable_pytree(subtree, trainable_vars), 97 | opt_state, 98 | is_leaf=is_same_shape_with_trainable_subtree 99 | ) 100 | -------------------------------------------------------------------------------- /kithara/optimizers/protocol.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Protocol, Tuple, TypeAlias, List, 3 | runtime_checkable) 4 | 5 | from jaxtyping import PyTree 6 | from keras.src.backend import Variable 7 | 8 | TrainableVariable: TypeAlias = List[Variable] 9 | OptimizerVariable: TypeAlias = PyTree[Variable] 10 | 11 | @runtime_checkable 12 | class KitharaOptimizer(Protocol): 13 | 14 | @property 15 | def learning_rate(self) -> float | None: 16 | ... 17 | @property 18 | def variables(self) -> OptimizerVariable: 19 | ... 20 | @variables.setter 21 | def variables(self, state: PyTree): 22 | ... 23 | @property 24 | def iterations(self) -> Variable: 25 | ... 26 | def stateless_apply(self, optimizer_variable: OptimizerVariable, gradients: PyTree, 27 | trainable_variables: TrainableVariable) -> Tuple[TrainableVariable, OptimizerVariable]: 28 | """Apply the optimizer to the trainable variables and gradients. 29 | 30 | This method should be stateless, i.e. it should not 31 | modify the internal state of the optimizer. If the optimizer has state, it should be passed as `optimizer_variable`.
32 | If `optimizer_variable` is not provided, the optimizer should use the internal state tracked by `KitharaOptimizer`, 33 | rather than state held by the optimizer object itself. 34 | 35 | Args: 36 | optimizer_variable (OptimizerVariable): The explicit optimizer state. 37 | gradients (PyTree): The gradients to apply to the trainable variables. 38 | The shape of the gradients should match the shape of the `trainable_variables`. 39 | trainable_variables (TrainableVariable): The trainable variables to apply the optimizer to. 40 | 41 | Returns: 42 | Tuple[TrainableVariable, OptimizerVariable]: The tuple of the updated trainable variables and the updated optimizer state. 43 | """ 44 | def get_optimizer_memory_usage(self) -> int: 45 | """Get the memory usage of the state/variables of the optimizer in bytes. 46 | 47 | Returns: 48 | int: bytes used by the optimizer state/variables. 49 | """ 50 | -------------------------------------------------------------------------------- /kithara/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from kithara.trainer.trainer import * -------------------------------------------------------------------------------- /kithara/utils/gcs_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from google.cloud.storage import Client, transfer_manager 18 | import os 19 | import time 20 | 21 | 22 | def upload_file_to_gcs( 23 | local_file: str, gs_bucket_path: str, remove_local_file_after_upload=False 24 | ): 25 | """Uploads a single file to Google Cloud Storage. 26 | 27 | Args: 28 | local_file: Path to local file 29 | gs_bucket_path: GCS destination (e.g. "gs://my-bucket/path/file.txt" or "my-bucket/path/file.txt") 30 | """ 31 | # Standardize bucket path format 32 | gs_bucket_path = gs_bucket_path.removeprefix("gs://") 33 | bucket_name = gs_bucket_path.split("/")[0] 34 | blob_name = gs_bucket_path[len(bucket_name) :].lstrip("/") 35 | 36 | print(f"-> Uploading {local_file} to {gs_bucket_path}...") 37 | # Upload file 38 | storage_client = Client() 39 | bucket = storage_client.bucket(bucket_name) 40 | blob = bucket.blob(blob_name) 41 | blob.upload_from_filename(local_file) 42 | 43 | print(f"✅ Uploaded {local_file} to {bucket.name}/{blob_name}") 44 | 45 | if remove_local_file_after_upload: 46 | os.remove(local_file) 47 | print(f"✅ Deleted {local_file}") 48 | 49 | 50 | def upload_folder_to_gcs(local_folder: str, gs_bucket_path: str, num_workers: int = 8): 51 | """Uploads all files from a local folder to Google Cloud Storage. 52 | 53 | Args: 54 | local_folder: Path to local folder (e.g. "data/images") 55 | gs_bucket_path: GCS destination (e.g.
"gs://my-bucket/images" or "my-bucket/images") 56 | num_workers: Number of parallel upload workers 57 | """ 58 | start_time = time.time() 59 | 60 | # Standardize bucket path format 61 | gs_bucket_path = gs_bucket_path.removeprefix("gs://") 62 | bucket_name = gs_bucket_path.split("/")[0] 63 | # Ensure destination ends with "/" 64 | destination_dir = gs_bucket_path[len(bucket_name) :] 65 | if destination_dir.startswith("/"): 66 | destination_dir = destination_dir[1:] 67 | if destination_dir != "" and not destination_dir.endswith("/"): 68 | destination_dir += "/" 69 | 70 | # Get files to upload 71 | files_in_local_folder = os.listdir(local_folder) 72 | # Set up GCS client 73 | storage_client = Client() 74 | bucket = storage_client.bucket(bucket_name) 75 | 76 | # Upload files in parallel 77 | results = transfer_manager.upload_many_from_filenames( 78 | bucket, 79 | files_in_local_folder, 80 | source_directory=local_folder, 81 | max_workers=num_workers, 82 | blob_name_prefix=destination_dir, 83 | ) 84 | 85 | # Report results 86 | for name, result in zip(files_in_local_folder, results): 87 | if isinstance(result, Exception): 88 | print(f"Failed to upload {name}: {result}") 89 | else: 90 | print(f"✅ Uploaded {name} to {bucket.name}/{destination_dir}{name}") 91 | 92 | print(f"Upload completed in {time.time() - start_time}s") 93 | 94 | 95 | def find_cache_root_dir(): 96 | if "KERAS_HOME" in os.environ: 97 | cachdir = os.environ.get("KERAS_HOME") 98 | else: 99 | cachdir = os.path.expanduser(os.path.join("~", ".keras")) 100 | if not os.access(cachdir, os.W_OK): 101 | cachdir = os.path.join("/tmp", ".keras") 102 | return cachdir 103 | -------------------------------------------------------------------------------- /kithara/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import platform 3 | 4 | 5 | def get_device_stats(): 6 | device_count = jax.device_count() 7 | device_kind = jax.devices()[0].device_kind 8 | memory_info = jax.devices()[0].memory_stats() 9 | memory_per_device = memory_info["bytes_limit"] / (1024**3) # Convert to GB 10 | total_memory = memory_per_device * device_count 11 | return { 12 | "device_count": device_count, 13 | "device_kind": device_kind, 14 | "memory_per_device_gb": round(memory_per_device, 2), 15 | "total_memory_gb": round(total_memory, 2), 16 | } 17 | 18 | 19 | def print_kithara_logo_and_platform_info(): 20 | platform_system = platform.system() 21 | tpu_stats = get_device_stats() 22 | statistics = ( 23 | f" '==='\n" 24 | f" |||\n" 25 | f" '- ||| -'\n" 26 | f" / ||||| \\ Kithara. Platform: {platform_system}. JAX: {jax.__version__}\n" 27 | f" | (|||) | Hardware: {tpu_stats['device_kind']}. Device count: {tpu_stats['device_count']}.\n" 28 | f" | |◕‿◕| | HBM Per Device: {tpu_stats['memory_per_device_gb']} GB. Total HBM Memory: {tpu_stats['total_memory_gb']} GB\n" 29 | f" \\ ||||| / Free Apache license: http://github.com/ai-hypercomputer/kithara\n" 30 | f" --|===|--" 31 | ) 32 | print(statistics) 33 | -------------------------------------------------------------------------------- /kithara/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import torch 18 | import numpy as np 19 | from jax.experimental import multihost_utils 20 | from typing import Optional 21 | 22 | 23 | def convert_jax_weight_to_torch( 24 | weight: "jax.Array", dtype: Optional[str] = None 25 | ) -> torch.Tensor: 26 | expected_dtype = str(weight.dtype) if dtype is None else dtype 27 | expected_shape = weight.shape 28 | weight = multihost_utils.process_allgather(weight) 29 | weight = np.array(weight, dtype="float32") 30 | torch_dtype = getattr(torch, expected_dtype) 31 | torch_array = torch.from_numpy(weight).to(torch_dtype).reshape(expected_shape) 32 | return torch_array 33 | -------------------------------------------------------------------------------- /kithara/utils/tree_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import jax 18 | 19 | def tree_path_to_string(path, sep=None): 20 | keys = [] 21 | for key in path: 22 | if isinstance(key, jax.tree_util.SequenceKey): 23 | keys.append(str(key.idx)) 24 | elif isinstance(key, jax.tree_util.DictKey): 25 | keys.append(str(key.key)) 26 | elif isinstance(key, jax.tree_util.GetAttrKey): 27 | keys.append(str(key.name)) 28 | elif isinstance(key, jax.tree_util.FlattenedIndexKey): 29 | keys.append(str(key.key)) 30 | else: 31 | keys.append(str(key)) 32 | if sep is None: 33 | return tuple(keys) 34 | return sep.join(keys) 35 | 36 | 37 | def named_tree_map(f, tree, *rest, is_leaf=None, sep=None): 38 | """An extended version of jax.tree_util.tree_map, where the mapped function 39 | f takes both the name (path) and the tree leaf as input. 40 | """ 41 | return jax.tree_util.tree_map_with_path( 42 | lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r), 43 | tree, 44 | *rest, 45 | is_leaf=is_leaf, 46 | ) 47 | -------------------------------------------------------------------------------- /perf/kerashub.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | """ 19 | Benchmark for training KerasHub models via Kithara. 20 | 21 | This benchmark script runs multi-host training with the specified KerasHub model. 22 | 23 | Metrics: step time, Tokens/s/device 24 | Artifact: Tensorboard, Xplane (Uploaded to BASE_OUTPUT_DIR) 25 | 26 | Purpose: Compare native MaxText performance against performance of MaxText via Kithara. 27 | 28 | Launch Script: python ray/submit_job.py "python perf/kerashub.py" 29 | 30 | TODO: Launch benchmarks via YAML config. 31 | """ 32 | 33 | import ray 34 | 35 | def run_benchmark(): 36 | import os 37 | os.environ["KERAS_BACKEND"] = "jax" 38 | import keras 39 | from examples.example_datasets import example_datasets 40 | from kithara import KerasHubModel 41 | from kithara.dataset import Dataloader, TextCompletionDataset 42 | from kithara.trainer import Trainer 43 | from kithara.distributed import PredefinedShardingStrategy 44 | from kithara.callbacks import Profiler 45 | 46 | # Run parameters 47 | BASE_OUTPUT_DIR = "gs://wenxindong-vm/jan24/perf" # MODIFY with your GS bucket 48 | MODEL_HANDLE = "hf://google/gemma-2-9b" 49 | SEQ_LEN = 2048 50 | PER_DEVICE_BATCH_SIZE = 1 51 | 52 | train_data, eval_data = example_datasets(option="finetune_toy") 53 | 54 | model = KerasHubModel.from_preset( 55 | MODEL_HANDLE, 56 | precision="mixed_bfloat16", 57 | sharding_strategy=PredefinedShardingStrategy(parallelism="fsdp", model="gemma2-9b"), 58 | remat_mode=None, 59 | flash_attention=True, 60 | ) 61 | 62 | # Create Keras optimizer 63 | optimizer = keras.optimizers.AdamW( 64 | learning_rate=5e-5, 65 | weight_decay=0.01, 66 | ) 67 | 68 | # Create Dataset 69 | train_ds = TextCompletionDataset( 70 | source = train_data, 71 | tokenizer_handle=MODEL_HANDLE, 72 | max_seq_len=SEQ_LEN, 73 | ) 74 | 75 | # Create Dataloader 76 | train_dataloader = Dataloader(train_ds, per_device_batch_size=PER_DEVICE_BATCH_SIZE) 77 | 78 | # Create Xprof Profiler 79 | profiler = Profiler( 80 | mode="xplane", 81 | output_path=BASE_OUTPUT_DIR, 82 | max_profile_steps=5, 83 | skip_first_n_steps=5, 84 | optional_postfix="kerashub", 85 | ) 86 | 87 | # Initialize trainer 88 | trainer = Trainer( 89 | model=model, 90 | optimizer=optimizer, 91 | train_dataloader=train_dataloader, 92 | steps=10, 93 | log_steps_interval=1, 94 | tensorboard_dir=BASE_OUTPUT_DIR, 95 | profiler=profiler, 96 | ) 97 | 98 | # Start training 99 | trainer.train() 100 | 101 | 102 | if __name__ == "__main__": 103 | ray.init() 104 | 105 | num_chips_per_host = 4 # 4 for v4 and v5, 8 for v4e and v5e 106 | num_tpu_devices = int(ray.cluster_resources()["TPU"]) 107 | num_tpu_hosts = num_tpu_devices // num_chips_per_host 108 | 109 | @ray.remote(resources={"TPU": num_chips_per_host}) 110 | def main(): 111 | import sys 112 | 113 | # Add the MaxText directory to the Python path 114 | maxtext_dir = "maxtext/MaxText" 115 | sys.path.append(maxtext_dir) 116 | 117 | import subprocess 118 | subprocess.run(["pip install -e .[tpu] --no-deps"], shell=True) 119 | subprocess.run(["pip uninstall -y keras keras-hub keras-nlp"], shell=True) 120 | subprocess.run(["pip install keras-hub"], shell=True) 121 | subprocess.run(["pip install --upgrade keras-nightly"], shell=True) 122 | subprocess.run(["pip install git+https://github.com/divyashreepathihalli/keras-nlp.git@gemma_fs_bool"], shell=True) 123 | run_benchmark() 124 | 125 | ray.get([main.remote() for _ in range(num_tpu_hosts)]) 126 | 127 | ray.shutdown() 128 | 
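# A typical multi-host launch (see the module docstring) submits this script to the Ray cluster:
#   python ray/submit_job.py "python perf/kerashub.py"
# The driver then fans main() out as one Ray task per TPU host.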
-------------------------------------------------------------------------------- /perf/maxtext_native.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | """ 18 | Benchmark for training MaxText models via MaxText. 19 | 20 | This benchmark script runs multi-host training with the specified MaxText model. 21 | 22 | Metrics: step time, TFLOP/s/device, Tokens/s/device 23 | Artifact: Tensorboard, Xplane (Uploaded to BASE_OUTPUT_DIR) 24 | 25 | Purpose: Compare native MaxText performance against performance of MaxText via Kithara. 26 | 27 | Launch Script: python ray/submit_job.py "python perf/maxtext_native.py" 28 | 29 | TODO: Launch benchmarks via YAML config. 30 | """ 31 | 32 | import ray 33 | 34 | if __name__ == "__main__": 35 | ray.init() 36 | 37 | num_chips_per_host = 4 # 4 for v4 and v5, 8 for v4e and v5e 38 | num_tpu_devices = int(ray.cluster_resources()["TPU"]) 39 | num_tpu_hosts = num_tpu_devices // num_chips_per_host 40 | 41 | print(f"{num_tpu_devices=}") 42 | print(f"{num_tpu_hosts=}") 43 | 44 | @ray.remote(resources={"TPU": num_chips_per_host}) 45 | def main(): 46 | import sys 47 | import subprocess 48 | 49 | # Add the MaxText directory to the Python path 50 | maxtext_dir = "maxtext/MaxText" 51 | sys.path.append(maxtext_dir) 52 | 53 | # Run parameters 54 | BASE_OUTPUT_DIR = "GS_BUCKET" # MODIFY with your GS bucket 55 | MODEL_NAME = "gemma2-9b" 56 | SEQ_LEN = 2048 57 | PER_DEVICE_BATCH_SIZE = 1 58 | 59 | subprocess.call( 60 | [ 61 | "python", 62 | "maxtext/MaxText/train.py", 63 | "kithara/model/maxtext/maxtext/MaxText/configs/base.yml", 64 | f"model_name={MODEL_NAME}", 65 | "run_name=maxtext_native", 66 | f"max_target_length={SEQ_LEN}", 67 | f"per_device_batch_size={PER_DEVICE_BATCH_SIZE}", 68 | "steps=10", 69 | "enable_checkpointing=false", 70 | "dataset_type=synthetic", 71 | "profiler=xplane", 72 | "scan_layers=false", 73 | "skip_first_n_steps_for_profiler=5", 74 | "gcs_metrics=true", 75 | "profiler_steps=5", 76 | "remat_policy=minimal", 77 | "attention=flash", 78 | f"base_output_directory={BASE_OUTPUT_DIR}", 79 | "dataset_path=gs://max-datasets-rogue/", 80 | ] 81 | ) 82 | 83 | ray.get([main.remote() for _ in range(num_tpu_hosts)]) 84 | 85 | ray.shutdown() 86 | -------------------------------------------------------------------------------- /perf/maxtext_via_kithara.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | """ 19 | Benchmark for training MaxText models via Kithara. 20 | 21 | This benchmark script runs multi-host training with the specified MaxText model 22 | 23 | Metrics: step time, Tokens/s/device 24 | Artifact: Tensorboard, Xplane (Uploaded to BASE_OUTPUT_DIR) 25 | 26 | Purpose: Compare native MaxText performance against performance of MaxText via Kithara. 27 | 28 | Launch Script: python ray/submit_job.py "python perf/maxtext_via_kithara.py" 29 | 30 | TODO: Launch benchmarks via YAML config. 31 | """ 32 | 33 | import ray 34 | 35 | def run_benchmark(): 36 | 37 | import os 38 | 39 | os.environ["KERAS_BACKEND"] = "jax" 40 | import keras 41 | from examples.example_datasets import example_datasets 42 | from kithara import MaxTextModel, Trainer, Dataloader, TextCompletionDataset 43 | from kithara.callbacks import Profiler 44 | 45 | # Run parameters 46 | BASE_OUTPUT_DIR = "GS_BUCKET" # MODIFY with your GS bucket 47 | MODEL_NAME = "gemma2-9b" 48 | SEQ_LEN = 2048 49 | PER_DEVICE_BATCH_SIZE = 1 50 | 51 | train_data, eval_data = example_datasets(option="finetune_toy") 52 | 53 | # Create a randomly initialized MaxText Model 54 | model = MaxTextModel.from_random( 55 | model_name=MODEL_NAME, 56 | seq_len=SEQ_LEN, 57 | per_device_batch_size=PER_DEVICE_BATCH_SIZE, 58 | scan_layers=False, 59 | maxtext_config_args={"remat_policy": "minimal"}, 60 | ) 61 | 62 | # Create Keras optimizer 63 | optimizer = keras.optimizers.AdamW( 64 | learning_rate=5e-5, 65 | weight_decay=0.01, 66 | ) 67 | 68 | # Create Dataset 69 | train_ds = TextCompletionDataset( 70 | source=train_data, 71 | tokenizer_handle="hf://google/gemma-2-2b", 72 | max_seq_len=SEQ_LEN, 73 | ) 74 | 75 | # Create Dataloader 76 | train_dataloader = Dataloader(train_ds, per_device_batch_size=PER_DEVICE_BATCH_SIZE) 77 | 78 | # Create Xprof Profiler 79 | profiler = Profiler( 80 | mode="xplane", 81 | output_path=BASE_OUTPUT_DIR, 82 | max_profile_steps=5, 83 | skip_first_n_steps=5, 84 | optional_postfix="maxtext_via_kithara", 85 | ) 86 | 87 | # Initialize trainer 88 | trainer = Trainer( 89 | model=model, 90 | optimizer=optimizer, 91 | train_dataloader=train_dataloader, 92 | steps=10, 93 | log_steps_interval=1, 94 | tensorboard_dir=BASE_OUTPUT_DIR, 95 | profiler=profiler, 96 | ) 97 | 98 | # Start training 99 | trainer.train() 100 | 101 | 102 | if __name__ == "__main__": 103 | ray.init() 104 | 105 | num_chips_per_host = 4 # 4 for v4 and v5, 8 for v4e and v5e 106 | num_tpu_devices = int(ray.cluster_resources()["TPU"]) 107 | num_tpu_hosts = num_tpu_devices // num_chips_per_host 108 | 109 | @ray.remote(resources={"TPU": num_chips_per_host}) 110 | def main(): 111 | import sys 112 | 113 | # Add the MaxText directory to the Python path 114 | maxtext_dir = "maxtext/MaxText" 115 | sys.path.append(maxtext_dir) 116 | 117 | run_benchmark() 118 | 119 | ray.get([main.remote() for _ in range(num_tpu_hosts)]) 120 | 121 | ray.shutdown() 122 | -------------------------------------------------------------------------------- /perf/sft_unsloth.py: -------------------------------------------------------------------------------- 1 | """ 2 
| Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | """ 18 | Benchmark for SFT via unsloth. 19 | 20 | This benchmark script runs supervised fine tuning using LoRA with the 21 | specified model. 22 | 23 | Metrics: train_samples_per_second, eval_samples_per_second 24 | 25 | Purpose: Compare performance of SFT using unsloth. 26 | 27 | set up: 28 | - Create 2 X A100 40Gb GPU 29 | - pip install unsloth tensorboardx 30 | 31 | Launch Script: python kithara/perf/sft_unsloth.py 32 | """ 33 | 34 | from unsloth import FastLanguageModel 35 | import torch 36 | 37 | max_seq_length = 4096 38 | 39 | model, tokenizer = FastLanguageModel.from_pretrained( 40 | model_name="unsloth/gemma-2-2b", 41 | max_seq_length=max_seq_length, 42 | dtype=torch.bfloat16, 43 | load_in_4bit=False, 44 | ) 45 | 46 | model = FastLanguageModel.get_peft_model( 47 | model, 48 | r = 16, 49 | target_modules = ["q_proj", "v_proj"], 50 | lora_alpha = 16, 51 | lora_dropout = 0, 52 | bias = "none", 53 | # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes! 54 | use_gradient_checkpointing = False, # True or "unsloth" for very long context 55 | random_state = 3407 56 | ) 57 | 58 | alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 59 | 60 | ### Instruction: 61 | {} 62 | 63 | ### Input: 64 | {} 65 | 66 | ### Response: 67 | {}""" 68 | 69 | def formatting_prompts_func(examples): 70 | EOS_TOKEN = tokenizer.eos_token 71 | instructions = examples["instruction"] 72 | inputs = examples["input"] 73 | outputs = examples["output"] 74 | texts = [] 75 | for instruction, input, output in zip(instructions, inputs, outputs): 76 | # Must add EOS_TOKEN, otherwise your generation will go on forever! 77 | text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN 78 | texts.append(text) 79 | return { "text" : texts, } 80 | 81 | from datasets import load_dataset 82 | dataset = load_dataset("yahma/alpaca-cleaned", split = "train") 83 | dataset = dataset.map(formatting_prompts_func, batched = True,) 84 | datasets = dataset.train_test_split(test_size=64, shuffle=False) 85 | train_dataset, test_dataset = datasets["train"], datasets["test"] 86 | 87 | from trl import SFTTrainer 88 | from transformers import TrainingArguments 89 | from unsloth import is_bfloat16_supported 90 | 91 | 92 | trainer = SFTTrainer( 93 | model=model, 94 | tokenizer=tokenizer, 95 | train_dataset=train_dataset, 96 | eval_dataset=test_dataset, 97 | dataset_text_field="text", 98 | max_seq_length=max_seq_length, 99 | dataset_num_proc=2, 100 | packing=True, # Can make training 5x faster for short sequences. 
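# (Packing concatenates several short examples into each max_seq_length window, reducing the fraction of compute spent on pad tokens.)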
101 | args=TrainingArguments( 102 | per_device_train_batch_size=2, 103 | do_predict=True, 104 | per_device_eval_batch_size=2, 105 | max_steps=500, 106 | eval_strategy="steps", 107 | eval_steps=100, 108 | learning_rate=2e-4, 109 | fp16=not is_bfloat16_supported(), 110 | bf16=is_bfloat16_supported(), # This is True 111 | logging_steps=10, 112 | optim="adamw_torch", 113 | weight_decay=0.01, 114 | lr_scheduler_type="constant", 115 | seed=3407, 116 | output_dir="unsloth_outputs", 117 | report_to="tensorboard", 118 | ), 119 | ) 120 | 121 | trainer_stats = trainer.train() 122 | 123 | FastLanguageModel.for_inference(model) 124 | inputs = tokenizer( 125 | [ 126 | alpaca_prompt.format( 127 | "Continue the Fibonacci sequence.", # instruction 128 | "1, 1, 2, 3, 5, 8", # input 129 | "", # output - leave this blank for generation! 130 | ) 131 | ], return_tensors = "pt").to("cuda") 132 | 133 | outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True) 134 | print("generated", tokenizer.batch_decode(outputs)) 135 | 136 | -------------------------------------------------------------------------------- /perf/trl_packing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | """ 18 | Benchmark for full parameter finetuning via trl. 19 | 20 | Metrics: train_samples_per_second 21 | 22 | Purpose: Compare performance of full parameter finetuning using trl. 23 | 24 | Launch Script: python kithara/perf/trl_packing.py 25 | """ 26 | 27 | 28 | """ 29 | Set up: 30 | pip install trl peft transformers datasets hf_transfer huggingface_hub[cli] 31 | 32 | Run this script on single host: HF_HUB_ENABLE_HF_TRANSFER=1 python kithara_trl_comparison/trl_train.py 33 | """ 34 | 35 | import torch._dynamo 36 | torch._dynamo.config.suppress_errors = True 37 | 38 | from datasets import load_dataset 39 | import torch 40 | from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments 41 | from peft import LoraConfig 42 | from trl import SFTTrainer, SFTConfig 43 | 44 | model_id = "google/gemma-2-2b" 45 | tokenizer = AutoTokenizer.from_pretrained(model_id) 46 | alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 47 | 48 | ### Instruction: 49 | {} 50 | 51 | ### Input: 52 | {} 53 | 54 | ### Response: 55 | {}""" 56 | 57 | def formatting_prompts_func(examples): 58 | EOS_TOKEN = tokenizer.eos_token 59 | instructions = examples["instruction"] 60 | inputs = examples["input"] 61 | outputs = examples["output"] 62 | texts = [] 63 | for instruction, input, output in zip(instructions, inputs, outputs): 64 | # Must add EOS_TOKEN, otherwise your generation will go on forever!
65 | text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN 66 | texts.append(text) 67 | return { "text" : texts, } 68 | 69 | from datasets import load_dataset 70 | dataset = load_dataset("yahma/alpaca-cleaned", split = "train") 71 | dataset = dataset.map(formatting_prompts_func, batched = True,) 72 | datasets = dataset.train_test_split(test_size=200, shuffle=False) 73 | train_dataset, test_dataset = datasets["train"], datasets["test"] 74 | 75 | model = AutoModelForCausalLM.from_pretrained( 76 | model_id, torch_dtype=torch.bfloat16 77 | ) 78 | 79 | trainer = SFTTrainer( 80 | model = model, 81 | tokenizer = tokenizer, 82 | train_dataset=train_dataset, # train on the training split, not the full dataset 83 | eval_dataset = test_dataset, 84 | args = SFTConfig( 85 | dataset_text_field="text", 86 | per_device_train_batch_size=2, 87 | max_steps=1000, 88 | logging_steps=10, 89 | # eval_strategy="steps", 90 | # eval_steps=100, 91 | learning_rate = 2e-4, 92 | optim = "adamw_torch", 93 | lr_scheduler_type="constant", 94 | bf16=True, 95 | seed = 3407, 96 | report_to = "tensorboard", 97 | max_seq_length = 8192, 98 | dataset_num_proc = 1, 99 | packing = False, 100 | output_dir = "trl_packing/no_packing" 101 | ), 102 | ) 103 | 104 | trainer.train() 105 | 106 | inputs = tokenizer( 107 | [ 108 | alpaca_prompt.format( 109 | "Continue the Fibonacci sequence.", # instruction 110 | "1, 1, 2, 3, 5, 8", # input 111 | "", # output - leave this blank for generation! 112 | ) 113 | ], return_tensors = "pt").to("cuda") 114 | 115 | outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True) 116 | print("generated", tokenizer.batch_decode(outputs)) 117 | 118 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=3.2,<4"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | [project] 6 | name = "kithara" 7 | description = "LLM post-training library" 8 | version = "0.0.10" 9 | readme = "README.md" 10 | license = { file = 'LICENSE' } 11 | authors = [{ name = 'Kithara Authors' }] 12 | requires-python = ">= 3.11" 13 | dependencies = [ 14 | "flax>=0.7.0", 15 | "datasets", 16 | "huggingface-hub", 17 | "keras>=3.8.0", 18 | "transformers>=4.45.1", 19 | "keras-hub>=0.18.1", 20 | "google-api-python-client", 21 | "google-auth-httplib2", 22 | "google-auth-oauthlib", 23 | "ray[default]==2.40.0", 24 | "jax[cpu]", 25 | "peft", 26 | "hf_transfer", 27 | "tabulate", 28 | "cryptography", 29 | "huggingface_hub", 30 | # MaxText dependencies 31 | "aqtp", 32 | "grain-nightly", 33 | "orbax-checkpoint>=0.10.3", 34 | "google-cloud-logging", 35 | "tensorboardx", 36 | "ml-collections", 37 | "tensorflow_datasets", 38 | "sentencepiece", 39 | "tiktoken", 40 | "cloud-accelerator-diagnostics", 41 | "cloud-tpu-diagnostics", 42 | "ml-goodput-measurement", 43 | "google-cloud-monitoring", 44 | "omegaconf", 45 | # Jetstream dependencies 46 | "setuptools==61.0", 47 | "jaxtyping", 48 | "clu", 49 | "editdistance", 50 | "pyglove", 51 | "tensorflow_datasets", 52 | "tfds-nightly" #"tfds-nightly==4.9.8.dev202503240044" 53 | ] 54 | 55 | [project.optional-dependencies] 56 | tpu = ["jax[tpu]", "torch==2.4.0+cpu"] 57 | gpu = ["jax[cuda]", "torch==2.4.0"] 58 | cpu = ["jax[cpu]", "torch==2.4.0"] 59 | dev = [ 60 | "twine", 61 | "flit", 62 | "sphinx==8.2.0", 63 | "sphinx-autobuild", 64 | "sphinxawesome-theme>=5.3.2", 65 | "sphinx_design", 66 | ] 67 | 68 | [tool.flit.module] 69 | name = "kithara" 70 | 71 | [project.urls] 72 | Homepage
= 'https://github.com/AI-Hypercomputer/kithara' 73 | Repository = 'https://github.com/AI-Hypercomputer/kithara' 74 | Documentation = "https://kithara.readthedocs.io/en/latest/index.html" 75 | 76 | [tool.flit.sdist] 77 | exclude = [ 78 | "kithara/model/maxtext/maxtext/MaxText/test_assets/", 79 | "**/*.Dockerfile", 80 | ] 81 | -------------------------------------------------------------------------------- /ray/GPU/cluster.yaml: -------------------------------------------------------------------------------- 1 | cluster_name: my-gpu-cluster # MODIFY choose your cluster name 2 | max_workers: 100 3 | 4 | provider: 5 | type: gcp 6 | region: us-central1 #MODIFY this with your region 7 | availability_zone: us-central1-a #MODIFY this with your availability zone 8 | project_id: #MODIFY this with your project id 9 | 10 | head_node_type: ray_head_node 11 | 12 | available_node_types: 13 | ray_head_node: 14 | node_config: 15 | machineType: n1-standard-8 16 | disks: 17 | - boot: true 18 | autoDelete: true 19 | type: PERSISTENT 20 | initializeParams: 21 | diskSizeGb: 50 22 | sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-2204-lts 23 | metadata: 24 | install-nvidia-driver: 'false' 25 | serviceAccounts: 26 | - email: # MODIFY this with your service account email 27 | scopes: 28 | - https://www.googleapis.com/auth/cloud-platform 29 | resources: {"CPU": 8} 30 | min_workers: 0 31 | max_workers: 0 32 | 33 | ray_worker_node: 34 | node_config: 35 | machineType: g2-standard-48 # Required for L4 GPUs 36 | guestAccelerators: 37 | - acceleratorType: nvidia-l4 # Use L4 GPUs 38 | acceleratorCount: 4 # Other GPU generations may have different accelerator counts 39 | disks: 40 | - boot: true 41 | autoDelete: true 42 | type: PERSISTENT 43 | initializeParams: 44 | diskSizeGb: 50 45 | sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-2204-lts 46 | metadata: 47 | install-nvidia-driver: 'true' 48 | serviceAccounts: 49 | - email: # MODIFY this with your service account email 50 | scopes: 51 | - https://www.googleapis.com/auth/cloud-platform 52 | scheduling: 53 | onHostMaintenance: "TERMINATE" # Disable live migration 54 | resources: {"CPU": 48, "GPU": 4} # Should match number of CPUs and GPUs of worker node 55 | min_workers: 1 56 | max_workers: 4 57 | 58 | initialization_commands: 59 | # Don't stall on ubuntu graphic... 60 | - sudo sed -i 's/#$nrconf{restart} = '"'"'i'"'"';/$nrconf{restart} = '"'"'a'"'"';/g' /etc/needrestart/needrestart.conf 61 | - sudo add-apt-repository -y ppa:deadsnakes/ppa 62 | - sudo apt-get update 63 | - sudo apt-get install -y python3.11 64 | - sudo apt-get install -y python3-pip python-is-python3 65 | - sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1 66 | - sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 67 | - echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.bashrc 68 | - source ~/.bashrc 69 | - python -m pip install --upgrade pip 70 | 71 | # Install dependencies needed by both the head and worker nodes 72 | # If you missed some dependencies during the set up process, you can 73 | # install them later during runtime. 
74 | setup_commands: 75 | - pip install "ray[default]==2.40.0" 76 | 77 | # Install dependencies needed only by the head node 78 | head_setup_commands: 79 | - pip install google-api-python-client 80 | - pip install -U kithara[cpu] 81 | 82 | # Ensure Python, CUDA, and PyTorch are installed on the worker nodes 83 | worker_setup_commands: 84 | # Install CUDA 85 | - wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin 86 | - sudo mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600 87 | - wget https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb 88 | - sudo dpkg -i cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb 89 | - sudo cp /var/cuda-repo-ubuntu2204-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/ 90 | - sudo apt-get update 91 | - sudo apt-get -y install cuda-toolkit-12-4 92 | - sudo apt-get install -y cuda-drivers-550 93 | - sudo nvidia-smi # allows for sudoless nvidia-smi 94 | - pip install -U kithara[gpu] -------------------------------------------------------------------------------- /ray/GPU/find_devices.py: -------------------------------------------------------------------------------- 1 | # Test your GPU Ray cluster set up with 2 | # ray job submit --working-dir . -- python3 ray/GPU/find_devices.py 3 | 4 | import ray 5 | 6 | ray.init() 7 | 8 | num_chips_per_host = 4 9 | num_gpu_devices = int(ray.cluster_resources()["GPU"]) 10 | print(f"{num_gpu_devices=}") 11 | 12 | @ray.remote(num_gpus=num_chips_per_host) 13 | def main(): 14 | import jax 15 | print(jax.devices()) 16 | 17 | # Launch one task per host; each task reserves all of that host's GPUs. 18 | ray.get([main.remote() for _ in range(num_gpu_devices // num_chips_per_host)]) 19 | 20 | ray.shutdown() -------------------------------------------------------------------------------- /ray/TPU/GCE/cluster.yaml: -------------------------------------------------------------------------------- 1 | # A unique identifier for the head node and workers of this cluster. 2 | cluster_name: ray-cluster # MODIFY choose your cluster name 3 | max_workers: 1 4 | available_node_types: 5 | # The ray head node is a CPU node 6 | ray_head_default: 7 | min_workers: 0 8 | max_workers: 0 9 | resources: { "CPU": 0 } # Don't change this 10 | # Provider-specific config for this node type, e.g. instance type. By default 11 | # Ray will auto-configure unspecified fields such as subnets and ssh-keys.
12 | # For more documentation on available fields, see: 13 | # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert 14 | node_config: 15 | machineType: n1-standard-16 16 | disks: 17 | - boot: true 18 | autoDelete: true 19 | type: PERSISTENT 20 | initializeParams: 21 | diskSizeGb: 50 # Increase this if you need more disk space 22 | # See https://cloud.google.com/compute/docs/images for more images 23 | sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-2204-lts 24 | # Uncomment and modify the following if you want to use a custom VPC network 25 | 26 | # networkInterfaces: 27 | # - kind: compute#networkInterface 28 | # subnetwork: https://www.googleapis.com/compute/v1/projects/{{project_name}}/regions/{{zone}}/subnetworks/{{subnet_name}} 29 | # accessConfigs: 30 | # - kind: compute#accessConfig 31 | # networkTier: STANDARD 32 | # name: "External NAT" 33 | # type: "ONE_TO_ONE_NAT" 34 | 35 | # Additional supported GCP configs: https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert 36 | 37 | # The Ray worker nodes are TPU nodes 38 | ray_tpu: 39 | min_workers: 1 40 | max_workers: 1 41 | resources: { "TPU": 4 } # MODIFY this to be the number of chips per host for your TPU type. Should be 4 for v4 and v5 and 8 for v4e and v5e 42 | node_config: 43 | acceleratorType: v4-16 # MODIFY this to be your TPU type 44 | runtimeVersion: tpu-ubuntu2204-base 45 | provider: # MODIFY this section with your provider spec 46 | type: gcp 47 | region: us-central2 #MODIFY this with your region 48 | availability_zone: us-central2-b #MODIFY this with your availability_zone 49 | project_id: project_id #MODIFY this with your project id 50 | 51 | initialization_commands: 52 | # Don't stall on ubuntu graphic... 53 | - sudo sed -i 's/#$nrconf{restart} = '"'"'i'"'"';/$nrconf{restart} = '"'"'a'"'"';/g' /etc/needrestart/needrestart.conf 54 | - sudo add-apt-repository -y ppa:deadsnakes/ppa 55 | - sudo apt-get update 56 | - sudo apt-get install -y python3.11 57 | - sudo apt-get install -y python3-pip python-is-python3 58 | - sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1 59 | - sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 60 | - echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.bashrc 61 | - source ~/.bashrc 62 | - python -m pip install --upgrade pip 63 | 64 | 65 | # Install dependencies needed by both the head and worker nodes 66 | # If you missed some dependencies during the set up process, you can 67 | # install them later during runtime. 68 | setup_commands: 69 | - pip install "ray[default]==2.40.0" 70 | 71 | # Install dependencies needed only by the head node 72 | head_setup_commands: 73 | - pip install google-api-python-client 74 | - pip install -U kithara[cpu] 75 | 76 | # Install dependencies needed only by the worker node 77 | worker_setup_commands: 78 | - pip install -U kithara[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://download.pytorch.org/whl/cpu 79 | 80 | # Specify the node type of the head node (as configured above).
81 | head_node_type: ray_head_default -------------------------------------------------------------------------------- /ray/TPU/GKE/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM rayproject/ray:2.42.0-py311 2 | 3 | RUN python -m pip install --upgrade pip 4 | 5 | RUN pip install -U kithara[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://download.pytorch.org/whl/cpu 6 | 7 | USER root 8 | -------------------------------------------------------------------------------- /ray/TPU/QR/cluster.yaml: -------------------------------------------------------------------------------- 1 | # A unique identifier for the head node and workers of this cluster. 2 | cluster_name: wenxindong-qr-cluster-dws-4 # MODIFY choose your cluster name 3 | max_workers: 1 4 | available_node_types: 5 | # The ray head node is a CPU node 6 | ray_head_default: 7 | min_workers: 0 8 | max_workers: 0 9 | resources: { "CPU": 0 } # Don't change this 10 | # Provider-specific config for this node type, e.g. instance type. By default 11 | # Ray will auto-configure unspecified fields such as subnets and ssh-keys. 12 | # For more documentation on available fields, see: 13 | # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert 14 | node_config: 15 | machineType: n1-standard-16 16 | disks: 17 | - boot: true 18 | autoDelete: true 19 | type: PERSISTENT 20 | initializeParams: 21 | diskSizeGb: 50 # Increase this if you need more disk space 22 | # See https://cloud.google.com/compute/docs/images for more images 23 | sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-2204-lts 24 | 25 | # Uncomment and modify the following if you want to use a custom VPC network 26 | 27 | # networkInterfaces: 28 | # - kind: compute#networkInterface 29 | # subnetwork: https://www.googleapis.com/compute/v1/projects/{{project_name}}/regions/{{zone}}/subnetworks/{{subnet_name}} 30 | # accessConfigs: 31 | # - kind: compute#accessConfig 32 | # networkTier: STANDARD 33 | # name: "External NAT" 34 | # type: "ONE_TO_ONE_NAT" 35 | 36 | # Additional supported GCP configs: https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert 37 | 38 | provider: # MODIFY this section with your provider spec 39 | type: gcp 40 | region: us-west4 #MODIFY this with your region 41 | availability_zone: us-west4-a #MODIFY this with your availability_zone 42 | project_id: tpu-prod-env-multipod #MODIFY this with your project id 43 | 44 | initialization_commands: 45 | # Don't stall on ubuntu graphic... 46 | - sudo sed -i 's/#$nrconf{restart} = '"'"'i'"'"';/$nrconf{restart} = '"'"'a'"'"';/g' /etc/needrestart/needrestart.conf 47 | - sudo add-apt-repository -y ppa:deadsnakes/ppa 48 | - sudo apt-get update 49 | - sudo apt-get install -y python3.11 50 | - sudo apt-get install -y python3-pip python-is-python3 51 | - sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1 52 | - sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 53 | - echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.bashrc 54 | - source ~/.bashrc 55 | - python -m pip install --upgrade pip 56 | 57 | # Install dependencies needed by both the head and worker nodes 58 | # If you missed some dependencies during the set up process, you can 59 | # install them later during runtime. 
60 | setup_commands: 61 | - pip install "ray[default]==2.40.0" 62 | 63 | # Install dependencies needed only by the head node 64 | head_setup_commands: 65 | - pip install google-api-python-client 66 | - pip install -U kithara[cpu]==0.0.8 67 | 68 | # Specify the node type of the head node (as configured above). 69 | head_node_type: ray_head_default -------------------------------------------------------------------------------- /ray/TPU/QR/qr_worker_startup_script.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | # Function to check and release dpkg lock 5 | check_dpkg_lock() { 6 | if sudo fuser -s /var/lib/dpkg/lock-frontend ; then 7 | sudo systemctl stop unattended-upgrades 8 | fi 9 | } 10 | 11 | check_dpkg_lock 12 | sudo sed -i 's/#$nrconf{restart} = '"'"'i'"'"';/$nrconf{restart} = '"'"'a'"'"';/g' /etc/needrestart/needrestart.conf 13 | 14 | check_dpkg_lock 15 | sudo add-apt-repository -y ppa:deadsnakes/ppa 16 | 17 | check_dpkg_lock 18 | sudo apt-get update 19 | 20 | check_dpkg_lock 21 | sudo apt-get install -y python3.11 22 | 23 | check_dpkg_lock 24 | sudo apt-get install -y python3-pip python-is-python3 25 | 26 | check_dpkg_lock 27 | sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1 28 | sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 29 | echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.bashrc 30 | source ~/.bashrc 31 | 32 | check_dpkg_lock 33 | sudo python -m pip install --upgrade pip 34 | 35 | check_dpkg_lock 36 | sudo pip install -U kithara[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://download.pytorch.org/whl/cpu 37 | -------------------------------------------------------------------------------- /ray/submit_job.py: -------------------------------------------------------------------------------- 1 | from ray.job_submission import JobSubmissionClient, JobStatus 2 | import time 3 | import argparse 4 | 5 | 6 | def wait_until_status(client, job_id, status_to_wait_for, timeout_seconds=60): 7 | start = time.time() 8 | while time.time() - start <= timeout_seconds: 9 | status = client.get_job_status(job_id) 10 | print(f"status: {status}") 11 | if status in status_to_wait_for: 12 | break 13 | time.sleep(1) 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser( 18 | description="Submit a Ray job with specified entrypoint" 19 | ) 20 | parser.add_argument( 21 | "job_file", 22 | type=str, 23 | help='The entrypoint script to run (e.g., "python examples/multihost/ray/TPU/sft_lora_example.py")', 24 | ) 25 | parser.add_argument( 26 | "--ray-address", 27 | type=str, 28 | default="http://localhost:8265", 29 | help="Ray cluster address (default: http://localhost:8265)", 30 | ) 31 | parser.add_argument( 32 | "--hf-token", 33 | type=str, 34 | default="", 35 | help="Optional. Your HuggingFace token, if HuggingFace login is needed.", 36 | ) 37 | 38 | return parser.parse_args() 39 | 40 |
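# Example invocation (the entrypoint below is the one cited in the job_file help text,
# and the address is the default dashboard address configured above):
#   python ray/submit_job.py "python examples/multihost/ray/TPU/sft_lora_example.py" \
#       --ray-address http://localhost:8265 --hf-token <your_hf_token>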
", 36 | ) 37 | 38 | return parser.parse_args() 39 | 40 | 41 | def main(): 42 | args = parse_args() 43 | client = JobSubmissionClient(args.ray_address) 44 | print("Initialized job client") 45 | 46 | job_id = client.submit_job( 47 | entrypoint=args.job_file, 48 | runtime_env={ 49 | "working_dir": "./", 50 | # These files are too large and are safe to exclude 51 | "excludes": ["**/test_assets/", ".git/", "*.Dockerfile", "*.pkl"], 52 | "pip": [], # Any missing dependencies to run job goes here 53 | "env_vars": {"HF_TOKEN": args.hf_token, "HF_HUB_ENABLE_HF_TRANSFER": "1"}, 54 | }, 55 | ) 56 | 57 | print(f"Submitted job with {job_id=}") 58 | 59 | wait_until_status( 60 | client, job_id, {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED} 61 | ) 62 | 63 | # See full logs in the Ray Dashboard 64 | logs = client.get_job_logs(job_id) 65 | print(logs) 66 | 67 | if __name__ == "__main__": 68 | 69 | main() 70 | 71 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Testing Guide 2 | 3 | ## Overview 4 | This guide explains how to add and run unit tests. All tests are located in the `/tests` directory, which mirrors the structure of the `/kithara` directory. Test files are excluded from the distributed wheel package. 5 | 6 | ## Test Discovery 7 | Tests are automatically discovered by `unittest` when they meet these criteria: 8 | - Files are named with the pattern `test_*.py` 9 | - Test functions start with `test_` 10 | 11 | ## Test Categories 12 | We use three categories of tests, controlled by environment variables: 13 | 14 | ### 1. Light Tests 15 | - Default test category 16 | - Quick to execute (under 3 minutes) 17 | - Sanity checks for basic correctness 18 | - No special decorator needed 19 | 20 | ### 2. Heavy Tests 21 | - Takes longer than 3 minutes to execute 22 | - Should run before merging to main branch 23 | - Requires following decorator: 24 | ```python 25 | @unittest.skipIf(int(os.getenv('RUN_LIGHT_TESTS_ONLY', 0)) == 1, "Heavy Test") 26 | ``` 27 | 28 | ### 3. Manual Tests 29 | - Tests model/task correctness 30 | - Results unlikely to change 31 | - Run on-demand only 32 | - Requires decorator: 33 | ```python 34 | @unittest.skipIf(int(os.getenv('RUN_SKIPPED_TESTS', 0)) != 1, "Manual Test") 35 | ``` 36 | 37 | ## Choosing Test Categories 38 | 39 | Use this decision flow to categorize your tests: 40 | 41 | 1. **Manual Test** if: 42 | - Tests numerical correctness, such as checkpoint conversion correctness. 43 | - Results should be stable across runs 44 | - Requires significant computation time 45 | 46 | 2. **Heavy Test** if: 47 | - Must run before main branch merges 48 | - Takes >3 minutes to execute 49 | - Tests critical functionality 50 | 51 | 3. 
79 | 80 | ## Best Practices 81 | 82 | 1. **Test Organization**: Group related tests in the same file, following the source code structure 83 | 84 | 2. **Documentation**: Include docstrings at the top of every test file to describe the purpose of the tests and how to run them in a standalone manner. E.g. 85 | ``` 86 | These tests validate that the MaxText implementation produces logits that are 87 | numerically comparable to the reference HuggingFace implementation. 88 | 89 | Usage: 90 | Run script on single host VM: RUN_SKIPPED_TESTS=1 python -m unittest tests/model/maxtext/ckpt_compatibility/test_loading_models.py 91 | ``` 92 | 93 | ## Contributing Tests 94 | 95 | 1. Create test file in appropriate `/tests` subdirectory 96 | 2. Choose appropriate test category 97 | 3. Add necessary decorators 98 | 4. Include clear docstrings and comments 99 | 5. Verify tests pass 100 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/__init__.py -------------------------------------------------------------------------------- /tests/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/callbacks/__init__.py -------------------------------------------------------------------------------- /tests/callbacks/test_orbax_checkpointer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | """ 16 | 17 | """End to end test for model checkpointing correctness 18 | 19 | Run script on single host VM: python -m unittest tests/callbacks/test_orbax_checkpointer.py 20 | """ 21 | 22 | import os 23 | 24 | os.environ["KERAS_BACKEND"] = "jax" 25 | 26 | import shutil 27 | import unittest 28 | import numpy as np 29 | from kithara import KerasHubModel, MaxTextModel 30 | from kithara.callbacks.checkpointer import Checkpointer 31 | from kithara.utils.gcs_utils import find_cache_root_dir 32 | from tests.model.utils import check_arrays_match 33 | import jax 34 | 35 | 36 | class TestOrbaxCheckpointing(unittest.TestCase): 37 | @classmethod 38 | def setUpClass(cls): 39 | cls.TMP_DIR = os.path.join(find_cache_root_dir(), "test/ckpt") 40 | 41 | def setUp(self): 42 | shutil.rmtree(self.TMP_DIR, ignore_errors=True) 43 | 44 | def tearDown(self): 45 | shutil.rmtree(self.TMP_DIR, ignore_errors=True) 46 | 47 | def _test_checkpoint_save_load(self, model, model_input): 48 | """Test that model state is preserved after save and load operations.""" 49 | # Get initial model output 50 | logits_before, _ = model.stateless_call( 51 | model.trainable_variables, 52 | model.non_trainable_variables, 53 | model_input, 54 | ) 55 | 56 | # Save and load model 57 | checkpointer = Checkpointer(self.TMP_DIR, model=model) 58 | checkpointer.save(0, blocking=True) 59 | checkpointer.load() 60 | 61 | # Get model output after loading 62 | logits_after, _ = model.stateless_call( 63 | model.trainable_variables, 64 | model.non_trainable_variables, 65 | model_input, 66 | ) 67 | 68 | # Verify outputs match within tolerance 69 | try: 70 | check_arrays_match(logits_before, logits_after, atol=0.0001) 71 | except AssertionError as e: 72 | self.fail(f"Model outputs don't match after save/load: {str(e)}") 73 | 74 | def test_kerashub_models(self): 75 | # Initialize model 76 | model = KerasHubModel.from_preset( 77 | "hf://google/gemma-2-2b", 78 | lora_rank=6 79 | ) 80 | 81 | # Prepare test input 82 | model_input = { 83 | "token_ids": np.array([[1, 2, 3]]), 84 | "padding_mask": np.array([[1, 1, 1]]), 85 | } 86 | self._test_checkpoint_save_load(model, model_input) 87 | 88 | def test_maxtext_model(self): 89 | model = MaxTextModel.from_random( 90 | model_name="default", seq_len=100, per_device_batch_size=1, scan_layers=True 91 | ) 92 | 93 | model_input = { 94 | "tokens": np.array([[i for i in range(100)]] * jax.device_count()), 95 | "positions": np.array([[i for i in range(100)]] * jax.device_count()), 96 | "segment_ids": np.array([[1 for _ in range(100)]] * jax.device_count()), 97 | } 98 | self._test_checkpoint_save_load(model, model_input) 99 | 100 | 101 | if __name__ == "__main__": 102 | unittest.main(verbosity=2) 103 | -------------------------------------------------------------------------------- /tests/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/dataset/__init__.py -------------------------------------------------------------------------------- /tests/dataset/test_dataloader_creation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | """Unit tests for creating Kithara Dataloaders 18 | 19 | Run test on a TPU VM: python -m unittest tests/dataset/test_dataloader_creation.py 20 | """ 21 | 22 | import unittest 23 | from kithara import TextCompletionDataset, Dataloader 24 | import time 25 | import unittest.result 26 | import tests.dataset.utils as dataset_utils 27 | import numpy as np 28 | import os 29 | 30 | class TestDataloaderCreation(unittest.TestCase): 31 | 32 | def setUp(self): 33 | print(f"\nStarting test: {self._testMethodName}") 34 | self.start_time = time.time() 35 | 36 | def tearDown(self): 37 | duration = time.time() - self.start_time 38 | print(f"Completed test: {self._testMethodName} in {duration:.2f} seconds\n") 39 | 40 | def _run_test(self, data_source): 41 | dataset = TextCompletionDataset( 42 | data_source, 43 | tokenizer_handle="hf://google/gemma-2-2b", 44 | model_type="KerasHub", 45 | max_seq_len=100, 46 | ) 47 | dataloader = Dataloader(dataset, per_device_batch_size=2) 48 | for batch in dataloader: 49 | self.assertTrue(isinstance(batch, dict)) 50 | self.assertTrue("x" in batch) 51 | self.assertTrue("y" in batch) 52 | self.assertTrue("token_ids" in batch["x"]) 53 | self.assertTrue("padding_mask" in batch["x"]) 54 | self.assertTrue(isinstance(batch["x"]["token_ids"], np.ndarray)) 55 | self.assertTrue(isinstance(batch["x"]["padding_mask"], np.ndarray)) 56 | self.assertTrue(isinstance(batch["y"], np.ndarray)) 57 | self.assertEqual(len(batch["x"]["token_ids"].shape), 2) 58 | break 59 | 60 | def test_dataloader_drops_last_batch(self): 61 | n = 500 62 | dict_dataset = dataset_utils.create_dict_ray_dataset(n) 63 | dataset = TextCompletionDataset( 64 | dict_dataset, 65 | tokenizer_handle="hf://google/gemma-2-2b", 66 | ) 67 | dataloader = Dataloader(dataset, per_device_batch_size=n+1) 68 | self.assertEqual(len(dataloader), 0) 69 | 70 | dataloader = Dataloader(dataset, per_device_batch_size=3) 71 | batch_size = dataloader.per_host_batch_size 72 | 73 | self.assertEqual(len(dataloader), n//batch_size) 74 | total = 0 75 | for batch in dataloader: 76 | total += 1 77 | self.assertEqual(total, n//batch_size) 78 | 79 | 80 | def test_creating_dataloader_for_fixed_data(self): 81 | dict_dataset = dataset_utils.create_dict_ray_dataset() 82 | self._run_test(dict_dataset) 83 | 84 | @unittest.skipIf(int(os.getenv('RUN_LIGHT_TESTS_ONLY', 0)) == 1, "Heavy Test") 85 | def test_creating_dataloader_for_streaming_data(self): 86 | streaming_dataset = dataset_utils.create_hf_streaming_ray_dataset() 87 | self._run_test(streaming_dataset) 88 | 89 | 90 | if __name__ == "__main__": 91 | unittest.main(verbosity=2) 92 | -------------------------------------------------------------------------------- /tests/dataset/test_dataset_creation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | """Unit tests for creating Kithara Datasets 18 | 19 | Run test on a TPU VM: python -m unittest tests/dataset/test_dataset_creation.py 20 | """ 21 | 22 | import unittest 23 | from kithara import SFTDataset, TextCompletionDataset 24 | import time 25 | import unittest.result 26 | import tests.dataset.utils as dataset_utils 27 | import numpy as np 28 | import os 29 | 30 | 31 | class TestDatasetCreation(unittest.TestCase): 32 | 33 | def setUp(self): 34 | print(f"\nStarting test: {self._testMethodName}") 35 | self.start_time = time.time() 36 | 37 | def tearDown(self): 38 | duration = time.time() - self.start_time 39 | print(f"Completed test: {self._testMethodName} in {duration:.2f} seconds\n") 40 | 41 | def _check_dataset_item_format(self, element): 42 | self.assertTrue(isinstance(element, dict)) 43 | self.assertTrue("x" in element) 44 | self.assertTrue("y" in element) 45 | self.assertTrue("token_ids" in element["x"]) 46 | self.assertTrue("padding_mask" in element["x"]) 47 | self.assertTrue(isinstance(element["x"]["token_ids"], np.ndarray)) 48 | self.assertTrue(isinstance(element["x"]["padding_mask"], np.ndarray)) 49 | self.assertTrue(isinstance(element["y"], np.ndarray)) 50 | 51 | def test_creating_sft_dataset_from_dict(self): 52 | dict_dataset = dataset_utils.create_dict_ray_dataset() 53 | dataset = SFTDataset( 54 | dict_dataset, 55 | tokenizer_handle="hf://google/gemma-2-2b", 56 | column_mapping={"prompt": "text", "answer": "text"}, 57 | model_type="KerasHub", 58 | ) 59 | self.assertIsNotNone(len(dataset)) 60 | for element in dataset: 61 | self._check_dataset_item_format(element) 62 | break 63 | 64 | def test_creating_sft_dataset_fail_when_tokenizer_is_not_provided(self): 65 | dict_dataset = dataset_utils.create_dict_ray_dataset() 66 | with self.assertRaises(AssertionError): 67 | SFTDataset(dict_dataset, model_type="KerasHub") 68 | 69 | def test_creating_sft_dataset_from_csv(self): 70 | csv_dataset = dataset_utils.create_csv_ray_dataset() 71 | dataset = SFTDataset( 72 | csv_dataset, 73 | tokenizer_handle="hf://google/gemma-2-2b", 74 | column_mapping={ 75 | "prompt": "sepal length (cm)", 76 | "answer": "sepal length (cm)", 77 | }, 78 | model_type="KerasHub", 79 | ) 80 | self.assertIsNotNone(len(dataset)) 81 | for element in dataset: 82 | self._check_dataset_item_format(element) 83 | break 84 | 85 | @unittest.skipIf(int(os.getenv("RUN_LIGHT_TESTS_ONLY", 0)) == 1, "Heavy Test") 86 | def test_creating_text_completion_dataset_from_streaming_hf(self): 87 | streaming_dataset = dataset_utils.create_hf_streaming_ray_dataset() 88 | dataset = TextCompletionDataset( 89 | streaming_dataset, 90 | tokenizer_handle="hf://google/gemma-2-2b", 91 | model_type="KerasHub", 92 | ) 93 | self.assertTrue(len(dataset) == 364608) 94 | for element in dataset: 95 | self._check_dataset_item_format(element) 96 | break 97 | 98 | @unittest.skipIf(int(os.getenv("RUN_LIGHT_TESTS_ONLY", 0)) == 1, "Heavy Test") 99 | def test_creating_dataset_from_raw_hf_dataset(self): 100 | from datasets import load_dataset 101 | 102 | hf_dataset = load_dataset( 103 | "allenai/c4", "en", 
split="validation", streaming=True 104 | ) 105 | dataset = TextCompletionDataset( 106 | hf_dataset, 107 | tokenizer_handle="hf://google/gemma-2-2b", 108 | model_type="KerasHub", 109 | ) 110 | self.assertTrue(len(dataset) == 364608) 111 | for element in dataset: 112 | self._check_dataset_item_format(element) 113 | break 114 | 115 | 116 | if __name__ == "__main__": 117 | unittest.main(verbosity=2) 118 | -------------------------------------------------------------------------------- /tests/dataset/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import ray 18 | from datasets import load_dataset 19 | 20 | def create_dict_ray_dataset(n=1000): 21 | dataset_items = [ 22 | {"text": f"{i} What is your name? My name is Kithara."} for i in range(n) 23 | ] 24 | return ray.data.from_items(dataset_items) 25 | 26 | def create_hf_streaming_ray_dataset(): 27 | hf_val_dataset = load_dataset( 28 | "allenai/c4", "en", split="validation", streaming=True 29 | ) 30 | return ray.data.from_huggingface(hf_val_dataset) 31 | 32 | def create_json_ray_dataset(): 33 | ds = ray.data.read_json("s3:/import ray/anonymous@ray-example-data/log.json") 34 | return ds 35 | 36 | def create_csv_ray_dataset(): 37 | ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv") 38 | return ds -------------------------------------------------------------------------------- /tests/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/model/__init__.py -------------------------------------------------------------------------------- /tests/model/kerashub/ckpt_compatibility/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/model/kerashub/ckpt_compatibility/__init__.py -------------------------------------------------------------------------------- /tests/model/kerashub/test_creation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | """Unit tests for creating a KerasHub model 18 | 19 | Run test on a TPU VM: python -m unittest tests/model/kerashub/test_creation.py 20 | """ 21 | import unittest 22 | from kithara import KerasHubModel 23 | import time 24 | import unittest.result 25 | 26 | class TestModelCreation(unittest.TestCase): 27 | 28 | def setUp(self): 29 | print(f"\nStarting test: {self._testMethodName}") 30 | self.start_time = time.time() 31 | 32 | def tearDown(self): 33 | duration = time.time() - self.start_time 34 | print(f"Completed test: {self._testMethodName} in {duration:.2f} seconds\n") 35 | 36 | def test_creating_gemma2_2b_model(self): 37 | KerasHubModel.from_preset( 38 | "hf://google/gemma-2-2b" 39 | ) 40 | def test_creating_gemma2_2b_model_with_lora(self): 41 | KerasHubModel.from_preset( 42 | "hf://google/gemma-2-2b", 43 | lora_rank=6 44 | ) 45 | 46 | def test_creating_model_with_precision(self): 47 | model = KerasHubModel.from_preset( 48 | "hf://google/gemma-2-2b", 49 | precision="float32" 50 | ) 51 | self.assertTrue(str(model.trainable_variables[0].value.dtype) == "float32") 52 | 53 | def test_creating_gemma2_9b_model(self): 54 | KerasHubModel.from_preset( 55 | "hf://google/gemma-2-9b" 56 | ) 57 | def test_creating_gemma2_2b_model_with_lora(self): 58 | KerasHubModel.from_preset( 59 | "hf://google/gemma-2-9b", 60 | lora_rank=16 61 | ) 62 | 63 | if __name__ == '__main__': 64 | unittest.main(verbosity=2) 65 | -------------------------------------------------------------------------------- /tests/model/maxtext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/model/maxtext/__init__.py -------------------------------------------------------------------------------- /tests/model/maxtext/ckpt_compatibility/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/model/maxtext/ckpt_compatibility/__init__.py -------------------------------------------------------------------------------- /tests/model/maxtext/test_creation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | """Unit tests for creating a MaxText model 18 | 19 | Run test on a TPU VM: python -m unittest tests/model/maxtext/test_creation.py 20 | """ 21 | import unittest 22 | from kithara import MaxTextModel 23 | import time 24 | import unittest.result 25 | import os 26 | 27 | class TestModelCreation(unittest.TestCase): 28 | 29 | def setUp(self): 30 | print(f"\nStarting test: {self._testMethodName}") 31 | self.start_time = time.time() 32 | 33 | def tearDown(self): 34 | duration = time.time() - self.start_time 35 | print(f"Completed test: {self._testMethodName} in {duration:.2f} seconds\n") 36 | 37 | def test_creating_random_gemma2_2b_model(self): 38 | MaxTextModel.from_random( 39 | "gemma2-2b", 40 | seq_len=100, 41 | per_device_batch_size=1, 42 | ) 43 | 44 | def test_creating_preset_gemma2_2b_model(self): 45 | MaxTextModel.from_preset( 46 | "hf://google/gemma-2-2b", 47 | seq_len=100, 48 | per_device_batch_size=1, 49 | ) 50 | 51 | def test_creating_preset_gemma2_2b_model_with_scan(self): 52 | MaxTextModel.from_preset( 53 | "hf://google/gemma-2-2b", 54 | seq_len=100, 55 | per_device_batch_size=1, 56 | scan_layers=True 57 | ) 58 | 59 | def test_creating_random_model_with_precision(self): 60 | model = MaxTextModel.from_random( 61 | "gemma2-2b", 62 | seq_len=100, 63 | per_device_batch_size=1, 64 | precision="bfloat16", 65 | ) 66 | self.assertTrue(str(model.trainable_variables[0].value.dtype) == "bfloat16") 67 | 68 | def test_creating_preset_model_with_precision(self): 69 | model = MaxTextModel.from_preset( 70 | "hf://google/gemma-2-2b", 71 | seq_len=100, 72 | per_device_batch_size=1, 73 | precision="bfloat16", 74 | ) 75 | self.assertTrue(str(model.trainable_variables[0].value.dtype) == "bfloat16") 76 | 77 | def test_creating_random_gemma2_2b_model(self): 78 | MaxTextModel.from_random( 79 | "gemma2-9b", 80 | seq_len=100, 81 | per_device_batch_size=1, 82 | ) 83 | 84 | @unittest.skipIf(int(os.getenv('RUN_LIGHT_TESTS_ONLY', 0)) == 1, "Heavy Test") 85 | def test_creating_preset_gemma2_9b_model(self): 86 | MaxTextModel.from_preset( 87 | "hf://google/gemma-2-9b", 88 | seq_len=100, 89 | per_device_batch_size=1, 90 | ) 91 | 92 | @unittest.skipIf(int(os.getenv('RUN_LIGHT_TESTS_ONLY', 0)) == 1, "Heavy Test") 93 | def test_creating_preset_gemma2_9b_model_with_scan(self): 94 | MaxTextModel.from_preset( 95 | "hf://google/gemma-2-9b", 96 | seq_len=100, 97 | per_device_batch_size=1, 98 | scan_layers=True 99 | ) 100 | 101 | if __name__ == '__main__': 102 | unittest.main(verbosity=2) 103 | -------------------------------------------------------------------------------- /tests/model/test_prompt.py: -------------------------------------------------------------------------------- 1 | TEST_PROMPT="""# Activation dtypes. 2 | dtype: "bfloat16" 3 | # Used to configure quantization in the transformer layers, defaults to null implying bf16. 4 | # Possible alternative settings are as follows: 5 | # 'int8' for dynamic range quantization using 8-bits 6 | # 'int8w' for weights only quantization using 8-bits 7 | # 'int4w' for weights only quantization using 4-bits 8 | # 'intmp' for mixed precision weight only quantization based on config file 9 | # 'fp8' for 8-bit floating-point GeMMs on NVIDIA GPUs. 10 | quantization: "" 11 | # Choose one of default, high, and highest. 
12 | # https://kolonist26-jax-kr.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision 13 | matmul_precision: "default" 14 | activations_in_float32: False # Sets activations to float32 before nonlinearity it true, else dtype 15 | # Path to file with quantization config - only used for mixed precision. 16 | # Example configs in ../Maxtext/configs/quantization 17 | # Allows us to configure different bits, tiling and scale for quantizing selected weights. 18 | # Bits represents number of bits to quantize to, 19 | # tile-size represents the tiling sized used in AQT tiled_dot_general, 20 | # Value of scale is used to scale the abs_max value used for AQT quantization 21 | # Defaults values are 8 bits, tile-size=-1 (no tiling) and scale=1. 22 | quant_cfg_path: "" 23 | quantize_kvcache: False # Set to True to quantize KV Cache values, defaults to False 24 | # Valid kv_quant_axis values: 25 | # - "" is valid only when quantize_kvcache is False 26 | # - "dkv" indicates quantize kv cache over the cache_kv, i.e. kv dimension axis 27 | # - "heads_and_dkv" indicates quantize kv cache over cache_heads and cache_kv axes 28 | # Default to "heads_and_dkv" for faster compution, kv_quant_axis is not used when quantize_kvcache is False 29 | # - "dkv" is expected with better accuracy but degraded computation 30 | kv_quant_axis: "heads_and_dkv" 31 | kv_quant_dtype: "int8" 32 | checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint 33 | # Saves params quantized on fly at following path 34 | save_quantized_params_path: "" 35 | # Shard the range finding operation for quantization. By default this is set to number of slices. 36 | quantization_local_shard_count: -1 37 | decoder_block: "llama2" # which style of DecoderBlock to use. 38 | # Global parameter scale needs to be a power of 2. If you want finer grained control of the model sizes 39 | # then you should explicitly set base_embed_dim, base_num_query_heads, base_num_kv_heads, 40 | # base_mlp_dim, base_num_decoder_layers and/or head_dim. 
41 | weight_dtype: float32 42 | global_parameter_scale: 1 43 | base_emb_dim: 2048 44 | base_num_query_heads: 16 45 | base_num_kv_heads: 16 46 | base_mlp_dim: 7168 47 | base_num_decoder_layers: 16 48 | head_dim: 128 49 | mlp_activations: ["silu", "linear"] 50 | dropout_rate: 0 51 | logits_via_embedding: False 52 | normalize_embedding_logits: True # whether to normlize pre-softmax logits if logits_via_embedding is true 53 | logits_dot_in_fp32: True # whether to use fp32 in logits_dense or shared_embedding dot product for stability""" -------------------------------------------------------------------------------- /tests/optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/optimizers/__init__.py -------------------------------------------------------------------------------- /tests/optimizers/test_optax_optimizer.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import unittest.mock 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import optax 7 | from keras.src import backend 8 | 9 | from kithara.optimizers import OptaxOptimizerInterface 10 | 11 | class TestOptaxOptimizerInterface(unittest.TestCase): 12 | 13 | def setUp(self): 14 | self.dummy_keras_variables = [ 15 | backend.Variable(jnp.ones((1024, 1024)), name="dummy_keras_variable") 16 | for _ in range(3)] 17 | 18 | return super().setUp() 19 | 20 | def test_learning_rate_is_none(self): 21 | optax_optimizer_interface = OptaxOptimizerInterface( 22 | optax.adam(1e-3), self.dummy_keras_variables) 23 | self.assertIsNone(optax_optimizer_interface.learning_rate) 24 | 25 | def test_optimizer_initialization(self): 26 | optax_optimizer_interface = OptaxOptimizerInterface( 27 | optax.adam(1e-3), self.dummy_keras_variables) 28 | self.assertTrue( 29 | jax.tree.reduce( 30 | lambda agg, leaf: agg and isinstance(leaf, backend.Variable), 31 | optax_optimizer_interface.variables, True)) 32 | 33 | def test_stateless_apply(self): 34 | optax_optimizer_interface = OptaxOptimizerInterface( 35 | optax.adam(1e-3), self.dummy_keras_variables) 36 | new_trainables, new_opt_variables = optax_optimizer_interface.stateless_apply( 37 | optax_optimizer_interface.variables, self.dummy_keras_variables, self.dummy_keras_variables,) 38 | # The update should yield one new value per trainable variable and a new optimizer state. 39 | self.assertEqual(len(new_trainables), len(self.dummy_keras_variables)) 40 | self.assertIsNotNone(new_opt_variables) -------------------------------------------------------------------------------- /tests/optimizers/test_protocol.py: -------------------------------------------------------------------------------- 1 | """Unit tests for testing the protocol's integrity and the conversion of optimizers to Kithara optimizers.
2 | 3 | Run test on a TPU VM: python -m unittest tests/optimizers/test_protocol.py 4 | """ 5 | import unittest 6 | import unittest.mock 7 | 8 | import keras 9 | import optax 10 | 11 | from kithara.optimizers import KitharaOptimizer, convert_to_kithara_optimizer 12 | 13 | class TestKitharaOptimizer(unittest.TestCase): 14 | 15 | def test_kithara_optimizer_interface_integrity(self): 16 | 17 | with self.assertRaises(TypeError): 18 | KitharaOptimizer() 19 | self.assertTrue(callable(KitharaOptimizer.stateless_apply)) 20 | self.assertTrue(callable(KitharaOptimizer.get_optimizer_memory_usage)) 21 | self.assertTrue(hasattr(KitharaOptimizer, "learning_rate")) 22 | self.assertTrue(hasattr(KitharaOptimizer, "iterations")) 23 | self.assertTrue(hasattr(KitharaOptimizer, "variables")) 24 | 25 | @unittest.mock.patch("kithara.optimizers.OptaxOptimizerInterface") 26 | def test_convert_to_kithara_optimizer( 27 | self, mocked_optax_optimizer_interface: unittest.mock.Mock): 28 | 29 | convert_to_kithara_optimizer(optax.adam(1e-3), 1) 30 | mocked_optax_optimizer_interface.assert_called_once() 31 | 32 | with self.assertRaises(NotImplementedError): 33 | convert_to_kithara_optimizer(keras.Optimizer(.01), 1.0) 34 | 35 | if __name__ == "__main__": 36 | unittest.main() -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Google LLC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import signal 18 | from functools import wraps 19 | 20 | def timeout(seconds): 21 | def decorator(func): 22 | @wraps(func) 23 | def wrapper(*args, **kwargs): 24 | def handler(signum, frame): 25 | raise TimeoutError(f"Test timed out after {seconds} seconds") 26 | 27 | # Set the timeout handler 28 | original_handler = signal.signal(signal.SIGALRM, handler) 29 | signal.alarm(seconds) 30 | 31 | try: 32 | result = func(*args, **kwargs) 33 | finally: 34 | # Restore the original handler and disable the alarm 35 | signal.alarm(0) 36 | signal.signal(signal.SIGALRM, original_handler) 37 | return result 38 | return wrapper 39 | return decorator 40 | -------------------------------------------------------------------------------- /tests/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/kithara/5ce848d574eb0fe8284a7027fb573216edeeabec/tests/trainer/__init__.py --------------------------------------------------------------------------------