├── version.txt ├── torchtune ├── version.txt ├── docs │ ├── source │ │ ├── deep_dives │ │ │ └── README.txt │ │ ├── tutorials │ │ │ └── README.txt │ │ ├── _static │ │ │ └── img │ │ │ │ ├── qlora_exp.png │ │ │ │ ├── kd-qwen2-res.png │ │ │ │ ├── kd-simplified.png │ │ │ │ ├── lora_diagram.png │ │ │ │ ├── qat_diagram.png │ │ │ │ ├── kd-hyperparam-lr.png │ │ │ │ ├── pytorch-logo-dark.png │ │ │ │ ├── pytorch-logo-flame.png │ │ │ │ ├── generic-pytorch-logo.png │ │ │ │ ├── kd-finetune-student.png │ │ │ │ ├── kd-finetune-teacher.png │ │ │ │ ├── torchtune_workspace.png │ │ │ │ ├── comet_torchtune_project.png │ │ │ │ ├── kd-hyperparam-kd-ratio.png │ │ │ │ ├── lora_experiment_loss_curves.png │ │ │ │ ├── card-background.svg │ │ │ │ ├── pytorch-logo-flame.svg │ │ │ │ └── pytorch-logo-dark.svg │ │ ├── _templates │ │ │ ├── autosummary │ │ │ │ ├── function.rst │ │ │ │ └── class.rst │ │ │ └── layout.html │ │ ├── api_ref_config.rst │ │ ├── api_ref_generation.rst │ │ ├── api_ref_utilities.rst │ │ ├── api_ref_rlhf.rst │ │ ├── api_ref_datasets.rst │ │ ├── api_ref_data.rst │ │ └── basics │ │ │ └── datasets_overview.rst │ ├── requirements.txt │ ├── license_header.txt │ └── Makefile ├── MANIFEST.in ├── tests │ ├── assets │ │ ├── generation_config.json │ │ ├── m.model │ │ ├── valid_dummy_config.yaml │ │ ├── rgb_pytorch.png │ │ ├── sentencepiece.model │ │ ├── dog_on_skateboard.jpg │ │ ├── vqa_tiny.json │ │ ├── invalid_dummy_config.yaml │ │ ├── hh_rlhf_tiny.json │ │ ├── tokenizer_config.json │ │ ├── instruct_tiny.json │ │ ├── README.md │ │ └── chat_tiny.json │ ├── recipes │ │ ├── __init__.py │ │ ├── common.py │ │ └── test_configs.py │ ├── torchtune │ │ ├── __init__.py │ │ ├── _cli │ │ │ ├── __init__.py │ │ │ ├── test_tune.py │ │ │ ├── test_ls.py │ │ │ └── test_validate.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── clip │ │ │ │ ├── __init__.py │ │ │ │ ├── test_clip_text_encoder.py │ │ │ │ └── test_clip_tokenizer.py │ │ │ ├── flux │ │ │ │ └── __init__.py │ │ │ ├── phi3 │ │ │ │ ├── 
__init__.py │ │ │ │ └── test_phi3.py │ │ │ ├── phi4 │ │ │ │ └── __init__.py │ │ │ ├── qwen2 │ │ │ │ ├── __init__.py │ │ │ │ └── test_qwen2.py │ │ │ ├── t5 │ │ │ │ ├── __init__.py │ │ │ │ └── test_t5_tokenizer.py │ │ │ ├── qwen2_5 │ │ │ │ └── __init__.py │ │ │ ├── llama2 │ │ │ │ ├── scripts │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── README.md │ │ │ │ └── test_llama2_prompt_template.py │ │ │ ├── mistral │ │ │ │ ├── scripts │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── mistral_test_config.py │ │ │ │ │ └── README.md │ │ │ │ ├── test_mistral.py │ │ │ │ ├── test_mistral_classifier.py │ │ │ │ └── test_mistral_prompt_template.py │ │ │ ├── llama3 │ │ │ │ └── test_llama3.py │ │ │ └── llama4 │ │ │ │ └── test_llama4_transform.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── peft │ │ │ │ └── __init__.py │ │ │ ├── low_precision │ │ │ │ ├── __init__.py │ │ │ │ └── test_nf4_dispatch_registration.py │ │ │ ├── model_fusion │ │ │ │ ├── __init__.py │ │ │ │ └── test_fusion_utils.py │ │ │ ├── transforms │ │ │ │ ├── test_pad_dim_to_size.py │ │ │ │ └── tokenizers │ │ │ │ │ └── test_utils.py │ │ │ ├── test_layernorm.py │ │ │ ├── test_feed_forward.py │ │ │ └── loss │ │ │ │ └── test_ce_chunked_output_loss.py │ │ ├── rlhf │ │ │ ├── __init__.py │ │ │ └── loss │ │ │ │ └── __init__.py │ │ ├── utils │ │ │ └── __init__.py │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ ├── multimodal │ │ │ │ ├── test_multimodal_chat_dataset.py │ │ │ │ └── test_vqa_dataset.py │ │ │ └── test_wikitext_dataset.py │ │ ├── generation │ │ │ └── __init__.py │ │ ├── dev │ │ │ └── rl │ │ │ │ └── rewards │ │ │ │ ├── __init__.py │ │ │ │ └── test_rewards.py │ │ └── config │ │ │ └── test_validate.py │ ├── common.py │ ├── test_import_recipes.py │ └── __init__.py ├── torchtune │ ├── dev │ │ ├── __init__.py │ │ ├── grpo │ │ │ └── __init__.py │ │ ├── rl │ │ │ ├── __init__.py │ │ │ ├── workers │ │ │ │ ├── weight_updaters │ │ │ │ │ └── __init__.py │ │ │ │ ├── trainers │ │ │ │ │ └── __init__.py │ │ │ │ ├── parameter_servers │ │ │ │ │ └── 
__init__.py │ │ │ │ ├── datacollectors │ │ │ │ │ └── __init__.py │ │ │ │ ├── __init__.py │ │ │ │ └── metric_logger.py │ │ │ ├── utils │ │ │ │ ├── __init__.py │ │ │ │ └── dist.py │ │ │ └── datatypes │ │ │ │ ├── vllm_completion_output.py │ │ │ │ ├── __init__.py │ │ │ │ └── trajectory.py │ │ └── README.md │ ├── _cli │ │ ├── __init__.py │ │ ├── subcommand.py │ │ └── tune.py │ ├── models │ │ ├── __init__.py │ │ ├── flux │ │ │ └── __init__.py │ │ ├── llama3_3 │ │ │ ├── __init__.py │ │ │ └── _model_builders.py │ │ ├── t5 │ │ │ ├── __init__.py │ │ │ └── _model_builders.py │ │ ├── phi4 │ │ │ └── __init__.py │ │ ├── llama3_2 │ │ │ └── __init__.py │ │ ├── gemma │ │ │ ├── __init__.py │ │ │ └── rms_norm.py │ │ ├── code_llama2 │ │ │ └── __init__.py │ │ ├── clip │ │ │ └── __init__.py │ │ ├── llama3 │ │ │ ├── __init__.py │ │ │ └── _model_utils.py │ │ ├── phi3 │ │ │ └── __init__.py │ │ ├── llama2 │ │ │ ├── _model_utils.py │ │ │ └── __init__.py │ │ ├── llama3_1 │ │ │ └── __init__.py │ │ ├── gemma2 │ │ │ └── __init__.py │ │ ├── qwen2 │ │ │ └── __init__.py │ │ ├── mistral │ │ │ └── __init__.py │ │ ├── llama3_2_vision │ │ │ └── __init__.py │ │ ├── llama4 │ │ │ ├── __init__.py │ │ │ └── _chunked_attention.py │ │ └── qwen2_5 │ │ │ └── __init__.py │ ├── modules │ │ ├── transforms │ │ │ ├── vision_utils │ │ │ │ ├── __init__.py │ │ │ │ └── pad_dim_to_size.py │ │ │ ├── __init__.py │ │ │ └── tokenizers │ │ │ │ └── __init__.py │ │ ├── low_precision │ │ │ ├── __init__.py │ │ │ └── _register_nf4_dispatch_ops.py │ │ ├── _export │ │ │ ├── install_requirements.sh │ │ │ └── README.md │ │ ├── moe │ │ │ └── __init__.py │ │ ├── model_fusion │ │ │ └── __init__.py │ │ ├── tanh_gate.py │ │ ├── loss │ │ │ └── __init__.py │ │ ├── peft │ │ │ └── __init__.py │ │ ├── tokenizers │ │ │ └── __init__.py │ │ ├── layer_norm.py │ │ ├── feed_forward.py │ │ └── rms_norm.py │ ├── data │ │ ├── _common.py │ │ ├── _torchdata.py │ │ └── __init__.py │ ├── rlhf │ │ ├── loss │ │ │ └── __init__.py │ │ ├── utils │ │ │ └── 
__init__.py │ │ └── __init__.py │ ├── config │ │ ├── __init__.py │ │ ├── _errors.py │ │ └── _validate.py │ ├── generation │ │ └── __init__.py │ ├── datasets │ │ ├── multimodal │ │ │ └── __init__.py │ │ └── __init__.py │ ├── utils │ │ ├── _import_guard.py │ │ ├── __init__.py │ │ └── _version.py │ ├── training │ │ ├── _model_util.py │ │ ├── pooling.py │ │ └── checkpointing │ │ │ └── __init__.py │ └── __init__.py ├── CITATION.cff ├── recipes │ ├── dev │ │ ├── gsm8k_sft.sbatch │ │ └── multinode_grpo.sbatch │ ├── configs │ │ ├── quantization.yaml │ │ ├── gemma │ │ │ └── evaluation.yaml │ │ ├── qwen3 │ │ │ └── evaluation.yaml │ │ ├── qwen2_5 │ │ │ └── evaluation.yaml │ │ ├── mistral │ │ │ └── evaluation.yaml │ │ ├── eleuther_evaluation.yaml │ │ ├── phi3 │ │ │ └── evaluation.yaml │ │ ├── llama3_2 │ │ │ └── evaluation.yaml │ │ ├── code_llama2 │ │ │ └── evaluation.yaml │ │ ├── qwen2 │ │ │ └── evaluation.yaml │ │ ├── phi4 │ │ │ └── evaluation.yaml │ │ ├── llama2 │ │ │ └── generation_v2.yaml │ │ ├── generation.yaml │ │ ├── llama3 │ │ │ └── 70B_generation_distributed.yaml │ │ ├── llama3_3 │ │ │ └── 70B_generation_distributed.yaml │ │ ├── llama3_1 │ │ │ └── 70B_generation_distributed.yaml │ │ ├── llama3_2_vision │ │ │ └── 11B_generation_v2.yaml │ │ └── llama4 │ │ │ └── scout_17B_16E_generation_distributed.yaml │ ├── __init__.py │ └── full_finetune_multinode.slurm └── LICENSE ├── MANIFEST.in ├── tests └── models │ └── qwen3 │ └── __init__.py ├── mirotrain ├── data │ └── __init__.py ├── __init__.py ├── modules │ ├── loss │ │ └── __init__.py │ ├── moe │ │ ├── __init__.py │ │ ├── grouped_gemm_util.py │ │ └── expert_parallel.py │ └── __init__.py ├── models │ ├── qwen3_moe │ │ └── __init__.py │ ├── convert_weights.py │ └── qwen3 │ │ ├── _checkpointing_utils.py │ │ └── _validate.py ├── datasets │ └── __init__.py ├── training │ ├── checkpointing │ │ ├── _utils.py │ │ └── __init__.py │ ├── __init__.py │ └── lr_schedulers.py └── monkey │ ├── __init__.py │ └── _errors.py ├── .flake8 └── 
.pre-commit-config.yaml /version.txt: -------------------------------------------------------------------------------- 1 | 0.1.0 2 | -------------------------------------------------------------------------------- /torchtune/version.txt: -------------------------------------------------------------------------------- 1 | 0.7.0 2 | -------------------------------------------------------------------------------- /torchtune/docs/source/deep_dives/README.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchtune/docs/source/tutorials/README.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune tests # Remove all testing files from final dist/ 2 | -------------------------------------------------------------------------------- /torchtune/MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune tests # Remove all testing files from final dist/ 2 | -------------------------------------------------------------------------------- /torchtune/tests/assets/generation_config.json: -------------------------------------------------------------------------------- 1 | {"bos_token_id": 0, "eos_token_id": -1} 2 | -------------------------------------------------------------------------------- /torchtune/tests/assets/m.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/tests/assets/m.model -------------------------------------------------------------------------------- /torchtune/tests/assets/valid_dummy_config.yaml: 
-------------------------------------------------------------------------------- 1 | test: 2 | _component_: torchtune.training.get_dtype 3 | dtype: fp32 4 | -------------------------------------------------------------------------------- /tests/models/qwen3/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2025 MiromindAI 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | -------------------------------------------------------------------------------- /torchtune/tests/assets/rgb_pytorch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/tests/assets/rgb_pytorch.png -------------------------------------------------------------------------------- /torchtune/tests/assets/sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/tests/assets/sentencepiece.model -------------------------------------------------------------------------------- /torchtune/tests/assets/dog_on_skateboard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/tests/assets/dog_on_skateboard.jpg -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/qlora_exp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/qlora_exp.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/kd-qwen2-res.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/kd-qwen2-res.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/kd-simplified.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/kd-simplified.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/lora_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/lora_diagram.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/qat_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/qat_diagram.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/kd-hyperparam-lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/kd-hyperparam-lr.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/pytorch-logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/pytorch-logo-dark.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/pytorch-logo-flame.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/pytorch-logo-flame.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/generic-pytorch-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/generic-pytorch-logo.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/kd-finetune-student.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/kd-finetune-student.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/kd-finetune-teacher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/kd-finetune-teacher.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/torchtune_workspace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/torchtune_workspace.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/comet_torchtune_project.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/comet_torchtune_project.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/kd-hyperparam-kd-ratio.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/kd-hyperparam-kd-ratio.png -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/lora_experiment_loss_curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/lora_experiment_loss_curves.png -------------------------------------------------------------------------------- /mirotrain/data/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2025 MiromindAI 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from ._messages import TracesToMessages 6 | 7 | __all__ = ["TracesToMessages"] 8 | -------------------------------------------------------------------------------- /torchtune/tests/assets/vqa_tiny.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "input": "What is presented on image?", 4 | "output": "PyTorch logo.", 5 | "image": "tests/assets/rgb_pytorch.png" 6 | } 7 | ] 8 | -------------------------------------------------------------------------------- /torchtune/docs/source/_templates/autosummary/function.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. 
autofunction:: {{ name }} 9 | -------------------------------------------------------------------------------- /torchtune/docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-gallery>0.11 2 | sphinx==5.0.0 3 | sphinx_design 4 | sphinx_copybutton 5 | sphinx-tabs 6 | matplotlib 7 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 8 | -------------------------------------------------------------------------------- /torchtune/tests/assets/invalid_dummy_config.yaml: -------------------------------------------------------------------------------- 1 | test1: 2 | _component_: torchtune.training.get_dtype 3 | dtype: fp32 4 | dummy: 3 5 | test2: 6 | _component_: torchtune.training.get_dtype 7 | dtype: fp32 8 | dummy: 3 9 | -------------------------------------------------------------------------------- /torchtune/docs/source/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. autoclass:: {{ name }} 9 | :members: 10 | -------------------------------------------------------------------------------- /torchtune/docs/license_header.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) Meta Platforms, Inc. and affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the BSD-style license found in the 5 | LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/recipes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/torchtune/dev/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/torchtune/_cli/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/torchtune/dev/grpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /torchtune/torchtune/dev/rl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/torchtune/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/_cli/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/rlhf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/generation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/clip/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/flux/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/phi3/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/phi4/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/qwen2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/t5/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/modules/peft/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/rlhf/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/dev/rl/rewards/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/qwen2_5/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/llama2/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/mistral/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/modules/low_precision/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/modules/model_fusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtune/torchtune/modules/transforms/vision_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /torchtune/tests/assets/hh_rlhf_tiny.json: -------------------------------------------------------------------------------- 1 | [{"chosen":[{"content":"What do I do when I have a hole in my trousers?","role":"user"},{"content":"Fix the hole.","role":"assistant"}],"rejected":[{"content":"What do I do when I have a hole in my trousers?","role":"user"},{"content":"Take them off.","role":"assistant"}]}] 2 | -------------------------------------------------------------------------------- /torchtune/docs/source/api_ref_config.rst: -------------------------------------------------------------------------------- 1 | .. _config: 2 | 3 | ================ 4 | torchtune.config 5 | ================ 6 | 7 | .. currentmodule:: torchtune.config 8 | 9 | .. autosummary:: 10 | :toctree: generated/ 11 | :nosignatures: 12 | 13 | instantiate 14 | parse 15 | validate 16 | log_config 17 | -------------------------------------------------------------------------------- /torchtune/tests/assets/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"bos_token": {"__type": "AddedToken", "content": "<|begin_of_sentence|>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false}, "eos_token": {"__type": "AddedToken", "content": "<|end_of_sentence|>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false}} 2 | -------------------------------------------------------------------------------- /torchtune/tests/recipes/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from pathlib import Path 8 | 9 | RECIPE_TESTS_DIR = Path(__file__).parent 10 | -------------------------------------------------------------------------------- /torchtune/torchtune/dev/rl/workers/weight_updaters/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .weight_updater import VLLMHFWeightUpdateReceiver # noqa: F401 8 | -------------------------------------------------------------------------------- /torchtune/tests/assets/instruct_tiny.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instruction": "What time is it in London?", 4 | "response": "It is 10:00 AM in London" 5 | }, 6 | { 7 | "instruction": "Is is Istanbul or Constantinople?", 8 | "response": "Istanbul was Constantinople. Now it's Istanbul, not Constantinople." 9 | } 10 | ] 11 | -------------------------------------------------------------------------------- /torchtune/torchtune/dev/README.md: -------------------------------------------------------------------------------- 1 | ## Torchtune dev components 2 | 3 | This directory houses experimental components. 4 | The purpose of `torchtune/dev` is to support bleeding-edge APIs 5 | with less stringent requirements on testing, documentation, or stability. 6 | The APIs in this folder are not public and are not guaranteed to adhere to backwards compatibility. 
7 | -------------------------------------------------------------------------------- /mirotrain/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2025 MiromindAI 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | __version__ = "0.1.0" 7 | 8 | # Import main modules (lazy import to avoid dependency issues) 9 | __all__ = [ 10 | "__version__", 11 | "data", 12 | "datasets", 13 | "models", 14 | "modules", 15 | "training", 16 | ] 17 | -------------------------------------------------------------------------------- /torchtune/CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: "torchtune: PyTorch's post-training library" 3 | message: "If you use this software, please cite it as below." 4 | type: software 5 | authors: 6 | - given-names: "torchtune maintainers and contributors" 7 | url: "https://github.com/pytorch/torchtune" 8 | license: "BSD-3-Clause" 9 | date-released: "2024-04-14" 10 | -------------------------------------------------------------------------------- /torchtune/torchtune/dev/rl/workers/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .training import TrainingWorker 8 | 9 | __all__ = ["TrainingWorker"] 10 | -------------------------------------------------------------------------------- /torchtune/tests/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from pathlib import Path 7 | 8 | TUNE_PATH = "torchtune/_cli/tune.py" 9 | 10 | ASSETS = Path(__file__).parent / "assets" 11 | -------------------------------------------------------------------------------- /torchtune/torchtune/dev/rl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dist import stateless_init_process_group 8 | 9 | __all__ = ["stateless_init_process_group"] 10 | -------------------------------------------------------------------------------- /torchtune/torchtune/dev/rl/workers/parameter_servers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .vllm import VLLMParameterServer 8 | 9 | __all__ = ["VLLMParameterServer"] 10 | -------------------------------------------------------------------------------- /torchtune/torchtune/models/flux/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | from ._model_builders import flux_1_autoencoder 7 | 8 | __all__ = [ 9 | "flux_1_autoencoder", 10 | ] 11 | -------------------------------------------------------------------------------- /mirotrain/modules/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2025 MiromindAI 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from .cross_entropy_loss import LigerLinearCrossEntropyLoss, LinearSqrtCrossEntropyLoss 6 | from .dpo_loss import DPOLoss 7 | 8 | __all__ = [ 9 | "LigerLinearCrossEntropyLoss", 10 | "LinearSqrtCrossEntropyLoss", 11 | "DPOLoss", 12 | ] 13 | -------------------------------------------------------------------------------- /torchtune/torchtune/modules/low_precision/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .nf4_linear import FrozenNF4Linear 8 | 9 | __all__ = [ 10 | "FrozenNF4Linear", 11 | ] 12 | -------------------------------------------------------------------------------- /torchtune/torchtune/dev/rl/workers/datacollectors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from .sync import SyncLLMCollector, VLLMWorkerWrapper 8 | 9 | __all__ = ["SyncLLMCollector", "VLLMWorkerWrapper"] 10 | -------------------------------------------------------------------------------- /mirotrain/models/qwen3_moe/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2025 MiromindAI 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from ._convert_weights import qwen3_moe_hf_to_tune, qwen3_moe_tune_to_hf 6 | from ._model_builders import qwen3_235b_a22b, qwen3_30b_a3b 7 | 8 | __all__ = [ 9 | "qwen3_moe_hf_to_tune", 10 | "qwen3_moe_tune_to_hf", 11 | "qwen3_30b_a3b", 12 | "qwen3_235b_a22b", 13 | ] 14 | -------------------------------------------------------------------------------- /torchtune/torchtune/data/_common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from typing import Dict, List, Union 7 | 8 | import torch 9 | 10 | CROSS_ENTROPY_IGNORE_IDX = -100 11 | PACK_TYPE = Dict[str, Union[torch.Tensor, List[int]]] 12 | -------------------------------------------------------------------------------- /torchtune/docs/source/api_ref_generation.rst: -------------------------------------------------------------------------------- 1 | .. _generation: 2 | 3 | ==================== 4 | torchtune.generation 5 | ==================== 6 | 7 | .. currentmodule:: torchtune.generation 8 | 9 | .. 
autosummary:: 10 | :toctree: generated/ 11 | :nosignatures: 12 | 13 | generate 14 | generate_next_token 15 | sample 16 | get_causal_mask_from_padding_mask 17 | get_position_ids_from_padding_mask 18 | -------------------------------------------------------------------------------- /torchtune/docs/source/api_ref_utilities.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | torchtune.utils 3 | =============== 4 | 5 | .. currentmodule:: torchtune.utils 6 | 7 | 8 | .. _gen_label: 9 | 10 | Miscellaneous 11 | ------------- 12 | 13 | .. autosummary:: 14 | :toctree: generated/ 15 | :nosignatures: 16 | 17 | batch_to_device 18 | get_device 19 | get_logger 20 | torch_version_ge 21 | get_world_size_and_rank 22 | -------------------------------------------------------------------------------- /torchtune/torchtune/dev/rl/datatypes/vllm_completion_output.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import vllm 8 | from tensordict import from_dataclass 9 | 10 | VllmCompletionOutput = from_dataclass(vllm.outputs.CompletionOutput) 11 | -------------------------------------------------------------------------------- /torchtune/torchtune/rlhf/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | from .dpo import DPOLoss, RSOLoss 9 | from .ppo import PPOLoss 10 | 11 | __all__ = [ 12 | "DPOLoss", 13 | "RSOLoss", 14 | "PPOLoss", 15 | ] 16 | -------------------------------------------------------------------------------- /torchtune/torchtune/modules/_export/install_requirements.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # Install pytorch 9 | pip install torch==2.6.0 torchvision==0.21.0 --extra-index-url https://download.pytorch.org/whl/nightly/cpu 10 | -------------------------------------------------------------------------------- /torchtune/torchtune/rlhf/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ._convert_weights import reward_hf_to_tune, reward_tune_to_hf # noqa 8 | 9 | __all__ = [ 10 | "reward_hf_to_tune", 11 | "reward_tune_to_hf", 12 | ] 13 | -------------------------------------------------------------------------------- /torchtune/docs/source/_static/img/card-background.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /torchtune/docs/source/api_ref_rlhf.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | torchtune.rlhf 3 | =============== 4 | 5 | .. currentmodule:: torchtune.rlhf 6 | 7 | Components and losses for RLHF algorithms like PPO and DPO. 
8 | 9 | .. autosummary:: 10 | :toctree: generated/ 11 | :nosignatures: 12 | 13 | estimate_advantages 14 | get_rewards_ppo 15 | truncate_sequence_at_first_stop_token 16 | loss.PPOLoss 17 | loss.DPOLoss 18 | loss.RSOLoss 19 | -------------------------------------------------------------------------------- /torchtune/torchtune/modules/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torchtune.modules.transforms._transforms import Transform, VisionCrossAttentionMask 8 | 9 | 10 | __all__ = [ 11 | "Transform", 12 | "VisionCrossAttentionMask", 13 | ] 14 | -------------------------------------------------------------------------------- /mirotrain/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2025 MiromindAI 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from ._chat import odr_chat_dataset 6 | from ._collate import padded_collate_dpo, padded_collate_packed 7 | from ._packed import StatefulDistributedStreamingPackedDataset 8 | 9 | __all__ = [ 10 | "odr_chat_dataset", 11 | "padded_collate_packed", 12 | "StatefulDistributedStreamingPackedDataset", 13 | "padded_collate_dpo", 14 | ] 15 | -------------------------------------------------------------------------------- /torchtune/tests/assets/README.md: -------------------------------------------------------------------------------- 1 | # Details on the assets in this folder 2 | 3 | ## `m.model` 4 | 5 | **Description**: 6 | **Creation**: 7 | **Usage**: 8 | 9 | 10 | ## `tiny_fair_checkpoint.pt` 11 | 12 | **Description**: 13 | **Creation**: 14 | **Usage**: 15 | 16 | ## `tiny_llama2_checkpoint.pt` 17 | 18 | **Description**: 19 | **Creation**: 20 | 
def test_import_recipes():
    """Importing ``recipes`` must fail with the expected ``ModuleNotFoundError``."""
    expected_message = "The torchtune recipes directory isn't a package"
    with pytest.raises(ModuleNotFoundError, match=expected_message):
        import recipes  # noqa
# NOTE(review): presumably the state-dict key under which the training step
# counter is stored — confirm against the checkpoint client.
STEP_KEY = "global_step"

# Rebuild torchtune's ModelType through the functional Enum API: every upstream
# member is copied by name and value, then the two Qwen3 entries are appended.
# NOTE(review): this creates a new Enum class, so its members are different
# objects from the upstream enum's members — confirm no caller compares a
# member of one enum against a member of the other.
ModelType = Enum(
    "ModelType",
    {
        **{member.name: member.value for member in original_ModelType},
        "QWEN3": "qwen3",
        "QWEN3_MoE": "qwen3_moe",
    },
)
class Subcommand:
    """Minimal base class for CLI subcommands.

    The base implementation accepts and ignores any constructor arguments and
    registers no arguments of its own.
    """

    def __init__(self, *args, **kwargs):
        """Accept arbitrary positional and keyword arguments and discard them."""

    @classmethod
    def create(cls, *args, **kwargs):
        """Build an instance of ``cls``, forwarding every argument to its constructor."""
        instance = cls(*args, **kwargs)
        return instance

    def _add_arguments(self):
        """No-op hook; the base class does nothing here."""
        return None
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ._model_builders import ( # noqa 8 | lora_phi4_14b, 9 | phi4_14b, 10 | phi4_tokenizer, 11 | qlora_phi4_14b, 12 | ) 13 | 14 | __all__ = [ 15 | "phi4_14b", 16 | "phi4_tokenizer", 17 | "lora_phi4_14b", 18 | "qlora_phi4_14b", 19 | ] 20 | -------------------------------------------------------------------------------- /mirotrain/training/checkpointing/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2025 MiromindAI 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from ._checkpoint_client import ( 6 | SDCheckpointClient, 7 | StepCheckpointClient, 8 | StepTrainingProgress, 9 | ) 10 | from ._checkpointer import FullModelHFCheckpointer 11 | from ._utils import STEP_KEY 12 | 13 | __all__ = [ 14 | "SDCheckpointClient", 15 | "FullModelHFCheckpointer", 16 | "STEP_KEY", 17 | "StepCheckpointClient", 18 | "StepTrainingProgress", 19 | ] 20 | -------------------------------------------------------------------------------- /mirotrain/modules/moe/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2025 MiromindAI 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from .dropless_layer import DroplessMoELayer 6 | from .expert_parallel import ( 7 | get_expert_parallel_group, 8 | get_expert_parallel_rank, 9 | get_expert_parallel_world_size, 10 | set_expert_parallel_group, 11 | ) 12 | 13 | __all__ = [ 14 | "DroplessMoELayer", 15 | "get_expert_parallel_group", 16 | "get_expert_parallel_rank", 17 | "get_expert_parallel_world_size", 18 | "set_expert_parallel_group", 19 | ] 20 | -------------------------------------------------------------------------------- /mirotrain/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # 
SPDX-FileCopyrightText: 2025 MiromindAI 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from .attention import MultiHeadAttentionWithUlysses 6 | from .transformer import ( 7 | MoETransformerDecoder, 8 | MoETransformerSelfAttentionLayer, 9 | SDTransformerDecoder, 10 | SDTransformerSelfAttentionLayer, 11 | ) 12 | 13 | __all__ = [ 14 | "MultiHeadAttentionWithUlysses", 15 | "MoETransformerDecoder", 16 | "MoETransformerSelfAttentionLayer", 17 | "SDTransformerDecoder", 18 | "SDTransformerSelfAttentionLayer", 19 | ] 20 | -------------------------------------------------------------------------------- /torchtune/torchtune/generation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ._generation import ( 8 | generate, 9 | generate_next_token, 10 | get_causal_mask_from_padding_mask, 11 | get_position_ids_from_padding_mask, 12 | sample, 13 | ) 14 | 15 | __all__ = [ 16 | "generate", 17 | "generate_next_token", 18 | "get_causal_mask_from_padding_mask", 19 | "get_position_ids_from_padding_mask", 20 | "sample", 21 | ] 22 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/llama2/scripts/README.md: -------------------------------------------------------------------------------- 1 | # Verifying Correctness against Reference Implementations 2 | 3 | This repository puts a high bar on correctness and testing. To make sure our model and 4 | module implementations are correct, we compare our implementation against reference implementations 5 | where possible. This folder contains scripts used for these comparisons. 6 | 7 | 8 | ## Running the scripts 9 | 10 | You can run the scripts using the following command as an example. 
@dataclass
class MistralTestConfig:
    """Mistral test configuration values.

    Each attribute is annotated so that ``@dataclass`` registers it as a real
    field; un-annotated class attributes are ignored by the decorator. The
    defaults are unchanged, so both ``MistralTestConfig.BSZ`` and
    ``MistralTestConfig().BSZ`` keep returning the same values, and individual
    values can now be overridden per instance, e.g. ``MistralTestConfig(SEED=0)``.
    """

    BSZ: int = 2
    SEQ_LEN: int = 128
    EMBED_DIM: int = 64
    VOCAB_SIZE: int = 512
    NUM_LAYERS: int = 4
    NUM_HEADS: int = 4
    NUM_KV_HEADS: int = 2
    INTERMEDIATE_DIM: int = 512
    MAX_SEQ_LEN: int = 256
    ROPE_BASE: int = 10000
    NORM_EPS: float = 1e-5
    SEED: int = 16
6 | 7 | from ._llava_instruct import llava_instruct_dataset 8 | from ._multimodal import multimodal_chat_dataset 9 | from ._the_cauldron import the_cauldron_dataset, the_cauldron_transform 10 | from ._vqa import vqa_dataset 11 | 12 | __all__ = [ 13 | "the_cauldron_dataset", 14 | "the_cauldron_transform", 15 | "llava_instruct_dataset", 16 | "multimodal_chat_dataset", 17 | "vqa_dataset", 18 | ] 19 | -------------------------------------------------------------------------------- /torchtune/recipes/dev/gsm8k_sft.sbatch: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --time=01:00:00 3 | #SBATCH --constraint=volta32gb 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --gpus-per-node=8 7 | #SBATCH --no-requeue 8 | #SBATCH --exclusive 9 | 10 | #SBATCH --job-name=torchtune 11 | #SBATCH --output=slurm_logs/%j.out 12 | #SBATCH --error=slurm_logs/%j.err 13 | 14 | # /\ Customize SBATCH directives to match your hardware 15 | 16 | # \/ Customize the virtual env/module load - this assumes a virtual env in root of torchtune 17 | source ../../.venv/bin/activate 18 | 19 | srun tune run \ 20 | --nnodes 1 \ 21 | --nproc_per_node 8 \ 22 | full_finetune_distributed --config dev/3B_sft_for_grpo "$@" 23 | -------------------------------------------------------------------------------- /torchtune/torchtune/modules/model_fusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from ._deep_fusion import DeepFusionModel 8 | from ._early_fusion import EarlyFusionModel 9 | from ._fusion_layers import FusionEmbedding, FusionLayer 10 | from ._fusion_utils import get_fusion_params, register_fusion_module 11 | 12 | __all__ = [ 13 | "DeepFusionModel", 14 | "FusionLayer", 15 | "FusionEmbedding", 16 | "register_fusion_module", 17 | "get_fusion_params", 18 | "EarlyFusionModel", 19 | ] 20 | -------------------------------------------------------------------------------- /torchtune/tests/assets/chat_tiny.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "conversations": [ 4 | { 5 | "from": "system", 6 | "value": "You are an AI assistant." 7 | }, 8 | { 9 | "from": "human", 10 | "value": "What is the meaning of life?" 11 | }, 12 | { 13 | "from": "gpt", 14 | "value": "The meaning of life is 42." 15 | }, 16 | { 17 | "from": "human", 18 | "value": "That's ridiculous." 19 | }, 20 | { 21 | "from": "gpt", 22 | "value": "I agree." 23 | } 24 | ] 25 | } 26 | ] 27 | -------------------------------------------------------------------------------- /torchtune/torchtune/models/llama3_2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from ._component_builders import llama3_2, lora_llama3_2 8 | 9 | from ._model_builders import ( # noqa 10 | llama3_2_1b, 11 | llama3_2_3b, 12 | lora_llama3_2_1b, 13 | lora_llama3_2_3b, 14 | qlora_llama3_2_1b, 15 | qlora_llama3_2_3b, 16 | ) 17 | 18 | __all__ = [ 19 | "llama3_2", 20 | "llama3_2_1b", 21 | "llama3_2_3b", 22 | "lora_llama3_2", 23 | "lora_llama3_2_1b", 24 | "lora_llama3_2_3b", 25 | "qlora_llama3_2_1b", 26 | "qlora_llama3_2_3b", 27 | ] 28 | -------------------------------------------------------------------------------- /torchtune/torchtune/dev/rl/workers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .datacollectors import SyncLLMCollector 8 | from .metric_logger import MetricLoggerWorker 9 | from .parameter_servers import VLLMParameterServer 10 | from .postprocessing import PostProcessingWorker 11 | from .trainers import TrainingWorker 12 | from .weight_updaters import VLLMHFWeightUpdateReceiver 13 | 14 | __all__ = [ 15 | "SyncLLMCollector", 16 | "MetricLoggerWorker", 17 | "VLLMParameterServer", 18 | "PostProcessingWorker", 19 | "TrainingWorker", 20 | "VLLMHFWeightUpdateReceiver", 21 | ] 22 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/_cli/test_tune.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
class TestTuneCLI:
    """Smoke test for the ``tune`` command-line script."""

    def test_tune_without_args_returns_help(self, capsys, monkeypatch):
        """Running the script with only the program name prints the welcome text."""
        # Make the script see a bare ``tune`` invocation with no subcommand.
        monkeypatch.setattr(sys, "argv", ["tune"])
        runpy.run_path(TUNE_PATH, run_name="__main__")

        stdout = capsys.readouterr().out.rstrip("\n")

        assert "Welcome to the torchtune CLI!" in stdout
@nf4_tensor_impl([torch.ops.aten.clone.default])
def clone(func, *args, **kwargs):
    """
    ``__torch_dispatch__`` override registered for ``aten.clone.default``.

    A new NF4 tensor is built with ``to_nf4`` from the value returned by
    ``get_original_weight()`` on the source tensor. Because the weight is
    quantized again, the result is not an exact clone (precision is lost).

    NOTE(review): the source tensor is read as ``args[0][0]`` — presumably the
    first entry of the dispatched op's positional-argument sequence; confirm
    against torchao's ``implements`` calling convention.
    """
    dispatched_args = args[0]
    source_tensor = dispatched_args[0]
    original_weight = source_tensor.get_original_weight()
    return to_nf4(original_weight)
def disable_dropout(model: torch.nn.Module) -> None:
    """
    Zero the drop probability of every active ``torch.nn.Dropout`` in ``model``.

    Every ``torch.nn.Dropout`` submodule whose ``p`` is non-zero triggers a
    warning and then has ``p`` set to ``0`` in place. Nothing is returned.

    Args:
        model (torch.nn.Module): The model in which dropout layers should be disabled.
    """
    for module in model.modules():
        # Only plain ``torch.nn.Dropout`` instances (and subclasses) are touched.
        if not isinstance(module, torch.nn.Dropout):
            continue
        # Already disabled: nothing to warn about or change.
        if module.p == 0:
            continue
        warnings.warn(
            f"Found Dropout with value {module.p} in module {module}. Setting to zero."
        )
        module.p = 0
class Trajectory(TensorClass["nocast"]):
    """
    ``TensorClass`` (declared with the ``"nocast"`` option) grouping the fields below.

    NOTE(review): the per-field notes are inferred from field names only; shapes,
    dtypes and exact semantics are not visible in this file — confirm against the
    code that constructs ``Trajectory`` objects.
    """

    # presumably prompt tokens followed by generated tokens — TODO confirm
    query_responses: torch.Tensor
    # presumably the generated part only — TODO confirm
    responses: torch.Tensor
    # presumably log-probabilities under the current policy — TODO confirm
    logprobs: torch.Tensor
    # presumably log-probabilities under a reference model — TODO confirm
    ref_logprobs: torch.Tensor
    # presumably a padding mask aligned with ``query_responses`` — TODO confirm
    query_response_padding_masks: torch.Tensor
    # presumably per-sequence lengths — TODO confirm
    seq_lens: torch.Tensor
    # presumably ground-truth answers; encoding not visible here — TODO confirm
    answers: torch.Tensor
    # plain integer (not a tensor) stored alongside the tensor fields
    policy_version: int
    advantages: torch.Tensor
    # one ``RewardOutput`` per list entry; granularity not visible here
    reward_outputs: List[RewardOutput]
    sequence_ids: List[str]
def grouped_gemm_is_available():
    """Check if grouped_gemm is available.

    Returns:
        bool: ``True`` when the module-level ``grouped_gemm`` name is not ``None``.
    """
    return grouped_gemm is not None


def assert_grouped_gemm_is_available():
    """Assert that grouped_gemm is available.

    The check is an explicit ``raise`` rather than an ``assert`` statement, so
    it is still enforced when Python runs with ``-O`` (which strips ``assert``
    statements).

    Raises:
        AssertionError: If ``grouped_gemm`` is not available; the message
            contains the install command.
    """
    if not grouped_gemm_is_available():
        raise AssertionError(
            "Grouped GEMM is not available. Please run "
            "`pip install git+https://github.com/fanshiqing/grouped_gemm@v1.1.4`."
        )


# ``grouped_gemm.ops`` when the package is available, otherwise ``None``.
ops = grouped_gemm.ops if grouped_gemm_is_available() else None
Additionally, `torchtune.models.mistral._component_builders.mistral_mlp` is compared in `tests.torchtune.models.mistral.scripts.compare_feed_forward.py` 3 | 4 | Since `torchtune.models.mistral` shares nearly all components with `torchtune.models.llama2`, please see `tests.torchtune.models.llama2.scripts` for comparison scripts for individual components. 5 | 6 | ## Running the scripts 7 | 8 | You can run the scripts using the following command as an example. 9 | Each script should print out the value being used in the associated unit tests. 10 | 11 | ``` 12 | python3 -m tests.torchtune.models.mistral.scripts.compare_mistral 13 | ``` 14 | -------------------------------------------------------------------------------- /mirotrain/training/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2025 MiromindAI 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from ._distributed import gather_cpu_state_dict, ParallelDims, shard_model 6 | from ._grad_scaler import scale_grads_ 7 | from .checkpointing import ( 8 | FullModelHFCheckpointer, 9 | SDCheckpointClient, 10 | STEP_KEY, 11 | StepCheckpointClient, 12 | StepTrainingProgress, 13 | ) 14 | from .clip_grad import clip_grad_norm_ 15 | from .lr_schedulers import get_cosine_schedule_with_warmup 16 | 17 | __all__ = [ 18 | "FullModelHFCheckpointer", 19 | "ParallelDims", 20 | "gather_cpu_state_dict", 21 | "shard_model", 22 | "scale_grads_", 23 | "clip_grad_norm_", 24 | "STEP_KEY", 25 | "StepCheckpointClient", 26 | "StepTrainingProgress", 27 | "get_cosine_schedule_with_warmup", 28 | "SDCheckpointClient", 29 | ] 30 | -------------------------------------------------------------------------------- /torchtune/torchtune/models/code_llama2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
class TanhGate(nn.Module):
    """Learnable gate that scales its input by ``tanh`` of a single trained parameter."""

    def __init__(self) -> None:
        super().__init__()
        # One-element parameter initialised to zero, so the gate starts at tanh(0) = 0.
        self.scale = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor to gate

        Returns:
            torch.Tensor: ``x`` multiplied by ``tanh(scale)``; the one-element gate
                is broadcast against ``x``.
        """
        gate = torch.tanh(self.scale)
        return x * gate
class TestMultimodalChatDataset:
    """Tests for the ``multimodal_chat_dataset`` builder."""

    @pytest.fixture
    def tokenizer(self):
        """Provide a ``DummyTokenizer`` instance to pass as the model transform."""
        dummy = DummyTokenizer()
        return dummy

    def test_dataset_fails_with_packed(self, tokenizer):
        """Calling the builder with ``packed=True`` must raise the packing ``ValueError``."""
        expected_message = "Multimodal datasets don't support packing yet."
        builder_kwargs = {
            "model_transform": tokenizer,
            "source": "json",
            "packed": True,
        }
        with pytest.raises(ValueError, match=expected_message):
            multimodal_chat_dataset(**builder_kwargs)
6 | 7 | from ._device import ( 8 | batch_to_device, 9 | DeviceSupport, 10 | get_device, 11 | get_device_support, 12 | get_torch_device_namespace, 13 | get_world_size_and_rank, 14 | ) 15 | from ._logging import deprecated, get_logger, log_once, log_rank_zero 16 | 17 | from ._version import torch_version_ge 18 | 19 | __all__ = [ 20 | "get_world_size_and_rank", 21 | "batch_to_device", 22 | "get_device", 23 | "get_logger", 24 | "torch_version_ge", 25 | "get_device_support", 26 | "get_torch_device_namespace", 27 | "DeviceSupport", 28 | "log_rank_zero", 29 | "deprecated", 30 | "log_once", 31 | ] 32 | -------------------------------------------------------------------------------- /torchtune/recipes/configs/quantization.yaml: -------------------------------------------------------------------------------- 1 | # Config for QuantizationRecipe in quantize.py 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run quantize --config quantization 5 | 6 | output_dir: /tmp/torchtune/llama2_7B/quantized # /tmp may be deleted by your system. Change it to your preference. 7 | 8 | # 9 | # Model arguments 10 | model: 11 | _component_: torchtune.models.llama2.llama2_7b 12 | 13 | checkpointer: 14 | _component_: torchtune.training.FullModelHFCheckpointer 15 | checkpoint_dir: /tmp/Llama-2-7b-hf 16 | checkpoint_files: [ 17 | pytorch_model-00001-of-00002.bin, 18 | pytorch_model-00002-of-00002.bin, 19 | ] 20 | recipe_checkpoint: null 21 | output_dir: ${output_dir} 22 | model_type: LLAMA2 23 | 24 | device: cuda 25 | dtype: bf16 26 | seed: 1234 27 | 28 | quantizer: 29 | _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer 30 | groupsize: 256 31 | -------------------------------------------------------------------------------- /torchtune/torchtune/models/clip/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ._component_builders import clip_mlp, clip_text_encoder, clip_vision_encoder 8 | from ._model_builders import clip_text_vit_large_patch14, clip_tokenizer 9 | from ._position_embeddings import ( 10 | TiledTokenPositionalEmbedding, 11 | TilePositionalEmbedding, 12 | TokenPositionalEmbedding, 13 | ) 14 | from ._transform import CLIPImageTransform 15 | 16 | __all__ = [ 17 | "clip_mlp", 18 | "clip_text_encoder", 19 | "clip_vision_encoder", 20 | "clip_text_vit_large_patch14", 21 | "clip_tokenizer", 22 | "CLIPImageTransform", 23 | "TokenPositionalEmbedding", 24 | "TiledTokenPositionalEmbedding", 25 | "TilePositionalEmbedding", 26 | ] 27 | -------------------------------------------------------------------------------- /torchtune/torchtune/modules/transforms/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from ._gpt2 import GPT2BaseTokenizer 8 | from ._hf_tokenizer import HuggingFaceBaseTokenizer 9 | from ._sentencepiece import SentencePieceBaseTokenizer 10 | from ._tiktoken import TikTokenBaseTokenizer 11 | from ._utils import ( 12 | BaseTokenizer, 13 | ModelTokenizer, 14 | parse_hf_tokenizer_json, 15 | tokenize_messages_no_special_tokens, 16 | ) 17 | 18 | __all__ = [ 19 | "SentencePieceBaseTokenizer", 20 | "TikTokenBaseTokenizer", 21 | "ModelTokenizer", 22 | "GPT2BaseTokenizer", 23 | "BaseTokenizer", 24 | "tokenize_messages_no_special_tokens", 25 | "parse_hf_tokenizer_json", 26 | "HuggingFaceBaseTokenizer", 27 | ] 28 | -------------------------------------------------------------------------------- /torchtune/torchtune/models/llama3/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ._component_builders import llama3, lora_llama3 8 | 9 | from ._model_builders import ( # noqa 10 | llama3_70b, 11 | llama3_8b, 12 | llama3_tokenizer, 13 | lora_llama3_70b, 14 | lora_llama3_8b, 15 | qlora_llama3_70b, 16 | qlora_llama3_8b, 17 | ) 18 | from ._parallelism import base_llama_tp_plan 19 | from ._tokenizer import Llama3Tokenizer 20 | 21 | __all__ = [ 22 | "Llama3Tokenizer", 23 | "llama3", 24 | "llama3_8b", 25 | "llama3_70b", 26 | "llama3_tokenizer", 27 | "lora_llama3", 28 | "lora_llama3_8b", 29 | "lora_llama3_70b", 30 | "qlora_llama3_8b", 31 | "qlora_llama3_70b", 32 | "base_llama_tp_plan", 33 | ] 34 | -------------------------------------------------------------------------------- /torchtune/torchtune/models/phi3/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ._component_builders import lora_phi3, phi3 # noqa 8 | from ._convert_weights import phi3_hf_to_tune, phi3_tune_to_hf # noqa 9 | from ._model_builders import ( # noqa 10 | lora_phi3_mini, 11 | phi3_mini, 12 | phi3_mini_tokenizer, 13 | qlora_phi3_mini, 14 | ) 15 | from ._position_embeddings import Phi3RotaryPositionalEmbeddings # noqa 16 | from ._tokenizer import Phi3MiniTokenizer # noqa 17 | 18 | __all__ = [ 19 | "phi3_mini", 20 | "phi3_mini_tokenizer", 21 | "lora_phi3_mini", 22 | "qlora_phi3_mini", 23 | "Phi3RotaryPositionalEmbeddings", 24 | "Phi3MiniTokenizer", 25 | "phi3_hf_to_tune", 26 | "phi3_tune_to_hf", 27 | "phi3", 28 | "lora_phi3", 29 | ] 30 | -------------------------------------------------------------------------------- /torchtune/docs/source/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | 3 | {% block sidebartitle %} 4 |
5 | {{ version }} ▼ 6 |
7 | {% include "searchbox.html" %} 8 | {% endblock %} 9 | 10 | 11 | {% block footer %} 12 | 13 | 17 | 18 | 21 | {{ super() }} 22 | 27 | {% endblock %} 28 | -------------------------------------------------------------------------------- /torchtune/torchtune/modules/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .ce_chunked_output_loss import CEWithChunkedOutputLoss 8 | 9 | from .cross_entropy_loss import LinearCrossEntropyLoss 10 | from .kd_losses import ( 11 | ForwardKLLoss, 12 | ForwardKLWithChunkedOutputLoss, 13 | ReverseKLLoss, 14 | ReverseKLWithChunkedOutputLoss, 15 | SymmetricKLLoss, 16 | SymmetricKLWithChunkedOutputLoss, 17 | ) 18 | from .loss_types import RLLoss, SFTLoss 19 | 20 | __all__ = [ 21 | "CEWithChunkedOutputLoss", 22 | "ForwardKLLoss", 23 | "ForwardKLWithChunkedOutputLoss", 24 | "ReverseKLLoss", 25 | "ReverseKLWithChunkedOutputLoss", 26 | "SymmetricKLLoss", 27 | "SymmetricKLWithChunkedOutputLoss", 28 | "LinearCrossEntropyLoss", 29 | "SFTLoss", 30 | "RLLoss", 31 | ] 32 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/_cli/test_ls.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
class TestTuneListCommand:
    """This class tests the `tune ls` command."""

    def test_ls_lists_all_recipes_and_configs(self, capsys, monkeypatch):
        """Run ``tune ls`` in-process and check every recipe and config name is printed."""
        cli_args = ["tune", "ls"]
        monkeypatch.setattr(sys, "argv", cli_args)

        # Execute the CLI entry script as if it were launched as ``__main__``.
        runpy.run_path(TUNE_PATH, run_name="__main__")

        output = capsys.readouterr().out.rstrip("\n")

        for recipe in get_all_recipes():
            assert recipe.name in output
            for config in recipe.configs:
                assert config.name in output
def scale_hidden_dim_for_mlp(dim: int, multiple_of: int = 256) -> int:
    """Scale hidden dimension for MLP to keep number of parameters and computation constant.

    The hidden size is ``4 * int(2 * dim / 3)`` (the (2/3)4d scaling used for
    SwiGLU), rounded up to the next multiple of ``multiple_of``. A value that is
    already a multiple is returned unchanged.

    Args:
        dim (int): Input dimension.
        multiple_of (int): The scaled dimension is rounded up to a multiple of this value.

    Returns:
        int: Scaled hidden dimension.
    """
    two_thirds_dim = int(2 * dim / 3)
    scaled = 4 * two_thirds_dim
    # Ceiling division: how many ``multiple_of``-sized chunks are needed to cover ``scaled``.
    num_chunks = (scaled + multiple_of - 1) // multiple_of
    return multiple_of * num_chunks
6 | 7 | from ._component_builders import llama3_1, lora_llama3_1 8 | 9 | from ._model_builders import ( # noqa 10 | llama3_1_405b, 11 | llama3_1_70b, 12 | llama3_1_8b, 13 | lora_llama3_1_405b, 14 | lora_llama3_1_70b, 15 | lora_llama3_1_8b, 16 | qlora_llama3_1_405b, 17 | qlora_llama3_1_70b, 18 | qlora_llama3_1_8b, 19 | ) 20 | from ._position_embeddings import Llama3ScaledRoPE 21 | 22 | __all__ = [ 23 | "llama3_1", 24 | "llama3_1_8b", 25 | "llama3_1_70b", 26 | "llama3_1_405b", 27 | "lora_llama3_1", 28 | "lora_llama3_1_8b", 29 | "lora_llama3_1_70b", 30 | "lora_llama3_1_405b", 31 | "qlora_llama3_1_8b", 32 | "qlora_llama3_1_70b", 33 | "qlora_llama3_1_405b", 34 | "Llama3ScaledRoPE", 35 | ] 36 | -------------------------------------------------------------------------------- /torchtune/torchtune/models/gemma/rms_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
class GemmaRMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learnable ``(1 + scale)`` gain."""

    # Copied from https://github.com/google/gemma_pytorch/blob/main/gemma/model.py
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Zero-initialised, so the effective gain ``1 + scale`` starts at 1.
        self.scale = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide ``x`` by the root-mean-square of its last dimension (``eps`` added for stability)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise ``x`` in float32, apply the gain, and cast back to the dtype of ``x``."""
        normed = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        gain = 1.0 + self.scale.float()
        scaled = normed * gain
        return scaled.type_as(x)
6 | 7 | from ..gemma._model_builders import gemma_tokenizer 8 | from ..gemma._tokenizer import GemmaTokenizer # noqa 9 | from ._component_builders import gemma2, lora_gemma2 # noqa 10 | from ._model_builders import ( # noqa 11 | gemma2_27b, 12 | gemma2_2b, 13 | gemma2_9b, 14 | lora_gemma2_27b, 15 | lora_gemma2_2b, 16 | lora_gemma2_9b, 17 | qlora_gemma2_27b, 18 | qlora_gemma2_2b, 19 | qlora_gemma2_9b, 20 | ) 21 | 22 | __all__ = [ 23 | "GemmaTokenizer", 24 | "gemma2", 25 | "gemma2_2b", 26 | "gemma2_9b", 27 | "gemma2_27b", 28 | "gemma_tokenizer", 29 | "lora_gemma2", 30 | "lora_gemma2_2b", 31 | "lora_gemma2_9b", 32 | "lora_gemma2_27b", 33 | "qlora_gemma2_2b", 34 | "qlora_gemma2_9b", 35 | "qlora_gemma2_27b", 36 | ] 37 | -------------------------------------------------------------------------------- /torchtune/torchtune/modules/peft/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from ._utils import ( # noqa 8 | AdapterModule, 9 | disable_adapter, 10 | get_adapter_params, 11 | get_adapter_state_dict, 12 | get_lora_module_names, 13 | get_merged_lora_ckpt, 14 | LORA_ATTN_MODULES, 15 | set_trainable_params, 16 | validate_missing_and_unexpected_for_lora, 17 | ) 18 | from .dora import DoRALinear 19 | from .lora import LoRALinear, QATLoRALinear, TrainableParams 20 | 21 | 22 | __all__ = [ 23 | "AdapterModule", 24 | "DoRALinear", 25 | "LoRALinear", 26 | "QATLoRALinear", 27 | "get_adapter_params", 28 | "set_trainable_params", 29 | "validate_missing_and_unexpected_for_lora", 30 | "disable_adapter", 31 | "get_adapter_state_dict", 32 | "get_merged_lora_ckpt", 33 | "get_lora_module_names", 34 | "TrainableParams", 35 | ] 36 | -------------------------------------------------------------------------------- /torchtune/torchtune/models/qwen2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from ._component_builders import lora_qwen2, qwen2 # noqa 8 | from ._convert_weights import qwen2_hf_to_tune, qwen2_tune_to_hf # noqa 9 | from ._model_builders import ( 10 | lora_qwen2_0_5b, 11 | lora_qwen2_1_5b, 12 | lora_qwen2_7b, 13 | qwen2_0_5b, 14 | qwen2_1_5b, 15 | qwen2_7b, 16 | qwen2_tokenizer, 17 | ) 18 | from ._positional_embeddings import Qwen2RotaryPositionalEmbeddings 19 | from ._tokenizer import Qwen2Tokenizer 20 | 21 | __all__ = [ 22 | "lora_qwen2", 23 | "qwen2", 24 | "qwen2_hf_to_tune", 25 | "qwen2_tune_to_hf", 26 | "lora_qwen2_0_5b", 27 | "lora_qwen2_1_5b", 28 | "lora_qwen2_7b", 29 | "qwen2_0_5b", 30 | "qwen2_1_5b", 31 | "qwen2_7b", 32 | "qwen2_tokenizer", 33 | "Qwen2RotaryPositionalEmbeddings", 34 | "Qwen2Tokenizer", 35 | ] 36 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # Suggested config from pytorch that we can adapt 3 | select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2 4 | max-line-length = 120 5 | # C408 ignored because we like the dict keyword argument syntax 6 | # E501 is not flexible enough, we're using B950 instead 7 | # N812 ignored because import torch.nn.functional as F is PyTorch convention 8 | # N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP) 9 | # E731 allow usage of assigning lambda expressions 10 | ignore = 11 | E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731 12 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying 13 | # to line this up with executable bit 14 | EXE001, 15 | # these ignores are from flake8-bugbear; please fix! 
@ray.remote(num_cpus=1, num_gpus=0)
class MetricLoggerWorker:
    """Ray actor (1 CPU, no GPU) wrapping the metric logger instantiated from the config."""

    def __init__(self, cfg):
        # Build the logger from ``cfg.metric_logger`` and record the whole config with it.
        self.logger = config.instantiate(cfg.metric_logger)
        self.logger.log_config(cfg)

    def log_dict(self, log_dict, step=None):
        """Forward ``log_dict`` to the wrapped logger."""
        # allowing actors to use their own step counters
        self.logger.log_dict(log_dict, step=step)

    def log_table(self, table_data, columns, table_name, step=None):
        """Log a table to WandB."""
        # Imported lazily so ``wandb`` is only needed when a table is actually logged.
        import wandb

        payload = {table_name: wandb.Table(columns=columns, data=table_data)}
        self.logger.log_dict(payload, step=step)

    def close(self):
        """Close the wrapped logger when it exposes a ``close`` attribute."""
        if not hasattr(self.logger, "close"):
            return
        self.logger.close()
def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
    """
    Translate a state-dict key into its counterpart using ``mapping_dict``.

    Keys that contain purely numeric dot-separated components (layer / expert
    indices) are looked up in abstract form, with every ``.<number>`` replaced by
    ``.{}``; the indices are then substituted back, in order, into the mapped key.
    Keys without such components are looked up verbatim.

    Args:
        key (str): key to convert, e.g. ``"model.layers.3.mlp.experts.7.w1.weight"``.
        mapping_dict (Dict[str, str]): mapping from (abstract) source keys to
            (abstract) destination keys, using ``{}`` as the index placeholder.

    Returns:
        str: the converted key.

    Raises:
        Exception: if the (abstract) key is not present in ``mapping_dict``.
    """
    try:
        # Checks if there is a layer # in the key
        if any(k.isdigit() for k in key.split(".")):
            # A single pattern drives both the placeholder substitution and the
            # index extraction, so the n-th "{}" always receives the n-th index.
            # Extracting with a bare r"\d+" would also pick up digits embedded in
            # names (e.g. "model2", "w1") and shift the indices out of position.
            index_pattern = r"\.(\d+)"
            # Replace layer number with "{}" to create key for lookup
            abstract_key = re.sub(index_pattern, ".{}", key)
            # Collect every index that was replaced by a placeholder, in order
            layer_nums = re.findall(index_pattern, key)
            new_key = mapping_dict[abstract_key]
            new_key = new_key.format(*layer_nums)
        else:
            new_key = mapping_dict[key]
    except KeyError as e:
        raise Exception(
            f'Error converting the state dict. Found unexpected key: "{key}". '
            "Please make sure you're loading a checkpoint with the right format. "
        ) from e

    return new_key
class ConfigError(Exception):
    """
    Raised when the yaml config is not well-formed. The string form lists all the
    collected errors at once, one per line.

    Args:
        errors (List[Exception]): exceptions found when validating `_component_`
            fields in the config
    """

    def __init__(self, errors: List[Exception]):
        self.errors = errors

    def __str__(self):
        header = "Config is not well-formed, found the following errors: \n"
        # One "<ExceptionType>: <message>" line per collected error
        lines = []
        for err in self.errors:
            lines.append(f"{type(err).__name__}: {str(err)}")
        return header + "\n".join(lines)
6 | 7 | from ._component_builders import ( 8 | lora_mistral, 9 | lora_mistral_classifier, 10 | mistral, 11 | mistral_classifier, 12 | ) 13 | from ._model_builders import ( 14 | lora_mistral_7b, 15 | lora_mistral_reward_7b, 16 | mistral_7b, 17 | mistral_reward_7b, 18 | mistral_tokenizer, 19 | qlora_mistral_7b, 20 | qlora_mistral_reward_7b, 21 | ) 22 | from ._prompt_template import MistralChatTemplate 23 | from ._tokenizer import MistralTokenizer 24 | 25 | __all__ = [ 26 | "MistralTokenizer", 27 | "MistralChatTemplate", 28 | "lora_mistral", 29 | "lora_mistral_classifier", 30 | "mistral", 31 | "mistral_classifier", 32 | "lora_mistral_7b", 33 | "lora_mistral_reward_7b", 34 | "mistral_7b", 35 | "mistral_reward_7b", 36 | "mistral_tokenizer", 37 | "qlora_mistral_7b", 38 | "qlora_mistral_reward_7b", 39 | ] 40 | -------------------------------------------------------------------------------- /torchtune/recipes/configs/gemma/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # Config for EleutherEvalRecipe in eleuther_eval.py 2 | # 3 | # To launch, run the following command: 4 | # tune run eleuther_eval --config gemma/evaluation 5 | 6 | output_dir: ./ # Not needed 7 | 8 | # Model Arguments 9 | model: 10 | _component_: torchtune.models.gemma.gemma_2b 11 | 12 | # Checkpointer 13 | checkpointer: 14 | _component_: torchtune.training.FullModelHFCheckpointer 15 | checkpoint_dir: /tmp/gemma-2b 16 | checkpoint_files: [ 17 | model-00001-of-00002.safetensors, 18 | model-00002-of-00002.safetensors, 19 | ] 20 | output_dir: ${output_dir} 21 | model_type: GEMMA 22 | 23 | # Tokenizer 24 | tokenizer: 25 | _component_: torchtune.models.gemma.gemma_tokenizer 26 | path: /tmp/gemma-2b/tokenizer.model 27 | 28 | # Environment 29 | device: cuda 30 | dtype: bf16 31 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed 32 | 33 | # EleutherAI specific eval args 34 | tasks: ["truthfulqa_mc2"] 35 | limit: 
null 36 | max_seq_length: 4096 37 | batch_size: 8 38 | enable_kv_cache: True 39 | 40 | # Quantization specific args 41 | quantizer: null 42 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: 'build' 2 | 3 | default_language_version: 4 | python: python3 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v5.0.0 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: check-ast 12 | - id: check-merge-conflict 13 | - id: no-commit-to-branch 14 | args: ['--branch=main'] 15 | - id: check-added-large-files 16 | args: ['--maxkb=1000'] 17 | - id: end-of-file-fixer 18 | exclude: '^(.*\.svg)$' 19 | 20 | - repo: https://github.com/pycqa/flake8 21 | rev: 7.1.1 22 | hooks: 23 | - id: flake8 24 | additional_dependencies: 25 | - flake8-bugbear == 22.4.25 26 | - pep8-naming == 0.12.1 27 | - torchfix 28 | args: ['--config=.flake8'] 29 | 30 | - repo: https://github.com/omnilib/ufmt 31 | rev: v2.8.0 32 | hooks: 33 | - id: ufmt 34 | additional_dependencies: 35 | - black == 22.12.0 36 | - usort == 1.0.5 37 | 38 | - repo: https://github.com/jsh9/pydoclint 39 | rev: 0.5.12 40 | hooks: 41 | - id: pydoclint 42 | args: [--config=pyproject.toml] 43 | -------------------------------------------------------------------------------- /mirotrain/monkey/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2025 MiromindAI 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # -*- coding: utf-8 -*- 6 | """ 7 | author: lei.lei@shanda.com 8 | time: 2025/05/29 10:38 9 | description: gevent style monkey patching: 10 | 1. explicit: user calls `mirotrain.monkey.patch_common()`. 11 | 2. declarative: module under `mirotrain` uses `__targets__` and `__implements__` to inform WHAT to patch. 12 | 3. 
def patch_common():
    """
    Apply the common patches by calling ``patch_by_source_module_fqn`` on each
    source module listed below, in order.

    TODO: add other patches here
    """
    source_module_fqns = (
        "mirotrain.data._messages",
        "mirotrain.datasets._packed",
        "mirotrain.training.checkpointing._checkpointer",
    )
    for source_module_fqn in source_module_fqns:
        patch_by_source_module_fqn(source_module_fqn)
class TestValidate:
    def test_validate(self):
        # A valid config must pass validation without raising
        valid_conf = OmegaConf.load(VALID_CONFIG_PATH)
        config.validate(valid_conf)

        # An invalid config must raise a ConfigError carrying the collected errors
        invalid_conf = OmegaConf.load(INVALID_CONFIG_PATH)
        with pytest.raises(ConfigError) as excinfo:
            config.validate(invalid_conf)

        collected_errors = excinfo.value.errors
        assert len(collected_errors) == 2
        for err in collected_errors:
            assert isinstance(err, TypeError)
            assert str(err) == "get_dtype got an unexpected keyword argument 'dummy'"
def stateless_init_process_group(
    master_address: str,
    master_port: int,
    rank: int,
    world_size: int,
    device: torch.device,
):
    """
    Create a vLLM ``StatelessProcessGroup`` and wrap it in a ``PyNcclCommunicator``.

    vLLM provides `StatelessProcessGroup` to create a process group
    without considering the global process group in torch.distributed.
    It is recommended to create `StatelessProcessGroup`, and then initialize
    the data-plane communication (NCCL) between external (train processes)
    and vLLM workers.

    Args:
        master_address (str): host passed to ``StatelessProcessGroup.create``.
        master_port (int): port passed to ``StatelessProcessGroup.create``.
        rank (int): rank of the calling process within the new group.
        world_size (int): total number of processes in the new group.
        device (torch.device): device handed to the communicator.

    Returns:
        The ``PyNcclCommunicator`` built on top of the stateless process group.
    """
    # Imported inside the function so vLLM is only needed when this helper is called
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    stateless_group = StatelessProcessGroup.create(
        host=master_address,
        port=master_port,
        rank=rank,
        world_size=world_size,
    )
    return PyNcclCommunicator(stateless_group, device=device)
6 | 7 | # flake8: noqa: F401 8 | 9 | # NOTE: This file is maintained for backward compatibility purposes. 10 | # The imports below point to the new location in `torchtune.modules.transforms.tokenizers`. 11 | # The import paths will be removed in v0.7. Please update your code to use the new path 12 | # (torchtune.modules.transforms.tokenizers) to avoid breaking changes in future releases. 13 | 14 | 15 | import warnings 16 | 17 | from torchtune.modules.transforms.tokenizers import ( 18 | BaseTokenizer, 19 | ModelTokenizer, 20 | parse_hf_tokenizer_json, 21 | SentencePieceBaseTokenizer, 22 | TikTokenBaseTokenizer, 23 | tokenize_messages_no_special_tokens, 24 | ) 25 | 26 | warnings.warn( 27 | "The import path 'torchtune.modules.tokenizers' is deprecated and will be removed in v0.7. " 28 | "Please update your imports to 'torchtune.modules.transforms.tokenizers'.", 29 | DeprecationWarning, 30 | stacklevel=2, 31 | ) 32 | -------------------------------------------------------------------------------- /torchtune/recipes/configs/qwen2_5/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # Config for EleutherEvalRecipe in eleuther_eval.py 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run eleuther_eval --config qwen2_5/evaluation 5 | 6 | output_dir: ./ # Not needed 7 | 8 | # Model Arguments 9 | model: 10 | _component_: torchtune.models.qwen2_5.qwen2_5_0_5b 11 | 12 | checkpointer: 13 | _component_: torchtune.training.FullModelHFCheckpointer 14 | checkpoint_dir: /tmp/Qwen2.5-0.5B-Instruct 15 | checkpoint_files: [ 16 | model.safetensors, 17 | ] 18 | output_dir: ${output_dir} 19 | model_type: QWEN2 20 | 21 | # Tokenizer 22 | tokenizer: 23 | _component_: torchtune.models.qwen2_5.qwen2_5_tokenizer 24 | path: /tmp/Qwen2.5-0.5B-Instruct/vocab.json 25 | merges_file: /tmp/Qwen2.5-0.5B-Instruct/merges.txt 26 | max_seq_len: null 27 | 28 | # Environment 29 | device: cuda 30 | dtype: bf16 31 | 
# Module-wide handle to the expert parallel process group; None until one is set
_EXPERT_PARALLEL_GROUP = None


def set_expert_parallel_group(group: dist.ProcessGroup):
    """
    Store ``group`` as the expert parallel process group.
    """
    global _EXPERT_PARALLEL_GROUP
    _EXPERT_PARALLEL_GROUP = group


def get_expert_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Return the stored expert parallel process group (``None`` if none was set).
    """
    return _EXPERT_PARALLEL_GROUP


def get_expert_parallel_world_size(group: dist.ProcessGroup = None) -> int:
    """
    Return the world size of ``group``, defaulting to the stored expert parallel
    group. Returns 1 when no group is available.
    """
    if group is None:
        group = get_expert_parallel_group()
    if not group:
        return 1
    return dist.get_world_size(group)


def get_expert_parallel_rank(group: dist.ProcessGroup = None) -> int:
    """
    Return the rank within ``group``, defaulting to the stored expert parallel
    group. Returns 0 when no group is available.
    """
    if group is None:
        group = get_expert_parallel_group()
    if not group:
        return 0
    return dist.get_rank(group)
class Fp32LayerNorm(nn.LayerNorm):
    """
    Wrapper around :class:`~torch.nn.LayerNorm` to support mixed-precision training.
    The normalization is computed in fp32 and the result is cast back to the input's type.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The normalized output tensor having the same shape as ``x``.
        """
        # Affine parameters may be absent; upcast them only when present
        fp32_weight = None if self.weight is None else self.weight.float()
        fp32_bias = None if self.bias is None else self.bias.float()
        normalized = nn.functional.layer_norm(
            x.float(),
            self.normalized_shape,
            fp32_weight,
            fp32_bias,
            self.eps,
        )
        # Cast back so callers get the same dtype they passed in
        return normalized.type_as(x)
6 | 7 | 8 | from ._types import ChosenRejectedOutputs, PPOStats, Trajectory 9 | 10 | from .rewards import ( 11 | estimate_advantages, 12 | get_reward_penalty_mask, 13 | get_rewards_ppo, 14 | masked_mean, 15 | masked_sum, 16 | masked_var, 17 | whiten, 18 | ) 19 | from .sequence_processing import ( 20 | batched_logits_to_logprobs, 21 | get_batch_log_probs, 22 | logits_to_logprobs, 23 | truncate_sequence_at_first_stop_token, 24 | truncate_sequence_for_logprobs, 25 | ) 26 | 27 | __all__ = [ 28 | "truncate_sequence_at_first_stop_token", 29 | "logits_to_logprobs", 30 | "batched_logits_to_logprobs", 31 | "truncate_sequence_for_logprobs", 32 | "get_reward_penalty_mask", 33 | "estimate_advantages", 34 | "get_rewards_ppo", 35 | "whiten", 36 | "masked_mean", 37 | "masked_sum", 38 | "masked_var", 39 | "PPOStats", 40 | "get_batch_log_probs", 41 | "Trajectory", 42 | "ChosenRejectedOutputs", 43 | ] 44 | -------------------------------------------------------------------------------- /torchtune/recipes/configs/mistral/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # Config for EleutherEvalRecipe in eleuther_eval.py 2 | # 3 | # To launch, run the following command: 4 | # tune run eleuther_eval --config mistral/evaluation 5 | 6 | output_dir: ./ # Not needed 7 | 8 | # Model Arguments 9 | model: 10 | _component_: torchtune.models.mistral.mistral_7b 11 | 12 | # Checkpointer 13 | checkpointer: 14 | _component_: torchtune.training.FullModelHFCheckpointer 15 | checkpoint_dir: /tmp/Mistral-7B-v0.1/ 16 | checkpoint_files: [ 17 | pytorch_model-00001-of-00002.bin, 18 | pytorch_model-00002-of-00002.bin 19 | ] 20 | output_dir: ${output_dir} 21 | model_type: MISTRAL 22 | resume_from_checkpoint: False 23 | 24 | # Tokenizer 25 | tokenizer: 26 | _component_: torchtune.models.mistral.mistral_tokenizer 27 | path: /tmp/Mistral-7B-v0.1/tokenizer.model 28 | max_seq_len: null 29 | 30 | # Environment 31 | device: cuda 32 | dtype: bf16 33 | seed: 1234 
# It is not recommended to change this seed, b/c it matches EleutherAI's default seed 34 | 35 | # EleutherAI specific eval args 36 | tasks: ["truthfulqa_mc2"] 37 | limit: null 38 | max_seq_length: 4096 39 | batch_size: 8 40 | enable_kv_cache: True 41 | 42 | # Quantization specific args 43 | quantizer: null 44 | -------------------------------------------------------------------------------- /torchtune/recipes/configs/eleuther_evaluation.yaml: -------------------------------------------------------------------------------- 1 | # Config for EleutherEvalRecipe in eleuther_eval.py 2 | # 3 | # To launch, run the following command from root torchtune directory: 4 | # tune run eleuther_eval --config eleuther_evaluation tasks=["truthfulqa_mc2","hellaswag"] 5 | 6 | output_dir: ./ # Not needed 7 | 8 | # Model Arguments 9 | model: 10 | _component_: torchtune.models.llama2.llama2_7b 11 | 12 | checkpointer: 13 | _component_: torchtune.training.FullModelHFCheckpointer 14 | checkpoint_dir: /tmp/Llama-2-7b-hf 15 | checkpoint_files: [ 16 | pytorch_model-00001-of-00002.bin, 17 | pytorch_model-00002-of-00002.bin, 18 | ] 19 | output_dir: ${output_dir} 20 | model_type: LLAMA2 21 | 22 | # Tokenizer 23 | tokenizer: 24 | _component_: torchtune.models.llama2.llama2_tokenizer 25 | path: /tmp/Llama-2-7b-hf/tokenizer.model 26 | max_seq_len: null 27 | 28 | # Environment 29 | device: cuda 30 | dtype: bf16 31 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed 32 | 33 | # EleutherAI specific eval args 34 | tasks: ["truthfulqa_mc2"] 35 | limit: null 36 | max_seq_length: 4096 37 | batch_size: 8 38 | enable_kv_cache: True 39 | 40 | # Quantization specific args 41 | quantizer: null 42 | -------------------------------------------------------------------------------- /torchtune/recipes/configs/phi3/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # Config for EleutherEvalRecipe in eleuther_eval.py 2 | 
# 3 | # To launch, run the following command: 4 | # tune run eleuther_eval --config phi3/evaluation 5 | 6 | output_dir: ./ # Not needed 7 | 8 | # Model Arguments 9 | model: 10 | _component_: torchtune.models.phi3.phi3_mini 11 | 12 | # Checkpointer 13 | checkpointer: 14 | _component_: torchtune.training.FullModelHFCheckpointer 15 | checkpoint_dir: /tmp/phi-3 16 | checkpoint_files: [ 17 | model-00001-of-00002.safetensors, 18 | model-00002-of-00002.safetensors 19 | ] 20 | recipe_checkpoint: null 21 | output_dir: ${output_dir} 22 | model_type: PHI3_MINI 23 | resume_from_checkpoint: False 24 | 25 | # Tokenizer 26 | tokenizer: 27 | _component_: torchtune.models.phi3.phi3_mini_tokenizer 28 | path: /tmp/phi-3/tokenizer.model 29 | max_seq_len: null 30 | 31 | # Environment 32 | device: cuda 33 | dtype: bf16 34 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed 35 | 36 | # EleutherAI specific eval args 37 | tasks: ["truthfulqa_mc2"] 38 | limit: null 39 | max_seq_length: 4096 40 | batch_size: 8 41 | enable_kv_cache: True 42 | 43 | # Quantization specific args 44 | quantizer: null 45 | -------------------------------------------------------------------------------- /torchtune/recipes/configs/llama3_2/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # Config for EleutherEvalRecipe in eleuther_eval.py 2 | # 3 | # To launch, run the following command: 4 | # tune run eleuther_eval --config llama3_2/evaluation 5 | 6 | output_dir: ./ # Not needed 7 | 8 | # Model Arguments 9 | model: 10 | _component_: torchtune.models.llama3_2.llama3_2_3b 11 | 12 | # Checkpointer 13 | checkpointer: 14 | _component_: torchtune.training.FullModelHFCheckpointer 15 | checkpoint_dir: /tmp/Llama-3.2-3B-Instruct 16 | checkpoint_files: [ 17 | model-00001-of-00002.safetensors, 18 | model-00002-of-00002.safetensors, 19 | ] 20 | recipe_checkpoint: null 21 | output_dir: ${output_dir} 22 | model_type: LLAMA3_2 23 | 
resume_from_checkpoint: False 24 | 25 | # Tokenizer 26 | tokenizer: 27 | _component_: torchtune.models.llama3.llama3_tokenizer 28 | path: /tmp/Llama-3.2-3B-Instruct/original/tokenizer.model 29 | max_seq_len: null 30 | 31 | # Environment 32 | device: cpu 33 | dtype: bf16 34 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed 35 | 36 | # EleutherAI specific eval args 37 | tasks: ["truthfulqa_mc2"] 38 | limit: null 39 | max_seq_length: 4096 40 | batch_size: 8 41 | enable_kv_cache: True 42 | 43 | # Quantization specific args 44 | quantizer: null 45 | -------------------------------------------------------------------------------- /torchtune/recipes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # This file mainly exists because we want to ensure that `recipes` aren't 8 | # importable *from the tests*. 9 | # We're using the `prepend` pytest import mode which adds the root dir (i.e. the 10 | # parent of torchtune/, tests/, recipes/) to the pythonpath during pytest 11 | # sessions 12 | # (https://docs.pytest.org/en/7.1.x/explanation/pythonpath.html#import-modes). 13 | # This has the positive effect that the `tests` folder becomes importable when 14 | # testing (we need that, considering how tests are currently set up) but ALSO 15 | # has the negative effect of making the `recipes/` importable when testing. 16 | # Since we don't want the tests to incorrectly assume that recipes are 17 | # importable, we have to explicitly raise an error here. 18 | 19 | raise ModuleNotFoundError( 20 | "The torchtune recipes directory isn't a package and you should not import anything from here. 
" 21 | "Refer to our docs for detailed instructions on how to use recipes: " 22 | "https://pytorch.org/torchtune/main/deep_dives/recipe_deepdive.html" 23 | ) 24 | -------------------------------------------------------------------------------- /torchtune/recipes/configs/code_llama2/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # Config for EleutherEvalRecipe in eleuther_eval.py 2 | # 3 | # To launch, run the following command: 4 | # tune run eleuther_eval --config code_llama2/evaluation 5 | 6 | output_dir: ./ # Not needed 7 | 8 | # Model arguments 9 | model: 10 | _component_: torchtune.models.code_llama2.code_llama2_7b 11 | 12 | # Tokenizer 13 | tokenizer: 14 | _component_: torchtune.models.llama2.llama2_tokenizer 15 | path: /tmp/CodeLlama-7b-hf/tokenizer.model 16 | max_seq_len: null 17 | 18 | # Checkpointer 19 | checkpointer: 20 | _component_: torchtune.training.FullModelHFCheckpointer 21 | checkpoint_dir: /tmp/CodeLlama-7b-hf 22 | checkpoint_files: [ 23 | pytorch_model-00001-of-00003.bin, 24 | pytorch_model-00002-of-00003.bin, 25 | pytorch_model-00003-of-00003.bin 26 | ] 27 | recipe_checkpoint: null 28 | output_dir: ${output_dir} 29 | model_type: LLAMA2 30 | resume_from_checkpoint: False 31 | 32 | # Environment 33 | device: cpu 34 | dtype: bf16 35 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed 36 | 37 | # EleutherAI specific eval args 38 | tasks: ["truthfulqa_mc2"] 39 | limit: null 40 | max_seq_length: 4096 41 | batch_size: 8 42 | enable_kv_cache: True 43 | 44 | # Quantization specific args 45 | quantizer: null 46 | -------------------------------------------------------------------------------- /torchtune/recipes/configs/qwen2/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # Config for EleutherEvalRecipe in eleuther_eval.py 2 | # 3 | # To launch, run the following command: 4 | # tune run eleuther_eval 
class TestNF4DispatchRegistration:
    """
    Class for testing NF4Tensor dispatch ops.
    """

    def test_inplace_copy_copies_expected_attributes(self):
        """
        This test ensures that we're copying over all relevant attributes when implementing
        torch.ops.aten.copy_.default. If this test fails, we would need to update our implementation
        in _register_nf4_dispatch_ops to cover the newly added attributes.
        """
        expected_attr_set = {
            "block_size",
            "n_blocks",
            "scaler_block_size",
            "quantized_scalers",
            "quantization_factor",
            "scaler_mean",
            "quantized_data",
            "nf4",
        }

        # Compare against the instance attributes of a freshly quantized tensor
        nf4_tensor = to_nf4(torch.rand(512, 512, dtype=torch.bfloat16))
        actual_attr_set = set(nf4_tensor.__dict__.keys())
        assert expected_attr_set == actual_attr_set
import pytest

from tests.common import ASSETS
from torchtune.models.t5._model_builders import t5_tokenizer


class TestT5Tokenizer:
    """Tests for the tokenizer returned by ``t5_tokenizer`` on the sentencepiece test asset."""

    @pytest.fixture
    def tokenizer(self):
        """Build the tokenizer from ``sentencepiece.model`` in the test assets directory."""
        return t5_tokenizer(str(ASSETS / "sentencepiece.model"))

    def test_encoding(self, tokenizer):
        """``encode`` must return the pinned token ids for each sample text."""
        texts = [
            "a cow jumping over the moon",
            "a helpful AI assistant",
        ]
        correct_tokens = [
            [3, 9, 9321, 15539, 147, 8, 8114, 1],
            [3, 9, 2690, 7833, 6165, 1],
        ]
        # No debug print here: pytest already shows both operands when the assert fails.
        for text, correct in zip(texts, correct_tokens):
            assert tokenizer.encode(text) == correct

    def test_decoding(self, tokenizer):
        """Encoding then decoding must give back the original text."""
        text = "this is torchtune"
        assert text == tokenizer.decode(tokenizer.encode(text))

    def test_call(self, tokenizer):
        """Calling the tokenizer on a sample replaces the "text" key with "tokens"."""
        sample = tokenizer({"text": "hello world"})
        assert "text" not in sample
        assert "tokens" in sample
import functools

from torchtune.models.llama3_1._model_builders import (
    llama3_1_70b,
    lora_llama3_1_70b,
    qlora_llama3_1_70b,
)

"""
Model builders build specific instantiations using component builders. The Llama3.3 model
builders all call the Llama3.1 models as they're identical models apart from the checkpoints.
"""


def _llama3_3_alias(builder, name, doc):
    """
    Wrap a Llama3.1 builder under a Llama3.3 name with its own docstring.

    A plain alias (``llama3_3_70b = llama3_1_70b``) shares one object between
    both names, so assigning ``__doc__`` through the alias would also overwrite
    the Llama3.1 builder's documentation. A thin wrapper avoids that.

    Args:
        builder: the Llama3.1 builder callable to delegate to.
        name (str): public name of the Llama3.3 builder.
        doc (str): docstring to attach to the Llama3.3 builder.

    Returns:
        A callable that forwards all positional and keyword arguments to
        ``builder`` and returns its result unchanged.
    """

    # functools.wraps records ``__wrapped__`` so inspect.signature still
    # reports the parameters of the underlying builder.
    @functools.wraps(builder)
    def _builder(*args, **kwargs):
        return builder(*args, **kwargs)

    _builder.__name__ = name
    _builder.__qualname__ = name
    _builder.__module__ = __name__
    _builder.__doc__ = doc
    return _builder


llama3_3_70b = _llama3_3_alias(
    llama3_1_70b,
    "llama3_3_70b",
    """
Builder for creating a Llama3.3 model initialized w/ the default 70B parameter values.
Please see `llama3_1_70b` for full API arguments.
""",
)

lora_llama3_3_70b = _llama3_3_alias(
    lora_llama3_1_70b,
    "lora_llama3_3_70b",
    """
Builder for creating a Llama3.3 70B model with LoRA enabled.
Please see `lora_llama3_1_70b` for full API arguments.
""",
)

qlora_llama3_3_70b = _llama3_3_alias(
    qlora_llama3_1_70b,
    "qlora_llama3_3_70b",
    """
Builder for creating a Llama3.3 70B model with QLoRA enabled. Base model weights in linear layers
that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
Please see `qlora_llama3_1_70b` for full API arguments.
""",
)
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from datetime import datetime

import torch


def _release_tuple(version: str) -> "tuple[int, ...]":
    """
    Extract the leading numeric release components of a version string.

    Anything after a ``+`` (local build tag) is dropped, and parsing stops at
    the first dot-separated segment that is not purely numeric, so
    ``"2.5.0.dev20240901+cu121"`` gives ``(2, 5, 0)`` and ``"2.5.0a0"`` gives
    ``(2, 5, 0)``.
    """
    release = []
    for segment in version.split("+", 1)[0].split("."):
        digits = ""
        for char in segment:
            if not "0" <= char <= "9":
                break
            digits += char
        if not digits:
            break
        release.append(int(digits))
        if len(digits) != len(segment):
            # Segment such as "0a0" or "0rc1": keep the number, ignore the rest.
            break
    return tuple(release)


def torch_version_ge(version: str) -> bool:
    """
    Check if torch version is greater than or equal to the given version.

    Only the numeric release components are compared, component by component
    (missing components count as 0). Pre-release and nightly builds therefore
    satisfy their own release number, e.g. ``2.5.0.dev20240901`` is treated as
    ``>= 2.5.0``. Comparing numerically avoids two pitfalls of string
    comparison: a substring test reports ``"4.0"`` as contained in ``"2.4.0"``,
    and ``"2.10.0"`` sorts before ``"2.4"`` lexicographically.

    Args:
        version (str): The torch version to compare against

    Returns:
        bool: True if torch version is greater than or equal to the given version.

    Example:
        >>> print(torch.__version__)
        2.4.0
        >>> torch_version_ge("2.0")
        True
    """
    installed = _release_tuple(str(torch.__version__))
    required = _release_tuple(version)
    # Pad to a common length so (2, 4) and (2, 4, 0) compare as equal.
    width = max(len(installed), len(required))
    installed += (0,) * (width - len(installed))
    required += (0,) * (width - len(required))
    return installed >= required


def _is_fbcode():
    """Return True when ``torch.version`` has no ``git_version`` attribute."""
    # NOTE(review): presumably this identifies Meta-internal (fbcode) builds -- confirm.
    return not hasattr(torch.version, "git_version")


def _nightly_version_ge(ao_version_str: str, date: str) -> bool:
    """
    Compare a torchao nightly version to a date of the form
    %Y-%m-%d.

    The nightly date is read from the text after ``dev`` (and before any
    ``+`` suffix) and is expected in ``%Y%m%d`` form.

    Returns True if the nightly version is greater than or equal to
    the date, False otherwise

    Raises:
        IndexError: if ``ao_version_str`` contains no ``dev`` marker.
        ValueError: if either date fails to parse with its expected format.
    """
    nightly_stamp = ao_version_str.split("+")[0].split("dev")[1]
    ao_datetime = datetime.strptime(nightly_stamp, "%Y%m%d")
    return ao_datetime >= datetime.strptime(date, "%Y-%m-%d")
torchtune.models.llama2.llama2_7b 14 | 15 | # Transform arguments 16 | tokenizer: 17 | _component_: torchtune.models.llama2.llama2_tokenizer 18 | path: /tmp/Llama-2-7b-chat-hf/tokenizer.model 19 | max_seq_len: 2048 20 | 21 | # Checkpointer 22 | checkpointer: 23 | _component_: torchtune.training.FullModelHFCheckpointer 24 | checkpoint_dir: /tmp/Llama-2-7b-chat-hf 25 | checkpoint_files: [ 26 | model-00001-of-00002.safetensors, 27 | model-00002-of-00002.safetensors 28 | ] 29 | output_dir: ${output_dir} 30 | model_type: LLAMA2 31 | 32 | # Device 33 | device: cuda 34 | dtype: bf16 35 | seed: 1234 36 | log_level: INFO # DEBUG, WARN, etc. 37 | 38 | # Generation arguments 39 | prompt: 40 | system: You are a helpful and creative AI assistant. 41 | user: What is the capital of France? 42 | max_new_tokens: 200 43 | temperature: 0.6 # 0.8 and 0.6 are popular values to try 44 | top_k: 300 45 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/llama3/test_llama3.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
# SPDX-FileCopyrightText: 2025 MiromindAI
#
# SPDX-License-Identifier: Apache-2.0

# -*- coding: utf-8 -*-

# This file is adapted from
# https://github.com/gevent/gevent/blob/73025a8837b3bff19c106e877fa2374889c59dd3/src/gevent/monkey/_errors.py

"""
Exception classes and errors that this package may raise.
"""


class _BadModule(ImportError):
    """
    Raised when a module is not importable.

    Args:
        name: name of the module that could not be imported; shown with ``%r``.
    """

    def __init__(self, name):
        ImportError.__init__(self, "Module %r is not importable" % (name,))


class _BadTargets(AttributeError):
    """
    Raised when ``__targets__`` is incorrect.

    Args:
        module: the module whose ``__targets__`` is bad or missing.
        target: the offending ``__targets__`` value.
    """

    def __init__(self, module, target):
        # Argument order must follow the placeholders: module first, then the value.
        AttributeError.__init__(
            self,
            "Module %r has a bad or missing value %r for __targets__"
            % (
                module,
                target,
            ),
        )


class _BadImplements(AttributeError):
    """
    Raised when ``__implements__`` is incorrect.

    Args:
        module: the module whose ``__implements__`` is bad or missing.
        implements: the offending ``__implements__`` value.
    """

    def __init__(self, module, implements):
        # Argument order must follow the placeholders: module first, then the value.
        AttributeError.__init__(
            self,
            "Module %r has a bad or missing value %r for __implements__"
            % (
                module,
                implements,
            ),
        )
# Enable faster Hugging Face downloads when hf_transfer is installed.
# For more info: https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
# To disable, run `HF_HUB_ENABLE_HF_TRANSFER=0 tune download `
try:
    import os

    import hf_transfer  # noqa

    # Only set a default; an explicit user setting (including "0") is kept.
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
except ImportError:
    # hf_transfer is optional; without it the environment is left untouched.
    pass

from torchtune import datasets, generation, models, modules, utils

# __all__ must contain names (strings), not the module objects themselves;
# otherwise `from torchtune import *` raises TypeError.
__all__ = ["datasets", "models", "modules", "utils", "generation"]
import runpy
import sys

import pytest
from tests.common import ASSETS, TUNE_PATH


class TestTuneValidateCommand:
    """Tests for the `tune validate` command, run against the dummy config assets."""

    VALID_CONFIG_PATH = ASSETS / "valid_dummy_config.yaml"
    INVALID_CONFIG_PATH = ASSETS / "invalid_dummy_config.yaml"

    def test_validate_good_config(self, capsys, monkeypatch):
        """A valid config makes the CLI print the success message on stdout."""
        argv = f"tune validate {self.VALID_CONFIG_PATH}".split()
        monkeypatch.setattr(sys, "argv", argv)

        runpy.run_path(TUNE_PATH, run_name="__main__")

        stdout = capsys.readouterr().out
        assert stdout.rstrip("\n") == "Config is well-formed!"

    def test_validate_bad_config(self, monkeypatch, capsys):
        """An invalid config makes the CLI exit and report the bad keyword on stderr."""
        argv = f"tune validate {self.INVALID_CONFIG_PATH}".split()
        monkeypatch.setattr(sys, "argv", argv)

        with pytest.raises(SystemExit):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        stderr = capsys.readouterr().err
        assert "got an unexpected keyword argument 'dummy'" in stderr.rstrip("\n")
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._component_builders import (
    llama2,
    llama2_classifier,
    lora_llama2,
    lora_llama2_classifier,
)

from ._model_builders import (  # noqa
    llama2_13b,
    llama2_70b,
    llama2_7b,
    llama2_reward_7b,
    llama2_tokenizer,
    lora_llama2_13b,
    lora_llama2_70b,
    lora_llama2_7b,
    lora_llama2_reward_7b,
    qlora_llama2_13b,
    qlora_llama2_70b,
    qlora_llama2_7b,
    qlora_llama2_reward_7b,
)
from ._prompt_template import Llama2ChatTemplate
from ._tokenizer import Llama2Tokenizer

# Public API of the llama2 package. Each imported name is listed exactly once
# ("lora_llama2" and "llama2_classifier" were previously listed twice).
__all__ = [
    "Llama2Tokenizer",
    "Llama2ChatTemplate",
    "llama2",
    "llama2_classifier",
    "lora_llama2_classifier",
    "llama2_reward_7b",
    "lora_llama2_reward_7b",
    "qlora_llama2_reward_7b",
    "lora_llama2",
    "llama2_13b",
    "llama2_70b",
    "llama2_7b",
    "llama2_tokenizer",
    "lora_llama2_13b",
    "lora_llama2_70b",
    "lora_llama2_7b",
    "qlora_llama2_13b",
    "qlora_llama2_70b",
    "qlora_llama2_7b",
]
import pytest
import torch
from tests.test_utils import fixed_init_model
from torchtune.models.phi3 import phi3
from torchtune.training.seed import set_seed

EMBED_DIM = 128
INTER_DIM = 256
NUM_LAYERS = 4
NUM_HEADS = 16
NUM_KV_HEADS = 8
VOCAB_SIZE = 32000
MAX_SEQ_LEN = 2048
BSZ = 2
SEQ_LEN = 100


@pytest.fixture(autouse=True)
def random():
    """Seed the RNGs before every test so the pinned expected value is reproducible."""
    set_seed(16)


class TestPhi3:
    """Forward-pass regression test for the phi3 builder with small dimensions."""

    @pytest.fixture
    def inputs(self):
        return torch.randint(0, VOCAB_SIZE, (BSZ, SEQ_LEN))

    def test_forward(self, inputs):
        builder_kwargs = {
            "vocab_size": VOCAB_SIZE,
            "num_layers": NUM_LAYERS,
            "num_heads": NUM_HEADS,
            "num_kv_heads": NUM_KV_HEADS,
            "embed_dim": EMBED_DIM,
            "intermediate_dim": INTER_DIM,
            "max_seq_len": MAX_SEQ_LEN,
        }
        model = phi3(**builder_kwargs)
        fixed_init_model(model, min_val=-0.25, max_val=0.5)

        output = model(inputs)
        expected_mean = torch.tensor(3.9763)

        assert output.shape == (BSZ, SEQ_LEN, VOCAB_SIZE)
        torch.testing.assert_close(output.mean(), expected_mean, atol=1e-4, rtol=1e-4)
from ._component_builders import (
    llama4_decoder,
    llama4_vision_encoder,
    llama4_vision_projection_head,
    lora_llama4_decoder,
    lora_llama4_vision_encoder,
    lora_llama4_vision_projection_head,
)
from ._encoder import Llama4VisionEncoder, Llama4VisionProjectionHead

from ._model_builders import (
    llama4_maverick_17b_128e,
    llama4_scout_17b_16e,
    llama4_transform,
    lora_llama4_scout_17b_16e,
)
from ._parallelism import decoder_only_tp_plan
from ._position_embeddings import Llama4ScaledRoPE
from ._tokenizer import Llama4Tokenizer
from ._transform import Llama4Transform

# Public API of the llama4 package, grouped by the submodule each name comes from.
__all__ = [
    # _component_builders
    "llama4_decoder",
    "llama4_vision_encoder",
    "llama4_vision_projection_head",
    "lora_llama4_decoder",
    "lora_llama4_vision_encoder",
    "lora_llama4_vision_projection_head",
    # _encoder
    "Llama4VisionEncoder",
    "Llama4VisionProjectionHead",
    # _model_builders
    "llama4_maverick_17b_128e",
    "llama4_scout_17b_16e",
    "llama4_transform",
    "lora_llama4_scout_17b_16e",
    # _parallelism, _position_embeddings, _tokenizer, _transform
    "decoder_only_tp_plan",
    "Llama4ScaledRoPE",
    "Llama4Tokenizer",
    "Llama4Transform",
]
import pytest
import torch
from tests.test_utils import fixed_init_model
from torchtune.models.qwen2 import qwen2
from torchtune.training.seed import set_seed

EMBED_DIM = 128
INTER_DIM = 256
NUM_LAYERS = 4
NUM_HEADS = 16
NUM_KV_HEADS = 8
VOCAB_SIZE = 32000
MAX_SEQ_LEN = 2048
BSZ = 2
SEQ_LEN = 100


@pytest.fixture(autouse=True)
def random():
    """Seed the RNGs before every test so the pinned expected value is reproducible."""
    set_seed(16)


class TestQwen2:
    """Forward-pass regression test for the qwen2 builder with small dimensions."""

    @pytest.fixture
    def inputs(self):
        return torch.randint(0, VOCAB_SIZE, (BSZ, SEQ_LEN))

    def test_forward(self, inputs):
        builder_kwargs = {
            "vocab_size": VOCAB_SIZE,
            "num_layers": NUM_LAYERS,
            "num_heads": NUM_HEADS,
            "num_kv_heads": NUM_KV_HEADS,
            "embed_dim": EMBED_DIM,
            "intermediate_dim": INTER_DIM,
            "max_seq_len": MAX_SEQ_LEN,
        }
        model = qwen2(**builder_kwargs)
        fixed_init_model(model, min_val=-0.25, max_val=0.5)

        output = model(inputs)
        expected_mean = torch.tensor(3.9763)

        assert output.shape == (BSZ, SEQ_LEN, VOCAB_SIZE)
        torch.testing.assert_close(output.mean(), expected_mean, atol=1e-4, rtol=1e-4)
import torch


def get_unmasked_sequence_lengths(mask: torch.Tensor) -> torch.Tensor:
    """
    Return, for each batch element, the 0-indexed position of its last unmasked token.

    Args:
        mask (torch.Tensor): Boolean mask with shape [b x s], where True marks a value
            to be masked out. This is usually a padding mask, where True marks a padding token.

    Returns:
        Tensor: ``torch.long`` tensor with shape [b] holding one index per batch element.
            A row that is entirely masked yields 0.

    Shape notation:
        - b = batch size
        - s = sequence length

    Example:
        >>> input_ids = torch.tensor([
        ...     [2, 4, 0, 0],
        ...     [2, 4, 6, 0],
        ...     [2, 4, 6, 9]
        ... ])
        >>> mask = input_ids == 0
        >>> mask
        tensor([[False, False,  True,  True],
                [False, False, False,  True],
                [False, False, False, False]])
        >>> get_unmasked_sequence_lengths(mask)
        tensor([1, 2, 3])

    """
    # Running count of unmasked tokens along the sequence; the count first reaches
    # its maximum at the last unmasked position, which is what argmax picks.
    unmasked_count = torch.cumsum(~mask, dim=-1)
    last_unmasked = torch.argmax(unmasked_count, dim=-1).to(dtype=torch.long)

    # Keep the result inside the valid index range [0, s - 1].
    upper_bound = mask.shape[1] - 1
    return last_unmasked.clip(0, upper_bound)
25 | 26 | """ 27 | num_warmup_steps = int(warmup_ratio * num_training_steps) 28 | 29 | def lr_lambda(current_step: int) -> float: 30 | # linear warmup phase 31 | if current_step < num_warmup_steps: 32 | return current_step / max(1, num_warmup_steps) 33 | 34 | # cosine 35 | progress = (current_step - num_warmup_steps) / max( 36 | 1, num_training_steps - num_warmup_steps 37 | ) 38 | 39 | cosine_lr_multiple = 0.5 * ( 40 | 1.0 + math.cos(math.pi * num_cycles * 2.0 * progress) 41 | ) 42 | return max(0.0, cosine_lr_multiple) 43 | 44 | return LambdaLR(optimizer, lr_lambda, last_epoch) 45 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/llama4/test_llama4_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
import pytest
import torch
from tests.common import ASSETS
from tests.test_utils import assert_expected
from torchtune.models.llama4._transform import Llama4Transform
from torchtune.training.seed import set_seed
from torchvision.transforms.v2 import functional as F

EMBED_DIM = 128
BSZ = 2
N_IMG = 1
N_TILES = 4
N_PATCHES = 17  # 16 + 1 for CLS token


@pytest.fixture(autouse=True)
def random():
    """Seed the RNGs before every test so the random image is reproducible."""
    set_seed(16)


def random_image():
    """Return a 4x4 RGB PIL image built from normal noise clamped to [0, 1]."""
    pixels = torch.randn(3, 4, 4).clamp(0, 1)
    return F.to_pil_image(pixels)


class TestImageThumbnailTransform:
    """Regression test for ``Llama4Transform.thumbnail_transform``."""

    @pytest.fixture
    def transform(self):
        transform_kwargs = {
            "path": str(ASSETS / "tiktoken_small_llama4.model"),
            "tile_size": 16,
            "patch_size": 4,
            "max_num_tiles": 4,
            "image_mean": [0.5, 0.5, 0.5],
            "image_std": [0.5, 0.5, 0.5],
            "dtype": torch.float32,
        }
        return Llama4Transform(**transform_kwargs)

    def test_call(self, transform):
        thumbnail = transform.thumbnail_transform(random_image())
        assert thumbnail.shape == (3, 16, 16)
        assert_expected(thumbnail.sum(), torch.tensor(-123.3412))
import pytest
import torch

from torchtune.models.clip._component_builders import clip_text_encoder
from torchtune.training.seed import set_seed

VOCAB_SIZE = 512
MAX_SEQ_LEN = 77
BSZ = 2
EMBED_DIM = 4


@pytest.fixture(autouse=True)
def random():
    """Seed the RNGs before every test so the pinned outputs are reproducible."""
    set_seed(0)


class TestClipTextEncoder:
    """Forward and backward tests for a tiny CLIP text encoder."""

    @pytest.fixture
    def model(self):
        encoder = clip_text_encoder(
            vocab_size=VOCAB_SIZE,
            max_seq_len=MAX_SEQ_LEN,
            embed_dim=EMBED_DIM,
            num_heads=2,
            num_layers=2,
        )

        # Overwrite every parameter with uniform [0, 1) values drawn from the seeded RNG.
        for parameter in encoder.parameters():
            parameter.data.uniform_(0, 1)

        return encoder

    @pytest.fixture
    def inputs(self):
        return torch.randint(0, VOCAB_SIZE, (BSZ, MAX_SEQ_LEN))

    def test_forward(self, model, inputs):
        output = model(inputs)
        expected_output = torch.tensor(
            [[0.1915, 1.3982, 0.6298, -0.0966], [0.2276, 1.3785, 0.6309, -0.1066]]
        )
        assert output.shape == (BSZ, EMBED_DIM)
        torch.testing.assert_close(output, expected_output, atol=1e-4, rtol=1e-4)

    def test_backward(self, model, inputs):
        # Smoke test: backward through the mean of the output must not raise.
        model(inputs).mean().backward()
from typing import Union

from torchtune.training.checkpointing._checkpointer import (
    DistributedCheckpointer,
    FullModelHFCheckpointer,
    FullModelMetaCheckpointer,
    FullModelTorchTuneCheckpointer,
)
from torchtune.training.checkpointing._utils import (
    ADAPTER_CONFIG,
    ADAPTER_KEY,
    DATALOADER_KEY,
    EPOCHS_KEY,
    FormattedCheckpointFiles,
    get_largest_iter_folder,
    MAX_STEPS_KEY,
    MODEL_KEY,
    ModelType,
    OPT_KEY,
    RNG_KEY,
    SEED_KEY,
    STEPS_KEY,
    TOTAL_EPOCHS_KEY,
    update_state_dict_for_classifier,
)

# Type alias covering every checkpointer class exported from this package.
Checkpointer = Union[
    DistributedCheckpointer,
    FullModelHFCheckpointer,
    FullModelMetaCheckpointer,
    FullModelTorchTuneCheckpointer,
]

__all__ = [
    # checkpointer classes and their union alias
    "DistributedCheckpointer",
    "FullModelHFCheckpointer",
    "FullModelMetaCheckpointer",
    "FullModelTorchTuneCheckpointer",
    "Checkpointer",
    # helper types and functions
    "FormattedCheckpointFiles",
    "ModelType",
    "get_largest_iter_folder",
    "update_state_dict_for_classifier",
    # constants from _utils
    "ADAPTER_CONFIG",
    "ADAPTER_KEY",
    "DATALOADER_KEY",
    "EPOCHS_KEY",
    "MAX_STEPS_KEY",
    "MODEL_KEY",
    "OPT_KEY",
    "RNG_KEY",
    "SEED_KEY",
    "STEPS_KEY",
    "TOTAL_EPOCHS_KEY",
]
run --nproc_per_node 8 dev/generate_v2_distributed --config llama3/70B_generation_distributed 9 | 10 | output_dir: ./ 11 | 12 | # Model arguments 13 | model: 14 | _component_: torchtune.models.llama3.llama3_70b 15 | 16 | tensor_parallel_plan: 17 | _component_: torchtune.models.llama3.base_llama_tp_plan 18 | 19 | # Transform arguments 20 | tokenizer: 21 | _component_: torchtune.models.llama3.llama3_tokenizer 22 | path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model 23 | prompt_template: null 24 | max_seq_len: 8192 25 | 26 | # Checkpointer 27 | checkpointer: 28 | _component_: torchtune.training.FullModelHFCheckpointer 29 | checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct 30 | checkpoint_files: 31 | filename_format: model-{}-of-{}.safetensors 32 | max_filename: "00030" 33 | recipe_checkpoint: null 34 | output_dir: ${output_dir} 35 | model_type: LLAMA3 36 | 37 | # Device 38 | device: cuda 39 | dtype: bf16 40 | seed: 1234 41 | log_level: INFO # DEBUG, WARN, etc. 42 | 43 | # Generation arguments 44 | prompt: 45 | system: null 46 | user: 47 | text: Tell a joke. 
48 | max_new_tokens: 200 49 | temperature: 0.6 # 0.8 and 0.6 are popular values to try 50 | top_k: 300 51 | -------------------------------------------------------------------------------- /torchtune/recipes/configs/llama3_3/70B_generation_distributed.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in dev/generate_v2.py to generate output 2 | # using a Llama3.3 70B Instruct model 3 | # 4 | # This config assumes that you've run the following command before launching: 5 | # tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*" --output-dir /tmp/Llama-3.3-70B-Instruct --hf-token 6 | # 7 | # To launch, run the following command from root torchtune directory: 8 | # tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3_3/70B_generation_distributed 9 | 10 | output_dir: ./ 11 | 12 | # Model arguments 13 | model: 14 | _component_: torchtune.models.llama3_3.llama3_3_70b 15 | 16 | tensor_parallel_plan: 17 | _component_: torchtune.models.llama3.base_llama_tp_plan 18 | 19 | # Transform arguments 20 | tokenizer: 21 | _component_: torchtune.models.llama3.llama3_tokenizer 22 | path: /tmp/Llama-3.3-70B-Instruct/original/tokenizer.model 23 | prompt_template: null 24 | max_seq_len: 8192 25 | 26 | # Checkpointer 27 | checkpointer: 28 | _component_: torchtune.training.FullModelHFCheckpointer 29 | checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/ 30 | checkpoint_files: 31 | filename_format: model-{}-of-{}.safetensors 32 | max_filename: "00030" 33 | recipe_checkpoint: null 34 | output_dir: ${output_dir} 35 | model_type: LLAMA3 36 | 37 | # Device 38 | device: cuda 39 | dtype: bf16 40 | seed: 1234 41 | log_level: INFO # DEBUG, WARN, etc. 42 | 43 | # Generation arguments 44 | prompt: 45 | system: null 46 | user: 47 | text: Tell a joke.
def assert_torchdata_installed():
    """
    Raise an ImportError when a usable torchdata installation is not available.

    Raises:
        ImportError: If ``_TORCHDATA_INSTALLED`` is falsy. The message includes a
            pip command for installing at least ``_TORCHDATA_MIN_VERSION``.
    """
    if not _TORCHDATA_INSTALLED:
        # The requirement specifier is wrapped in double quotes so the command can
        # be pasted into a shell as-is: unquoted, `>` is parsed as an output
        # redirection and the version constraint is silently dropped.
        raise ImportError(
            "torchdata is not installed, or the current version is too old. "
            f'Please (re-)install it with `pip install "torchdata>={_TORCHDATA_MIN_VERSION}"`. '
        )


def requires_torchdata(func: Callable) -> Callable:
    """
    Decorator to check if torchdata is installed and raise an ImportError if not.

    The check runs on every call of the decorated function (not at decoration
    time), so modules using this decorator can still be imported without torchdata.

    Args:
        func (Callable): The function that needs torchdata at call time.

    Returns:
        Callable: A wrapper with the same metadata as ``func`` that first calls
            :func:`assert_torchdata_installed` and then forwards all arguments.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        assert_torchdata_installed()
        return func(*args, **kwargs)

    return wrapper
def pad_dim_to_size(
    input: torch.Tensor, size: int, dim: int, *, fill: float = 0.0
) -> torch.Tensor:
    """
    Pads the given dimension of the input to the given size.

    Padding is appended at the end ("back") of the dimension only; existing
    values keep their positions.

    Example:
        >>> image = torch.rand(1, 4, 4)
        >>> padded_image = pad_dim_to_size(image, 3, dim=0)
        >>> padded_image.shape
        torch.Size([3, 4, 4])

    Args:
        input (torch.Tensor): Tensor to pad.
        size (int): Size to pad to.
        dim (int): Dimension to pad. Negative values count from the last dimension.
        fill (float): Value to fill the padded region with. Default: 0.0

    Returns:
        torch.Tensor: Padded input.

    Raises:
        IndexError: If ``dim`` is outside ``[-input.dim(), input.dim() - 1]``.
        AssertionError: If the input is already larger than ``size`` along ``dim``.
    """
    ndim = input.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})"
        )
    # Normalize negative dims so the pad_index arithmetic below stays in range.
    if dim < 0:
        dim += ndim

    pad_size = size - input.shape[dim]
    assert (
        pad_size >= 0
    ), f"Tensor input shape {input.shape[dim]} is larger than given size {size}"

    # Set up 0 padding for the entire tensor.
    # Padding is in order W*H*C*N, with front and back for each dim.
    # https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
    padding = [0] * 2 * ndim
    # Find the pad_index: convert NCHW to WHCN, and only pad the back
    # (not both sides).
    pad_index = (ndim - dim) * 2 - 1
    padding[pad_index] = pad_size
    # Pad dim to size.
    return F.pad(input, padding, value=fill)
| seed: 1234 41 | log_level: INFO # DEBUG, WARN, etc. 42 | 43 | # Generation arguments 44 | prompt: 45 | system: null 46 | user: 47 | text: Tell a joke. 48 | max_new_tokens: 200 49 | temperature: 0.6 # 0.8 and 0.6 are popular values to try 50 | top_k: 300 51 | -------------------------------------------------------------------------------- /mirotrain/models/qwen3/_checkpointing_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2025 MiromindAI 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import json 6 | from functools import wraps 7 | from pathlib import Path 8 | 9 | from ._positional_embeddings import RopeScaling 10 | 11 | 12 | def patch_copy_files(max_position_embeddings: int, rope_scaling: RopeScaling): 13 | import mirotrain.training.checkpointing._checkpointer as _checkpointer 14 | 15 | original_copy_files = _checkpointer.copy_files 16 | 17 | @wraps(original_copy_files) 18 | def patched_copy_files( 19 | input_dir: str | Path, output_dir: str | Path, *args, **kwargs 20 | ): 21 | """ 22 | Find the config.json in output_dir and update it with custom model_extra parameters. 
def validate_yarn_cfg(max_position_embeddings: int, rope_scaling: RopeScaling):
    """
    Check that a YARN rope-scaling configuration is complete and self-consistent.

    Args:
        max_position_embeddings (int): Target context length; must equal
            ``int(original_max_position_embeddings * factor)``.
        rope_scaling (RopeScaling): Mapping read with ``.get`` for the keys
            ``rope_type``, ``factor`` and ``original_max_position_embeddings``.

    Raises:
        KeyError: If ``rope_type``, ``factor`` or ``original_max_position_embeddings``
            is missing (or ``None``), or if ``max_position_embeddings`` is ``None``.
        ValueError: If ``rope_type`` is not ``"yarn"``, or if
            ``max_position_embeddings`` does not match the scaled original length.
    """

    scaling_kind = rope_scaling.get("rope_type")
    if scaling_kind is None:
        raise KeyError("'rope_type' must be specified in 'rope_scaling' configuration.")
    if scaling_kind != "yarn":
        raise ValueError(
            f"Unsupported 'rope_type': {scaling_kind}. "
            "Currently, only 'yarn' is supported for Qwen patching."
        )

    # Both YARN-specific entries are mandatory; they are checked in this order.
    yarn_values = {}
    for key in ("factor", "original_max_position_embeddings"):
        value = rope_scaling.get(key)
        if value is None:
            raise KeyError(f"'{key}' must be specified in 'rope_scaling' for YARN.")
        yarn_values[key] = value
    factor = yarn_values["factor"]
    original_len = yarn_values["original_max_position_embeddings"]

    if max_position_embeddings is None:
        raise KeyError(
            "'max_position_embeddings' must be specified in model_extra when rope_scaling is enabled."
        )

    expected_len = int(original_len * factor)
    if max_position_embeddings == expected_len:
        return
    raise ValueError(
        f"Configured 'max_position_embeddings' ({max_position_embeddings}) is not equal to "
        f"'original_max_position_embeddings' ({original_len}) * "
        f"'rope_scaling.factor' ({factor})."
    )
/tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model 20 | prompt_template: null 21 | max_seq_len: 8192 22 | 23 | # Checkpointer 24 | checkpointer: 25 | _component_: torchtune.training.FullModelHFCheckpointer 26 | checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/ 27 | checkpoint_files: 28 | filename_format: model-{}-of-{}.safetensors 29 | max_filename: "00005" 30 | output_dir: ${output_dir} 31 | model_type: LLAMA3_VISION 32 | 33 | # Device 34 | device: cuda 35 | dtype: bf16 36 | seed: 1234 37 | log_level: INFO # DEBUG, WARN, etc. 38 | 39 | # Generation arguments 40 | prompt: 41 | system: You are a helpful assistant who responds like the author Shakespeare. 42 | user: 43 | image: https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg 44 | text: What is in this image? 45 | max_new_tokens: 200 46 | temperature: 0.6 # 0.8 and 0.6 are popular values to try 47 | top_k: 300 48 | -------------------------------------------------------------------------------- /torchtune/torchtune/_cli/tune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
class TuneCLIParser:
    """Top-level parser for the ``tune`` command line and dispatcher for its subcommands."""

    def __init__(self):
        # Root parser for the `tune` program.
        root = argparse.ArgumentParser(
            prog="tune",
            description="Welcome to the torchtune CLI!",
            add_help=True,
        )
        self._parser = root

        def _show_help(args):
            # Fallback when no subcommand is given: print the top-level help.
            return self._parser.print_help()

        root.set_defaults(func=_show_help)

        # Register each subcommand on the shared subparser group, in display order.
        subparsers = root.add_subparsers(title="subcommands")
        for command in (Download, List, Copy, Run, Validate, Cat):
            command.create(subparsers)

    def parse_args(self) -> argparse.Namespace:
        """Parse the process's command-line arguments with the root parser."""
        namespace = self._parser.parse_args()
        return namespace

    def run(self, args: argparse.Namespace) -> None:
        """Invoke the handler stored on ``args.func``, passing it the parsed arguments."""
        handler = args.func
        handler(args)
autosummary:: 18 | :toctree: generated/ 19 | :nosignatures: 20 | 21 | alpaca_dataset 22 | alpaca_cleaned_dataset 23 | grammar_dataset 24 | hh_rlhf_helpful_dataset 25 | samsum_dataset 26 | slimorca_dataset 27 | stack_exchange_paired_dataset 28 | cnn_dailymail_articles_dataset 29 | wikitext_dataset 30 | 31 | Image + Text datasets 32 | --------------------- 33 | 34 | .. autosummary:: 35 | :toctree: generated/ 36 | :nosignatures: 37 | 38 | multimodal.llava_instruct_dataset 39 | multimodal.the_cauldron_dataset 40 | multimodal.vqa_dataset 41 | 42 | .. _dataset_builders: 43 | 44 | Generic dataset builders 45 | ------------------------ 46 | 47 | torchtune also supports generic dataset builders for common formats like chat models and instruct models. 48 | These are especially useful for specifying from a YAML config. 49 | 50 | .. autosummary:: 51 | :toctree: generated/ 52 | :nosignatures: 53 | 54 | instruct_dataset 55 | chat_dataset 56 | preference_dataset 57 | text_completion_dataset 58 | 59 | Generic dataset classes 60 | ----------------------- 61 | 62 | Class representations for the above dataset builders. 63 | 64 | .. autosummary:: 65 | :toctree: generated/ 66 | :nosignatures: 67 | 68 | TextCompletionDataset 69 | ConcatDataset 70 | PackedDataset 71 | PreferenceDataset 72 | SFTDataset 73 | -------------------------------------------------------------------------------- /torchtune/torchtune/config/_validate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
def validate(cfg: DictConfig) -> None:
    """
    Ensure that all components in the config can be instantiated correctly

    Every top-level node that declares a ``_component_`` is resolved and its
    remaining keys are bound against the component's signature. All failures are
    collected and reported together.

    Args:
        cfg (DictConfig): The config to validate

    Raises:
        ConfigError: If any component cannot be instantiated
    """

    errors = []
    for _, nodedict in cfg.items():
        if not _has_component(nodedict):
            continue
        try:
            _component_ = _get_component_from_path(nodedict.get("_component_"))
            kwargs = {k: v for k, v in nodedict.items() if k != "_component_"}
            sig = inspect.signature(_component_)
            sig.bind(**kwargs)
        # Some objects require other objects as arguments, like optimizers,
        # lr_schedulers, datasets, etc. Try doing partial instantiation
        except TypeError as e:
            if "missing a required argument" in str(e):
                try:
                    sig.bind_partial(**kwargs)
                except TypeError as partial_err:
                    # The partial bind can still fail (e.g. an unexpected keyword
                    # alongside a missing one). Collect it like any other binding
                    # error instead of letting a raw TypeError escape the handler.
                    errors.append(
                        TypeError(f"{_component_.__name__} {str(partial_err)}")
                    )
            else:
                # inspect.signature does not retain the function name in the
                # exception, so we manually add it back in
                errors.append(TypeError(f"{_component_.__name__} {str(e)}"))

    if errors:
        raise ConfigError(errors)
class TestLayerNorm:
    """
    Class for testing our LayerNorm, which is just a wrapper around torch.nn.LayerNorm
    to support fp16 training.
    """

    @pytest.fixture
    def dim(self) -> int:
        # Size of the normalized (last) dimension used throughout the tests.
        return 8

    @pytest.fixture
    def eps(self) -> float:
        return 1e-6

    @pytest.fixture
    def input_random_fp16(self, dim) -> torch.Tensor:
        return torch.randn(dim, dtype=torch.float16)

    @pytest.fixture
    def layer_norm(self, dim, eps) -> Fp32LayerNorm:
        return Fp32LayerNorm(dim, eps=eps)

    def test_forward_fp16(self, layer_norm, input_random_fp16, eps, dim) -> None:
        """Output keeps the fp16 input dtype and matches torch.nn.LayerNorm on fp32 input."""
        output_fp16 = layer_norm(input_random_fp16)

        # assert dtype as fp16
        # (message must be an f-string, otherwise the placeholder is printed verbatim)
        assert (
            output_fp16.dtype == torch.float16
        ), f"Expected output to be fp16, but got {output_fp16.dtype=}"

        # assert value as fp32
        expected_output = torch.nn.LayerNorm(dim, eps=eps)(input_random_fp16.float())
        output_fp32 = layer_norm(input_random_fp16.float())
        assert_expected(
            output_fp32.mean(), expected_output.mean(), atol=1e-8, rtol=1e-8
        )
currentmodule:: torchtune.data 8 | 9 | Text templates 10 | -------------- 11 | 12 | Templates for instruct prompts and chat prompts. Includes some specific formatting for different datasets 13 | and models. 14 | 15 | .. autosummary:: 16 | :toctree: generated/ 17 | :nosignatures: 18 | 19 | GrammarErrorCorrectionTemplate 20 | SummarizeTemplate 21 | QuestionAnswerTemplate 22 | PromptTemplate 23 | PromptTemplateInterface 24 | ChatMLTemplate 25 | 26 | Types 27 | ----- 28 | 29 | .. autosummary:: 30 | :toctree: generated/ 31 | :nosignatures: 32 | 33 | Message 34 | Role 35 | 36 | .. _message_transforms_ref: 37 | 38 | Message transforms 39 | ------------------ 40 | 41 | Converts data from common schema and conversation JSON formats into a list of torchtune :class:`Message`. 42 | 43 | .. autosummary:: 44 | :toctree: generated/ 45 | :nosignatures: 46 | 47 | InputOutputToMessages 48 | ShareGPTToMessages 49 | OpenAIToMessages 50 | ChosenRejectedToMessages 51 | AlpacaToMessages 52 | 53 | Collaters 54 | --------- 55 | 56 | Collaters used to collect samples into batches and handle any padding. 57 | 58 | .. autosummary:: 59 | :toctree: generated/ 60 | :nosignatures: 61 | 62 | padded_collate 63 | padded_collate_dpo 64 | padded_collate_packed 65 | padded_collate_sft 66 | padded_collate_tiled_images_and_mask 67 | left_pad_sequence 68 | 69 | Helper functions 70 | ---------------- 71 | 72 | Miscellaneous helper functions used in modifying data. 73 | 74 | .. autosummary:: 75 | :toctree: generated/ 76 | :nosignatures: 77 | 78 | validate_messages 79 | truncate 80 | load_image 81 | format_content_with_images 82 | mask_messages 83 | -------------------------------------------------------------------------------- /torchtune/torchtune/models/qwen2_5/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ._model_builders import ( 8 | lora_qwen2_5_0_5b, 9 | lora_qwen2_5_14b_base, 10 | lora_qwen2_5_14b_instruct, 11 | lora_qwen2_5_1_5b_base, 12 | lora_qwen2_5_1_5b_instruct, 13 | lora_qwen2_5_32b_base, 14 | lora_qwen2_5_32b_instruct, 15 | lora_qwen2_5_3b, 16 | lora_qwen2_5_72b_base, 17 | lora_qwen2_5_72b_instruct, 18 | lora_qwen2_5_7b_base, 19 | lora_qwen2_5_7b_instruct, 20 | qwen2_5_0_5b, 21 | qwen2_5_14b_base, 22 | qwen2_5_14b_instruct, 23 | qwen2_5_1_5b_base, 24 | qwen2_5_1_5b_instruct, 25 | qwen2_5_32b_base, 26 | qwen2_5_32b_instruct, 27 | qwen2_5_3b, 28 | qwen2_5_72b_base, 29 | qwen2_5_72b_instruct, 30 | qwen2_5_7b_base, 31 | qwen2_5_7b_instruct, 32 | qwen2_5_tokenizer, 33 | ) 34 | 35 | __all__ = [ 36 | "lora_qwen2_5_0_5b", 37 | "lora_qwen2_5_14b_base", 38 | "lora_qwen2_5_14b_instruct", 39 | "lora_qwen2_5_1_5b_base", 40 | "lora_qwen2_5_1_5b_instruct", 41 | "lora_qwen2_5_32b_base", 42 | "lora_qwen2_5_32b_instruct", 43 | "lora_qwen2_5_3b", 44 | "lora_qwen2_5_72b_base", 45 | "lora_qwen2_5_72b_instruct", 46 | "lora_qwen2_5_7b_base", 47 | "lora_qwen2_5_7b_instruct", 48 | "qwen2_5_0_5b", 49 | "qwen2_5_14b_base", 50 | "qwen2_5_14b_instruct", 51 | "qwen2_5_1_5b_base", 52 | "qwen2_5_1_5b_instruct", 53 | "qwen2_5_32b_base", 54 | "qwen2_5_32b_instruct", 55 | "qwen2_5_3b", 56 | "qwen2_5_72b_base", 57 | "qwen2_5_72b_instruct", 58 | "qwen2_5_7b_base", 59 | "qwen2_5_7b_instruct", 60 | "qwen2_5_tokenizer", 61 | ] 62 | -------------------------------------------------------------------------------- /torchtune/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | ifneq ($(EXAMPLES_PATTERN),) 5 | EXAMPLES_PATTERN_OPTS := -D sphinx_gallery_conf.filename_pattern="$(EXAMPLES_PATTERN)" 6 | endif 7 | 8 | # You can set these 
variables from the command line. 9 | SPHINXOPTS = -W -j auto -T -v $(EXAMPLES_PATTERN_OPTS) 10 | SPHINXBUILD = sphinx-build 11 | SPHINXPROJ = torchtune 12 | SOURCEDIR = source 13 | BUILDDIR = build 14 | 15 | # Put it first so that "make" without argument is like "make help". 16 | help: 17 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 18 | 19 | docset: html 20 | doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url http://pytorch.org/vision/ --force $(BUILDDIR)/html/ 21 | 22 | # Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution. 23 | cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png 24 | convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png 25 | 26 | html-noplot: # Avoids running the gallery examples, which may take time 27 | $(SPHINXBUILD) -D plot_gallery=0 -b html "${SOURCEDIR}" "$(BUILDDIR)"/html 28 | @echo 29 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 30 | 31 | clean: 32 | rm -rf $(BUILDDIR)/* 33 | rm -rf $(SOURCEDIR)/generated_examples/ # sphinx-gallery 34 | rm -rf $(SOURCEDIR)/gen_modules/ # sphinx-gallery 35 | rm -rf $(SOURCEDIR)/sg_execution_times.rst # sphinx-gallery 36 | rm -rf $(SOURCEDIR)/generated/ # autosummary 37 | 38 | .PHONY: help Makefile docset 39 | 40 | # Catch-all target: route all unknown targets to Sphinx using the new 41 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 42 | %: Makefile 43 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 44 | -------------------------------------------------------------------------------- /torchtune/recipes/full_finetune_multinode.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 
4 | 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # ---------- SBATCH commands ---------- # 9 | #SBATCH --job-name=torchtune-multi-node 10 | #SBATCH --ntasks=2 11 | #SBATCH --nodes=2 12 | #SBATCH --gpus-per-task=8 13 | #SBATCH --cpus-per-task=96 14 | #SBATCH --partition=train 15 | 16 | # ---------- Set env variables ---------- # 17 | # Grab the IP for head node: 18 | # You may need to set this to the fully qualified domain name of your head node 19 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) 20 | nodes_array=($nodes) 21 | head_node=${nodes_array[0]} 22 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) 23 | echo Node IP: $head_node_ip 24 | 25 | # You might need to explicitly set the network interface for distributed backends: 26 | # export NCCL_SOCKET_IFNAME=... 27 | # export GLOO_SOCKET_IFNAME=... 28 | 29 | export TORCH_DIST_INIT_BARRIER=1 30 | export LOGLEVEL=INFO 31 | 32 | # ---------- Launch training ---------- # 33 | # You probably want to load in a virtual env w/ conda... 
34 | # module load conda 35 | # conda activate torchtune 36 | # ...or venv 37 | # source torchtune/bin/activate 38 | 39 | SHARED_FS=/mnt/slurm # <-- Replace w/ your filesystem 40 | CHECKPOINT_DIR="$SHARED_FS/Llama-3.3-70B-Instruct" 41 | OUTPUT_DIR="$SHARED_FS/Llama3.3-70B-fft-output" 42 | 43 | # Adjust sbatch --ntasks and sbatch --nodes above and --nnodes below to your specific node count 44 | srun tune run --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" \ 45 | full_finetune_distributed --config llama3_3/70B_full_multinode checkpoint_dir=$CHECKPOINT_DIR output_dir=$OUTPUT_DIR 46 | -------------------------------------------------------------------------------- /torchtune/recipes/dev/multinode_grpo.sbatch: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # ---------- SBATCH commands ---------- # 9 | #SBATCH --time=72:00:00 10 | #SBATCH --job-name=torchtune-multi-node 11 | #SBATCH --constraint=volta32gb 12 | #SBATCH --ntasks-per-node=1 13 | #SBATCH --nodes=2 14 | #SBATCH --exclusive 15 | #SBATCH --gpus-per-task=8 16 | #SBATCH --cpus-per-task=80 17 | #SBATCH --output=slurm_logs/%j/%N.out 18 | #SBATCH --error=slurm_logs/%j/%N.err 19 | 20 | 21 | # ---------- Set env variables ---------- # 22 | # Grab the IP for head node: 23 | # You may need to set this to the fully qualified domain name of your head node 24 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) 25 | nodes_array=($nodes) 26 | head_node=${nodes_array[0]} 27 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) 28 | echo Node IP: $head_node_ip 29 | 30 | # You might need to explicitly set the network interface for distributed backends: 31 | # export NCCL_SOCKET_IFNAME=... 32 | # export GLOO_SOCKET_IFNAME=... 33 | 34 | export TORCH_DIST_INIT_BARRIER=1 35 | export LOGLEVEL=INFO 36 | 37 | # ---------- Launch training ---------- # 38 | # You probably want to load in a virtual env w/ conda... 
39 | # module load conda 40 | # conda activate torchtune 41 | # ...or venv 42 | # source torchtune/bin/activate 43 | 44 | source ../../.venv/bin/activate 45 | 46 | # Adjust sbatch --ntasks and sbatch --nodes above and --nnodes below to your specific node count 47 | srun --export=ALL,OMP_NUM_THREADS=8 tune run --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" \ 48 | dev/grpo_full_finetune_distributed --config dev/3B_full_grpo "$@" 49 | -------------------------------------------------------------------------------- /torchtune/recipes/configs/llama4/scout_17B_16E_generation_distributed.yaml: -------------------------------------------------------------------------------- 1 | # Config for running the InferenceRecipe in dev/generate_v2.py to generate output 2 | # from a Llama4 17Bx16E MoE model 3 | # 4 | # This config assumes that you've run the following command before launching 5 | # tune download meta-llama/Llama-4-Scout-17B-16E-Instruct 6 | # 7 | # To launch, run the following command: 8 | # tune run --nproc_per_node 4 dev/generate_v2_distributed --config llama4/scout_17B_16E_generation_distributed 9 | 10 | # Model arguments 11 | model: 12 | _component_: torchtune.models.llama4.llama4_scout_17b_16e 13 | 14 | tensor_parallel_plan: 15 | _component_: torchtune.models.llama4.decoder_only_tp_plan 16 | 17 | tokenizer: 18 | _component_: torchtune.models.llama4.llama4_transform 19 | path: /tmp/Llama-4-Scout-17B-16E-Instruct/tokenizer.model 20 | max_seq_len: null 21 | max_num_tiles: 16 22 | 23 | checkpointer: 24 | _component_: torchtune.training.FullModelHFCheckpointer 25 | checkpoint_dir: /tmp/Llama-4-Scout-17B-16E-Instruct # You can also point this to your finetuned model! 
26 | checkpoint_files: 27 | filename_format: model-{}-of-{}.safetensors 28 | max_filename: "00050" 29 | output_dir: ./ # No need for an output dir 30 | model_type: LLAMA4 31 | 32 | use_distributed_state_dict: True 33 | use_flex: True # Use PyTorch's FlexAttention for construction of attention masks 34 | 35 | # Environment 36 | device: cuda 37 | dtype: bf16 38 | seed: 1234 39 | log_level: INFO # DEBUG, WARN, etc. 40 | 41 | # Generation arguments 42 | prompt: 43 | system: You are a helpful assistant who responds like the author Shakespeare. 44 | user: 45 | image: https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Olive_baboon.jpg/330px-Olive_baboon.jpg 46 | text: What is in this image? 47 | max_new_tokens: 200 48 | temperature: 0.6 # 0.8 and 0.6 are popular values to try 49 | top_k: 300 50 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/models/mistral/test_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
import pytest
import torch
from tests.test_utils import fixed_init_model
from tests.torchtune.models.mistral.scripts.mistral_test_config import MistralTestConfig
from torchtune.models.mistral import mistral
from torchtune.training.seed import set_seed


@pytest.fixture(autouse=True)
def random():
    """Call ``set_seed`` with the shared Mistral test seed before every test."""
    set_seed(MistralTestConfig.SEED)


class TestMistral:
    """Forward-pass regression check for the ``mistral`` builder."""

    @pytest.fixture
    def inputs(self):
        """Random token ids of shape ``(BSZ, SEQ_LEN)`` below ``VOCAB_SIZE``."""
        cfg = MistralTestConfig
        batch_shape = (cfg.BSZ, cfg.SEQ_LEN)
        return torch.randint(0, cfg.VOCAB_SIZE, batch_shape)

    def test_forward(self, inputs):
        """The output has one logit row per token and its mean matches the pinned value."""
        cfg = MistralTestConfig
        builder_kwargs = dict(
            vocab_size=cfg.VOCAB_SIZE,
            embed_dim=cfg.EMBED_DIM,
            num_heads=cfg.NUM_HEADS,
            num_layers=cfg.NUM_LAYERS,
            num_kv_heads=cfg.NUM_KV_HEADS,
            max_seq_len=cfg.MAX_SEQ_LEN,
            intermediate_dim=cfg.INTERMEDIATE_DIM,
            norm_eps=cfg.NORM_EPS,
            rope_base=cfg.ROPE_BASE,
        )
        model = mistral(**builder_kwargs)
        fixed_init_model(model)

        logits = model(inputs)

        assert logits.shape == (cfg.BSZ, cfg.SEQ_LEN, cfg.VOCAB_SIZE)
        pinned_mean = torch.tensor(18.2749)
        torch.testing.assert_close(logits.mean(), pinned_mean, atol=1e-4, rtol=1e-4)


# --------------------------------------------------------------------------------
# /torchtune/tests/torchtune/dev/rl/rewards/test_rewards.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from torchtune.dev.rl.rewards import RewardOutput


class TestRewardOutput:
    """Checks the keys and values produced by ``RewardOutput.log``."""

    @pytest.fixture
    def sample_reward_output(self):
        """A three-sample ``RewardOutput`` with two named sub-rewards."""
        sub_rewards = {
            "sub_reward_1": torch.tensor([0.5, 1.5, 2.5]),
            "sub_reward_2": torch.tensor([10.0, 20.0, 30.0]),
        }
        return RewardOutput(
            reward_base_name="test_reward",
            total_reward=torch.tensor([1.0, 2.0, 3.0]),
            successes=torch.tensor([1.0, 0.0, 1.0]),
            rewards=sub_rewards,
        )

    @staticmethod
    def _assert_log_matches(log_dict, expected_log):
        """Require identical key sets and approximately equal values."""
        assert log_dict.keys() == expected_log.keys()
        for key, wanted in expected_log.items():
            assert log_dict[key] == pytest.approx(wanted)

    def test_log(self, sample_reward_output):
        """With a prefix, every key is nested under ``train/``."""
        self._assert_log_matches(
            sample_reward_output.log(prefix="train"),
            {
                "train/test_reward/sub_reward_1": 1.5,
                "train/test_reward/sub_reward_2": 20.0,
                "train/test_reward": 2.0,
                "train/test_reward/successes": 2.0 / 3.0,
            },
        )

    def test_log_no_prefix(self, sample_reward_output):
        """Without a prefix, keys start at the reward base name."""
        self._assert_log_matches(
            sample_reward_output.log(),
            {
                "test_reward/sub_reward_1": 1.5,
                "test_reward/sub_reward_2": 20.0,
                "test_reward": 2.0,
                "test_reward/successes": 2.0 / 3.0,
            },
        )


# --------------------------------------------------------------------------------
# /torchtune/torchtune/models/t5/_model_builders.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torchtune.models.t5._component_builders import t5_encoder
from torchtune.models.t5._encoder import T5Encoder
from torchtune.models.t5._tokenizer import T5Tokenizer


def t5_v1_1_xxl_encoder(max_seq_len: int = 512) -> T5Encoder:
    """
    Build the encoder of T5 v1.1 XXL (11B parameters).

    T5 paper: https://arxiv.org/abs/1910.10683

    1.1 release:
    https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511

    Args:
        max_seq_len (int): The maximum sequence length (context length) of the model.
            Default: 512

    Returns:
        T5Encoder: Instantiation of the T5 encoder
    """
    # Fixed architecture hyper-parameters; only the context length is configurable.
    xxl_hparams = dict(
        embed_dim=4096,
        mlp_dim=10240,
        num_heads=64,
        head_dim=64,
        num_layers=24,
        rel_pos_num_buckets=32,
        rel_pos_max_dist=128,
        vocab_size=32128,
        norm_eps=1e-6,
    )
    return t5_encoder(max_seq_len=max_seq_len, **xxl_hparams)


def t5_tokenizer(path: str, max_seq_len: int = 512, truncate: bool = True):
    """
    Build the T5 tokenizer.

    Args:
        path (str): the path to the T5 sentencepiece tokenizer file
        max_seq_len (int): the context length. Default: 512
        truncate (bool): whether to truncate the token sequence when longer than
            max_seq_len. Default: True

    Returns:
        T5Tokenizer: Instantiation of the T5 tokenizer
    """
    tokenizer = T5Tokenizer(path, max_seq_len=max_seq_len, truncate=truncate)
    return tokenizer


# --------------------------------------------------------------------------------
# /torchtune/tests/torchtune/models/llama2/test_llama2_prompt_template.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from tests.test_utils import assert_dialogue_equal, MESSAGE_SAMPLE
from torchtune.data import Message
from torchtune.models.llama2 import Llama2ChatTemplate


class TestLlama2ChatTemplate:
    """Checks that ``Llama2ChatTemplate`` wraps the sample dialogue in Llama2 chat tags."""

    # The system prompt is folded into the user turn between the Llama2
    # ``<<SYS>>`` / ``<</SYS>>`` markers; the whole turn sits inside ``[INST] ... [/INST]``.
    expected_dialogue = [
        Message(
            role="user",
            content="[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. "
            "Your goal is to complete the task as faithfully as you can. While performing "
            "the task think step-by-step and justify your steps.\n<</SYS>>\n\nPlease "
            "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
            "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
            "How about on an icy road? Well one father in Russia did just that, and recorded "
            "the entire thing. To her credit, the child seemed to be doing a great job. "
            "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
            "Summary: [/INST] ",
        ),
        Message(
            role="assistant",
            content="A father in Russia allowed his 8-year-old child to drive his car on an "
            "icy road and recorded the event. The child appeared to be handling the situation well, "
            "showcasing their driving skills despite the challenging conditions.",
        ),
    ]

    def test_call(self):
        """Applying the template to ``MESSAGE_SAMPLE`` yields ``expected_dialogue``."""
        actual = Llama2ChatTemplate()(MESSAGE_SAMPLE)
        assert_dialogue_equal(actual, self.expected_dialogue)


# --------------------------------------------------------------------------------
# /torchtune/tests/torchtune/modules/transforms/tokenizers/test_utils.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest

from tests.test_utils import DummyTokenizer
from torchtune.data import Message

from torchtune.modules.transforms.tokenizers import tokenize_messages_no_special_tokens


class TestTokenizerUtils:
    """Tests for ``tokenize_messages_no_special_tokens``."""

    @pytest.fixture
    def tokenizer(self):
        """A ``DummyTokenizer`` limited to 100 tokens."""
        return DummyTokenizer(max_seq_len=100)

    @pytest.fixture
    def messages(self):
        """A masked user turn followed by an unmasked assistant turn."""
        user_turn = Message(role="user", content="hello world!", masked=True)
        assistant_turn = Message(role="assistant", content="hello back!")
        return [user_turn, assistant_turn]

    @pytest.mark.parametrize(
        "add_bos, add_eos",
        [
            (True, True),
            (False, False),
        ],
    )
    def test_tokenize_no_special_tokens(self, tokenizer, messages, add_bos, add_eos):
        """BOS/EOS appear exactly when their ids are supplied; the mask follows the messages."""
        bos_id = tokenizer.bos_id if add_bos else None
        eos_id = tokenizer.eos_id if add_eos else None

        tokens, mask = tokenize_messages_no_special_tokens(
            tokenizer, messages, bos_id=bos_id, eos_id=eos_id
        )

        assert len(tokens) == len(mask)

        # The leading (user) position is masked, the trailing (assistant) one is not.
        assert mask[0] is True
        assert mask[-1] is False

        # A special token sits at each end if and only if it was requested.
        assert (tokens[0] == tokenizer.bos_id) == add_bos
        assert (tokens[-1] == tokenizer.eos_id) == add_eos


# --------------------------------------------------------------------------------
# /torchtune/tests/torchtune/models/mistral/test_mistral_classifier.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from tests.test_utils import fixed_init_model
from torchtune.models.mistral import mistral_classifier
from torchtune.training.seed import set_seed

NUM_LAYERS = 4
NUM_HEADS = 16
NUM_KV_HEADS = 8
VOCAB_SIZE = 32000
MAX_SEQ_LEN = 2048
INTERMEDIATE_DIM = 512


@pytest.fixture(autouse=True)
def random():
    """Call ``set_seed(16)`` before every test in this module."""
    set_seed(16)


class TestMistralClassifier:
    """Forward-pass regression check for ``mistral_classifier``."""

    # expected values are calculated using
    # `tests.torchtune.models.scripts.compare_mistral_classifier`
    @pytest.mark.parametrize(
        "bsz, embed_dim, seq_len, n_classes, expected",
        [
            (2, 64, 64, 2, 22.6879),
            (1, 256, 256, 1, 110.2561),
        ],
    )
    def test_forward(
        self, bsz: int, embed_dim: int, seq_len: int, n_classes: int, expected: float
    ):
        """The output is ``(bsz, seq_len, n_classes)`` and its mean matches the pinned value."""
        token_ids = torch.randint(low=0, high=VOCAB_SIZE, size=(bsz, seq_len))

        # NOTE(review): num_layers is taken from n_classes rather than the
        # module-level NUM_LAYERS (which is otherwise unused). The pinned
        # expectations were presumably generated with this setting -- confirm
        # against the reference script before changing it.
        classifier = mistral_classifier(
            num_classes=n_classes,
            vocab_size=VOCAB_SIZE,
            num_layers=n_classes,
            num_heads=NUM_HEADS,
            num_kv_heads=NUM_KV_HEADS,
            embed_dim=embed_dim,
            intermediate_dim=INTERMEDIATE_DIM,
            max_seq_len=MAX_SEQ_LEN,
        )
        fixed_init_model(classifier)

        scores = classifier(token_ids)

        assert scores.shape == (bsz, seq_len, n_classes)
        torch.testing.assert_close(
            scores.mean(), torch.tensor(expected), atol=1e-4, rtol=1e-4
        )


# --------------------------------------------------------------------------------
# /torchtune/torchtune/data/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6 | 7 | from torchtune.data._collate import ( 8 | left_pad_sequence, 9 | padded_collate, 10 | padded_collate_dpo, 11 | padded_collate_packed, 12 | padded_collate_sft, 13 | padded_collate_tiled_images_and_mask, 14 | ) 15 | from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX 16 | from torchtune.data._messages import ( 17 | AlpacaToMessages, 18 | ChosenRejectedToMessages, 19 | InputOutputToMessages, 20 | mask_messages, 21 | Message, 22 | OpenAIToMessages, 23 | Role, 24 | ShareGPTToMessages, 25 | validate_messages, 26 | ) 27 | from torchtune.data._prompt_templates import ( 28 | ChatMLTemplate, 29 | GrammarErrorCorrectionTemplate, 30 | PromptTemplate, 31 | PromptTemplateInterface, 32 | QuestionAnswerTemplate, 33 | SummarizeTemplate, 34 | ) 35 | from torchtune.data._utils import format_content_with_images, load_image, truncate 36 | 37 | __all__ = [ 38 | "CROSS_ENTROPY_IGNORE_IDX", 39 | "GrammarErrorCorrectionTemplate", 40 | "SummarizeTemplate", 41 | "OpenAIToMessages", 42 | "ShareGPTToMessages", 43 | "AlpacaToMessages", 44 | "truncate", 45 | "Message", 46 | "validate_messages", 47 | "mask_messages", 48 | "Role", 49 | "format_content_with_images", 50 | "PromptTemplateInterface", 51 | "PromptTemplate", 52 | "InputOutputToMessages", 53 | "ChosenRejectedToMessages", 54 | "QuestionAnswerTemplate", 55 | "ChatMLTemplate", 56 | "padded_collate_sft", 57 | "padded_collate_dpo", 58 | "left_pad_sequence", 59 | "padded_collate", 60 | "padded_collate_tiled_images_and_mask", 61 | "padded_collate_packed", 62 | "load_image", 63 | ] 64 | -------------------------------------------------------------------------------- /torchtune/tests/torchtune/modules/test_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
from typing import Tuple

import pytest

import torch

from tests.test_utils import assert_expected, fixed_init_model
from torch import nn

from torchtune.modules import FeedForward
from torchtune.training.seed import set_seed


@pytest.fixture(autouse=True)
def random():
    """Call ``set_seed(0)`` before every test in this module."""
    set_seed(0)


class TestFeedForward:
    """Class for testing FFN implementation."""

    @pytest.fixture
    def input_params(self) -> Tuple[int, int]:
        """Return ``(dim, hidden_dim)`` used to size the projections."""
        dim = 4096
        hidden_dim = 11008  # Scaled for SwiGLU
        return dim, hidden_dim

    @pytest.fixture
    def inputs(self, input_params: Tuple[int, int]) -> torch.Tensor:
        """A single random input row of width ``dim``."""
        dim, _ = input_params
        return torch.randn(1, dim)

    @pytest.fixture
    def ffn(self, input_params: Tuple[int, int]) -> FeedForward:
        """A ``FeedForward`` built from bias-free linear layers, in eval mode."""
        dim, hidden_dim = input_params
        gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        down_proj = nn.Linear(hidden_dim, dim, bias=False)
        up_proj = nn.Linear(dim, hidden_dim, bias=False)
        ff = FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)
        fixed_init_model(ff)
        # One eval() call is enough; the module was previously switched twice.
        ff.eval()
        return ff

    def test_forward(self, inputs: torch.Tensor, ffn: FeedForward) -> None:
        """The mean and max of the output match the pinned reference values."""
        with torch.no_grad():
            x_out = ffn(inputs)
        assert_expected(x_out.mean(), torch.tensor(251.5356), atol=1e-7, rtol=1e-3)
        assert_expected(x_out.max(), torch.tensor(503.0614), atol=1e-7, rtol=1e-3)


# --------------------------------------------------------------------------------
# /torchtune/torchtune/datasets/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6 | 7 | from torchtune.datasets import multimodal 8 | from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset 9 | from torchtune.datasets._chat import chat_dataset 10 | from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset 11 | from torchtune.datasets._concat import ConcatDataset 12 | from torchtune.datasets._grammar import grammar_dataset 13 | from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset 14 | from torchtune.datasets._instruct import instruct_dataset 15 | from torchtune.datasets._packed import PackedDataset 16 | from torchtune.datasets._preference import preference_dataset, PreferenceDataset 17 | from torchtune.datasets._samsum import samsum_dataset 18 | from torchtune.datasets._sft import SFTDataset 19 | from torchtune.datasets._slimorca import slimorca_dataset 20 | from torchtune.datasets._stack_exchange_paired import stack_exchange_paired_dataset 21 | from torchtune.datasets._text_completion import ( 22 | text_completion_dataset, 23 | TextCompletionDataset, 24 | ) 25 | from torchtune.datasets._wikitext import wikitext_dataset 26 | 27 | __all__ = [ 28 | "alpaca_dataset", 29 | "alpaca_cleaned_dataset", 30 | "grammar_dataset", 31 | "samsum_dataset", 32 | "stack_exchange_paired_dataset", 33 | "slimorca_dataset", 34 | "instruct_dataset", 35 | "preference_dataset", 36 | "chat_dataset", 37 | "text_completion_dataset", 38 | "TextCompletionDataset", 39 | "cnn_dailymail_articles_dataset", 40 | "PackedDataset", 41 | "ConcatDataset", 42 | "wikitext_dataset", 43 | "PreferenceDataset", 44 | "SFTDataset", 45 | "hh_rlhf_helpful_dataset", 46 | "multimodal", 47 | ] 48 | -------------------------------------------------------------------------------- /torchtune/docs/source/basics/datasets_overview.rst: -------------------------------------------------------------------------------- 1 | .. 
_datasets_overview: 2 | 3 | ================= 4 | Datasets Overview 5 | ================= 6 | torchtune lets you fine-tune LLMs and VLMs using any dataset found on Hugging Face Hub, downloaded locally, 7 | or hosted at a remote URL. We provide built-in dataset builders to help you quickly bootstrap your fine-tuning project 8 | for workflows including instruct tuning, preference alignment, continued pre-training, and more. Beyond those, torchtune 9 | lets you fully customize your dataset pipeline, so you can train on any data format or schema. 10 | 11 | The following tasks are supported: 12 | 13 | - Text supervised fine-tuning 14 | - :ref:`instruct_dataset_usage_label` 15 | - :ref:`chat_dataset_usage_label` 16 | - Multimodal supervised fine-tuning 17 | - :ref:`multimodal_dataset_usage_label` 18 | - RLHF 19 | - :ref:`preference_dataset_usage_label` 20 | - Continued pre-training 21 | - :ref:`text_completion_dataset_usage_label` 22 | 23 | Data pipeline 24 | ------------- 25 | .. image:: /_static/img/torchtune_datasets.svg 26 | 27 | From raw data samples to the model inputs in the training recipe, all torchtune datasets follow 28 | the same pipeline: 29 | 30 | 1. Raw data is queried one sample at a time from a Hugging Face dataset, local file, or remote file. 31 | 2. :ref:`message_transform_usage_label` convert the raw sample, which can take any format, into a list of torchtune 32 | :ref:`messages_usage_label`. Images are contained in the message object they are associated with. 33 | 3. :ref:`model_transform_usage_label` applies model-specific transforms to the messages, including tokenization (see :ref:`tokenizers_usage_label`), 34 | prompt templating (see :ref:`prompt_templates_usage_label`), image transforms, and anything else required for that particular model. 35 | 4. The collater packages the processed samples together in a batch and the batch is passed into the model during training.
# --------------------------------------------------------------------------------
# /torchtune/tests/torchtune/modules/loss/test_ce_chunked_output_loss.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import assert_expected
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.training.seed import set_seed


@pytest.fixture(autouse=True)
def random():
    """Call ``set_seed(42)`` before every test in this module."""
    set_seed(42)


class TestCEWithChunkedOutputLoss:
    """Compares the chunked cross-entropy loss with ``F.cross_entropy``."""

    def test_chunked_cross_entropy_loss(self):
        """Both losses agree within 1e-2 on random logits with some ignored labels."""
        ignore_index = -100
        batch_size, num_tokens, vocab_size = 3, 50, 50

        # Random bf16 logits and integer labels.
        logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16)
        labels = torch.randint(
            0, vocab_size, (batch_size, num_tokens), dtype=torch.long
        )

        # Overwrite the labels wherever an independent random draw falls below
        # num_tokens // 5, so part of the batch is ignored by both losses.
        ignore_draw = torch.randint(0, num_tokens, (batch_size, num_tokens))
        labels[ignore_draw < num_tokens // 5] = ignore_index

        # Chunked loss: logits are split along the token dimension.
        chunked_criterion = CEWithChunkedOutputLoss(
            num_output_chunks=8, ignore_index=ignore_index
        )
        chunked_loss = chunked_criterion(
            logits.tensor_split(chunked_criterion.num_output_chunks, dim=1), labels
        )

        # Reference loss on the flattened tensors, computed in fp32.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_labels = labels.reshape(-1)
        standard_loss = torch.nn.functional.cross_entropy(
            flat_logits.float(),
            flat_labels,
            reduction="mean",
            ignore_index=ignore_index,
        )

        assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)
# --------------------------------------------------------------------------------
# /torchtune/tests/torchtune/models/clip/test_clip_tokenizer.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest

from tests.common import ASSETS
from torchtune.models.clip._model_builders import clip_tokenizer


class TestCLIPTokenizer:
    """Encoding, decoding and call behaviour of the CLIP tokenizer."""

    @pytest.fixture
    def tokenizer(self):
        """Tokenizer built from the tiny BPE merges asset."""
        merges_path = ASSETS / "tiny_bpe_merges.txt"
        return clip_tokenizer(merges_path)

    def test_encoding(self, tokenizer):
        """Each sample text encodes to its pinned token ids."""
        pinned = {
            "a cow jumping over the moon": [
                2416, 320, 66, 78, 342, 73, 669, 79, 515,
                326, 1190, 337, 673, 324, 76, 819, 333, 2417,
            ],
            "a helpful AI assistant": [
                2416, 320, 516, 75, 79, 69, 84, 331, 64, 328, 813, 667, 540, 339, 2417,
            ],
        }
        for text, wanted_ids in pinned.items():
            assert tokenizer.encode(text) == wanted_ids

    def test_decoding(self, tokenizer):
        """An encode/decode round trip adds the start and end markers."""
        round_trip = tokenizer.decode(tokenizer.encode("this is torchtune"))
        assert "<|startoftext|>this is torchtune <|endoftext|>" == round_trip

    def test_call(self, tokenizer):
        """Calling the tokenizer replaces the ``text`` field with ``tokens``."""
        transformed = tokenizer({"text": "hello world"})
        assert "text" not in transformed
        assert "tokens" in transformed


# --------------------------------------------------------------------------------
# /torchtune/torchtune/modules/feed_forward.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional

import torch
from torch import nn


class FeedForward(nn.Module):
    """Gated feed-forward network derived from Llama2.

    Computes ``down_proj(activation(gate_proj(x)) * up_proj(x))``; when no
    ``up_proj`` is given the gating product is skipped.

    Args:
        gate_proj (nn.Module): Projection from input dim to hidden dim, fed through activation
            and multiplied by up_proj.
        down_proj (nn.Module): Final projection to output dim.
        up_proj (Optional[nn.Module]): Projection from input dim to hidden dim, multiplied by
            activation(gate_proj).
        activation (nn.Module): Activation function to use. Default is nn.SiLU().
    """

    def __init__(
        self,
        *,
        gate_proj: nn.Module,
        down_proj: nn.Module,
        up_proj: Optional[nn.Module] = None,
        activation: nn.Module = nn.SiLU(),
    ):
        super().__init__()
        # Attribute names w1/w2/w3 are kept as-is; they become the submodule names.
        self.w1 = gate_proj
        self.w2 = down_proj
        self.w3 = up_proj
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor with shape ``(..., in_dim)``, where ``in_dim`` is the
                input dimension of both ``gate_proj`` and ``up_proj``.

        Returns:
            torch.Tensor: output tensor with shape ``(..., out_dim)``, where ``out_dim`` is the \
                output dimension of ``down_proj``.
        """
        gated = self.activation(self.w1(x))
        if self.w3 is None:
            # No up projection: feed the activated gate straight into down_proj.
            return self.w2(gated)
        return self.w2(gated * self.w3(x))


# --------------------------------------------------------------------------------
# /torchtune/tests/torchtune/models/mistral/test_mistral_prompt_template.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
from tests.test_utils import assert_dialogue_equal, MESSAGE_SAMPLE
from torchtune.data import Message
from torchtune.models.mistral import MistralChatTemplate


class TestMistralChatTemplate:
    """Checks ``MistralChatTemplate`` formatting and its rejection of system prompts."""

    # User turn wrapped in [INST] ... [/INST]; the assistant turn is left untouched.
    _expected_user_turn = Message(
        role="user",
        content="[INST] Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
        "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
        "How about on an icy road? Well one father in Russia did just that, and recorded "
        "the entire thing. To her credit, the child seemed to be doing a great job. "
        "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
        "Summary: [/INST] ",
    )
    _expected_assistant_turn = Message(
        role="assistant",
        content="A father in Russia allowed his 8-year-old child to drive his car on an "
        "icy road and recorded the event. The child appeared to be handling the situation well, "
        "showcasing their driving skills despite the challenging conditions.",
    )
    expected_dialogue = [_expected_user_turn, _expected_assistant_turn]

    def test_format(self):
        """With the leading system message removed, the dialogue is formatted as expected."""
        template = MistralChatTemplate()
        formatted = template(MESSAGE_SAMPLE[1:])
        assert_dialogue_equal(formatted, self.expected_dialogue)

    def test_format_with_system_prompt_raises(self):
        """Passing the full sample, system message included, raises ``ValueError``."""
        template = MistralChatTemplate()
        with pytest.raises(
            ValueError, match="System prompts are not supported in MistralChatTemplate"
        ):
            template(MESSAGE_SAMPLE)


# --------------------------------------------------------------------------------
# /torchtune/torchtune/modules/rms_norm.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn


class RMSNorm(nn.Module):
    """
    Root Mean Square Normalization, computed in fp32, with a learnable scale.

    See: https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html

    Args:
        dim (int): embedding size
        eps (float): small value to avoid division by zero. Default: 1e-6
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.normalized_shape = (dim,)
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor to normalize

        Returns:
            torch.Tensor: The normalized and scaled tensor having the same shape as ``x``.
        """
        # Normalize in fp32 (cast back to x's dtype), then apply the learnable scale.
        return rms_norm(x, eps=self.eps) * self.scale


def rms_norm(x: torch.Tensor, eps: float = 1e-6):
    """
    Functional RMSNorm without the trainable scale parameter.

    Args:
        x (torch.Tensor): input tensor to normalize
        eps (float): small value to avoid division by zero. Default: 1e-6

    Returns:
        torch.Tensor: The normalized tensor having the same shape as ``x``.

    """
    upcast = x.float()
    # Mean of squares over the last dimension, kept for broadcasting.
    mean_square = upcast.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    return (upcast * inv_rms).type_as(x)


# --------------------------------------------------------------------------------
# /torchtune/tests/torchtune/datasets/multimodal/test_vqa_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from tests.common import ASSETS
from tests.test_utils import DummyTokenizer
from torchtune.datasets.multimodal import vqa_dataset


class TestMultimodalInstructDataset:
    """Tests for the ``vqa_dataset`` builder."""

    @pytest.fixture
    def tokenizer(self):
        """A ``DummyTokenizer`` used as the model transform."""
        return DummyTokenizer()

    def test_get_item(self, tokenizer):
        """The tiny VQA asset yields one sample with the pinned tokens, labels and an image tensor."""
        system_prompt = "follow this prompt"

        dataset = vqa_dataset(
            model_transform=tokenizer,
            source="json",
            data_files=str(ASSETS / "vqa_tiny.json"),
            split="train",
            new_system_prompt=system_prompt,
        )

        expected_tokens = [
            [0, 6, 4, 6, -2, 4, 2, 9, 2, 6, 7, 5, -1],
        ]

        expected_labels = [
            [-100, -100, -100, -100, -100, -100, -100, -100, -100, 7, 5, -1, -100]
        ]

        assert len(dataset) == 1

        for i, (want_tokens, want_labels) in enumerate(
            zip(expected_tokens, expected_labels)
        ):
            # Fetch each sample once; every dataset[i] access re-runs the transform.
            sample = dataset[i]
            assert sample["tokens"] == want_tokens
            assert sample["labels"] == want_labels
            assert isinstance(sample["images"][0], torch.Tensor)

    def test_dataset_fails_with_packed(self, tokenizer):
        """Requesting ``packed=True`` raises ``ValueError``."""
        with pytest.raises(
            ValueError, match="Multimodal datasets don't support packing yet."
        ):
            vqa_dataset(
                model_transform=tokenizer,
                source="json",
                packed=True,
            )


# --------------------------------------------------------------------------------
# /torchtune/torchtune/models/llama4/_chunked_attention.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
from torchtune.modules.attention_utils import (
    _MaskType,
    _SUPPORTS_FLEX_ATTENTION,
    causal_mask_flex,
)

if _SUPPORTS_FLEX_ATTENTION:
    from torch.nn.attention.flex_attention import BlockMask, create_block_mask


def get_chunked_attention_mask(
    mask: Optional[_MaskType],
    chunk_size: int,
    bsz: int,
    seq_len: int,
) -> _MaskType:
    """
    Build a flex-attention block mask that confines attention to fixed-size chunks.

    A query position may attend to a key/value position only when both fall in the
    same chunk (their indices floor-divided by ``chunk_size`` are equal) and the
    base mask, evaluated on the within-chunk offsets (indices modulo
    ``chunk_size``), also allows it.

    Args:
        mask (Optional[_MaskType]): base mask. ``None`` selects ``causal_mask_flex``
            over a ``seq_len`` x ``seq_len`` grid; a ``BlockMask`` contributes its own
            ``mask_mod`` and ``seq_lengths``.
        chunk_size (int): number of consecutive positions per chunk.
        bsz (int): batch size passed to ``create_block_mask``.
        seq_len (int): query and key/value length used when ``mask`` is ``None``.

    Returns:
        _MaskType: the block mask returned by ``create_block_mask``.

    Raises:
        ValueError: if flex attention is not supported, or if ``mask`` is neither
            ``None`` nor a ``BlockMask``.
    """
    # TODO: check this somewhere that doesn't get called every forward
    if not _SUPPORTS_FLEX_ATTENTION:
        raise ValueError("Local attention is only supported with flex attention.")
    if mask is None:
        mask_mod = causal_mask_flex
        q_seq_len, kv_seq_len = seq_len, seq_len
    elif isinstance(mask, BlockMask):
        mask_mod = mask.mask_mod
        q_seq_len, kv_seq_len = mask.seq_lengths
    else:
        raise ValueError("Unsupported mask type")

    def chunked_attention_mask_mod(
        b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
    ):
        # Get the chunk index of the query and key
        q_chunk = q_idx // chunk_size
        kv_chunk = kv_idx // chunk_size
        # Only allow attention within the same chunk
        same_chunk = q_chunk == kv_chunk
        # Apply the original mask mod to the positions relative to their chunk start
        inner_mask = mask_mod(b, h, q_idx % chunk_size, kv_idx % chunk_size)
        return same_chunk & inner_mask

    return create_block_mask(
        chunked_attention_mask_mod,
        bsz,
        None,  # head dimension left unspecified (see create_block_mask's H argument)
        q_seq_len,
        kv_seq_len,
    )


# --------------------------------------------------------------------------------