├── version.txt
├── torchtune
├── version.txt
├── docs
│ ├── source
│ │ ├── deep_dives
│ │ │ └── README.txt
│ │ ├── tutorials
│ │ │ └── README.txt
│ │ ├── _static
│ │ │ └── img
│ │ │ │ ├── qlora_exp.png
│ │ │ │ ├── kd-qwen2-res.png
│ │ │ │ ├── kd-simplified.png
│ │ │ │ ├── lora_diagram.png
│ │ │ │ ├── qat_diagram.png
│ │ │ │ ├── kd-hyperparam-lr.png
│ │ │ │ ├── pytorch-logo-dark.png
│ │ │ │ ├── pytorch-logo-flame.png
│ │ │ │ ├── generic-pytorch-logo.png
│ │ │ │ ├── kd-finetune-student.png
│ │ │ │ ├── kd-finetune-teacher.png
│ │ │ │ ├── torchtune_workspace.png
│ │ │ │ ├── comet_torchtune_project.png
│ │ │ │ ├── kd-hyperparam-kd-ratio.png
│ │ │ │ ├── lora_experiment_loss_curves.png
│ │ │ │ ├── card-background.svg
│ │ │ │ ├── pytorch-logo-flame.svg
│ │ │ │ └── pytorch-logo-dark.svg
│ │ ├── _templates
│ │ │ ├── autosummary
│ │ │ │ ├── function.rst
│ │ │ │ └── class.rst
│ │ │ └── layout.html
│ │ ├── api_ref_config.rst
│ │ ├── api_ref_generation.rst
│ │ ├── api_ref_utilities.rst
│ │ ├── api_ref_rlhf.rst
│ │ ├── api_ref_datasets.rst
│ │ ├── api_ref_data.rst
│ │ └── basics
│ │ │ └── datasets_overview.rst
│ ├── requirements.txt
│ ├── license_header.txt
│ └── Makefile
├── MANIFEST.in
├── tests
│ ├── assets
│ │ ├── generation_config.json
│ │ ├── m.model
│ │ ├── valid_dummy_config.yaml
│ │ ├── rgb_pytorch.png
│ │ ├── sentencepiece.model
│ │ ├── dog_on_skateboard.jpg
│ │ ├── vqa_tiny.json
│ │ ├── invalid_dummy_config.yaml
│ │ ├── hh_rlhf_tiny.json
│ │ ├── tokenizer_config.json
│ │ ├── instruct_tiny.json
│ │ ├── README.md
│ │ └── chat_tiny.json
│ ├── recipes
│ │ ├── __init__.py
│ │ ├── common.py
│ │ └── test_configs.py
│ ├── torchtune
│ │ ├── __init__.py
│ │ ├── _cli
│ │ │ ├── __init__.py
│ │ │ ├── test_tune.py
│ │ │ ├── test_ls.py
│ │ │ └── test_validate.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── clip
│ │ │ │ ├── __init__.py
│ │ │ │ ├── test_clip_text_encoder.py
│ │ │ │ └── test_clip_tokenizer.py
│ │ │ ├── flux
│ │ │ │ └── __init__.py
│ │ │ ├── phi3
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_phi3.py
│ │ │ ├── phi4
│ │ │ │ └── __init__.py
│ │ │ ├── qwen2
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_qwen2.py
│ │ │ ├── t5
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_t5_tokenizer.py
│ │ │ ├── qwen2_5
│ │ │ │ └── __init__.py
│ │ │ ├── llama2
│ │ │ │ ├── scripts
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── README.md
│ │ │ │ └── test_llama2_prompt_template.py
│ │ │ ├── mistral
│ │ │ │ ├── scripts
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── mistral_test_config.py
│ │ │ │ │ └── README.md
│ │ │ │ ├── test_mistral.py
│ │ │ │ ├── test_mistral_classifier.py
│ │ │ │ └── test_mistral_prompt_template.py
│ │ │ ├── llama3
│ │ │ │ └── test_llama3.py
│ │ │ └── llama4
│ │ │ │ └── test_llama4_transform.py
│ │ ├── modules
│ │ │ ├── __init__.py
│ │ │ ├── peft
│ │ │ │ └── __init__.py
│ │ │ ├── low_precision
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_nf4_dispatch_registration.py
│ │ │ ├── model_fusion
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_fusion_utils.py
│ │ │ ├── transforms
│ │ │ │ ├── test_pad_dim_to_size.py
│ │ │ │ └── tokenizers
│ │ │ │ │ └── test_utils.py
│ │ │ ├── test_layernorm.py
│ │ │ ├── test_feed_forward.py
│ │ │ └── loss
│ │ │ │ └── test_ce_chunked_output_loss.py
│ │ ├── rlhf
│ │ │ ├── __init__.py
│ │ │ └── loss
│ │ │ │ └── __init__.py
│ │ ├── utils
│ │ │ └── __init__.py
│ │ ├── datasets
│ │ │ ├── __init__.py
│ │ │ ├── multimodal
│ │ │ │ ├── test_multimodal_chat_dataset.py
│ │ │ │ └── test_vqa_dataset.py
│ │ │ └── test_wikitext_dataset.py
│ │ ├── generation
│ │ │ └── __init__.py
│ │ ├── dev
│ │ │ └── rl
│ │ │ │ └── rewards
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_rewards.py
│ │ └── config
│ │ │ └── test_validate.py
│ ├── common.py
│ ├── test_import_recipes.py
│ └── __init__.py
├── torchtune
│ ├── dev
│ │ ├── __init__.py
│ │ ├── grpo
│ │ │ └── __init__.py
│ │ ├── rl
│ │ │ ├── __init__.py
│ │ │ ├── workers
│ │ │ │ ├── weight_updaters
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── trainers
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── parameter_servers
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── datacollectors
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── __init__.py
│ │ │ │ └── metric_logger.py
│ │ │ ├── utils
│ │ │ │ ├── __init__.py
│ │ │ │ └── dist.py
│ │ │ └── datatypes
│ │ │ │ ├── vllm_completion_output.py
│ │ │ │ ├── __init__.py
│ │ │ │ └── trajectory.py
│ │ └── README.md
│ ├── _cli
│ │ ├── __init__.py
│ │ ├── subcommand.py
│ │ └── tune.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── flux
│ │ │ └── __init__.py
│ │ ├── llama3_3
│ │ │ ├── __init__.py
│ │ │ └── _model_builders.py
│ │ ├── t5
│ │ │ ├── __init__.py
│ │ │ └── _model_builders.py
│ │ ├── phi4
│ │ │ └── __init__.py
│ │ ├── llama3_2
│ │ │ └── __init__.py
│ │ ├── gemma
│ │ │ ├── __init__.py
│ │ │ └── rms_norm.py
│ │ ├── code_llama2
│ │ │ └── __init__.py
│ │ ├── clip
│ │ │ └── __init__.py
│ │ ├── llama3
│ │ │ ├── __init__.py
│ │ │ └── _model_utils.py
│ │ ├── phi3
│ │ │ └── __init__.py
│ │ ├── llama2
│ │ │ ├── _model_utils.py
│ │ │ └── __init__.py
│ │ ├── llama3_1
│ │ │ └── __init__.py
│ │ ├── gemma2
│ │ │ └── __init__.py
│ │ ├── qwen2
│ │ │ └── __init__.py
│ │ ├── mistral
│ │ │ └── __init__.py
│ │ ├── llama3_2_vision
│ │ │ └── __init__.py
│ │ ├── llama4
│ │ │ ├── __init__.py
│ │ │ └── _chunked_attention.py
│ │ └── qwen2_5
│ │ │ └── __init__.py
│ ├── modules
│ │ ├── transforms
│ │ │ ├── vision_utils
│ │ │ │ ├── __init__.py
│ │ │ │ └── pad_dim_to_size.py
│ │ │ ├── __init__.py
│ │ │ └── tokenizers
│ │ │ │ └── __init__.py
│ │ ├── low_precision
│ │ │ ├── __init__.py
│ │ │ └── _register_nf4_dispatch_ops.py
│ │ ├── _export
│ │ │ ├── install_requirements.sh
│ │ │ └── README.md
│ │ ├── moe
│ │ │ └── __init__.py
│ │ ├── model_fusion
│ │ │ └── __init__.py
│ │ ├── tanh_gate.py
│ │ ├── loss
│ │ │ └── __init__.py
│ │ ├── peft
│ │ │ └── __init__.py
│ │ ├── tokenizers
│ │ │ └── __init__.py
│ │ ├── layer_norm.py
│ │ ├── feed_forward.py
│ │ └── rms_norm.py
│ ├── data
│ │ ├── _common.py
│ │ ├── _torchdata.py
│ │ └── __init__.py
│ ├── rlhf
│ │ ├── loss
│ │ │ └── __init__.py
│ │ ├── utils
│ │ │ └── __init__.py
│ │ └── __init__.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── _errors.py
│ │ └── _validate.py
│ ├── generation
│ │ └── __init__.py
│ ├── datasets
│ │ ├── multimodal
│ │ │ └── __init__.py
│ │ └── __init__.py
│ ├── utils
│ │ ├── _import_guard.py
│ │ ├── __init__.py
│ │ └── _version.py
│ ├── training
│ │ ├── _model_util.py
│ │ ├── pooling.py
│ │ └── checkpointing
│ │ │ └── __init__.py
│ └── __init__.py
├── CITATION.cff
├── recipes
│ ├── dev
│ │ ├── gsm8k_sft.sbatch
│ │ └── multinode_grpo.sbatch
│ ├── configs
│ │ ├── quantization.yaml
│ │ ├── gemma
│ │ │ └── evaluation.yaml
│ │ ├── qwen3
│ │ │ └── evaluation.yaml
│ │ ├── qwen2_5
│ │ │ └── evaluation.yaml
│ │ ├── mistral
│ │ │ └── evaluation.yaml
│ │ ├── eleuther_evaluation.yaml
│ │ ├── phi3
│ │ │ └── evaluation.yaml
│ │ ├── llama3_2
│ │ │ └── evaluation.yaml
│ │ ├── code_llama2
│ │ │ └── evaluation.yaml
│ │ ├── qwen2
│ │ │ └── evaluation.yaml
│ │ ├── phi4
│ │ │ └── evaluation.yaml
│ │ ├── llama2
│ │ │ └── generation_v2.yaml
│ │ ├── generation.yaml
│ │ ├── llama3
│ │ │ └── 70B_generation_distributed.yaml
│ │ ├── llama3_3
│ │ │ └── 70B_generation_distributed.yaml
│ │ ├── llama3_1
│ │ │ └── 70B_generation_distributed.yaml
│ │ ├── llama3_2_vision
│ │ │ └── 11B_generation_v2.yaml
│ │ └── llama4
│ │ │ └── scout_17B_16E_generation_distributed.yaml
│ ├── __init__.py
│ └── full_finetune_multinode.slurm
└── LICENSE
├── MANIFEST.in
├── tests
└── models
│ └── qwen3
│ └── __init__.py
├── mirotrain
├── data
│ └── __init__.py
├── __init__.py
├── modules
│ ├── loss
│ │ └── __init__.py
│ ├── moe
│ │ ├── __init__.py
│ │ ├── grouped_gemm_util.py
│ │ └── expert_parallel.py
│ └── __init__.py
├── models
│ ├── qwen3_moe
│ │ └── __init__.py
│ ├── convert_weights.py
│ └── qwen3
│ │ ├── _checkpointing_utils.py
│ │ └── _validate.py
├── datasets
│ └── __init__.py
├── training
│ ├── checkpointing
│ │ ├── _utils.py
│ │ └── __init__.py
│ ├── __init__.py
│ └── lr_schedulers.py
└── monkey
│ ├── __init__.py
│ └── _errors.py
├── .flake8
└── .pre-commit-config.yaml
/version.txt:
--------------------------------------------------------------------------------
1 | 0.1.0
2 |
--------------------------------------------------------------------------------
/torchtune/version.txt:
--------------------------------------------------------------------------------
1 | 0.7.0
2 |
--------------------------------------------------------------------------------
/torchtune/docs/source/deep_dives/README.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/torchtune/docs/source/tutorials/README.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | prune tests # Remove all testing files from final dist/
2 |
--------------------------------------------------------------------------------
/torchtune/MANIFEST.in:
--------------------------------------------------------------------------------
1 | prune tests # Remove all testing files from final dist/
2 |
--------------------------------------------------------------------------------
/torchtune/tests/assets/generation_config.json:
--------------------------------------------------------------------------------
1 | {"bos_token_id": 0, "eos_token_id": -1}
2 |
--------------------------------------------------------------------------------
/torchtune/tests/assets/m.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/tests/assets/m.model
--------------------------------------------------------------------------------
/torchtune/tests/assets/valid_dummy_config.yaml:
--------------------------------------------------------------------------------
1 | test:
2 | _component_: torchtune.training.get_dtype
3 | dtype: fp32
4 |
--------------------------------------------------------------------------------
/tests/models/qwen3/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
--------------------------------------------------------------------------------
/torchtune/tests/assets/rgb_pytorch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/tests/assets/rgb_pytorch.png
--------------------------------------------------------------------------------
/torchtune/tests/assets/sentencepiece.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/tests/assets/sentencepiece.model
--------------------------------------------------------------------------------
/torchtune/tests/assets/dog_on_skateboard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/tests/assets/dog_on_skateboard.jpg
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/qlora_exp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/qlora_exp.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/kd-qwen2-res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/kd-qwen2-res.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/kd-simplified.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/kd-simplified.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/lora_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/lora_diagram.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/qat_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/qat_diagram.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/kd-hyperparam-lr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/kd-hyperparam-lr.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/pytorch-logo-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/pytorch-logo-dark.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/pytorch-logo-flame.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/pytorch-logo-flame.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/generic-pytorch-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/generic-pytorch-logo.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/kd-finetune-student.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/kd-finetune-student.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/kd-finetune-teacher.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/kd-finetune-teacher.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/torchtune_workspace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/torchtune_workspace.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/comet_torchtune_project.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/comet_torchtune_project.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/kd-hyperparam-kd-ratio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/kd-hyperparam-kd-ratio.png
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/lora_experiment_loss_curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiroMindAI/MiroTrain/HEAD/torchtune/docs/source/_static/img/lora_experiment_loss_curves.png
--------------------------------------------------------------------------------
/mirotrain/data/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from ._messages import TracesToMessages
6 |
7 | __all__ = ["TracesToMessages"]
8 |
--------------------------------------------------------------------------------
/torchtune/tests/assets/vqa_tiny.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "input": "What is presented on image?",
4 | "output": "PyTorch logo.",
5 | "image": "tests/assets/rgb_pytorch.png"
6 | }
7 | ]
8 |
--------------------------------------------------------------------------------
/torchtune/docs/source/_templates/autosummary/function.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 | .. currentmodule:: {{ module }}
4 |
5 |
6 | {{ name | underline}}
7 |
8 | .. autofunction:: {{ name }}
9 |
--------------------------------------------------------------------------------
/torchtune/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx-gallery>0.11
2 | sphinx==5.0.0
3 | sphinx_design
4 | sphinx_copybutton
5 | sphinx-tabs
6 | matplotlib
7 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
8 |
--------------------------------------------------------------------------------
/torchtune/tests/assets/invalid_dummy_config.yaml:
--------------------------------------------------------------------------------
1 | test1:
2 | _component_: torchtune.training.get_dtype
3 | dtype: fp32
4 | dummy: 3
5 | test2:
6 | _component_: torchtune.training.get_dtype
7 | dtype: fp32
8 | dummy: 3
9 |
--------------------------------------------------------------------------------
/torchtune/docs/source/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 | .. currentmodule:: {{ module }}
4 |
5 |
6 | {{ name | underline}}
7 |
8 | .. autoclass:: {{ name }}
9 | :members:
10 |
--------------------------------------------------------------------------------
/torchtune/docs/license_header.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) Meta Platforms, Inc. and affiliates.
2 | All rights reserved.
3 |
4 | This source code is licensed under the BSD-style license found in the
5 | LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/recipes/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/torchtune/_cli/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/grpo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/rl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/_cli/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/rlhf/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/generation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/clip/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/flux/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/phi3/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/phi4/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/qwen2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/t5/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/modules/peft/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/rlhf/loss/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/dev/rl/rewards/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/qwen2_5/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/llama2/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/mistral/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/modules/low_precision/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/modules/model_fusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/transforms/vision_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchtune/tests/assets/hh_rlhf_tiny.json:
--------------------------------------------------------------------------------
1 | [{"chosen":[{"content":"What do I do when I have a hole in my trousers?","role":"user"},{"content":"Fix the hole.","role":"assistant"}],"rejected":[{"content":"What do I do when I have a hole in my trousers?","role":"user"},{"content":"Take them off.","role":"assistant"}]}]
2 |
--------------------------------------------------------------------------------
/torchtune/docs/source/api_ref_config.rst:
--------------------------------------------------------------------------------
1 | .. _config:
2 |
3 | ================
4 | torchtune.config
5 | ================
6 |
7 | .. currentmodule:: torchtune.config
8 |
9 | .. autosummary::
10 | :toctree: generated/
11 | :nosignatures:
12 |
13 | instantiate
14 | parse
15 | validate
16 | log_config
17 |
--------------------------------------------------------------------------------
/torchtune/tests/assets/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {"bos_token": {"__type": "AddedToken", "content": "<|begin_of_sentence|>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false}, "eos_token": {"__type": "AddedToken", "content": "<|end_of_sentence|>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false}}
2 |
--------------------------------------------------------------------------------
/torchtune/tests/recipes/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from pathlib import Path
8 |
9 | RECIPE_TESTS_DIR = Path(__file__).parent
10 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/rl/workers/weight_updaters/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .weight_updater import VLLMHFWeightUpdateReceiver # noqa: F401
8 |
--------------------------------------------------------------------------------
/torchtune/tests/assets/instruct_tiny.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "instruction": "What time is it in London?",
4 | "response": "It is 10:00 AM in London"
5 | },
6 | {
7 | "instruction": "Is is Istanbul or Constantinople?",
8 | "response": "Istanbul was Constantinople. Now it's Istanbul, not Constantinople."
9 | }
10 | ]
11 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/README.md:
--------------------------------------------------------------------------------
1 | ## Torchtune dev components
2 |
3 | This directory houses experimental components.
4 | The purpose of `torchtune/dev` is to support bleeding-edge APIs
5 | with less stringent requirements on testing, documentation, or stability.
6 | The APIs in this folder are not public and are not guaranteed to adhere to backwards compatibility.
7 |
--------------------------------------------------------------------------------
/mirotrain/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 |
__version__ = "0.1.0"

# Subpackages are deliberately NOT imported at package import time (to avoid
# dependency issues). They are named in ``__all__`` so that
# ``from mirotrain import *`` imports them on demand; otherwise import them
# explicitly, e.g. ``import mirotrain.models``.
__all__ = ["__version__"]
__all__ += [
    "data",
    "datasets",
    "models",
    "modules",
    "training",
]
17 |
--------------------------------------------------------------------------------
/torchtune/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | title: "torchtune: PyTorch's post-training library"
3 | message: "If you use this software, please cite it as below."
4 | type: software
5 | authors:
6 | - given-names: "torchtune maintainers and contributors"
7 | url: "https://github.com/pytorch/torchtune"
8 | license: "BSD-3-Clause"
9 | date-released: "2024-04-14"
10 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/rl/workers/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .training import TrainingWorker
8 |
9 | __all__ = ["TrainingWorker"]
10 |
--------------------------------------------------------------------------------
/torchtune/tests/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from pathlib import Path
7 |
# Path (as a string) to the ``tune`` CLI entry script. It is derived from this
# file's location, like ``ASSETS`` below, so that tests calling
# ``runpy.run_path(TUNE_PATH)`` work regardless of the current working
# directory instead of only when run from the repository root.
TUNE_PATH = str(Path(__file__).resolve().parents[1] / "torchtune" / "_cli" / "tune.py")

# Directory holding the test fixtures (tokenizer models, tiny datasets, ...).
ASSETS = Path(__file__).parent / "assets"
11 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/rl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .dist import stateless_init_process_group
8 |
9 | __all__ = ["stateless_init_process_group"]
10 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/rl/workers/parameter_servers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .vllm import VLLMParameterServer
8 |
9 | __all__ = ["VLLMParameterServer"]
10 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/flux/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from ._model_builders import flux_1_autoencoder
7 |
8 | __all__ = [
9 | "flux_1_autoencoder",
10 | ]
11 |
--------------------------------------------------------------------------------
/mirotrain/modules/loss/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from .cross_entropy_loss import LigerLinearCrossEntropyLoss, LinearSqrtCrossEntropyLoss
6 | from .dpo_loss import DPOLoss
7 |
8 | __all__ = [
9 | "LigerLinearCrossEntropyLoss",
10 | "LinearSqrtCrossEntropyLoss",
11 | "DPOLoss",
12 | ]
13 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/low_precision/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .nf4_linear import FrozenNF4Linear
8 |
9 | __all__ = [
10 | "FrozenNF4Linear",
11 | ]
12 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/rl/workers/datacollectors/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sync import SyncLLMCollector, VLLMWorkerWrapper
8 |
9 | __all__ = ["SyncLLMCollector", "VLLMWorkerWrapper"]
10 |
--------------------------------------------------------------------------------
/mirotrain/models/qwen3_moe/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from ._convert_weights import qwen3_moe_hf_to_tune, qwen3_moe_tune_to_hf
6 | from ._model_builders import qwen3_235b_a22b, qwen3_30b_a3b
7 |
8 | __all__ = [
9 | "qwen3_moe_hf_to_tune",
10 | "qwen3_moe_tune_to_hf",
11 | "qwen3_30b_a3b",
12 | "qwen3_235b_a22b",
13 | ]
14 |
--------------------------------------------------------------------------------
/torchtune/torchtune/data/_common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from typing import Dict, List, Union
7 |
8 | import torch
9 |
10 | CROSS_ENTROPY_IGNORE_IDX = -100
11 | PACK_TYPE = Dict[str, Union[torch.Tensor, List[int]]]
12 |
--------------------------------------------------------------------------------
/torchtune/docs/source/api_ref_generation.rst:
--------------------------------------------------------------------------------
1 | .. _generation:
2 |
3 | ====================
4 | torchtune.generation
5 | ====================
6 |
7 | .. currentmodule:: torchtune.generation
8 |
9 | .. autosummary::
10 | :toctree: generated/
11 | :nosignatures:
12 |
13 | generate
14 | generate_next_token
15 | sample
16 | get_causal_mask_from_padding_mask
17 | get_position_ids_from_padding_mask
18 |
--------------------------------------------------------------------------------
/torchtune/docs/source/api_ref_utilities.rst:
--------------------------------------------------------------------------------
1 | ===============
2 | torchtune.utils
3 | ===============
4 |
5 | .. currentmodule:: torchtune.utils
6 |
7 |
8 | .. _gen_label:
9 |
10 | Miscellaneous
11 | -------------
12 |
13 | .. autosummary::
14 | :toctree: generated/
15 | :nosignatures:
16 |
17 | batch_to_device
18 | get_device
19 | get_logger
20 | torch_version_ge
21 | get_world_size_and_rank
22 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/rl/datatypes/vllm_completion_output.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import vllm
8 | from tensordict import from_dataclass
9 |
10 | VllmCompletionOutput = from_dataclass(vllm.outputs.CompletionOutput)
11 |
--------------------------------------------------------------------------------
/torchtune/torchtune/rlhf/loss/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | from .dpo import DPOLoss, RSOLoss
9 | from .ppo import PPOLoss
10 |
11 | __all__ = [
12 | "DPOLoss",
13 | "RSOLoss",
14 | "PPOLoss",
15 | ]
16 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/_export/install_requirements.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # Install pytorch
9 | pip install torch==2.6.0 torchvision==0.21.0 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
10 |
--------------------------------------------------------------------------------
/torchtune/torchtune/rlhf/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._convert_weights import reward_hf_to_tune, reward_tune_to_hf # noqa
8 |
9 | __all__ = [
10 | "reward_hf_to_tune",
11 | "reward_tune_to_hf",
12 | ]
13 |
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/card-background.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/torchtune/docs/source/api_ref_rlhf.rst:
--------------------------------------------------------------------------------
1 | ===============
2 | torchtune.rlhf
3 | ===============
4 |
5 | .. currentmodule:: torchtune.rlhf
6 |
7 | Components and losses for RLHF algorithms like PPO and DPO.
8 |
9 | .. autosummary::
10 | :toctree: generated/
11 | :nosignatures:
12 |
13 | estimate_advantages
14 | get_rewards_ppo
15 | truncate_sequence_at_first_stop_token
16 | loss.PPOLoss
17 | loss.DPOLoss
18 | loss.RSOLoss
19 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from torchtune.modules.transforms._transforms import Transform, VisionCrossAttentionMask
8 |
9 |
10 | __all__ = [
11 | "Transform",
12 | "VisionCrossAttentionMask",
13 | ]
14 |
--------------------------------------------------------------------------------
/mirotrain/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from ._chat import odr_chat_dataset
6 | from ._collate import padded_collate_dpo, padded_collate_packed
7 | from ._packed import StatefulDistributedStreamingPackedDataset
8 |
9 | __all__ = [
10 | "odr_chat_dataset",
11 | "padded_collate_packed",
12 | "StatefulDistributedStreamingPackedDataset",
13 | "padded_collate_dpo",
14 | ]
15 |
--------------------------------------------------------------------------------
/torchtune/tests/assets/README.md:
--------------------------------------------------------------------------------
1 | # Details on the assets in this folder
2 |
3 | ## `m.model`
4 |
5 | **Description**:
6 | **Creation**:
7 | **Usage**:
8 |
9 |
10 | ## `tiny_fair_checkpoint.pt`
11 |
12 | **Description**:
13 | **Creation**:
14 | **Usage**:
15 |
16 | ## `tiny_llama2_checkpoint.pt`
17 |
18 | **Description**:
19 | **Creation**:
20 | **Usage**:
21 |
22 | ## `tiny_state_dict_with_one_key.pt`
23 |
24 | **Description**:
25 | **Creation**:
26 | **Usage**:
27 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/llama3_3/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._model_builders import llama3_3_70b, lora_llama3_3_70b, qlora_llama3_3_70b # noqa
8 |
9 | __all__ = [
10 | "llama3_3_70b",
11 | "lora_llama3_3_70b",
12 | "qlora_llama3_3_70b",
13 | ]
14 |
--------------------------------------------------------------------------------
/torchtune/tests/test_import_recipes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 |
9 |
def test_import_recipes():
    """Importing ``recipes`` as a Python package must be rejected.

    The import is expected to fail with ``ModuleNotFoundError`` carrying the
    torchtune-specific explanation below (matched as a regex by pytest).
    """
    expected_message = "The torchtune recipes directory isn't a package"
    with pytest.raises(ModuleNotFoundError, match=expected_message):
        import recipes  # noqa
15 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/t5/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._component_builders import t5_encoder
8 | from ._model_builders import t5_tokenizer, t5_v1_1_xxl_encoder
9 |
10 | __all__ = [
11 | "t5_encoder",
12 | "t5_tokenizer",
13 | "t5_v1_1_xxl_encoder",
14 | ]
15 |
--------------------------------------------------------------------------------
/mirotrain/training/checkpointing/_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from enum import Enum
6 |
7 | from torchtune.training.checkpointing._utils import ModelType as original_ModelType
8 |
STEP_KEY = "global_step"

# Enum classes that already define members cannot be subclassed, so torchtune's
# ModelType is extended by building a brand-new Enum: every member of the
# original (name -> value), followed by the added Qwen3 dense and MoE entries.
# A keyword entry whose name matches an original member would overwrite it.
ModelType = Enum(
    "ModelType",
    dict(
        ((member.name, member.value) for member in original_ModelType),
        QWEN3="qwen3",
        QWEN3_MoE="qwen3_moe",
    ),
)
19 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/moe/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .experts import GroupedExperts, LoRAGroupedExperts
8 | from .moe import MoE, TokenChoiceTopKRouter
9 |
10 | __all__ = [
11 | "MoE",
12 | "GroupedExperts",
13 | "LoRAGroupedExperts",
14 | "TokenChoiceTopKRouter",
15 | ]
16 |
--------------------------------------------------------------------------------
/torchtune/torchtune/_cli/subcommand.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
class Subcommand:
    """Minimal base class for CLI subcommands.

    The base implementation accepts and ignores any constructor arguments and
    provides no-op hooks; concrete subcommands presumably override them.
    """

    def __init__(self, *args, **kwargs):
        # Deliberately empty: arguments are accepted only so that ``create``
        # can forward whatever a subclass constructor needs.
        return None

    @classmethod
    def create(cls, *args, **kwargs):
        """Build an instance of ``cls``, forwarding all arguments to ``__init__``."""
        instance = cls(*args, **kwargs)
        return instance

    def _add_arguments(self):
        """Hook for adding arguments (name suggests CLI arg registration); no-op here."""
        return None
18 |
--------------------------------------------------------------------------------
/torchtune/torchtune/config/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._instantiate import instantiate
8 | from ._parse import parse
9 | from ._utils import log_config
10 | from ._validate import validate
11 |
12 | __all__ = [
13 | "instantiate",
14 | "parse",
15 | "log_config",
16 | "validate",
17 | ]
18 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/rl/datatypes/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .request_output import RequestOutput
8 | from .trajectory import Trajectory
9 | from .vllm_completion_output import VllmCompletionOutput
10 |
11 | __all__ = [
12 | "RequestOutput",
13 | "Trajectory",
14 | "VllmCompletionOutput",
15 | ]
16 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/phi4/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._model_builders import ( # noqa
8 | lora_phi4_14b,
9 | phi4_14b,
10 | phi4_tokenizer,
11 | qlora_phi4_14b,
12 | )
13 |
14 | __all__ = [
15 | "phi4_14b",
16 | "phi4_tokenizer",
17 | "lora_phi4_14b",
18 | "qlora_phi4_14b",
19 | ]
20 |
--------------------------------------------------------------------------------
/mirotrain/training/checkpointing/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from ._checkpoint_client import (
6 | SDCheckpointClient,
7 | StepCheckpointClient,
8 | StepTrainingProgress,
9 | )
10 | from ._checkpointer import FullModelHFCheckpointer
11 | from ._utils import STEP_KEY
12 |
13 | __all__ = [
14 | "SDCheckpointClient",
15 | "FullModelHFCheckpointer",
16 | "STEP_KEY",
17 | "StepCheckpointClient",
18 | "StepTrainingProgress",
19 | ]
20 |
--------------------------------------------------------------------------------
/mirotrain/modules/moe/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from .dropless_layer import DroplessMoELayer
6 | from .expert_parallel import (
7 | get_expert_parallel_group,
8 | get_expert_parallel_rank,
9 | get_expert_parallel_world_size,
10 | set_expert_parallel_group,
11 | )
12 |
13 | __all__ = [
14 | "DroplessMoELayer",
15 | "get_expert_parallel_group",
16 | "get_expert_parallel_rank",
17 | "get_expert_parallel_world_size",
18 | "set_expert_parallel_group",
19 | ]
20 |
--------------------------------------------------------------------------------
/mirotrain/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from .attention import MultiHeadAttentionWithUlysses
6 | from .transformer import (
7 | MoETransformerDecoder,
8 | MoETransformerSelfAttentionLayer,
9 | SDTransformerDecoder,
10 | SDTransformerSelfAttentionLayer,
11 | )
12 |
13 | __all__ = [
14 | "MultiHeadAttentionWithUlysses",
15 | "MoETransformerDecoder",
16 | "MoETransformerSelfAttentionLayer",
17 | "SDTransformerDecoder",
18 | "SDTransformerSelfAttentionLayer",
19 | ]
20 |
--------------------------------------------------------------------------------
/torchtune/torchtune/generation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._generation import (
8 | generate,
9 | generate_next_token,
10 | get_causal_mask_from_padding_mask,
11 | get_position_ids_from_padding_mask,
12 | sample,
13 | )
14 |
15 | __all__ = [
16 | "generate",
17 | "generate_next_token",
18 | "get_causal_mask_from_padding_mask",
19 | "get_position_ids_from_padding_mask",
20 | "sample",
21 | ]
22 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/llama2/scripts/README.md:
--------------------------------------------------------------------------------
1 | # Verifying Correctness against Reference Implementations
2 |
3 | This repository puts a high bar on correctness and testing. To make sure our model and
4 | module implementations are correct, we compare our implementation against reference implementations
5 | where possible. This folder contains scripts used for these comparisons.
6 |
7 |
8 | ## Running the scripts
9 |
10 | You can run the scripts using the following command as an example.
11 | Each script should print out the value being used in the associated unit tests.
12 |
13 | ```
14 | python3 -m tests.torchtune.models.llama2.scripts.compare_attention
15 | ```
16 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/mistral/scripts/mistral_test_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from dataclasses import dataclass
8 |
9 |
@dataclass
class MistralTestConfig:
    """Constants describing the tiny Mistral model used by the comparison scripts.

    The attributes carry no type annotations, so they are plain class
    attributes rather than dataclass fields; read them as
    ``MistralTestConfig.EMBED_DIM`` etc.
    """

    # Input batch shape.
    BSZ = 2
    SEQ_LEN = 128
    # Model dimensions.
    EMBED_DIM = 64
    VOCAB_SIZE = 512
    NUM_LAYERS = 4
    NUM_HEADS = 4
    NUM_KV_HEADS = 2
    INTERMEDIATE_DIM = 512
    MAX_SEQ_LEN = 256
    # Positional encoding / normalization.
    ROPE_BASE = 10_000
    NORM_EPS = 1.0e-5
    # RNG seed for reproducible weights.
    SEED = 16
24 |
--------------------------------------------------------------------------------
/torchtune/torchtune/datasets/multimodal/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._llava_instruct import llava_instruct_dataset
8 | from ._multimodal import multimodal_chat_dataset
9 | from ._the_cauldron import the_cauldron_dataset, the_cauldron_transform
10 | from ._vqa import vqa_dataset
11 |
12 | __all__ = [
13 | "the_cauldron_dataset",
14 | "the_cauldron_transform",
15 | "llava_instruct_dataset",
16 | "multimodal_chat_dataset",
17 | "vqa_dataset",
18 | ]
19 |
--------------------------------------------------------------------------------
/torchtune/recipes/dev/gsm8k_sft.sbatch:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --time=01:00:00
3 | #SBATCH --constraint=volta32gb
4 | #SBATCH --nodes=1
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --gpus-per-node=8
7 | #SBATCH --no-requeue
8 | #SBATCH --exclusive
9 |
10 | #SBATCH --job-name=torchtune
11 | #SBATCH --output=slurm_logs/%j.out
12 | #SBATCH --error=slurm_logs/%j.err
13 |
14 | # /\ Customize SBATCH directives to customize your hardware
15 |
16 | # \/ Customize the virtual env/module load - this assumes a virtual env in root of torchtune
17 | source ../../.venv/bin/activate
18 |
19 | srun tune run \
20 | --nnodes 1 \
21 | --nproc_per_node 8 \
22 | full_finetune_distributed --config dev/3B_sft_for_grpo "$@"
23 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/model_fusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._deep_fusion import DeepFusionModel
8 | from ._early_fusion import EarlyFusionModel
9 | from ._fusion_layers import FusionEmbedding, FusionLayer
10 | from ._fusion_utils import get_fusion_params, register_fusion_module
11 |
12 | __all__ = [
13 | "DeepFusionModel",
14 | "FusionLayer",
15 | "FusionEmbedding",
16 | "register_fusion_module",
17 | "get_fusion_params",
18 | "EarlyFusionModel",
19 | ]
20 |
--------------------------------------------------------------------------------
/torchtune/tests/assets/chat_tiny.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "conversations": [
4 | {
5 | "from": "system",
6 | "value": "You are an AI assistant."
7 | },
8 | {
9 | "from": "human",
10 | "value": "What is the meaning of life?"
11 | },
12 | {
13 | "from": "gpt",
14 | "value": "The meaning of life is 42."
15 | },
16 | {
17 | "from": "human",
18 | "value": "That's ridiculous."
19 | },
20 | {
21 | "from": "gpt",
22 | "value": "I agree."
23 | }
24 | ]
25 | }
26 | ]
27 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/llama3_2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._component_builders import llama3_2, lora_llama3_2
8 |
9 | from ._model_builders import ( # noqa
10 | llama3_2_1b,
11 | llama3_2_3b,
12 | lora_llama3_2_1b,
13 | lora_llama3_2_3b,
14 | qlora_llama3_2_1b,
15 | qlora_llama3_2_3b,
16 | )
17 |
18 | __all__ = [
19 | "llama3_2",
20 | "llama3_2_1b",
21 | "llama3_2_3b",
22 | "lora_llama3_2",
23 | "lora_llama3_2_1b",
24 | "lora_llama3_2_3b",
25 | "qlora_llama3_2_1b",
26 | "qlora_llama3_2_3b",
27 | ]
28 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/rl/workers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .datacollectors import SyncLLMCollector
8 | from .metric_logger import MetricLoggerWorker
9 | from .parameter_servers import VLLMParameterServer
10 | from .postprocessing import PostProcessingWorker
11 | from .trainers import TrainingWorker
12 | from .weight_updaters import VLLMHFWeightUpdateReceiver
13 |
14 | __all__ = [
15 | "SyncLLMCollector",
16 | "MetricLoggerWorker",
17 | "VLLMParameterServer",
18 | "PostProcessingWorker",
19 | "TrainingWorker",
20 | "VLLMHFWeightUpdateReceiver",
21 | ]
22 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/_cli/test_tune.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import runpy
9 | import sys
10 |
11 | from tests.common import TUNE_PATH
12 |
13 |
class TestTuneCLI:
    def test_tune_without_args_returns_help(self, capsys, monkeypatch):
        """Invoking bare ``tune`` (no subcommand) should print the CLI welcome text."""
        monkeypatch.setattr(sys, "argv", ["tune"])
        runpy.run_path(TUNE_PATH, run_name="__main__")

        stdout = capsys.readouterr().out
        assert "Welcome to the torchtune CLI!" in stdout.rstrip("\n")
25 |
--------------------------------------------------------------------------------
/torchtune/torchtune/utils/_import_guard.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import importlib
8 |
9 | import torch
10 |
# ``import importlib`` alone does not guarantee the ``importlib.util`` submodule is
# loaded; import it explicitly so ``importlib.util.find_spec`` below cannot fail with
# an AttributeError depending on import order elsewhere.
import importlib.util

# Flex attention / BlockMask are only enabled on a CUDA GPU with compute capability
# >= (7, 5), i.e. Turing / SM75 and newer.
# NOTE(review): flex attention also needs torch >= 2.5.0, which is NOT checked here;
# presumably that minimum is enforced by the package's torch dependency — confirm.
_SUPPORTS_FLEX_ATTENTION = (
    torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5)
)

# Minimum supported torchdata version. It is not compared against the installed
# version here: availability of the ``torchdata.nodes`` submodule is used as the
# signal instead (presumably ``nodes`` first shipped in this release — confirm).
_TORCHDATA_MIN_VERSION = "0.10.0"
_TORCHDATA_INSTALLED = (
    importlib.util.find_spec("torchdata") is not None
    and importlib.util.find_spec("torchdata.nodes") is not None
)
24 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/_export/README.md:
--------------------------------------------------------------------------------
1 | # Export
2 |
3 | This directory provides [exportable](https://pytorch.org/docs/stable/export.html) variants of torchtune modules.
4 |
5 | Modules in this directory:
6 |
7 | * Take the same arguments to `__init__()` and `forward()` as the corresponding reference modules in torchtune.
8 | * Give the same output as the corresponding reference module in torchtune (unless stated otherwise in the docstring).
9 | * Are guaranteed to work out of the box with torch.export.export().
10 | * Should work out of the box with torch.aot_compile().
11 |
12 | All modules should be covered by unit tests (under `tests/torchtune/modules/_export/`) that run daily and on PRs that touch this directory.
13 |
14 | These modules are subject to change so proceed with caution.
15 |
16 | Contributors: @larryliu0820, @Jack-Khuu, @dvorjackz
17 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/low_precision/_register_nf4_dispatch_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torchao.dtypes.nf4tensor import implements as nf4_tensor_impl, to_nf4
9 |
10 |
@nf4_tensor_impl([torch.ops.aten.clone.default])
def clone(func, *args, **kwargs):
    """
    ``__torch_dispatch__`` override that is called when cloning an NF4Tensor.

    The clone is built by dequantizing the input (``get_original_weight``) and
    re-quantizing it with ``to_nf4``. It is therefore not an exact "clone": the
    round trip can lose precision. The source tensor's ``block_size`` and
    ``scaler_block_size`` are forwarded, so the clone keeps the same quantization
    layout instead of silently falling back to ``to_nf4``'s defaults.

    Args:
        func: The dispatched aten op (``aten.clone.default``); unused.
        *args: Dispatch arguments; ``args[0][0]`` is the NF4Tensor being cloned.
        **kwargs: Unused. Clone options such as ``memory_format`` are ignored.

    Returns:
        A new NF4Tensor re-quantized from the source tensor's dequantized weight.
    """
    source = args[0][0]
    return to_nf4(
        source.get_original_weight(),
        block_size=source.block_size,
        scaler_block_size=source.scaler_block_size,
    )
20 |
--------------------------------------------------------------------------------
/torchtune/torchtune/training/_model_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import warnings
7 |
8 | import torch
9 |
10 |
def disable_dropout(model: torch.nn.Module) -> None:
    """
    Disables dropout layers in the given model by setting their drop probability to zero.

    The model is modified in place. All standard dropout layers are handled: ``Dropout``,
    ``Dropout1d``, ``Dropout2d``, ``Dropout3d``, ``AlphaDropout`` and ``FeatureAlphaDropout``.
    The last five do not subclass ``nn.Dropout``. A warning naming the submodule's path
    is emitted for every layer whose probability was non-zero.

    Args:
        model (torch.nn.Module): The model in which dropout layers should be disabled.
    """
    dropout_types = (
        torch.nn.Dropout,
        torch.nn.Dropout1d,
        torch.nn.Dropout2d,
        torch.nn.Dropout3d,
        torch.nn.AlphaDropout,
        torch.nn.FeatureAlphaDropout,
    )
    # named_modules() visits the same (deduplicated) modules as modules(), but also
    # yields each module's dotted path so the warning can say where the layer lives.
    for name, module in model.named_modules():
        if isinstance(module, dropout_types) and module.p != 0:
            warnings.warn(
                f"Found Dropout with value {module.p} in module '{name}' ({module}). "
                "Setting to zero.",
                stacklevel=2,
            )
            module.p = 0
24 |
--------------------------------------------------------------------------------
/torchtune/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Check at the top-level that torchao is installed.
8 | # This is better than doing it at every import site.
9 | # We have to do this because it is not currently possible to
10 | # properly support both nightly and stable installs of PyTorch + torchao
11 | # in pyproject.toml.
12 | try:
13 | import torchao # noqa
14 | except ImportError as e:
15 | raise ImportError(
16 | """
17 | torchao not installed.
18 | Please follow the instructions at https://pytorch.org/torchtune/main/install.html#pre-requisites
19 | to install torchao.
20 | """
21 | ) from e
22 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/gemma/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._component_builders import gemma, lora_gemma # noqa
8 | from ._model_builders import ( # noqa
9 | gemma_2b,
10 | gemma_7b,
11 | gemma_tokenizer,
12 | lora_gemma_2b,
13 | lora_gemma_7b,
14 | qlora_gemma_2b,
15 | qlora_gemma_7b,
16 | )
17 | from ._tokenizer import GemmaTokenizer # noqa
18 |
19 | __all__ = [
20 | "GemmaTokenizer",
21 | "gemma",
22 | "gemma_2b",
23 | "gemma_7b",
24 | "gemma_tokenizer",
25 | "lora_gemma",
26 | "lora_gemma_2b",
27 | "lora_gemma_7b",
28 | "qlora_gemma_2b",
29 | "qlora_gemma_7b",
30 | ]
31 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/rl/datatypes/trajectory.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List
8 |
9 | import torch
10 | from tensordict import TensorClass
11 | from torchtune.dev.rl.rewards import RewardOutput
12 |
13 |
class Trajectory(TensorClass["nocast"]):
    """Typed container bundling per-sample tensors and metadata for RL training.

    NOTE(review): ``TensorClass["nocast"]`` presumably stops tensordict from casting the
    non-tensor fields below (``policy_version`` and the two lists) into tensors; confirm
    against the tensordict docs. Field order may be relied on by positional construction,
    so keep it stable.
    """

    # NOTE(review): the per-field notes below are inferred from names only; confirm with producers.
    query_responses: torch.Tensor  # presumably prompt tokens followed by generated tokens
    responses: torch.Tensor  # presumably the generated tokens only
    logprobs: torch.Tensor  # presumably policy log-probabilities of `responses`
    ref_logprobs: torch.Tensor  # presumably reference-model log-probabilities of `responses`
    query_response_padding_masks: torch.Tensor  # presumably marks padding in `query_responses`
    seq_lens: torch.Tensor
    answers: torch.Tensor
    policy_version: int  # presumably the policy weight version that generated this data
    advantages: torch.Tensor
    reward_outputs: List[RewardOutput]
    sequence_ids: List[str]
25 | sequence_ids: List[str]
26 |
--------------------------------------------------------------------------------
/mirotrain/modules/moe/grouped_gemm_util.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/grouped_gemm_util.py
6 |
try:
    import grouped_gemm
except ImportError:
    # Optional dependency: the module stays importable without it.
    grouped_gemm = None


def grouped_gemm_is_available():
    """Check if grouped_gemm is available (i.e. the optional package imported successfully)."""
    return grouped_gemm is not None


def assert_grouped_gemm_is_available():
    """Raise ``AssertionError`` with install instructions if grouped_gemm is not available.

    This uses an explicit ``raise`` rather than an ``assert`` statement so the check still
    runs under ``python -O``. The exception type callers may catch is unchanged.
    """
    if not grouped_gemm_is_available():
        raise AssertionError(
            "Grouped GEMM is not available. Please run "
            "`pip install git+https://github.com/fanshiqing/grouped_gemm@v1.1.4`."
        )


# ``grouped_gemm.ops`` when the package is installed, otherwise ``None``.
ops = grouped_gemm.ops if grouped_gemm_is_available() else None
27 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/modules/transforms/test_pad_dim_to_size.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 |
9 | import torch
10 |
11 | from torchtune.modules.transforms.vision_utils.pad_dim_to_size import pad_dim_to_size
12 |
13 |
def test_pad_dim_to_size():
    """Padding dim 1 from 2 to 4 zero-fills, keeps dtype, and shrinking must fail."""
    padded = pad_dim_to_size(torch.ones(2, 2, 2, 2, dtype=torch.float16), 4, 1)

    assert padded.shape == (2, 4, 2, 2)
    assert padded.mean() == 0.5, "Expected mean to be 0.5 after padding"
    assert padded.dtype == torch.float16, "Expected dtype to be float16 after padding"

    # dim 1 is already size 4, so "padding" it down to 2 is invalid
    with pytest.raises(Exception):
        pad_dim_to_size(padded, 2, 1)
23 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/mistral/scripts/README.md:
--------------------------------------------------------------------------------
1 | ## Verifying correctness
2 | This directory compares the current implementation of `mistral` to the reference implementation at https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py. Additionally, `torchtune.models.mistral._component_builders.mistral_mlp` is compared in `tests.torchtune.models.mistral.scripts.compare_feed_forward.py`.
3 |
4 | Since `torchtune.models.mistral` shares nearly all components with `torchtune.models.llama2`, please see `tests.torchtune.models.llama2.scripts` for comparison scripts for individual components.
5 |
6 | ## Running the scripts
7 |
8 | You can run the scripts using the following command as an example.
9 | Each script should print out the value being used in the associated unit tests.
10 |
11 | ```
12 | python3 -m tests.torchtune.models.mistral.scripts.compare_mistral
13 | ```
14 |
--------------------------------------------------------------------------------
/mirotrain/training/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from ._distributed import gather_cpu_state_dict, ParallelDims, shard_model
6 | from ._grad_scaler import scale_grads_
7 | from .checkpointing import (
8 | FullModelHFCheckpointer,
9 | SDCheckpointClient,
10 | STEP_KEY,
11 | StepCheckpointClient,
12 | StepTrainingProgress,
13 | )
14 | from .clip_grad import clip_grad_norm_
15 | from .lr_schedulers import get_cosine_schedule_with_warmup
16 |
17 | __all__ = [
18 | "FullModelHFCheckpointer",
19 | "ParallelDims",
20 | "gather_cpu_state_dict",
21 | "shard_model",
22 | "scale_grads_",
23 | "clip_grad_norm_",
24 | "STEP_KEY",
25 | "StepCheckpointClient",
26 | "StepTrainingProgress",
27 | "get_cosine_schedule_with_warmup",
28 | "SDCheckpointClient",
29 | ]
30 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/code_llama2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._model_builders import ( # noqa
8 | code_llama2_13b,
9 | code_llama2_70b,
10 | code_llama2_7b,
11 | lora_code_llama2_13b,
12 | lora_code_llama2_70b,
13 | lora_code_llama2_7b,
14 | qlora_code_llama2_13b,
15 | qlora_code_llama2_70b,
16 | qlora_code_llama2_7b,
17 | )
18 |
19 | __all__ = [
20 | "code_llama2_13b",
21 | "code_llama2_70b",
22 | "code_llama2_7b",
23 | "lora_code_llama2_13b",
24 | "lora_code_llama2_70b",
25 | "lora_code_llama2_7b",
26 | "qlora_code_llama2_13b",
27 | "qlora_code_llama2_70b",
28 | "qlora_code_llama2_7b",
29 | ]
30 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/tanh_gate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from torch import nn
10 |
11 |
class TanhGate(nn.Module):
    """Learnable scalar gate that multiplies its input by ``tanh(scale)``.

    ``scale`` is a single-element parameter initialised to zero, so the gate starts
    closed (``tanh(0) == 0``) and opens as ``scale`` is learned.
    """

    def __init__(self) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor to gate

        Returns:
            torch.Tensor: ``x`` multiplied by ``tanh(self.scale)``; same shape as ``x``.
        """
        gate = torch.tanh(self.scale)
        return x * gate
28 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/datasets/multimodal/test_multimodal_chat_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | from tests.test_utils import DummyTokenizer
9 |
10 | from torchtune.datasets.multimodal import multimodal_chat_dataset
11 |
12 |
class TestMultimodalChatDataset:
    @pytest.fixture
    def tokenizer(self):
        """Lightweight stand-in tokenizer passed as the dataset's model transform."""
        return DummyTokenizer()

    def test_dataset_fails_with_packed(self, tokenizer):
        """Requesting packing for a multimodal chat dataset must raise ValueError."""
        expected_msg = "Multimodal datasets don't support packing yet."
        with pytest.raises(ValueError, match=expected_msg):
            multimodal_chat_dataset(model_transform=tokenizer, source="json", packed=True)
25 |
--------------------------------------------------------------------------------
/torchtune/torchtune/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._device import (
8 | batch_to_device,
9 | DeviceSupport,
10 | get_device,
11 | get_device_support,
12 | get_torch_device_namespace,
13 | get_world_size_and_rank,
14 | )
15 | from ._logging import deprecated, get_logger, log_once, log_rank_zero
16 |
17 | from ._version import torch_version_ge
18 |
19 | __all__ = [
20 | "get_world_size_and_rank",
21 | "batch_to_device",
22 | "get_device",
23 | "get_logger",
24 | "torch_version_ge",
25 | "get_device_support",
26 | "get_torch_device_namespace",
27 | "DeviceSupport",
28 | "log_rank_zero",
29 | "deprecated",
30 | "log_once",
31 | ]
32 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/quantization.yaml:
--------------------------------------------------------------------------------
1 | # Config for QuantizationRecipe in quantize.py
2 | #
3 | # To launch, run the following command from root torchtune directory:
4 | # tune run quantize --config quantization
5 |
6 | output_dir: /tmp/torchtune/llama2_7B/quantized # /tmp may be deleted by your system. Change it to your preference.
7 |
8 | #
9 | # Model arguments
10 | model:
11 | _component_: torchtune.models.llama2.llama2_7b
12 |
13 | checkpointer:
14 | _component_: torchtune.training.FullModelHFCheckpointer
15 | checkpoint_dir: /tmp/Llama-2-7b-hf
16 | checkpoint_files: [
17 | pytorch_model-00001-of-00002.bin,
18 | pytorch_model-00002-of-00002.bin,
19 | ]
20 | recipe_checkpoint: null
21 | output_dir: ${output_dir}
22 | model_type: LLAMA2
23 |
24 | device: cuda
25 | dtype: bf16
26 | seed: 1234
27 |
28 | quantizer:
29 | _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
30 | groupsize: 256
31 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/clip/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._component_builders import clip_mlp, clip_text_encoder, clip_vision_encoder
8 | from ._model_builders import clip_text_vit_large_patch14, clip_tokenizer
9 | from ._position_embeddings import (
10 | TiledTokenPositionalEmbedding,
11 | TilePositionalEmbedding,
12 | TokenPositionalEmbedding,
13 | )
14 | from ._transform import CLIPImageTransform
15 |
16 | __all__ = [
17 | "clip_mlp",
18 | "clip_text_encoder",
19 | "clip_vision_encoder",
20 | "clip_text_vit_large_patch14",
21 | "clip_tokenizer",
22 | "CLIPImageTransform",
23 | "TokenPositionalEmbedding",
24 | "TiledTokenPositionalEmbedding",
25 | "TilePositionalEmbedding",
26 | ]
27 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/transforms/tokenizers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._gpt2 import GPT2BaseTokenizer
8 | from ._hf_tokenizer import HuggingFaceBaseTokenizer
9 | from ._sentencepiece import SentencePieceBaseTokenizer
10 | from ._tiktoken import TikTokenBaseTokenizer
11 | from ._utils import (
12 | BaseTokenizer,
13 | ModelTokenizer,
14 | parse_hf_tokenizer_json,
15 | tokenize_messages_no_special_tokens,
16 | )
17 |
18 | __all__ = [
19 | "SentencePieceBaseTokenizer",
20 | "TikTokenBaseTokenizer",
21 | "ModelTokenizer",
22 | "GPT2BaseTokenizer",
23 | "BaseTokenizer",
24 | "tokenize_messages_no_special_tokens",
25 | "parse_hf_tokenizer_json",
26 | "HuggingFaceBaseTokenizer",
27 | ]
28 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/llama3/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._component_builders import llama3, lora_llama3
8 |
9 | from ._model_builders import ( # noqa
10 | llama3_70b,
11 | llama3_8b,
12 | llama3_tokenizer,
13 | lora_llama3_70b,
14 | lora_llama3_8b,
15 | qlora_llama3_70b,
16 | qlora_llama3_8b,
17 | )
18 | from ._parallelism import base_llama_tp_plan
19 | from ._tokenizer import Llama3Tokenizer
20 |
21 | __all__ = [
22 | "Llama3Tokenizer",
23 | "llama3",
24 | "llama3_8b",
25 | "llama3_70b",
26 | "llama3_tokenizer",
27 | "lora_llama3",
28 | "lora_llama3_8b",
29 | "lora_llama3_70b",
30 | "qlora_llama3_8b",
31 | "qlora_llama3_70b",
32 | "base_llama_tp_plan",
33 | ]
34 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/phi3/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._component_builders import lora_phi3, phi3 # noqa
8 | from ._convert_weights import phi3_hf_to_tune, phi3_tune_to_hf # noqa
9 | from ._model_builders import ( # noqa
10 | lora_phi3_mini,
11 | phi3_mini,
12 | phi3_mini_tokenizer,
13 | qlora_phi3_mini,
14 | )
15 | from ._position_embeddings import Phi3RotaryPositionalEmbeddings # noqa
16 | from ._tokenizer import Phi3MiniTokenizer # noqa
17 |
18 | __all__ = [
19 | "phi3_mini",
20 | "phi3_mini_tokenizer",
21 | "lora_phi3_mini",
22 | "qlora_phi3_mini",
23 | "Phi3RotaryPositionalEmbeddings",
24 | "Phi3MiniTokenizer",
25 | "phi3_hf_to_tune",
26 | "phi3_tune_to_hf",
27 | "phi3",
28 | "lora_phi3",
29 | ]
30 |
--------------------------------------------------------------------------------
/torchtune/docs/source/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 |
3 | {% block sidebartitle %}
4 |
7 | {% include "searchbox.html" %}
8 | {% endblock %}
9 |
10 |
11 | {% block footer %}
12 |
13 |
17 |
18 |
21 | {{ super() }}
22 |
27 | {% endblock %}
28 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/loss/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .ce_chunked_output_loss import CEWithChunkedOutputLoss
8 |
9 | from .cross_entropy_loss import LinearCrossEntropyLoss
10 | from .kd_losses import (
11 | ForwardKLLoss,
12 | ForwardKLWithChunkedOutputLoss,
13 | ReverseKLLoss,
14 | ReverseKLWithChunkedOutputLoss,
15 | SymmetricKLLoss,
16 | SymmetricKLWithChunkedOutputLoss,
17 | )
18 | from .loss_types import RLLoss, SFTLoss
19 |
20 | __all__ = [
21 | "CEWithChunkedOutputLoss",
22 | "ForwardKLLoss",
23 | "ForwardKLWithChunkedOutputLoss",
24 | "ReverseKLLoss",
25 | "ReverseKLWithChunkedOutputLoss",
26 | "SymmetricKLLoss",
27 | "SymmetricKLWithChunkedOutputLoss",
28 | "LinearCrossEntropyLoss",
29 | "SFTLoss",
30 | "RLLoss",
31 | ]
32 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/_cli/test_ls.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import runpy
7 | import sys
8 |
9 | from tests.common import TUNE_PATH
10 |
11 | from torchtune._recipe_registry import get_all_recipes
12 |
13 |
class TestTuneListCommand:
    """This class tests the `tune ls` command."""

    def test_ls_lists_all_recipes_and_configs(self, capsys, monkeypatch):
        monkeypatch.setattr(sys, "argv", "tune ls".split())
        runpy.run_path(TUNE_PATH, run_name="__main__")
        listing = capsys.readouterr().out.rstrip("\n")

        # Every recipe name, each followed by its config names, must appear in the listing.
        expected_names = [
            name
            for recipe in get_all_recipes()
            for name in (recipe.name, *(cfg.name for cfg in recipe.configs))
        ]
        for name in expected_names:
            assert name in listing
30 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/llama2/_model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
def scale_hidden_dim_for_mlp(dim: int, multiple_of: int = 256) -> int:
    """Scale hidden dimension for MLP to keep number of parameters and computation constant.

    For SwiGLU the conventional ``4 * dim`` hidden size is scaled by 2/3. The result is
    then rounded *up* (not to the nearest value) to a multiple of ``multiple_of``.

    Args:
        dim (int): Input dimension.
        multiple_of (int): The result is rounded up to a multiple of this value. Default: 256

    Returns:
        int: Scaled hidden dimension.

    Example:
        >>> scale_hidden_dim_for_mlp(4096)
        11008
    """
    swiglu_dim = 4 * int(2 * dim / 3)
    # Integer ceiling division: how many `multiple_of`-sized chunks are needed.
    n_chunks = (swiglu_dim + multiple_of - 1) // multiple_of
    return n_chunks * multiple_of
24 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/llama3/_model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
def scale_hidden_dim_for_mlp(dim: int, multiple_of: int = 256) -> int:
    """Scale hidden dimension for MLP to keep number of parameters and computation constant.

    For SwiGLU the conventional ``4 * dim`` hidden size is scaled by 2/3. The result is
    then rounded *up* (not to the nearest value) to a multiple of ``multiple_of``.

    Args:
        dim (int): Input dimension.
        multiple_of (int): The result is rounded up to a multiple of this value. Default: 256

    Returns:
        int: Scaled hidden dimension.

    Example:
        >>> scale_hidden_dim_for_mlp(4096)
        11008
    """
    swiglu_dim = 4 * int(2 * dim / 3)
    # Integer ceiling division: how many `multiple_of`-sized chunks are needed.
    n_chunks = (swiglu_dim + multiple_of - 1) // multiple_of
    return n_chunks * multiple_of
24 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/llama3_1/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._component_builders import llama3_1, lora_llama3_1
8 |
9 | from ._model_builders import ( # noqa
10 | llama3_1_405b,
11 | llama3_1_70b,
12 | llama3_1_8b,
13 | lora_llama3_1_405b,
14 | lora_llama3_1_70b,
15 | lora_llama3_1_8b,
16 | qlora_llama3_1_405b,
17 | qlora_llama3_1_70b,
18 | qlora_llama3_1_8b,
19 | )
20 | from ._position_embeddings import Llama3ScaledRoPE
21 |
22 | __all__ = [
23 | "llama3_1",
24 | "llama3_1_8b",
25 | "llama3_1_70b",
26 | "llama3_1_405b",
27 | "lora_llama3_1",
28 | "lora_llama3_1_8b",
29 | "lora_llama3_1_70b",
30 | "lora_llama3_1_405b",
31 | "qlora_llama3_1_8b",
32 | "qlora_llama3_1_70b",
33 | "qlora_llama3_1_405b",
34 | "Llama3ScaledRoPE",
35 | ]
36 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/gemma/rms_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 |
10 |
class GemmaRMSNorm(nn.Module):
    """Gemma-style RMS normalization over the last dimension.

    Copied from https://github.com/google/gemma_pytorch/blob/main/gemma/model.py

    The learnable ``scale`` starts at zeros and is applied as ``1 + scale``. The
    multiplication happens in float32 before casting back to the input dtype.

    Args:
        dim (int): size of the last dimension, matching the shape of ``scale``.
        eps (float): added to the mean square before the reciprocal square root. Default: 1e-6
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), reduced over the last dimension
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        normed = self._norm(x.float())
        weight = 1.0 + self.scale.float()
        return (normed * weight).type_as(x)
27 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/gemma2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ..gemma._model_builders import gemma_tokenizer
8 | from ..gemma._tokenizer import GemmaTokenizer # noqa
9 | from ._component_builders import gemma2, lora_gemma2 # noqa
10 | from ._model_builders import ( # noqa
11 | gemma2_27b,
12 | gemma2_2b,
13 | gemma2_9b,
14 | lora_gemma2_27b,
15 | lora_gemma2_2b,
16 | lora_gemma2_9b,
17 | qlora_gemma2_27b,
18 | qlora_gemma2_2b,
19 | qlora_gemma2_9b,
20 | )
21 |
22 | __all__ = [
23 | "GemmaTokenizer",
24 | "gemma2",
25 | "gemma2_2b",
26 | "gemma2_9b",
27 | "gemma2_27b",
28 | "gemma_tokenizer",
29 | "lora_gemma2",
30 | "lora_gemma2_2b",
31 | "lora_gemma2_9b",
32 | "lora_gemma2_27b",
33 | "qlora_gemma2_2b",
34 | "qlora_gemma2_9b",
35 | "qlora_gemma2_27b",
36 | ]
37 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/peft/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._utils import ( # noqa
8 | AdapterModule,
9 | disable_adapter,
10 | get_adapter_params,
11 | get_adapter_state_dict,
12 | get_lora_module_names,
13 | get_merged_lora_ckpt,
14 | LORA_ATTN_MODULES,
15 | set_trainable_params,
16 | validate_missing_and_unexpected_for_lora,
17 | )
18 | from .dora import DoRALinear
19 | from .lora import LoRALinear, QATLoRALinear, TrainableParams
20 |
21 |
22 | __all__ = [
23 | "AdapterModule",
24 | "DoRALinear",
25 | "LoRALinear",
26 | "QATLoRALinear",
27 | "get_adapter_params",
28 | "set_trainable_params",
29 | "validate_missing_and_unexpected_for_lora",
30 | "disable_adapter",
31 | "get_adapter_state_dict",
32 | "get_merged_lora_ckpt",
33 | "get_lora_module_names",
34 | "TrainableParams",
35 | ]
36 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/qwen2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._component_builders import lora_qwen2, qwen2 # noqa
8 | from ._convert_weights import qwen2_hf_to_tune, qwen2_tune_to_hf # noqa
9 | from ._model_builders import (
10 | lora_qwen2_0_5b,
11 | lora_qwen2_1_5b,
12 | lora_qwen2_7b,
13 | qwen2_0_5b,
14 | qwen2_1_5b,
15 | qwen2_7b,
16 | qwen2_tokenizer,
17 | )
18 | from ._positional_embeddings import Qwen2RotaryPositionalEmbeddings
19 | from ._tokenizer import Qwen2Tokenizer
20 |
__all__ = [
    # classes: from ._tokenizer / ._positional_embeddings
    "Qwen2Tokenizer",
    "Qwen2RotaryPositionalEmbeddings",
    # from ._component_builders
    "qwen2",
    "lora_qwen2",
    # from ._convert_weights
    "qwen2_hf_to_tune",
    "qwen2_tune_to_hf",
    # from ._model_builders
    "qwen2_0_5b",
    "qwen2_1_5b",
    "qwen2_7b",
    "lora_qwen2_0_5b",
    "lora_qwen2_1_5b",
    "lora_qwen2_7b",
    "qwen2_tokenizer",
]
36 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | # Suggested config from pytorch that we can adapt
3 | select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2
4 | max-line-length = 120
5 | # C408 ignored because we like the dict keyword argument syntax
6 | # E501 is not flexible enough, we're using B950 instead
7 | # N812 ignored because import torch.nn.functional as F is PyTorch convention
8 | # N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
9 | # E731 allow usage of assigning lambda expressions
10 | ignore =
11 | E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731
12 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying
13 | # to line this up with executable bit
14 | EXE001,
15 | # these ignores are from flake8-bugbear; please fix!
16 | B007,B008,
17 | optional-ascii-coding = True
18 | exclude =
19 | ./.git,
20 | ./docs
21 | ./build
22 | ./scripts,
23 | ./venv,
24 | *.pyi
25 | .pre-commit-config.yaml
26 | *.md
27 | .flake8
28 | tests/torchtune/models/llama2/scripts/*.py
29 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/rl/workers/metric_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import ray
8 | from torchtune import config
9 |
10 |
@ray.remote(num_cpus=1, num_gpus=0)
class MetricLoggerWorker:
    """
    Ray actor (1 CPU, no GPU) that owns the metric logger instantiated from
    ``cfg.metric_logger`` and forwards logging calls to it.
    """

    def __init__(self, cfg):
        self.logger = config.instantiate(cfg.metric_logger)
        # Record the full run configuration with the logger once, on start-up.
        self.logger.log_config(cfg)

    def log_dict(self, log_dict, step=None):
        """Forward ``log_dict`` to the underlying logger."""
        # ``step`` is passed straight through so each calling actor can keep
        # its own step counter.
        self.logger.log_dict(log_dict, step=step)

    def log_table(self, table_data, columns, table_name, step=None):
        """Log a table to WandB."""
        # Imported here, so wandb is only needed when a table is actually logged.
        # NOTE(review): assumes the configured logger accepts ``wandb.Table``
        # values (i.e. it is WandB-backed) - confirm.
        import wandb

        self.logger.log_dict(
            {table_name: wandb.Table(columns=columns, data=table_data)}, step=step
        )

    def close(self):
        """Close the underlying logger if it exposes a ``close`` attribute."""
        if not hasattr(self.logger, "close"):
            return
        self.logger.close()
31 |
--------------------------------------------------------------------------------
/torchtune/tests/recipes/test_configs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os
7 | from pathlib import Path
8 |
9 | import torchtune
10 |
11 | from omegaconf import OmegaConf
12 | from torchao.utils import TORCH_VERSION_AFTER_2_4
13 | from torchtune import config
14 |
15 | CONFIG_DIR = Path(torchtune.__file__).parent.parent / "recipes" / "configs"
16 |
17 |
class TestConfigs:
    def test_instantiate(self) -> None:
        """Validate every ``*.yaml`` file directly inside ``CONFIG_DIR``."""
        for file_name in os.listdir(CONFIG_DIR):
            if not file_name.endswith(".yaml"):
                continue
            config_path = os.path.join(CONFIG_DIR, file_name)
            # QAT config is only compatible with PyTorch 2.4+
            if config_path.endswith("qat_full.yaml") and not TORCH_VERSION_AFTER_2_4:
                continue
            config.validate(OmegaConf.load(config_path))
31 |
--------------------------------------------------------------------------------
/mirotrain/models/convert_weights.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | # Adapted from torchtune.models.convert_weights
6 |
7 | import re
8 |
9 | from typing import Dict
10 |
11 |
def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
    """
    Translate a state-dict key using ``mapping_dict``.

    Dot-separated segments of ``key`` that consist only of digits (e.g. layer or
    expert indices) are replaced with ``"{}"`` to build a template key. That
    template is looked up in ``mapping_dict``, and the extracted numbers are
    substituted, in order, into the mapped template. Keys without numeric
    segments are looked up verbatim.

    Args:
        key (str): source state-dict key, e.g. ``"layers.3.attn.q_proj.weight"``.
        mapping_dict (Dict[str, str]): template-key to template-key mapping.

    Returns:
        str: the converted key. A ``None`` mapping value is returned unchanged.

    Raises:
        Exception: if the (template) key is not present in ``mapping_dict``.
    """
    # Match a whole segment made only of digits: not preceded by a non-dot
    # character and not followed by one. Digits inside names such as "w1" or
    # "h2" are part of the name, not an index, and must be left untouched.
    numeric_segment = r"(?<![^.])\d+(?![^.])"
    indices = re.findall(numeric_segment, key)
    lookup_key = re.sub(numeric_segment, "{}", key) if indices else key

    # Keep the ``try`` narrow so errors raised while formatting the mapped key
    # are not misreported as an unexpected input key.
    try:
        new_key = mapping_dict[lookup_key]
    except KeyError as e:
        raise Exception(
            f'Error converting the state dict. Found unexpected key: "{key}". '
            "Please make sure you're loading a checkpoint with the right format. "
        ) from e

    # A ``None`` value is passed through, as in the non-numeric case, instead of
    # failing on ``None.format``.
    if indices and new_key is not None:
        new_key = new_key.format(*indices)
    return new_key
31 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/modules/model_fusion/test_fusion_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from torch import nn
8 | from torchtune.modules.model_fusion import get_fusion_params, register_fusion_module
9 |
10 |
def test_register_fusion_module():
    """
    A module registered as a fusion module reports all of its own parameters
    as fusion params.
    """
    linear = nn.Linear(1, 1)
    register_fusion_module(linear)
    assert set(linear.fusion_params()) == {"weight", "bias"}
20 |
21 |
def test_get_fusion_params():
    """
    Only parameters of registered submodules are collected, named by their
    fully-qualified path inside the parent model.
    """
    plain, fused = nn.Linear(1, 1), nn.Linear(1, 1)
    register_fusion_module(fused)
    model = nn.Sequential(plain, fused)
    assert set(get_fusion_params(model)) == {"1.weight", "1.bias"}
33 |
--------------------------------------------------------------------------------
/torchtune/torchtune/config/_errors.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List
8 |
9 |
class InstantiationError(Exception):
    """
    Signals that the `_component_` named in a config could not be instantiated.
    """
16 |
17 |
class ConfigError(Exception):
    """
    Raised when the yaml config is not well-formed. Prints all the collected
    errors at once.

    Args:
        errors (List[Exception]): exceptions found when validating `_component_`
            fields in the config
    """

    def __init__(self, errors: List[Exception]):
        # Forward ``errors`` to ``Exception.__init__`` so ``self.args`` is set.
        # Default exception pickling re-creates the object as
        # ``ConfigError(*self.args)``, which fails when ``args`` is empty.
        super().__init__(errors)
        self.errors = errors

    def __str__(self):
        error_messages = [f"{type(e).__name__}: {str(e)}" for e in self.errors]
        return "Config is not well-formed, found the following errors: \n" + "\n".join(
            error_messages
        )
36 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/mistral/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._component_builders import (
8 | lora_mistral,
9 | lora_mistral_classifier,
10 | mistral,
11 | mistral_classifier,
12 | )
13 | from ._model_builders import (
14 | lora_mistral_7b,
15 | lora_mistral_reward_7b,
16 | mistral_7b,
17 | mistral_reward_7b,
18 | mistral_tokenizer,
19 | qlora_mistral_7b,
20 | qlora_mistral_reward_7b,
21 | )
22 | from ._prompt_template import MistralChatTemplate
23 | from ._tokenizer import MistralTokenizer
24 |
__all__ = [
    # from ._component_builders
    "mistral",
    "mistral_classifier",
    "lora_mistral",
    "lora_mistral_classifier",
    # from ._model_builders
    "mistral_7b",
    "mistral_reward_7b",
    "lora_mistral_7b",
    "lora_mistral_reward_7b",
    "qlora_mistral_7b",
    "qlora_mistral_reward_7b",
    "mistral_tokenizer",
    # from ._prompt_template
    "MistralChatTemplate",
    # from ._tokenizer
    "MistralTokenizer",
]
40 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/gemma/evaluation.yaml:
--------------------------------------------------------------------------------
1 | # Config for EleutherEvalRecipe in eleuther_eval.py
2 | #
3 | # To launch, run the following command:
4 | # tune run eleuther_eval --config gemma/evaluation
5 |
6 | output_dir: ./ # Not needed
7 |
8 | # Model Arguments
9 | model:
10 | _component_: torchtune.models.gemma.gemma_2b
11 |
12 | # Checkpointer
13 | checkpointer:
14 | _component_: torchtune.training.FullModelHFCheckpointer
15 | checkpoint_dir: /tmp/gemma-2b
16 | checkpoint_files: [
17 | model-00001-of-00002.safetensors,
18 | model-00002-of-00002.safetensors,
19 | ]
20 | output_dir: ${output_dir}
21 | model_type: GEMMA
22 |
23 | # Tokenizer
24 | tokenizer:
25 | _component_: torchtune.models.gemma.gemma_tokenizer
26 | path: /tmp/gemma-2b/tokenizer.model
27 |
28 | # Environment
29 | device: cuda
30 | dtype: bf16
31 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed
32 |
33 | # EleutherAI specific eval args
34 | tasks: ["truthfulqa_mc2"]
35 | limit: null
36 | max_seq_length: 4096
37 | batch_size: 8
38 | enable_kv_cache: True
39 |
40 | # Quantization specific args
41 | quantizer: null
42 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: 'build'
2 |
3 | default_language_version:
4 | python: python3
5 |
6 | repos:
7 | - repo: https://github.com/pre-commit/pre-commit-hooks
8 | rev: v5.0.0
9 | hooks:
10 | - id: trailing-whitespace
11 | - id: check-ast
12 | - id: check-merge-conflict
13 | - id: no-commit-to-branch
14 | args: ['--branch=main']
15 | - id: check-added-large-files
16 | args: ['--maxkb=1000']
17 | - id: end-of-file-fixer
18 | exclude: '^(.*\.svg)$'
19 |
20 | - repo: https://github.com/pycqa/flake8
21 | rev: 7.1.1
22 | hooks:
23 | - id: flake8
24 | additional_dependencies:
25 | - flake8-bugbear == 22.4.25
26 | - pep8-naming == 0.12.1
27 | - torchfix
28 | args: ['--config=.flake8']
29 |
30 | - repo: https://github.com/omnilib/ufmt
31 | rev: v2.8.0
32 | hooks:
33 | - id: ufmt
34 | additional_dependencies:
35 | - black == 22.12.0
36 | - usort == 1.0.5
37 |
38 | - repo: https://github.com/jsh9/pydoclint
39 | rev: 0.5.12
40 | hooks:
41 | - id: pydoclint
42 | args: [--config=pyproject.toml]
43 |
--------------------------------------------------------------------------------
/mirotrain/monkey/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | # -*- coding: utf-8 -*-
6 | """
7 | author: lei.lei@shanda.com
8 | time: 2025/05/29 10:38
9 | description: gevent style monkey patching:
10 | 1. explicit: user calls `mirotrain.monkey.patch_common()`.
11 | 2. declarative: modules under `mirotrain` declare WHAT to patch via `__targets__` and `__implements__`.
12 | 3. silent errors: if patching fails, a warning is given and execution continues; no exception is raised.
13 | """
14 | import logging
15 |
16 | from ._state import is_anything_patched, is_module_patched, is_object_patched
17 | from .api import patch_by_source_module_fqn
18 |
__all__ = [
    # entry point defined in this module
    "patch_common",
    # patch-state queries re-exported from ._state
    "is_anything_patched",
    "is_module_patched",
    "is_object_patched",
]

logger = logging.getLogger("mirotrain")
27 |
28 |
def patch_common():
    """Apply the default set of mirotrain patches. TODO: add other patches here."""
    for source_module_fqn in (
        "mirotrain.data._messages",
        "mirotrain.datasets._packed",
        "mirotrain.training.checkpointing._checkpointer",
    ):
        patch_by_source_module_fqn(source_module_fqn)
34 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/config/test_validate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | from omegaconf import OmegaConf
9 | from torchtune import config
10 | from torchtune.config._errors import ConfigError
11 |
12 | VALID_CONFIG_PATH = "tests/assets/valid_dummy_config.yaml"
13 | INVALID_CONFIG_PATH = "tests/assets/invalid_dummy_config.yaml"
14 |
15 |
class TestValidate:
    def test_validate(self):
        # A well-formed config validates without raising.
        valid_cfg = OmegaConf.load(VALID_CONFIG_PATH)
        config.validate(valid_cfg)

        # A malformed config raises one ConfigError carrying every collected error.
        invalid_cfg = OmegaConf.load(INVALID_CONFIG_PATH)
        with pytest.raises(ConfigError) as excinfo:
            config.validate(invalid_cfg)

        collected = excinfo.value.errors
        assert len(collected) == 2
        assert all(isinstance(err, TypeError) for err in collected)
        assert all(
            str(err) == "get_dtype got an unexpected keyword argument 'dummy'"
            for err in collected
        )
30 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/qwen3/evaluation.yaml:
--------------------------------------------------------------------------------
1 | # Config for EleutherEvalRecipe in eleuther_eval.py
2 | #
3 | # To launch, run the following command from root torchtune directory:
4 | # tune run eleuther_eval --config qwen3/evaluation
5 |
6 | output_dir: ./ # Not needed
7 |
8 | # Model Arguments
9 | model:
10 | _component_: torchtune.models.qwen3.qwen3_0_6b_instruct
11 |
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen3-0.6B
  checkpoint_files: [
    model.safetensors,
  ]
  output_dir: ${output_dir}
  # Qwen3 checkpoints need the Qwen3 weight conversion (QWEN2 lacks the q/k norm keys)
  model_type: QWEN3
20 |
21 | # Tokenizer
22 | tokenizer:
23 | _component_: torchtune.models.qwen3.qwen3_tokenizer
24 | path: /tmp/Qwen3-0.6B/vocab.json
25 | merges_file: /tmp/Qwen3-0.6B/merges.txt
26 | max_seq_len: null
27 |
28 | # Environment
29 | device: cuda
30 | dtype: bf16
31 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed
32 |
33 | # EleutherAI specific eval args
34 | tasks: ["truthfulqa_mc2"]
35 | limit: null
36 | max_seq_length: 4096
37 | batch_size: 8
38 | enable_kv_cache: True
39 |
40 | # Quantization specific args
41 | quantizer: null
42 |
--------------------------------------------------------------------------------
/torchtune/torchtune/dev/rl/utils/dist.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 |
def stateless_init_process_group(
    master_address: str,
    master_port: int,
    rank: int,
    world_size: int,
    device: torch.device,
):
    """
    Create a vLLM ``StatelessProcessGroup`` and wrap it in a ``PyNcclCommunicator``.

    vLLM's ``StatelessProcessGroup`` builds a process group independently of the
    global ``torch.distributed`` group. It is the recommended way to set up the
    NCCL data plane between external (training) processes and vLLM workers.
    The resulting communicator, bound to ``device``, is returned.
    """
    # Local imports keep vLLM an optional dependency of this module.
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    return PyNcclCommunicator(
        StatelessProcessGroup.create(
            host=master_address,
            port=master_port,
            rank=rank,
            world_size=world_size,
        ),
        device=device,
    )
33 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/tokenizers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa: F401
8 |
9 | # NOTE: This file is maintained for backward compatibility purposes.
10 | # The imports below point to the new location in `torchtune.modules.transforms.tokenizers`.
11 | # The import paths will be removed in v0.7. Please update your code to use the new path
12 | # (torchtune.modules.transforms.tokenizers) to avoid breaking changes in future releases.
13 |
14 |
15 | import warnings
16 |
17 | from torchtune.modules.transforms.tokenizers import (
18 | BaseTokenizer,
19 | ModelTokenizer,
20 | parse_hf_tokenizer_json,
21 | SentencePieceBaseTokenizer,
22 | TikTokenBaseTokenizer,
23 | tokenize_messages_no_special_tokens,
24 | )
25 |
# Emitted once, when this backward-compatibility module is first imported.
warnings.warn(
    message=(
        "The import path 'torchtune.modules.tokenizers' is deprecated and will be "
        "removed in v0.7. Please update your imports to "
        "'torchtune.modules.transforms.tokenizers'."
    ),
    category=DeprecationWarning,
    stacklevel=2,
)
32 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/qwen2_5/evaluation.yaml:
--------------------------------------------------------------------------------
1 | # Config for EleutherEvalRecipe in eleuther_eval.py
2 | #
3 | # To launch, run the following command from root torchtune directory:
4 | # tune run eleuther_eval --config qwen2_5/evaluation
5 |
6 | output_dir: ./ # Not needed
7 |
8 | # Model Arguments
9 | model:
10 | _component_: torchtune.models.qwen2_5.qwen2_5_0_5b
11 |
12 | checkpointer:
13 | _component_: torchtune.training.FullModelHFCheckpointer
14 | checkpoint_dir: /tmp/Qwen2.5-0.5B-Instruct
15 | checkpoint_files: [
16 | model.safetensors,
17 | ]
18 | output_dir: ${output_dir}
19 | model_type: QWEN2
20 |
21 | # Tokenizer
22 | tokenizer:
23 | _component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
24 | path: /tmp/Qwen2.5-0.5B-Instruct/vocab.json
25 | merges_file: /tmp/Qwen2.5-0.5B-Instruct/merges.txt
26 | max_seq_len: null
27 |
28 | # Environment
29 | device: cuda
30 | dtype: bf16
31 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed
32 |
33 | # EleutherAI specific eval args
34 | tasks: ["truthfulqa_mc2"]
35 | limit: null
36 | max_seq_length: 4096
37 | batch_size: 8
38 | enable_kv_cache: True
39 |
40 | # Quantization specific args
41 | quantizer: null
42 |
--------------------------------------------------------------------------------
/mirotrain/modules/moe/expert_parallel.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from typing import Optional
6 |
7 | import torch.distributed as dist
8 |
# Process group registered for expert-parallel communication; ``None`` means no
# group has been registered.
_EXPERT_PARALLEL_GROUP: Optional[dist.ProcessGroup] = None


def set_expert_parallel_group(group: Optional[dist.ProcessGroup]) -> None:
    """
    Set expert parallel process group (``None`` clears it).
    """
    global _EXPERT_PARALLEL_GROUP
    _EXPERT_PARALLEL_GROUP = group


def get_expert_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Get expert parallel process group, or ``None`` if none was set.
    """
    # Only reading the module global, so no ``global`` declaration is needed.
    return _EXPERT_PARALLEL_GROUP


def get_expert_parallel_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
    """
    Get expert parallel world size.

    Uses ``group`` if given, else the registered group; returns 1 when neither exists.
    """
    group = get_expert_parallel_group() if group is None else group
    return dist.get_world_size(group) if group is not None else 1


def get_expert_parallel_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """
    Get expert parallel rank.

    Uses ``group`` if given, else the registered group; returns 0 when neither exists.
    """
    group = get_expert_parallel_group() if group is None else group
    return dist.get_rank(group) if group is not None else 0
42 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/layer_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | from typing import Any
9 |
10 | import torch
11 | from torch import nn
12 |
13 |
class Fp32LayerNorm(nn.LayerNorm):
    """
    :class:`~torch.nn.LayerNorm` that always normalizes in float32, for
    mixed-precision training.

    The input and any affine parameters are upcast to fp32 for the
    normalization, and the result is cast back to the input's type.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The normalized output tensor having the same shape as ``x``.
        """
        # Affine parameters are absent when elementwise_affine / bias is disabled.
        weight_fp32 = None if self.weight is None else self.weight.float()
        bias_fp32 = None if self.bias is None else self.bias.float()

        normalized = nn.functional.layer_norm(
            x.float(), self.normalized_shape, weight_fp32, bias_fp32, self.eps
        )
        return normalized.type_as(x)
38 |
--------------------------------------------------------------------------------
/torchtune/torchtune/rlhf/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | from ._types import ChosenRejectedOutputs, PPOStats, Trajectory
9 |
10 | from .rewards import (
11 | estimate_advantages,
12 | get_reward_penalty_mask,
13 | get_rewards_ppo,
14 | masked_mean,
15 | masked_sum,
16 | masked_var,
17 | whiten,
18 | )
19 | from .sequence_processing import (
20 | batched_logits_to_logprobs,
21 | get_batch_log_probs,
22 | logits_to_logprobs,
23 | truncate_sequence_at_first_stop_token,
24 | truncate_sequence_for_logprobs,
25 | )
26 |
__all__ = [
    # from ._types
    "ChosenRejectedOutputs",
    "PPOStats",
    "Trajectory",
    # from .rewards
    "estimate_advantages",
    "get_reward_penalty_mask",
    "get_rewards_ppo",
    "masked_mean",
    "masked_sum",
    "masked_var",
    "whiten",
    # from .sequence_processing
    "batched_logits_to_logprobs",
    "get_batch_log_probs",
    "logits_to_logprobs",
    "truncate_sequence_at_first_stop_token",
    "truncate_sequence_for_logprobs",
]
44 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/mistral/evaluation.yaml:
--------------------------------------------------------------------------------
1 | # Config for EleutherEvalRecipe in eleuther_eval.py
2 | #
3 | # To launch, run the following command:
4 | # tune run eleuther_eval --config mistral/evaluation
5 |
6 | output_dir: ./ # Not needed
7 |
8 | # Model Arguments
9 | model:
10 | _component_: torchtune.models.mistral.mistral_7b
11 |
12 | # Checkpointer
13 | checkpointer:
14 | _component_: torchtune.training.FullModelHFCheckpointer
15 | checkpoint_dir: /tmp/Mistral-7B-v0.1/
16 | checkpoint_files: [
17 | pytorch_model-00001-of-00002.bin,
18 | pytorch_model-00002-of-00002.bin
19 | ]
20 | output_dir: ${output_dir}
21 | model_type: MISTRAL
22 | resume_from_checkpoint: False
23 |
24 | # Tokenizer
25 | tokenizer:
26 | _component_: torchtune.models.mistral.mistral_tokenizer
27 | path: /tmp/Mistral-7B-v0.1/tokenizer.model
28 | max_seq_len: null
29 |
30 | # Environment
31 | device: cuda
32 | dtype: bf16
33 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed
34 |
35 | # EleutherAI specific eval args
36 | tasks: ["truthfulqa_mc2"]
37 | limit: null
38 | max_seq_length: 4096
39 | batch_size: 8
40 | enable_kv_cache: True
41 |
42 | # Quantization specific args
43 | quantizer: null
44 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/eleuther_evaluation.yaml:
--------------------------------------------------------------------------------
1 | # Config for EleutherEvalRecipe in eleuther_eval.py
2 | #
3 | # To launch, run the following command from root torchtune directory:
4 | # tune run eleuther_eval --config eleuther_evaluation tasks=["truthfulqa_mc2","hellaswag"]
5 |
6 | output_dir: ./ # Not needed
7 |
8 | # Model Arguments
9 | model:
10 | _component_: torchtune.models.llama2.llama2_7b
11 |
12 | checkpointer:
13 | _component_: torchtune.training.FullModelHFCheckpointer
14 | checkpoint_dir: /tmp/Llama-2-7b-hf
15 | checkpoint_files: [
16 | pytorch_model-00001-of-00002.bin,
17 | pytorch_model-00002-of-00002.bin,
18 | ]
19 | output_dir: ${output_dir}
20 | model_type: LLAMA2
21 |
22 | # Tokenizer
23 | tokenizer:
24 | _component_: torchtune.models.llama2.llama2_tokenizer
25 | path: /tmp/Llama-2-7b-hf/tokenizer.model
26 | max_seq_len: null
27 |
28 | # Environment
29 | device: cuda
30 | dtype: bf16
31 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed
32 |
33 | # EleutherAI specific eval args
34 | tasks: ["truthfulqa_mc2"]
35 | limit: null
36 | max_seq_length: 4096
37 | batch_size: 8
38 | enable_kv_cache: True
39 |
40 | # Quantization specific args
41 | quantizer: null
42 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/phi3/evaluation.yaml:
--------------------------------------------------------------------------------
1 | # Config for EleutherEvalRecipe in eleuther_eval.py
2 | #
3 | # To launch, run the following command:
4 | # tune run eleuther_eval --config phi3/evaluation
5 |
6 | output_dir: ./ # Not needed
7 |
8 | # Model Arguments
9 | model:
10 | _component_: torchtune.models.phi3.phi3_mini
11 |
12 | # Checkpointer
13 | checkpointer:
14 | _component_: torchtune.training.FullModelHFCheckpointer
15 | checkpoint_dir: /tmp/phi-3
16 | checkpoint_files: [
17 | model-00001-of-00002.safetensors,
18 | model-00002-of-00002.safetensors
19 | ]
20 | recipe_checkpoint: null
21 | output_dir: ${output_dir}
22 | model_type: PHI3_MINI
23 | resume_from_checkpoint: False
24 |
25 | # Tokenizer
26 | tokenizer:
27 | _component_: torchtune.models.phi3.phi3_mini_tokenizer
28 | path: /tmp/phi-3/tokenizer.model
29 | max_seq_len: null
30 |
31 | # Environment
32 | device: cuda
33 | dtype: bf16
34 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed
35 |
36 | # EleutherAI specific eval args
37 | tasks: ["truthfulqa_mc2"]
38 | limit: null
39 | max_seq_length: 4096
40 | batch_size: 8
41 | enable_kv_cache: True
42 |
43 | # Quantization specific args
44 | quantizer: null
45 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/llama3_2/evaluation.yaml:
--------------------------------------------------------------------------------
1 | # Config for EleutherEvalRecipe in eleuther_eval.py
2 | #
3 | # To launch, run the following command:
4 | # tune run eleuther_eval --config llama3_2/evaluation
5 |
6 | output_dir: ./ # Not needed
7 |
8 | # Model Arguments
9 | model:
10 | _component_: torchtune.models.llama3_2.llama3_2_3b
11 |
12 | # Checkpointer
13 | checkpointer:
14 | _component_: torchtune.training.FullModelHFCheckpointer
15 | checkpoint_dir: /tmp/Llama-3.2-3B-Instruct
16 | checkpoint_files: [
17 | model-00001-of-00002.safetensors,
18 | model-00002-of-00002.safetensors,
19 | ]
20 | recipe_checkpoint: null
21 | output_dir: ${output_dir}
22 | model_type: LLAMA3_2
23 | resume_from_checkpoint: False
24 |
25 | # Tokenizer
26 | tokenizer:
27 | _component_: torchtune.models.llama3.llama3_tokenizer
28 | path: /tmp/Llama-3.2-3B-Instruct/original/tokenizer.model
29 | max_seq_len: null
30 |
31 | # Environment
32 | device: cpu
33 | dtype: bf16
34 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed
35 |
36 | # EleutherAI specific eval args
37 | tasks: ["truthfulqa_mc2"]
38 | limit: null
39 | max_seq_length: 4096
40 | batch_size: 8
41 | enable_kv_cache: True
42 |
43 | # Quantization specific args
44 | quantizer: null
45 |
--------------------------------------------------------------------------------
/torchtune/recipes/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # This file mainly exists because we want to ensure that `recipes` aren't
8 | # importable *from the tests*.
9 | # We're using the `prepend` pytest import mode which adds the root dir (i.e. the
10 | # parent of torchtune/, tests/, recipes/) to the pythonpath during pytest
11 | # sessions
12 | # (https://docs.pytest.org/en/7.1.x/explanation/pythonpath.html#import-modes).
13 | # This has the positive effect that the `tests` folder becomes importable when
14 | # testing (we need that, considering how tests are currently set up) but ALSO
15 | # has the negative effect of making the `recipes/` importable when testing.
16 | # Since we don't want the tests to incorrectly assume that recipes are
17 | # importable, we have to explicitly raise an error here.
18 |
# Importing this directory is always an error; see the explanation above.
raise ModuleNotFoundError(
    "The torchtune recipes directory isn't a package and you should not import "
    "anything from here. Refer to our docs for detailed instructions on how to "
    "use recipes: https://pytorch.org/torchtune/main/deep_dives/recipe_deepdive.html"
)
24 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/code_llama2/evaluation.yaml:
--------------------------------------------------------------------------------
1 | # Config for EleutherEvalRecipe in eleuther_eval.py
2 | #
3 | # To launch, run the following command:
4 | # tune run eleuther_eval --config code_llama2/evaluation
5 |
6 | output_dir: ./ # Not needed
7 |
8 | # Model arguments
9 | model:
10 | _component_: torchtune.models.code_llama2.code_llama2_7b
11 |
12 | # Tokenizer
13 | tokenizer:
14 | _component_: torchtune.models.llama2.llama2_tokenizer
15 | path: /tmp/CodeLlama-7b-hf/tokenizer.model
16 | max_seq_len: null
17 |
18 | # Checkpointer
19 | checkpointer:
20 | _component_: torchtune.training.FullModelHFCheckpointer
21 | checkpoint_dir: /tmp/CodeLlama-7b-hf
22 | checkpoint_files: [
23 | pytorch_model-00001-of-00003.bin,
24 | pytorch_model-00002-of-00003.bin,
25 | pytorch_model-00003-of-00003.bin
26 | ]
27 | recipe_checkpoint: null
28 | output_dir: ${output_dir}
29 | model_type: LLAMA2
30 | resume_from_checkpoint: False
31 |
32 | # Environment
33 | device: cpu
34 | dtype: bf16
35 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed
36 |
37 | # EleutherAI specific eval args
38 | tasks: ["truthfulqa_mc2"]
39 | limit: null
40 | max_seq_length: 4096
41 | batch_size: 8
42 | enable_kv_cache: True
43 |
44 | # Quantization specific args
45 | quantizer: null
46 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/qwen2/evaluation.yaml:
--------------------------------------------------------------------------------
1 | # Config for EleutherEvalRecipe in eleuther_eval.py
2 | #
3 | # To launch, run the following command:
4 | # tune run eleuther_eval --config qwen2/evaluation
5 |
6 | output_dir: ./ # Not needed
7 |
8 | # Model Arguments
9 | model:
10 | _component_: torchtune.models.qwen2.qwen2_7b
11 |
12 | # Checkpointer
13 | checkpointer:
14 | _component_: torchtune.training.FullModelHFCheckpointer
15 | checkpoint_dir: /tmp/Qwen2-7B-Instruct
16 | checkpoint_files: [
17 | model-00001-of-00004.safetensors,
18 | model-00002-of-00004.safetensors,
19 | model-00003-of-00004.safetensors,
20 | model-00004-of-00004.safetensors
21 | ]
22 | output_dir: ${output_dir}
23 | model_type: QWEN2
24 |
25 | # Tokenizer
26 | tokenizer:
27 | _component_: torchtune.models.qwen2.qwen2_tokenizer
28 | path: /tmp/Qwen2-7B-Instruct/vocab.json
29 | merges_file: /tmp/Qwen2-7B-Instruct/merges.txt
30 | max_seq_len: null
31 |
32 | # Environment
33 | device: cuda
34 | dtype: bf16
35 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed
36 |
37 | # EleutherAI specific eval args
38 | tasks: ["truthfulqa_mc2"]
39 | limit: null
40 | max_seq_length: 4096
41 | batch_size: 8
42 | enable_kv_cache: True
43 |
44 | # Quantization specific args
45 | quantizer: null
46 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/modules/low_precision/test_nf4_dispatch_registration.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torchao.dtypes import to_nf4
9 |
10 |
class TestNF4DispatchRegistration:
    """
    Tests covering the dispatch ops registered for ``NF4Tensor``.
    """

    def test_inplace_copy_copies_expected_attributes(self):
        """
        Guard against ``NF4Tensor`` instances gaining or losing attributes.

        Our ``torch.ops.aten.copy_.default`` implementation in ``_register_nf4_dispatch_ops``
        copies a fixed set of attributes. If this test fails, that implementation must be
        updated to cover the changed attribute set.
        """
        expected_attrs = {
            "block_size",
            "n_blocks",
            "scaler_block_size",
            "quantized_scalers",
            "quantization_factor",
            "scaler_mean",
            "quantized_data",
            "nf4",
        }

        nf4_tensor = to_nf4(torch.rand(512, 512, dtype=torch.bfloat16))
        actual_attrs = set(nf4_tensor.__dict__.keys())
        assert expected_attrs == actual_attrs
36 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/llama3_2_vision/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._component_builders import ( # noqa
8 | llama3_2_vision_decoder,
9 | llama3_2_vision_encoder,
10 | lora_llama3_2_vision_decoder,
11 | lora_llama3_2_vision_encoder,
12 | )
13 | from ._encoder import Llama3VisionEncoder, Llama3VisionProjectionHead
14 |
15 | from ._model_builders import ( # noqa
16 | llama3_2_vision_11b,
17 | llama3_2_vision_90b,
18 | llama3_2_vision_transform,
19 | lora_llama3_2_vision_11b,
20 | lora_llama3_2_vision_90b,
21 | qlora_llama3_2_vision_11b,
22 | qlora_llama3_2_vision_90b,
23 | )
24 | from ._transform import Llama3VisionTransform
25 |
# Public API of this package: the model builders, component builders,
# encoder modules and transform imported above.
__all__ = [
    "llama3_2_vision_11b",
    "llama3_2_vision_transform",
    "lora_llama3_2_vision_11b",
    "qlora_llama3_2_vision_11b",
    "llama3_2_vision_90b",
    "lora_llama3_2_vision_90b",
    "qlora_llama3_2_vision_90b",
    "llama3_2_vision_decoder",
    "llama3_2_vision_encoder",
    "lora_llama3_2_vision_decoder",
    "lora_llama3_2_vision_encoder",
    "Llama3VisionEncoder",
    "Llama3VisionProjectionHead",
    "Llama3VisionTransform",
]
42 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/t5/test_t5_tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import pytest
7 |
8 | from tests.common import ASSETS
9 | from torchtune.models.t5._model_builders import t5_tokenizer
10 |
11 |
class TestT5Tokenizer:
    """Tests for the tokenizer returned by ``t5_tokenizer``."""

    @pytest.fixture
    def tokenizer(self):
        # SentencePiece model from the shared test assets directory.
        return t5_tokenizer(str(ASSETS / "sentencepiece.model"))

    def test_encoding(self, tokenizer):
        texts = [
            "a cow jumping over the moon",
            "a helpful AI assistant",
        ]
        correct_tokens = [
            [3, 9, 9321, 15539, 147, 8, 8114, 1],
            [3, 9, 2690, 7833, 6165, 1],
        ]
        for text, correct in zip(texts, correct_tokens):
            tokens = tokenizer.encode(text)
            # Report the offending text/tokens in the failure message instead of
            # printing them unconditionally on every run.
            assert tokens == correct, f"unexpected tokens for {text!r}: {tokens}"

    def test_decoding(self, tokenizer):
        # Round trip: decode(encode(text)) must reproduce the input text.
        text = "this is torchtune"
        assert text == tokenizer.decode(tokenizer.encode(text))

    def test_call(self, tokenizer):
        # Calling the tokenizer on a sample replaces the "text" key with "tokens".
        sample = {"text": "hello world"}
        sample = tokenizer(sample)
        assert "text" not in sample
        assert "tokens" in sample
39 | assert "tokens" in sample
40 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/llama3_3/_model_builders.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from torchtune.models.llama3_1._model_builders import (
7 | llama3_1_70b,
8 | lora_llama3_1_70b,
9 | qlora_llama3_1_70b,
10 | )
11 |
12 | """
13 | Model builders build specific instantiations using component builders. The Llama3.3 model
14 | builders all call the Llama3.1 models as they're identical models apart from the checkpoints.
15 | """
16 |
def _llama3_3_alias(builder, name, doc):
    """
    Return a new function that forwards to the Llama3.1 ``builder``, named ``name`` with docstring ``doc``.

    A plain alias (``llama3_3_70b = llama3_1_70b``) shares one function object, so
    assigning ``__doc__`` on it would also overwrite the Llama3.1 builder's docstring.
    ``functools.wraps`` sets ``__wrapped__`` so ``inspect.signature`` still reports
    the Llama3.1 builder's arguments.
    """
    import functools

    @functools.wraps(builder)
    def wrapper(*args, **kwargs):
        return builder(*args, **kwargs)

    wrapper.__name__ = name
    wrapper.__qualname__ = name
    wrapper.__module__ = __name__
    wrapper.__doc__ = doc
    return wrapper


llama3_3_70b = _llama3_3_alias(
    llama3_1_70b,
    "llama3_3_70b",
    """
    Builder for creating a Llama3.3 model initialized w/ the default 70B parameter values.
    Please see `llama3_1_70b` for full API arguments.
    """,
)

lora_llama3_3_70b = _llama3_3_alias(
    lora_llama3_1_70b,
    "lora_llama3_3_70b",
    """
    Builder for creating a Llama3.3 70B model with LoRA enabled.
    Please see `lora_llama3_1_70b` for full API arguments.
    """,
)

qlora_llama3_3_70b = _llama3_3_alias(
    qlora_llama3_1_70b,
    "qlora_llama3_3_70b",
    """
    Builder for creating a Llama3.3 70B model with QLoRA enabled. Base model weights in linear layers
    that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
    Please see `qlora_llama3_1_70b` for full API arguments.
    """,
)
38 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/phi4/evaluation.yaml:
--------------------------------------------------------------------------------
1 | # Config for EleutherEvalRecipe in eleuther_eval.py
2 | #
3 | # To launch, run the following command:
4 | # tune run eleuther_eval --config phi4/evaluation
5 |
6 | output_dir: ./ # Not needed
7 |
8 | # Model Arguments
9 | model:
10 | _component_: torchtune.models.phi4.phi4_14b
11 |
12 | # Checkpointer
13 | checkpointer:
14 | _component_: torchtune.training.FullModelHFCheckpointer
15 | checkpoint_dir: /tmp/phi-4
16 | checkpoint_files: [
17 | model-00001-of-00006.safetensors,
18 | model-00002-of-00006.safetensors,
19 | model-00003-of-00006.safetensors,
20 | model-00004-of-00006.safetensors,
21 | model-00005-of-00006.safetensors,
22 | model-00006-of-00006.safetensors,
23 | ]
24 | recipe_checkpoint: null
25 | output_dir: ${output_dir}
26 | model_type: PHI4
27 | resume_from_checkpoint: False
28 |
29 | # Tokenizer
30 | tokenizer:
31 | _component_: torchtune.models.phi4.phi4_tokenizer
32 | vocab_path: /tmp/phi-4/vocab.json
33 | merges_path: /tmp/phi-4/merges.txt
34 | max_seq_len: null
35 |
36 | # Environment
37 | device: cuda
38 | dtype: bf16
39 | seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed
40 |
41 | # EleutherAI specific eval args
42 | tasks: ["truthfulqa_mc2"]
43 | limit: null
44 | max_seq_length: 4096
45 | batch_size: 8
46 | enable_kv_cache: True
47 |
48 | # Quantization specific args
49 | quantizer: null
50 |
--------------------------------------------------------------------------------
/torchtune/torchtune/utils/_version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
import re
from datetime import datetime

import torch
10 |
11 |
def _release_tuple(version: str) -> tuple:
    """
    Parse the leading numeric release segment of a version string.

    Local labels (``+cu121``) and pre-release/dev suffixes (``a0``, ``.dev20240708``)
    are ignored, e.g. ``"2.5.0.dev20240708+cu121"`` -> ``(2, 5, 0)``.

    Args:
        version (str): A version string such as ``torch.__version__``.

    Returns:
        tuple: The release components as integers.

    Raises:
        ValueError: If ``version`` does not start with a numeric component.
    """
    match = re.match(r"\d+(?:\.\d+)*", str(version).strip())
    if match is None:
        raise ValueError(f"Could not parse a release version from {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def torch_version_ge(version: str) -> bool:
    """
    Check if torch version is greater than or equal to the given version.

    Only the numeric release components are compared. They are compared as integers
    and zero-padded to equal length, so ``"2.10"`` > ``"2.4"`` and ``"2.4"`` == ``"2.4.0"``.
    Pre-release and nightly builds count as the release they belong to: for example,
    torch ``2.5.0.dev20240708`` satisfies ``torch_version_ge("2.5")``.

    Args:
        version (str): The torch version to compare against

    Returns:
        bool: True if torch version is greater than or equal to the given version.

    Raises:
        ValueError: If either version string has no leading numeric release segment.

    Example:
        >>> print(torch.__version__)
        2.4.0
        >>> torch_version_ge("2.0")
        True
    """
    # Comparing raw strings is lexicographic ("2.10.0" < "2.4"), and a substring
    # test matches unrelated versions ("4.0" in "2.4.0"), so compare integer tuples.
    current = _release_tuple(torch.__version__)
    target = _release_tuple(version)
    width = max(len(current), len(target))
    current += (0,) * (width - len(current))
    target += (0,) * (width - len(target))
    return current >= target
29 |
30 |
def _is_fbcode():
    """
    Return True when ``torch.version`` has no ``git_version`` attribute, which this
    module uses as the signal for an fbcode build.
    """
    has_git_version = hasattr(torch.version, "git_version")
    return not has_git_version
33 |
34 |
def _nightly_version_ge(ao_version_str: str, date: str) -> bool:
    """
    Check whether a torchao nightly version is dated on or after ``date``.

    The nightly date is taken from the text following ``"dev"`` (before any ``+``
    local label) and parsed as ``%Y%m%d``, e.g. ``"0.5.0.dev20240708+cu121"``.

    Args:
        ao_version_str (str): torchao nightly version string.
        date (str): Date of the form ``%Y-%m-%d``.

    Returns:
        bool: True if the nightly's date is >= ``date``, False otherwise.
    """
    public_part = ao_version_str.split("+")[0]
    nightly_stamp = public_part.split("dev")[1]
    return datetime.strptime(nightly_stamp, "%Y%m%d") >= datetime.strptime(
        date, "%Y-%m-%d"
    )
47 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/llama2/generation_v2.yaml:
--------------------------------------------------------------------------------
# Config for running the InferenceRecipe in dev/generate_v2.py to generate output from an LLM
2 | #
3 | # This config assumes that you've run the following command before launching:
#   tune download meta-llama/Llama-2-7b-chat-hf --output-dir /tmp/Llama-2-7b-chat-hf --ignore-patterns "*.bin" --hf-token <HF_TOKEN>
5 | #
6 | # To launch, run the following command:
7 | # tune run dev/generate_v2 --config llama2/generation_v2
8 |
9 | output_dir: ./ # Not needed
10 |
11 | # Model arguments
12 | model:
13 | _component_: torchtune.models.llama2.llama2_7b
14 |
15 | # Transform arguments
16 | tokenizer:
17 | _component_: torchtune.models.llama2.llama2_tokenizer
18 | path: /tmp/Llama-2-7b-chat-hf/tokenizer.model
19 | max_seq_len: 2048
20 |
21 | # Checkpointer
22 | checkpointer:
23 | _component_: torchtune.training.FullModelHFCheckpointer
24 | checkpoint_dir: /tmp/Llama-2-7b-chat-hf
25 | checkpoint_files: [
26 | model-00001-of-00002.safetensors,
27 | model-00002-of-00002.safetensors
28 | ]
29 | output_dir: ${output_dir}
30 | model_type: LLAMA2
31 |
32 | # Device
33 | device: cuda
34 | dtype: bf16
35 | seed: 1234
36 | log_level: INFO # DEBUG, WARN, etc.
37 |
38 | # Generation arguments
39 | prompt:
40 | system: You are a helpful and creative AI assistant.
41 | user: What is the capital of France?
42 | max_new_tokens: 200
43 | temperature: 0.6 # 0.8 and 0.6 are popular values to try
44 | top_k: 300
45 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/llama3/test_llama3.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | import torch
9 | from tests.test_utils import fixed_init_model
10 | from torchtune.models.llama3 import llama3
11 | from torchtune.training.seed import set_seed
12 |
13 | EMBED_DIM = 128
14 | NUM_LAYERS = 4
15 | NUM_HEADS = 16
16 | NUM_KV_HEADS = 8
17 | VOCAB_SIZE = 32000
18 | MAX_SEQ_LEN = 2048
19 | BSZ = 2
20 | SEQ_LEN = 100
21 |
22 |
@pytest.fixture(autouse=True)
def random():
    """Seed the RNGs via ``set_seed`` before every test in this module."""
    set_seed(16)


class TestLlama3:
    """Forward-pass regression test for a small model built with ``llama3``."""

    @pytest.fixture
    def inputs(self):
        # Random token ids of shape (BSZ, SEQ_LEN).
        return torch.randint(0, VOCAB_SIZE, (BSZ, SEQ_LEN))

    def test_forward(self, inputs):
        config = dict(
            vocab_size=VOCAB_SIZE,
            num_layers=NUM_LAYERS,
            num_heads=NUM_HEADS,
            num_kv_heads=NUM_KV_HEADS,
            embed_dim=EMBED_DIM,
            max_seq_len=MAX_SEQ_LEN,
        )
        model = llama3(**config)
        # Fixed weight initialization so the output mean can be pinned to a constant.
        fixed_init_model(model, min_val=-0.25, max_val=0.5)
        output = model(inputs)
        assert output.shape == (BSZ, SEQ_LEN, VOCAB_SIZE)
        torch.testing.assert_close(
            output.mean(), torch.tensor(3.9763), atol=1e-4, rtol=1e-4
        )
47 |
--------------------------------------------------------------------------------
/mirotrain/monkey/_errors.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | # -*- coding: utf-8 -*-
6 |
7 | # This file is adapted from
8 | # https://github.com/gevent/gevent/blob/73025a8837b3bff19c106e877fa2374889c59dd3/src/gevent/monkey/_errors.py
9 |
10 | """
11 | Exception classes and errors that this package may raise.
12 | """
13 |
14 |
class _BadModule(ImportError):
    """
    Raised when a module cannot be imported.

    Args:
        name: The module (or module name) that could not be imported; shown via ``repr``.
    """

    def __init__(self, name):
        message = f"Module {name!r} is not importable"
        super().__init__(message)
22 |
23 |
class _BadTargets(AttributeError):
    """
    Raised when a module's ``__targets__`` is incorrect.

    Args:
        module: The module (or module name) whose ``__targets__`` is bad; shown via ``repr``.
        target: The offending ``__targets__`` value; shown via ``repr``.
    """

    def __init__(self, module, target):
        # Format in (module, target) order so each value lands in the matching
        # slot of the message template.
        AttributeError.__init__(
            self,
            "Module %r has a bad or missing value %r for __targets__"
            % (
                module,
                target,
            ),
        )
38 |
39 |
class _BadImplements(AttributeError):
    """
    Raised when a module's ``__implements__`` is incorrect.

    Args:
        module: The module (or module name) whose ``__implements__`` is bad; shown via ``repr``.
        implements: The offending ``__implements__`` value; shown via ``repr``.
    """

    def __init__(self, module, implements):
        # Format in (module, implements) order so each value lands in the
        # matching slot of the message template.
        AttributeError.__init__(
            self,
            "Module %r has a bad or missing value %r for __implements__"
            % (
                module,
                implements,
            ),
        )
54 |
--------------------------------------------------------------------------------
/torchtune/torchtune/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
# Package version string; it is empty in this source tree.
# NOTE(review): presumably populated by the build/release process — confirm.
__version__ = ""
8 |
9 |
10 | # Check at the top-level that torchao is installed.
11 | # This is better than doing it at every import site.
12 | # We have to do this because it is not currently possible to
13 | # properly support both nightly and stable installs of PyTorch + torchao
14 | # in pyproject.toml.
15 | try:
16 | import torchao # noqa
17 | except ImportError as e:
18 | raise ImportError(
19 | """
20 | torchao not installed.
21 | Please follow the instructions at https://pytorch.org/torchtune/main/install.html#pre-requisites
22 | to install torchao.
23 | """
24 | ) from e
25 |
# Enables faster downloading. For more info: https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
# To disable, run `HF_HUB_ENABLE_HF_TRANSFER=0 tune download <repo-id>`
import os

try:
    import hf_transfer  # noqa
except ImportError:
    # hf_transfer is optional; without it downloads use the default path.
    pass
else:
    # Opt in only when the user has not set the variable themselves.
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
37 |
38 | from torchtune import datasets, generation, models, modules, utils
39 |
# Public submodules exported by ``from torchtune import *``. Entries must be
# strings: module objects here make a star-import raise TypeError.
__all__ = ["datasets", "models", "modules", "utils", "generation"]
41 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/generation.yaml:
--------------------------------------------------------------------------------
1 | # Config for running the InferenceRecipe in generate.py to generate output
2 | # from Llama2 7B model
3 | #
4 | # This config assumes that you've run the following command before launching
5 | # this run:
#   tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --ignore-patterns "*.safetensors" --hf-token <HF_TOKEN>
7 | #
8 | # To launch, run the following command from root torchtune directory:
9 | # tune run generate --config generation
10 |
11 | output_dir: ./ # Not needed
12 |
13 | # Model arguments
14 | model:
15 | _component_: torchtune.models.llama2.llama2_7b
16 |
17 | checkpointer:
18 | _component_: torchtune.training.FullModelHFCheckpointer
19 | checkpoint_dir: /tmp/Llama-2-7b-hf/
20 | checkpoint_files: [
21 | pytorch_model-00001-of-00002.bin,
22 | pytorch_model-00002-of-00002.bin,
23 | ]
24 | output_dir: ${output_dir}
25 | model_type: LLAMA2
26 |
27 | device: cuda
28 | dtype: bf16
29 |
30 | seed: 1234
31 |
32 | # Tokenizer arguments
33 | tokenizer:
34 | _component_: torchtune.models.llama2.llama2_tokenizer
35 | path: /tmp/Llama-2-7b-hf/tokenizer.model
36 | max_seq_len: null
37 | prompt_template: null
38 |
39 | # Generation arguments; defaults taken from gpt-fast
40 | prompt:
41 | system: null
42 | user: "Tell me a joke."
43 | max_new_tokens: 300
44 | temperature: 0.6 # 0.8 and 0.6 are popular values to try
45 | top_k: 300
46 |
47 | enable_kv_cache: True
48 |
49 | quantizer: null
50 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/_cli/test_validate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import runpy
8 | import sys
9 |
10 | import pytest
11 | from tests.common import ASSETS, TUNE_PATH
12 |
13 |
class TestTuneValidateCommand:
    """Tests for the `tune validate` CLI command."""

    VALID_CONFIG_PATH = ASSETS / "valid_dummy_config.yaml"
    INVALID_CONFIG_PATH = ASSETS / "invalid_dummy_config.yaml"

    @staticmethod
    def _run_validate(monkeypatch, config_path):
        """Invoke the `tune` entry point as `tune validate <config_path>`."""
        argv = f"tune validate {config_path}".split()
        monkeypatch.setattr(sys, "argv", argv)
        runpy.run_path(TUNE_PATH, run_name="__main__")

    def test_validate_good_config(self, capsys, monkeypatch):
        self._run_validate(monkeypatch, self.VALID_CONFIG_PATH)

        stdout = capsys.readouterr().out.rstrip("\n")
        assert stdout == "Config is well-formed!"

    def test_validate_bad_config(self, monkeypatch, capsys):
        # An invalid config makes the CLI exit; its error goes to stderr.
        with pytest.raises(SystemExit):
            self._run_validate(monkeypatch, self.INVALID_CONFIG_PATH)

        stderr = capsys.readouterr().err.rstrip("\n")
        assert "got an unexpected keyword argument 'dummy'" in stderr
42 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/llama2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._component_builders import (
8 | llama2,
9 | llama2_classifier,
10 | lora_llama2,
11 | lora_llama2_classifier,
12 | )
13 |
14 | from ._model_builders import ( # noqa
15 | llama2_13b,
16 | llama2_70b,
17 | llama2_7b,
18 | llama2_reward_7b,
19 | llama2_tokenizer,
20 | lora_llama2_13b,
21 | lora_llama2_70b,
22 | lora_llama2_7b,
23 | lora_llama2_reward_7b,
24 | qlora_llama2_13b,
25 | qlora_llama2_70b,
26 | qlora_llama2_7b,
27 | qlora_llama2_reward_7b,
28 | )
29 | from ._prompt_template import Llama2ChatTemplate
30 | from ._tokenizer import Llama2Tokenizer
31 |
# Public API of this package. Each exported name is listed exactly once.
__all__ = [
    "Llama2Tokenizer",
    "Llama2ChatTemplate",
    "llama2",
    "llama2_classifier",
    "lora_llama2_classifier",
    "llama2_reward_7b",
    "lora_llama2_reward_7b",
    "qlora_llama2_reward_7b",
    "lora_llama2",
    "llama2_13b",
    "llama2_70b",
    "llama2_7b",
    "llama2_tokenizer",
    "lora_llama2_13b",
    "lora_llama2_70b",
    "lora_llama2_7b",
    "qlora_llama2_13b",
    "qlora_llama2_70b",
    "qlora_llama2_7b",
]
55 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/phi3/test_phi3.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | import torch
9 | from tests.test_utils import fixed_init_model
10 | from torchtune.models.phi3 import phi3
11 | from torchtune.training.seed import set_seed
12 |
13 | EMBED_DIM = 128
14 | INTER_DIM = 256
15 | NUM_LAYERS = 4
16 | NUM_HEADS = 16
17 | NUM_KV_HEADS = 8
18 | VOCAB_SIZE = 32000
19 | MAX_SEQ_LEN = 2048
20 | BSZ = 2
21 | SEQ_LEN = 100
22 |
23 |
@pytest.fixture(autouse=True)
def random():
    """Seed the RNGs via ``set_seed`` before every test in this module."""
    set_seed(16)


class TestPhi3:
    """Forward-pass regression test for a small model built with ``phi3``."""

    @pytest.fixture
    def inputs(self):
        # Random token ids of shape (BSZ, SEQ_LEN).
        return torch.randint(0, VOCAB_SIZE, (BSZ, SEQ_LEN))

    def test_forward(self, inputs):
        config = dict(
            vocab_size=VOCAB_SIZE,
            num_layers=NUM_LAYERS,
            num_heads=NUM_HEADS,
            num_kv_heads=NUM_KV_HEADS,
            embed_dim=EMBED_DIM,
            intermediate_dim=INTER_DIM,
            max_seq_len=MAX_SEQ_LEN,
        )
        model = phi3(**config)
        # Fixed weight initialization so the output mean can be pinned to a constant.
        fixed_init_model(model, min_val=-0.25, max_val=0.5)
        output = model(inputs)
        assert output.shape == (BSZ, SEQ_LEN, VOCAB_SIZE)
        torch.testing.assert_close(
            output.mean(), torch.tensor(3.9763), atol=1e-4, rtol=1e-4
        )
49 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/llama4/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._component_builders import (
8 | llama4_decoder,
9 | llama4_vision_encoder,
10 | llama4_vision_projection_head,
11 | lora_llama4_decoder,
12 | lora_llama4_vision_encoder,
13 | lora_llama4_vision_projection_head,
14 | )
15 | from ._encoder import Llama4VisionEncoder, Llama4VisionProjectionHead
16 |
17 | from ._model_builders import (
18 | llama4_maverick_17b_128e,
19 | llama4_scout_17b_16e,
20 | llama4_transform,
21 | lora_llama4_scout_17b_16e,
22 | )
23 | from ._parallelism import decoder_only_tp_plan
24 | from ._position_embeddings import Llama4ScaledRoPE
25 | from ._tokenizer import Llama4Tokenizer
26 | from ._transform import Llama4Transform
27 |
# Public API of this package: the component and model builders, encoder
# modules, tokenizer, transforms, RoPE module and tensor-parallel plan
# imported above.
__all__ = [
    "llama4_vision_encoder",
    "decoder_only_tp_plan",
    "llama4_vision_projection_head",
    "Llama4VisionEncoder",
    "Llama4VisionProjectionHead",
    "llama4_decoder",
    "Llama4Tokenizer",
    "llama4_scout_17b_16e",
    "llama4_maverick_17b_128e",
    "lora_llama4_vision_encoder",
    "lora_llama4_vision_projection_head",
    "lora_llama4_decoder",
    "lora_llama4_scout_17b_16e",
    "Llama4Transform",
    "llama4_transform",
    "Llama4ScaledRoPE",
]
46 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/qwen2/test_qwen2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | import torch
9 | from tests.test_utils import fixed_init_model
10 | from torchtune.models.qwen2 import qwen2
11 | from torchtune.training.seed import set_seed
12 |
13 | EMBED_DIM = 128
14 | INTER_DIM = 256
15 | NUM_LAYERS = 4
16 | NUM_HEADS = 16
17 | NUM_KV_HEADS = 8
18 | VOCAB_SIZE = 32000
19 | MAX_SEQ_LEN = 2048
20 | BSZ = 2
21 | SEQ_LEN = 100
22 |
23 |
@pytest.fixture(autouse=True)
def random():
    """Seed the RNGs via ``set_seed`` before every test in this module."""
    set_seed(16)


class TestQwen2:
    """Forward-pass regression test for a small model built with ``qwen2``."""

    @pytest.fixture
    def inputs(self):
        # Random token ids of shape (BSZ, SEQ_LEN).
        return torch.randint(0, VOCAB_SIZE, (BSZ, SEQ_LEN))

    def test_forward(self, inputs):
        config = dict(
            vocab_size=VOCAB_SIZE,
            num_layers=NUM_LAYERS,
            num_heads=NUM_HEADS,
            num_kv_heads=NUM_KV_HEADS,
            embed_dim=EMBED_DIM,
            intermediate_dim=INTER_DIM,
            max_seq_len=MAX_SEQ_LEN,
        )
        model = qwen2(**config)
        # Fixed weight initialization so the output mean can be pinned to a constant.
        fixed_init_model(model, min_val=-0.25, max_val=0.5)
        output = model(inputs)
        assert output.shape == (BSZ, SEQ_LEN, VOCAB_SIZE)
        torch.testing.assert_close(
            output.mean(), torch.tensor(3.9763), atol=1e-4, rtol=1e-4
        )
49 |
--------------------------------------------------------------------------------
/torchtune/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright 2024 Meta
4 |
5 | Redistribution and use in source and binary forms, with or without modification,
6 | are permitted provided that the following conditions are met:
7 |
1. Redistributions of source code must retain the above copyright notice, this list
9 | of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice, this
12 | list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its contributors may
16 | be used to endorse or promote products derived from this software without specific
17 | prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
20 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
22 | SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
23 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
24 | TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
25 | BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
27 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
28 | DAMAGE.
29 |
--------------------------------------------------------------------------------
/torchtune/torchtune/training/pooling.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import torch
7 |
8 |
def get_unmasked_sequence_lengths(mask: torch.Tensor) -> torch.Tensor:
    """
    Returns, for each batch element, the index of the last unmasked token.

    For right-padded inputs this equals the unpadded sequence length minus one.
    A row that is entirely masked yields index 0.

    Args:
        mask (torch.Tensor): Boolean mask with shape [b x s], where True indicates a value to be masked out.
            This is usually a mask for padding tokens, where True indicates a padding token.

    Returns:
        Tensor: Indices of the last unmasked token, with shape [b] and dtype ``torch.long``.

    Shape notation:
        - b = batch size
        - s = sequence length

    Example:
        >>> input_ids = torch.tensor([
        ...        [2, 4, 0, 0],
        ...        [2, 4, 6, 0],
        ...        [2, 4, 6, 9]
        ...    ])
        >>> mask = input_ids == 0
        >>> mask
        tensor([[False, False,  True,  True],
                [False, False, False,  True],
                [False, False, False, False]])
        >>> get_unmasked_sequence_lengths(mask)
        tensor([1, 2, 3])

    """
    # The running count of unmasked tokens stops increasing after the last
    # unmasked position, and argmax returns the first index of the maximum.
    unmasked_running_count = (~mask).cumsum(dim=-1)
    last_unmasked_idx = unmasked_running_count.argmax(dim=-1).long()
    return torch.clip(last_unmasked_idx, min=0, max=mask.shape[1] - 1)
43 |
--------------------------------------------------------------------------------
/mirotrain/training/lr_schedulers.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import math
6 |
7 | import torch
8 | from torch.optim.lr_scheduler import LambdaLR
9 |
10 |
def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
    warmup_ratio: float = 0.1,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Build a linear-warmup + cosine-decay learning rate schedule.

    The LR multiplier rises linearly from 0.0 to 1.0 over the first
    ``int(warmup_ratio * num_training_steps)`` steps. It then follows a cosine curve
    over the remaining steps and is floored at 0.0. With ``num_cycles=0.5`` it decays
    from 1.0 to 0.0 at ``num_training_steps``.

    This is based on the Hugging Face implementation
    https://github.com/huggingface/transformers/blob/v4.23.1/src/transformers/optimization.py#L104.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer whose learning rate is scheduled.
        num_training_steps (int): Total number of training steps.
        warmup_ratio (float): Fraction of ``num_training_steps`` used for warmup. Default: 0.1
        num_cycles (float): Number of cosine cycles after warmup. Default: 0.5
        last_epoch (int): Forwarded unchanged to ``LambdaLR``. Default: -1

    Returns:
        LambdaLR: Scheduler applying the multiplier to ``optimizer``.
    """
    warmup_steps = int(warmup_ratio * num_training_steps)
    # max(1, ...) avoids division by zero when no steps remain after warmup.
    decay_steps = max(1, num_training_steps - warmup_steps)

    def lr_lambda(current_step: int) -> float:
        if current_step >= warmup_steps:
            # Fraction of the post-warmup phase that has elapsed.
            progress = (current_step - warmup_steps) / decay_steps
            return max(
                0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
            )
        # Linear warmup; max(1, ...) guards against a zero-length warmup.
        return current_step / max(1, warmup_steps)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
45 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/llama4/test_llama4_transform.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | import torch
9 | from tests.common import ASSETS
10 | from tests.test_utils import assert_expected
11 | from torchtune.models.llama4._transform import Llama4Transform
12 | from torchtune.training.seed import set_seed
13 | from torchvision.transforms.v2 import functional as F
14 |
15 | EMBED_DIM = 128
16 | BSZ = 2
17 | N_IMG = 1
18 | N_TILES = 4
19 | N_PATCHES = 17 # 16 + 1 for CLS token
20 |
21 |
22 | @pytest.fixture(autouse=True)
23 | def random():
24 | set_seed(16)
25 |
26 |
def random_image():
    """Build a tiny 3-channel 4x4 PIL image from Gaussian noise clamped to [0, 1]."""
    noise = torch.randn(3, 4, 4).clamp(0, 1)
    return F.to_pil_image(noise)
32 |
33 |
class TestImageThumbnailTransform:
    """Checks ``Llama4Transform.thumbnail_transform`` on a tiny random image."""

    @pytest.fixture
    def transform(self):
        tokenizer_path = ASSETS / "tiktoken_small_llama4.model"
        return Llama4Transform(
            path=str(tokenizer_path),
            tile_size=16,
            patch_size=4,
            max_num_tiles=4,
            image_mean=[0.5] * 3,
            image_std=[0.5] * 3,
            dtype=torch.float32,
        )

    def test_call(self, transform):
        thumbnail = transform.thumbnail_transform(random_image())
        # The thumbnail is resized to a single tile_size x tile_size image.
        assert thumbnail.shape == (3, 16, 16)
        assert_expected(thumbnail.sum(), torch.tensor(-123.3412))
52 |
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/pytorch-logo-flame.svg:
--------------------------------------------------------------------------------
1 |
2 |
34 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/clip/test_clip_text_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | import torch
9 |
10 | from torchtune.models.clip._component_builders import clip_text_encoder
11 | from torchtune.training.seed import set_seed
12 |
# Sizes for the tiny CLIP text encoder and random token inputs used by the
# tests in this module.
VOCAB_SIZE = 512
MAX_SEQ_LEN = 77
BSZ = 2
EMBED_DIM = 4
17 |
18 |
@pytest.fixture(autouse=True)
def random():
    """Reseed every RNG (seed 0) before each test in this module."""
    seed = 0
    set_seed(seed)
22 |
23 |
class TestClipTextEncoder:
    """Forward/backward checks for a tiny, deterministically initialized CLIP text encoder."""

    @pytest.fixture
    def model(self):
        model = clip_text_encoder(
            vocab_size=VOCAB_SIZE,
            max_seq_len=MAX_SEQ_LEN,
            embed_dim=EMBED_DIM,
            num_heads=2,
            num_layers=2,
        )

        # Replace the default init with U(0, 1) weights drawn from the seeded
        # RNG so the hard-coded forward values below are reproducible.
        for param in model.parameters():
            param.data.uniform_(0, 1)

        return model

    @pytest.fixture
    def inputs(self):
        # Random token ids of shape (BSZ, MAX_SEQ_LEN).
        return torch.randint(0, VOCAB_SIZE, (BSZ, MAX_SEQ_LEN))

    def test_forward(self, model, inputs):
        actual = model(inputs)
        expected = torch.tensor(
            [[0.1915, 1.3982, 0.6298, -0.0966], [0.2276, 1.3785, 0.6309, -0.1066]]
        )
        assert actual.shape == (BSZ, EMBED_DIM)
        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

    def test_backward(self, model, inputs):
        y = model(inputs)
        assert y.shape == (BSZ, EMBED_DIM)
        loss = y.mean()
        loss.backward()

        # Not raising is not enough: gradients must actually reach the
        # parameters, and each populated gradient must match its parameter.
        populated = [(p, p.grad) for p in model.parameters() if p.grad is not None]
        assert populated, "backward() did not populate any parameter gradients"
        assert all(grad.shape == param.shape for param, grad in populated)
56 |
--------------------------------------------------------------------------------
/torchtune/torchtune/training/checkpointing/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from typing import Union
7 |
8 | from torchtune.training.checkpointing._checkpointer import (
9 | DistributedCheckpointer,
10 | FullModelHFCheckpointer,
11 | FullModelMetaCheckpointer,
12 | FullModelTorchTuneCheckpointer,
13 | )
14 | from torchtune.training.checkpointing._utils import (
15 | ADAPTER_CONFIG,
16 | ADAPTER_KEY,
17 | DATALOADER_KEY,
18 | EPOCHS_KEY,
19 | FormattedCheckpointFiles,
20 | get_largest_iter_folder,
21 | MAX_STEPS_KEY,
22 | MODEL_KEY,
23 | ModelType,
24 | OPT_KEY,
25 | RNG_KEY,
26 | SEED_KEY,
27 | STEPS_KEY,
28 | TOTAL_EPOCHS_KEY,
29 | update_state_dict_for_classifier,
30 | )
31 |
# Type alias covering every checkpointer class imported above, for code that
# can accept any of them.
Checkpointer = Union[
    DistributedCheckpointer,
    FullModelHFCheckpointer,
    FullModelMetaCheckpointer,
    FullModelTorchTuneCheckpointer,
]

# Public API of this package: the checkpointer classes, the Checkpointer alias,
# the helpers, and the state-dict key constants re-exported from _utils.
__all__ = [
    "FullModelHFCheckpointer",
    "FullModelMetaCheckpointer",
    "FullModelTorchTuneCheckpointer",
    "DistributedCheckpointer",
    "ModelType",
    "Checkpointer",
    "update_state_dict_for_classifier",
    "ADAPTER_CONFIG",
    "get_largest_iter_folder",
    "ADAPTER_KEY",
    "EPOCHS_KEY",
    "MAX_STEPS_KEY",
    "MODEL_KEY",
    "OPT_KEY",
    "RNG_KEY",
    "SEED_KEY",
    "STEPS_KEY",
    "TOTAL_EPOCHS_KEY",
    "FormattedCheckpointFiles",
    "DATALOADER_KEY",
]
61 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/llama3/70B_generation_distributed.yaml:
--------------------------------------------------------------------------------
1 | # Config for running the InferenceRecipe in dev/generate_v2_distributed.py to generate output
2 | # using a Llama3 70B Instruct model
3 | #
4 | # This config assumes that you've run the following command before launching:
5 | # tune download meta-llama/Meta-Llama-3-70B-Instruct --output-dir /tmp/Meta-Llama-3-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token <HF_TOKEN>
6 | #
7 | # To launch, run the following command from root torchtune directory:
8 | # tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3/70B_generation_distributed
9 |
10 | output_dir: ./
11 |
12 | # Model arguments
13 | model:
14 | _component_: torchtune.models.llama3.llama3_70b
15 |
16 | tensor_parallel_plan:
17 | _component_: torchtune.models.llama3.base_llama_tp_plan
18 |
19 | # Transform arguments
20 | tokenizer:
21 | _component_: torchtune.models.llama3.llama3_tokenizer
22 | path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model
23 | prompt_template: null
24 | max_seq_len: 8192
25 |
26 | # Checkpointer
27 | checkpointer:
28 | _component_: torchtune.training.FullModelHFCheckpointer
29 | checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
30 | checkpoint_files:
31 | filename_format: model-{}-of-{}.safetensors
32 | max_filename: "00030"
33 | recipe_checkpoint: null
34 | output_dir: ${output_dir}
35 | model_type: LLAMA3
36 |
37 | # Device
38 | device: cuda
39 | dtype: bf16
40 | seed: 1234
41 | log_level: INFO # DEBUG, WARN, etc.
42 |
43 | # Generation arguments
44 | prompt:
45 | system: null
46 | user:
47 | text: Tell a joke.
48 | max_new_tokens: 200
49 | temperature: 0.6 # 0.8 and 0.6 are popular values to try
50 | top_k: 300
51 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/llama3_3/70B_generation_distributed.yaml:
--------------------------------------------------------------------------------
1 | # Config for running the InferenceRecipe in dev/generate_v2_distributed.py to generate output
2 | # using a Llama3.3 70B Instruct model
3 | #
4 | # This config assumes that you've run the following command before launching:
5 | # tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*" --output-dir /tmp/Llama-3.3-70B-Instruct --hf-token <HF_TOKEN>
6 | #
7 | # To launch, run the following command from root torchtune directory:
8 | # tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3_3/70B_generation_distributed
9 |
10 | output_dir: ./
11 |
12 | # Model arguments
13 | model:
14 | _component_: torchtune.models.llama3_3.llama3_3_70b
15 |
16 | tensor_parallel_plan:
17 | _component_: torchtune.models.llama3.base_llama_tp_plan
18 |
19 | # Transform arguments
20 | tokenizer:
21 | _component_: torchtune.models.llama3.llama3_tokenizer
22 | path: /tmp/Llama-3.3-70B-Instruct/original/tokenizer.model
23 | prompt_template: null
24 | max_seq_len: 8192
25 |
26 | # Checkpointer
27 | checkpointer:
28 | _component_: torchtune.training.FullModelHFCheckpointer
29 | checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
30 | checkpoint_files:
31 | filename_format: model-{}-of-{}.safetensors
32 | max_filename: "00030"
33 | recipe_checkpoint: null
34 | output_dir: ${output_dir}
35 | model_type: LLAMA3
36 |
37 | # Device
38 | device: cuda
39 | dtype: bf16
40 | seed: 1234
41 | log_level: INFO # DEBUG, WARN, etc.
42 |
43 | # Generation arguments
44 | prompt:
45 | system: null
46 | user:
47 | text: Tell a joke.
48 | max_new_tokens: 200
49 | temperature: 0.6 # 0.8 and 0.6 are popular values to try
50 | top_k: 300
51 |
--------------------------------------------------------------------------------
/torchtune/torchtune/data/_torchdata.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import functools
8 | from typing import Any, Callable, Iterable, Iterator, Mapping, TypeVar
9 |
10 | from torchtune.utils._import_guard import _TORCHDATA_INSTALLED, _TORCHDATA_MIN_VERSION
11 |
12 | from typing_extensions import TypeAlias
13 |
14 |
if not _TORCHDATA_INSTALLED:
    # torchdata is missing or too old: define minimal stand-ins so names and
    # annotations that reference BaseNode/Loader still resolve. Constructing a
    # Loader raises the informative ImportError from assert_torchdata_installed.
    T = TypeVar("T")

    class BaseNode(Iterator[T]):
        def __init__(self, *args, **kwargs):
            pass

    class Loader(Iterable):
        def __init__(self, *args, **kwargs):
            assert_torchdata_installed()

else:
    from torchdata.nodes import BaseNode, Loader  # noqa


# A node yielding str-keyed sample mappings.
DatasetType: TypeAlias = BaseNode[Mapping[str, Any]]  # type: ignore
31 |
32 |
def assert_torchdata_installed():
    """Raise ``ImportError`` with install instructions unless a compatible torchdata is available."""
    if _TORCHDATA_INSTALLED:
        return
    message = (
        "torchdata is not installed, or the current version is too old. "
        f"Please (re-)install it with `pip install torchdata>={_TORCHDATA_MIN_VERSION}`. "
    )
    raise ImportError(message)
39 |
40 |
def requires_torchdata(func: Callable) -> Callable:
    """
    Decorator that verifies torchdata is available before every call to ``func``.

    The wrapped function keeps ``func``'s metadata (name, docstring, ``__wrapped__``);
    each call first runs ``assert_torchdata_installed``, which raises ``ImportError``
    when torchdata is missing.
    """

    @functools.wraps(func)
    def _guarded(*args, **kwargs):
        assert_torchdata_installed()
        return func(*args, **kwargs)

    return _guarded
52 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/datasets/test_wikitext_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from unittest.mock import patch
7 |
8 | import pytest
9 |
10 | from tests.test_utils import DummyTokenizer
11 |
12 | from torchtune.datasets import wikitext_dataset
13 |
14 |
class TestWikiTextDataset:
    """Smoke tests for ``wikitext_dataset`` with the HF ``load_dataset`` call mocked out."""

    @pytest.fixture
    def tokenizer(self):
        return DummyTokenizer()

    @patch("torchtune.datasets._text_completion.load_dataset")
    @pytest.mark.parametrize("max_seq_len", [128, 512, 1024, 4096])
    def test_dataset_get_item(self, load_dataset, tokenizer, max_seq_len):
        # One fake wikitext row, returned by the mock in place of the real dataset.
        page_text = (
            "Bart , like the rest of his family , has yellow skin . "
            "Bart usually wears a red T @-@ shirt , blue shorts and blue trainers . "
            "When the Simpson family goes to church in the episodes , or to school "
            "events or shows , Bart wears a blue suit with a white shirt , a purple "
            "tie , blue shorts and a blue jacket ."
        )
        load_dataset.return_value = [{"page": page_text}]

        ds = wikitext_dataset(tokenizer=tokenizer, max_seq_len=max_seq_len)
        tokens = ds[0]["tokens"]
        labels = ds[0]["labels"]

        # Sequences respect max_seq_len, labels align with tokens, and the
        # sequence starts with the tokenizer's BOS id.
        assert len(tokens) <= max_seq_len
        assert len(labels) <= max_seq_len
        assert len(tokens) == len(labels)
        assert tokens[0] == tokenizer.bos_id
42 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/transforms/vision_utils/pad_dim_to_size.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 |
def pad_dim_to_size(
    input: torch.Tensor, size: int, dim: int, *, fill: float = 0.0
) -> torch.Tensor:
    """
    Pads the given dimension of the input to the given size.

    Padding is added only at the end (high-index side) of ``dim``; every other
    dimension is left untouched.

    Example:
        >>> image = torch.rand(1, 4, 4)
        >>> padded_image = pad_dim_to_size(image, 3, dim=0)
        >>> padded_image.shape
        torch.Size([3, 4, 4])

    Args:
        input (torch.Tensor): Tensor to pad.
        size (int): Size to pad to.
        dim (int): Dimension to pad. Negative values count from the end,
            as in ``input.shape[dim]``.
        fill (float): Value to fill the padded region with. Default: 0.0

    Returns:
        torch.Tensor: Padded input.
    """
    pad_size = size - input.shape[dim]
    assert (
        pad_size >= 0
    ), f"Tensor input shape {input.shape[dim]} is larger than given size {size}"

    # input.shape[dim] above accepts negative dims, so normalize before
    # computing the pad index (otherwise e.g. dim=-1 indexes past `padding`).
    if dim < 0:
        dim += input.dim()

    # F.pad takes (front, back) pairs starting from the LAST dimension:
    # dimension `dim` is pair number `input.dim() - 1 - dim`, and its "back"
    # slot is the second entry of that pair.
    # https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
    padding = [0] * (2 * input.dim())
    pad_index = 2 * (input.dim() - 1 - dim) + 1
    padding[pad_index] = pad_size
    return F.pad(input, padding, value=fill)
47 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/llama3_1/70B_generation_distributed.yaml:
--------------------------------------------------------------------------------
1 | # Config for running the InferenceRecipe in dev/generate_v2_distributed.py to generate output
2 | # using a Llama3.1 70B Instruct model
3 | #
4 | # This config assumes that you've run the following command before launching:
5 | # tune download meta-llama/Meta-Llama-3.1-70B-Instruct --output-dir /tmp/Meta-Llama-3.1-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token <HF_TOKEN>
6 | #
7 | # To launch, run the following command from root torchtune directory:
8 | # tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3_1/70B_generation_distributed
9 |
10 | output_dir: ./
11 |
12 | # Model arguments
13 | model:
14 | _component_: torchtune.models.llama3_1.llama3_1_70b
15 |
16 | tensor_parallel_plan:
17 | _component_: torchtune.models.llama3.base_llama_tp_plan
18 |
19 | # Transform arguments
20 | tokenizer:
21 | _component_: torchtune.models.llama3.llama3_tokenizer
22 | path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model
23 | prompt_template: null
24 | max_seq_len: 8192
25 |
26 | # Checkpointer
27 | checkpointer:
28 | _component_: torchtune.training.FullModelHFCheckpointer
29 | checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
30 | checkpoint_files:
31 | filename_format: model-{}-of-{}.safetensors
32 | max_filename: "00030"
33 | recipe_checkpoint: null
34 | output_dir: ${output_dir}
35 | model_type: LLAMA3
36 |
37 | # Device
38 | device: cuda
39 | dtype: bf16
40 | seed: 1234
41 | log_level: INFO # DEBUG, WARN, etc.
42 |
43 | # Generation arguments
44 | prompt:
45 | system: null
46 | user:
47 | text: Tell a joke.
48 | max_new_tokens: 200
49 | temperature: 0.6 # 0.8 and 0.6 are popular values to try
50 | top_k: 300
51 |
--------------------------------------------------------------------------------
/mirotrain/models/qwen3/_checkpointing_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import json
6 | from functools import wraps
7 | from pathlib import Path
8 |
9 | from ._positional_embeddings import RopeScaling
10 |
11 |
def patch_copy_files(max_position_embeddings: int, rope_scaling: RopeScaling):
    """
    Monkey-patch ``mirotrain.training.checkpointing._checkpointer.copy_files``.

    After the original ``copy_files`` runs, the patched version rewrites
    ``<output_dir>/config.json`` (when it exists) so that it carries the given
    ``max_position_embeddings`` and ``rope_scaling`` values.

    Args:
        max_position_embeddings (int): Value written to ``config["max_position_embeddings"]``.
        rope_scaling (RopeScaling): Value written to ``config["rope_scaling"]``. It must be
            JSON-serializable (presumably a plain dict — TODO confirm against RopeScaling).

    Returns:
        The patched ``copy_files`` function, which is also installed on the
        ``_checkpointer`` module.
    """
    import mirotrain.training.checkpointing._checkpointer as _checkpointer

    original_copy_files = _checkpointer.copy_files

    @wraps(original_copy_files)
    def patched_copy_files(
        input_dir: str | Path, output_dir: str | Path, *args, **kwargs
    ):
        """
        Run the original ``copy_files``, then overwrite ``max_position_embeddings``
        and ``rope_scaling`` in ``<output_dir>/config.json`` if that file exists.

        Returns whatever the original ``copy_files`` returns.
        """
        result = original_copy_files(input_dir, output_dir, *args, **kwargs)

        config_path = Path(output_dir) / "config.json"
        if not config_path.exists():
            return result

        # JSON is UTF-8; be explicit since ensure_ascii=False may emit non-ASCII.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        config["max_position_embeddings"] = max_position_embeddings
        config["rope_scaling"] = rope_scaling

        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, ensure_ascii=False)
            f.write("\n")

        return result

    # Install the patch on the module first, then hand it back to the caller.
    _checkpointer.copy_files = patched_copy_files
    return patched_copy_files
50 |
--------------------------------------------------------------------------------
/mirotrain/models/qwen3/_validate.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2025 MiromindAI
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from ._positional_embeddings import RopeScaling
6 |
7 |
def validate_yarn_cfg(max_position_embeddings: int, rope_scaling: RopeScaling):
    """
    Validate YARN-specific configs

    Checks that ``rope_scaling`` selects YaRN and carries the fields it needs, and
    that ``max_position_embeddings`` equals
    ``original_max_position_embeddings * factor``.

    Args:
        max_position_embeddings (int): Target context length configured for the model.
        rope_scaling (RopeScaling): Mapping read (via ``.get``) for ``rope_type``,
            ``factor`` and ``original_max_position_embeddings``.

    Raises:
        KeyError: If a required ``rope_scaling`` field is missing, or
            ``max_position_embeddings`` is None.
        ValueError: If ``rope_type`` is not ``"yarn"`` or the lengths are inconsistent.
    """
    import math

    rope_type = rope_scaling.get("rope_type")
    if rope_type is None:
        raise KeyError("'rope_type' must be specified in 'rope_scaling' configuration.")
    if rope_type != "yarn":
        raise ValueError(
            f"Unsupported 'rope_type': {rope_type}. "
            "Currently, only 'yarn' is supported for Qwen patching."
        )

    factor = rope_scaling.get("factor")
    if factor is None:
        raise KeyError("'factor' must be specified in 'rope_scaling' for YARN.")

    original_max_position_embeddings = rope_scaling.get(
        "original_max_position_embeddings"
    )
    if original_max_position_embeddings is None:
        raise KeyError(
            "'original_max_position_embeddings' must be specified in 'rope_scaling' for YARN."
        )

    if max_position_embeddings is None:
        raise KeyError(
            "'max_position_embeddings' must be specified in model_extra when rope_scaling is enabled."
        )

    expected = original_max_position_embeddings * factor
    # int() truncates, so a float product landing just below an integer
    # (e.g. 100 * 4.35 == 434.99999999999994) would wrongly reject 435.
    # Keep accepting the truncated value for backward compatibility, and also
    # accept a value equal to the product up to float round-off.
    matches = max_position_embeddings == int(expected) or (
        isinstance(max_position_embeddings, (int, float))
        and math.isclose(max_position_embeddings, expected, rel_tol=0.0, abs_tol=1e-6)
    )
    if not matches:
        raise ValueError(
            f"Configured 'max_position_embeddings' ({max_position_embeddings}) is not equal to "
            f"'original_max_position_embeddings' ({original_max_position_embeddings}) * "
            f"'rope_scaling.factor' ({factor})."
        )
45 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/llama3_2_vision/11B_generation_v2.yaml:
--------------------------------------------------------------------------------
1 | # Config for running the InferenceRecipe in dev/generate_v2.py to generate output
2 | # from a Llama3.2 11B Vision Instruct model
3 | #
4 | # This config assumes that you've run the following command before launching:
5 | # tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct --ignore-patterns "original/consolidated*"
6 | #
7 | # To launch, run the following command from root torchtune directory:
8 | # tune run dev/generate_v2 --config llama3_2_vision/11B_generation_v2
9 |
10 | output_dir: ./
11 |
12 | # Model arguments
13 | model:
14 | _component_: torchtune.models.llama3_2_vision.llama3_2_vision_11b
15 |
16 | # Transform arguments
17 | tokenizer:
18 | _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
19 | path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
20 | prompt_template: null
21 | max_seq_len: 8192
22 |
23 | # Checkpointer
24 | checkpointer:
25 | _component_: torchtune.training.FullModelHFCheckpointer
26 | checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
27 | checkpoint_files:
28 | filename_format: model-{}-of-{}.safetensors
29 | max_filename: "00005"
30 | output_dir: ${output_dir}
31 | model_type: LLAMA3_VISION
32 |
33 | # Device
34 | device: cuda
35 | dtype: bf16
36 | seed: 1234
37 | log_level: INFO # DEBUG, WARN, etc.
38 |
39 | # Generation arguments
40 | prompt:
41 | system: You are a helpful assistant who responds like the author Shakespeare.
42 | user:
43 | image: https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg
44 | text: What is in this image?
45 | max_new_tokens: 200
46 | temperature: 0.6 # 0.8 and 0.6 are popular values to try
47 | top_k: 300
48 |
--------------------------------------------------------------------------------
/torchtune/torchtune/_cli/tune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 |
9 | from torchtune._cli.cat import Cat
10 |
11 | from torchtune._cli.cp import Copy
12 | from torchtune._cli.download import Download
13 | from torchtune._cli.ls import List
14 | from torchtune._cli.run import Run
15 | from torchtune._cli.validate import Validate
16 |
17 |
class TuneCLIParser:
    """Top-level ``tune`` argument parser that registers every subcommand."""

    def __init__(self):
        self._parser = argparse.ArgumentParser(
            prog="tune",
            description="Welcome to the torchtune CLI!",
            add_help=True,
        )
        # If no subcommand is given, args.func falls back to printing help.
        self._parser.set_defaults(func=lambda args: self._parser.print_help())

        subparsers = self._parser.add_subparsers(title="subcommands")
        # Each command class attaches its own subparser; registration order is kept.
        for command in (Download, List, Copy, Run, Validate, Cat):
            command.create(subparsers)

    def parse_args(self) -> argparse.Namespace:
        """Parse the process command-line arguments."""
        return self._parser.parse_args()

    def run(self, args: argparse.Namespace) -> None:
        """Invoke the handler stored in ``args.func``."""
        handler = args.func
        handler(args)
47 |
48 |
def main():
    """Parse command-line arguments and run the selected ``tune`` command."""
    cli = TuneCLIParser()
    cli.run(cli.parse_args())


if __name__ == "__main__":
    main()
57 |
--------------------------------------------------------------------------------
/torchtune/docs/source/api_ref_datasets.rst:
--------------------------------------------------------------------------------
1 | .. _datasets:
2 |
3 | ==================
4 | torchtune.datasets
5 | ==================
6 |
7 | .. currentmodule:: torchtune.datasets
8 |
9 | For a detailed general usage guide, please see :ref:`datasets_overview`.
10 |
11 |
12 | Text datasets
13 | -------------
14 |
15 | torchtune supports several widely used text-only datasets to help quickly bootstrap your fine-tuning.
16 |
17 | .. autosummary::
18 | :toctree: generated/
19 | :nosignatures:
20 |
21 | alpaca_dataset
22 | alpaca_cleaned_dataset
23 | grammar_dataset
24 | hh_rlhf_helpful_dataset
25 | samsum_dataset
26 | slimorca_dataset
27 | stack_exchange_paired_dataset
28 | cnn_dailymail_articles_dataset
29 | wikitext_dataset
30 |
31 | Image + Text datasets
32 | ---------------------
33 |
34 | .. autosummary::
35 | :toctree: generated/
36 | :nosignatures:
37 |
38 | multimodal.llava_instruct_dataset
39 | multimodal.the_cauldron_dataset
40 | multimodal.vqa_dataset
41 |
42 | .. _dataset_builders:
43 |
44 | Generic dataset builders
45 | ------------------------
46 |
47 | torchtune also supports generic dataset builders for common formats like chat models and instruct models.
48 | These are especially useful for specifying from a YAML config.
49 |
50 | .. autosummary::
51 | :toctree: generated/
52 | :nosignatures:
53 |
54 | instruct_dataset
55 | chat_dataset
56 | preference_dataset
57 | text_completion_dataset
58 |
59 | Generic dataset classes
60 | -----------------------
61 |
62 | Class representations for the above dataset builders.
63 |
64 | .. autosummary::
65 | :toctree: generated/
66 | :nosignatures:
67 |
68 | TextCompletionDataset
69 | ConcatDataset
70 | PackedDataset
71 | PreferenceDataset
72 | SFTDataset
73 |
--------------------------------------------------------------------------------
/torchtune/torchtune/config/_validate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import inspect
8 |
9 | from omegaconf import DictConfig
10 | from torchtune.config._errors import ConfigError
11 | from torchtune.config._utils import _get_component_from_path, _has_component
12 |
13 |
14 | def validate(cfg: DictConfig) -> None:
15 | """
16 | Ensure that all components in the config can be instantiated correctly
17 |
18 | Args:
19 | cfg (DictConfig): The config to validate
20 |
21 | Raises:
22 | ConfigError: If any component cannot be instantiated
23 | """
24 |
25 | errors = []
26 | for node, nodedict in cfg.items():
27 | if _has_component(nodedict):
28 | try:
29 | _component_ = _get_component_from_path(nodedict.get("_component_"))
30 | kwargs = {k: v for k, v in nodedict.items() if k != "_component_"}
31 | sig = inspect.signature(_component_)
32 | sig.bind(**kwargs)
33 | # Some objects require other objects as arguments, like optimizers,
34 | # lr_schedulers, datasets, etc. Try doing partial instantiation
35 | except TypeError as e:
36 | if "missing a required argument" in str(e):
37 | sig.bind_partial(**kwargs)
38 | else:
39 | # inspect.signature does not retain the function name in the
40 | # exception, so we manually add it back in
41 | e = TypeError(f"{_component_.__name__} {str(e)}")
42 | errors.append(e)
43 |
44 | if errors:
45 | raise ConfigError(errors)
46 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/modules/test_layernorm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 |
9 | import torch
10 |
11 | from tests.test_utils import assert_expected
12 |
13 | from torchtune.modules.layer_norm import Fp32LayerNorm
14 | from torchtune.training.seed import set_seed
15 |
16 |
@pytest.fixture(autouse=True)
def random():
    """Reseed every RNG (seed 0) before each test in this module."""
    seed = 0
    set_seed(seed)
20 |
21 |
class TestLayerNorm:
    """
    Class for testing our LayerNorm, which is just a wrapper around torch.nn.LayerNorm
    to support fp16 training.
    """

    @pytest.fixture
    def dim(self) -> int:
        return 8

    @pytest.fixture
    def eps(self) -> float:
        return 1e-6

    @pytest.fixture
    def input_random_fp16(self, dim) -> torch.Tensor:
        return torch.randn(dim, dtype=torch.float16)

    @pytest.fixture
    def layer_norm(self, dim, eps) -> Fp32LayerNorm:
        return Fp32LayerNorm(dim, eps=eps)

    def test_forward_fp16(self, layer_norm, input_random_fp16, eps, dim) -> None:
        output_fp16 = layer_norm(input_random_fp16)

        # The output dtype must match the fp16 input. (f-prefix needed so the
        # actual dtype is interpolated into the failure message.)
        assert (
            output_fp16.dtype == torch.float16
        ), f"Expected output to be fp16, but got {output_fp16.dtype=}"

        # On fp32 input, Fp32LayerNorm must match a default-initialized
        # torch.nn.LayerNorm elementwise, not merely in the mean.
        expected_output = torch.nn.LayerNorm(dim, eps=eps)(input_random_fp16.float())
        output_fp32 = layer_norm(input_random_fp16.float())
        assert_expected(output_fp32, expected_output, atol=1e-8, rtol=1e-8)
58 |
--------------------------------------------------------------------------------
/torchtune/docs/source/_static/img/pytorch-logo-dark.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
25 |
--------------------------------------------------------------------------------
/torchtune/docs/source/api_ref_data.rst:
--------------------------------------------------------------------------------
1 | .. _data:
2 |
3 | ==============
4 | torchtune.data
5 | ==============
6 |
7 | .. currentmodule:: torchtune.data
8 |
9 | Text templates
10 | --------------
11 |
12 | Templates for instruct prompts and chat prompts. Includes some specific formatting for different datasets
13 | and models.
14 |
15 | .. autosummary::
16 | :toctree: generated/
17 | :nosignatures:
18 |
19 | GrammarErrorCorrectionTemplate
20 | SummarizeTemplate
21 | QuestionAnswerTemplate
22 | PromptTemplate
23 | PromptTemplateInterface
24 | ChatMLTemplate
25 |
26 | Types
27 | -----
28 |
29 | .. autosummary::
30 | :toctree: generated/
31 | :nosignatures:
32 |
33 | Message
34 | Role
35 |
36 | .. _message_transforms_ref:
37 |
38 | Message transforms
39 | ------------------
40 |
41 | Converts data from common schemas and conversation JSON formats into a list of torchtune :class:`Message`.
42 |
43 | .. autosummary::
44 | :toctree: generated/
45 | :nosignatures:
46 |
47 | InputOutputToMessages
48 | ShareGPTToMessages
49 | OpenAIToMessages
50 | ChosenRejectedToMessages
51 | AlpacaToMessages
52 |
53 | Collaters
54 | ---------
55 |
56 | Collaters used to collect samples into batches and handle any padding.
57 |
58 | .. autosummary::
59 | :toctree: generated/
60 | :nosignatures:
61 |
62 | padded_collate
63 | padded_collate_dpo
64 | padded_collate_packed
65 | padded_collate_sft
66 | padded_collate_tiled_images_and_mask
67 | left_pad_sequence
68 |
69 | Helper functions
70 | ----------------
71 |
72 | Miscellaneous helper functions used in modifying data.
73 |
74 | .. autosummary::
75 | :toctree: generated/
76 | :nosignatures:
77 |
78 | validate_messages
79 | truncate
80 | load_image
81 | format_content_with_images
82 | mask_messages
83 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/qwen2_5/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._model_builders import (
8 | lora_qwen2_5_0_5b,
9 | lora_qwen2_5_14b_base,
10 | lora_qwen2_5_14b_instruct,
11 | lora_qwen2_5_1_5b_base,
12 | lora_qwen2_5_1_5b_instruct,
13 | lora_qwen2_5_32b_base,
14 | lora_qwen2_5_32b_instruct,
15 | lora_qwen2_5_3b,
16 | lora_qwen2_5_72b_base,
17 | lora_qwen2_5_72b_instruct,
18 | lora_qwen2_5_7b_base,
19 | lora_qwen2_5_7b_instruct,
20 | qwen2_5_0_5b,
21 | qwen2_5_14b_base,
22 | qwen2_5_14b_instruct,
23 | qwen2_5_1_5b_base,
24 | qwen2_5_1_5b_instruct,
25 | qwen2_5_32b_base,
26 | qwen2_5_32b_instruct,
27 | qwen2_5_3b,
28 | qwen2_5_72b_base,
29 | qwen2_5_72b_instruct,
30 | qwen2_5_7b_base,
31 | qwen2_5_7b_instruct,
32 | qwen2_5_tokenizer,
33 | )
34 |
# Public API of this package: every builder imported from _model_builders
# above (the lora_qwen2_5_* and qwen2_5_* model builders) plus the tokenizer
# builder.
__all__ = [
    "lora_qwen2_5_0_5b",
    "lora_qwen2_5_14b_base",
    "lora_qwen2_5_14b_instruct",
    "lora_qwen2_5_1_5b_base",
    "lora_qwen2_5_1_5b_instruct",
    "lora_qwen2_5_32b_base",
    "lora_qwen2_5_32b_instruct",
    "lora_qwen2_5_3b",
    "lora_qwen2_5_72b_base",
    "lora_qwen2_5_72b_instruct",
    "lora_qwen2_5_7b_base",
    "lora_qwen2_5_7b_instruct",
    "qwen2_5_0_5b",
    "qwen2_5_14b_base",
    "qwen2_5_14b_instruct",
    "qwen2_5_1_5b_base",
    "qwen2_5_1_5b_instruct",
    "qwen2_5_32b_base",
    "qwen2_5_32b_instruct",
    "qwen2_5_3b",
    "qwen2_5_72b_base",
    "qwen2_5_72b_instruct",
    "qwen2_5_7b_base",
    "qwen2_5_7b_instruct",
    "qwen2_5_tokenizer",
]
62 |
--------------------------------------------------------------------------------
/torchtune/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | ifneq ($(EXAMPLES_PATTERN),)
5 | EXAMPLES_PATTERN_OPTS := -D sphinx_gallery_conf.filename_pattern="$(EXAMPLES_PATTERN)"
6 | endif
7 |
8 | # You can set these variables from the command line.
9 | SPHINXOPTS = -W -j auto -T -v $(EXAMPLES_PATTERN_OPTS)
10 | SPHINXBUILD = sphinx-build
11 | SPHINXPROJ = torchtune
12 | SOURCEDIR = source
13 | BUILDDIR = build
14 |
15 | # Put it first so that "make" without argument is like "make help".
16 | help:
17 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
18 |
19 | docset: html
20 | doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url http://pytorch.org/vision/ --force $(BUILDDIR)/html/
21 |
22 | # Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution.
23 | cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png
24 | convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png
25 |
26 | html-noplot: # Avoids running the gallery examples, which may take time
27 | $(SPHINXBUILD) -D plot_gallery=0 -b html "${SOURCEDIR}" "$(BUILDDIR)"/html
28 | @echo
29 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
30 |
31 | clean:
32 | rm -rf $(BUILDDIR)/*
33 | rm -rf $(SOURCEDIR)/generated_examples/ # sphinx-gallery
34 | rm -rf $(SOURCEDIR)/gen_modules/ # sphinx-gallery
35 | rm -rf $(SOURCEDIR)/sg_execution_times.rst # sphinx-gallery
36 | rm -rf $(SOURCEDIR)/generated/ # autosummary
37 |
38 | .PHONY: help Makefile docset
39 |
40 | # Catch-all target: route all unknown targets to Sphinx using the new
41 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
42 | %: Makefile
43 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
44 |
--------------------------------------------------------------------------------
/torchtune/recipes/full_finetune_multinode.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # ---------- SBATCH commands ---------- #
9 | #SBATCH --job-name=torchtune-multi-node
10 | #SBATCH --ntasks=2
11 | #SBATCH --nodes=2
12 | #SBATCH --gpus-per-task=8
13 | #SBATCH --cpus-per-task=96
14 | #SBATCH --partition=train
15 |
16 | # ---------- Set env variables ---------- #
17 | # Grab the IP for head node:
18 | # You may need to set this to the fully qualified domain name of your head node
19 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
20 | nodes_array=($nodes)
21 | head_node=${nodes_array[0]}
22 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
23 | echo Node IP: $head_node_ip
24 |
25 | # You might need to explicitly set the network interface for distributed backends:
26 | # export NCCL_SOCKET_IFNAME=...
27 | # export GLOO_SOCKET_IFNAME=...
28 |
29 | export TORCH_DIST_INIT_BARRIER=1
30 | export LOGLEVEL=INFO
31 |
32 | # ---------- Launch training ---------- #
33 | # You probably want to load in a virtual env w/ conda...
34 | # module load conda
35 | # conda activate torchtune
36 | # ...or venv
37 | # source torchtune/bin/activate
38 |
39 | SHARED_FS=/mnt/slurm # <-- Replace w/ your filesystem
40 | CHECKPOINT_DIR="$SHARED_FS/Llama-3.3-70B-Instruct"
41 | OUTPUT_DIR="$SHARED_FS/Llama3.3-70B-fft-output"
42 |
43 | # Adjust sbatch --ntasks and sbatch --nodes above and --nnodes below to your specific node count
44 | srun tune run --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" \
45 | full_finetune_distributed --config llama3_3/70B_full_multinode checkpoint_dir=$CHECKPOINT_DIR output_dir=$OUTPUT_DIR
46 |
--------------------------------------------------------------------------------
/torchtune/recipes/dev/multinode_grpo.sbatch:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # ---------- SBATCH commands ---------- #
9 | #SBATCH --time=72:00:00
10 | #SBATCH --job-name=torchtune-multi-node
11 | #SBATCH --constraint=volta32gb
12 | #SBATCH --ntasks-per-node=1
13 | #SBATCH --nodes=2
14 | #SBATCH --exclusive
15 | #SBATCH --gpus-per-task=8
16 | #SBATCH --cpus-per-task=80
17 | #SBATCH --output=slurm_logs/%j/%N.out
18 | #SBATCH --error=slurm_logs/%j/%N.err
19 |
20 |
21 | # ---------- Set env variables ---------- #
22 | # Grab the IP for head node:
23 | # You may need to set this to the fully qualified domain name of your head node
24 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
25 | nodes_array=($nodes)
26 | head_node=${nodes_array[0]}
27 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
28 | echo Node IP: $head_node_ip
29 |
30 | # You might need to explicitly set the network interface for distributed backends:
31 | # export NCCL_SOCKET_IFNAME=...
32 | # export GLOO_SOCKET_IFNAME=...
33 |
34 | export TORCH_DIST_INIT_BARRIER=1
35 | export LOGLEVEL=INFO
36 |
37 | # ---------- Launch training ---------- #
38 | # You probably want to load in a virtual env w/ conda...
39 | # module load conda
40 | # conda activate torchtune
41 | # ...or venv
42 | # source torchtune/bin/activate
43 |
44 | source ../../.venv/bin/activate
45 |
46 | # Adjust sbatch --ntasks and sbatch --nodes above and --nnodes below to your specific node count
47 | srun --export=ALL,OMP_NUM_THREADS=8 tune run --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" \
48 | dev/grpo_full_finetune_distributed --config dev/3B_full_grpo "$@"
49 |
--------------------------------------------------------------------------------
/torchtune/recipes/configs/llama4/scout_17B_16E_generation_distributed.yaml:
--------------------------------------------------------------------------------
1 | # Config for running the InferenceRecipe in dev/generate_v2.py to generate output
2 | # from a Llama4 17Bx16E MoE model
3 | #
4 | # This config assumes that you've run the following command before launching
5 | # tune download meta-llama/Llama-4-Scout-17B-16E-Instruct
6 | #
7 | # To launch, run the following command:
8 | # tune run --nproc_per_node 4 dev/generate_v2_distributed --config llama4/scout_17B_16E_generation_distributed
9 |
10 | # Model arguments
11 | model:
12 | _component_: torchtune.models.llama4.llama4_scout_17b_16e
13 |
14 | tensor_parallel_plan:
15 | _component_: torchtune.models.llama4.decoder_only_tp_plan
16 |
17 | tokenizer:
18 | _component_: torchtune.models.llama4.llama4_transform
19 | path: /tmp/Llama-4-Scout-17B-16E-Instruct/tokenizer.model
20 | max_seq_len: null
21 | max_num_tiles: 16
22 |
23 | checkpointer:
24 | _component_: torchtune.training.FullModelHFCheckpointer
25 | checkpoint_dir: /tmp/Llama-4-Scout-17B-16E-Instruct # You can also point this to your finetuned model!
26 | checkpoint_files:
27 | filename_format: model-{}-of-{}.safetensors
28 | max_filename: "00050"
29 | output_dir: ./ # No need for an output dir
30 | model_type: LLAMA4
31 |
32 | use_distributed_state_dict: True
33 | use_flex: True # Use PyTorch's FlexAttention for construction of attention masks
34 |
35 | # Environment
36 | device: cuda
37 | dtype: bf16
38 | seed: 1234
39 | log_level: INFO # DEBUG, WARN, etc.
40 |
41 | # Generation arguments
42 | prompt:
43 | system: You are a helpful assistant who responds like the author Shakespeare.
44 | user:
45 | image: https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Olive_baboon.jpg/330px-Olive_baboon.jpg
46 | text: What is in this image?
47 | max_new_tokens: 200
48 | temperature: 0.6 # 0.8 and 0.6 are popular values to try
49 | top_k: 300
50 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/mistral/test_mistral.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | import torch
9 | from tests.test_utils import fixed_init_model
10 | from tests.torchtune.models.mistral.scripts.mistral_test_config import MistralTestConfig
11 | from torchtune.models.mistral import mistral
12 | from torchtune.training.seed import set_seed
13 |
14 |
15 | @pytest.fixture(autouse=True)
16 | def random():
17 | set_seed(MistralTestConfig.SEED)
18 |
19 |
20 | class TestMistral:
21 | @pytest.fixture
22 | def inputs(self):
23 | return torch.randint(
24 | 0,
25 | MistralTestConfig.VOCAB_SIZE,
26 | (MistralTestConfig.BSZ, MistralTestConfig.SEQ_LEN),
27 | )
28 |
29 | def test_forward(self, inputs):
30 | model = mistral(
31 | vocab_size=MistralTestConfig.VOCAB_SIZE,
32 | embed_dim=MistralTestConfig.EMBED_DIM,
33 | num_heads=MistralTestConfig.NUM_HEADS,
34 | num_layers=MistralTestConfig.NUM_LAYERS,
35 | num_kv_heads=MistralTestConfig.NUM_KV_HEADS,
36 | max_seq_len=MistralTestConfig.MAX_SEQ_LEN,
37 | intermediate_dim=MistralTestConfig.INTERMEDIATE_DIM,
38 | norm_eps=MistralTestConfig.NORM_EPS,
39 | rope_base=MistralTestConfig.ROPE_BASE,
40 | )
41 | fixed_init_model(model)
42 | actual = model(inputs)
43 | expected = torch.tensor(18.2749)
44 | assert actual.shape == (
45 | MistralTestConfig.BSZ,
46 | MistralTestConfig.SEQ_LEN,
47 | MistralTestConfig.VOCAB_SIZE,
48 | )
49 | torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-4)
50 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/dev/rl/rewards/test_rewards.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | import torch
9 | from torchtune.dev.rl.rewards import RewardOutput
10 |
11 |
12 | class TestRewardOutput:
13 | @pytest.fixture
14 | def sample_reward_output(self):
15 | return RewardOutput(
16 | reward_base_name="test_reward",
17 | total_reward=torch.tensor([1.0, 2.0, 3.0]),
18 | successes=torch.tensor([1.0, 0.0, 1.0]),
19 | rewards={
20 | "sub_reward_1": torch.tensor([0.5, 1.5, 2.5]),
21 | "sub_reward_2": torch.tensor([10.0, 20.0, 30.0]),
22 | },
23 | )
24 |
25 | def test_log(self, sample_reward_output):
26 | log_dict = sample_reward_output.log(prefix="train")
27 | expected_log = {
28 | "train/test_reward/sub_reward_1": 1.5,
29 | "train/test_reward/sub_reward_2": 20.0,
30 | "train/test_reward": 2.0,
31 | "train/test_reward/successes": 2.0 / 3.0,
32 | }
33 | assert log_dict.keys() == expected_log.keys()
34 | for key in expected_log:
35 | assert log_dict[key] == pytest.approx(expected_log[key])
36 |
37 | def test_log_no_prefix(self, sample_reward_output):
38 | log_dict = sample_reward_output.log()
39 | expected_log = {
40 | "test_reward/sub_reward_1": 1.5,
41 | "test_reward/sub_reward_2": 20.0,
42 | "test_reward": 2.0,
43 | "test_reward/successes": 2.0 / 3.0,
44 | }
45 | assert log_dict.keys() == expected_log.keys()
46 | for key in expected_log:
47 | assert log_dict[key] == pytest.approx(expected_log[key])
48 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/t5/_model_builders.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from torchtune.models.t5._component_builders import t5_encoder
8 | from torchtune.models.t5._encoder import T5Encoder
9 | from torchtune.models.t5._tokenizer import T5Tokenizer
10 |
11 |
12 | def t5_v1_1_xxl_encoder(max_seq_len: int = 512) -> T5Encoder:
13 | """
14 | Builder for the T5 v1.1 XXL (11B parameters) encoder.
15 |
16 | T5 paper: https://arxiv.org/abs/1910.10683
17 |
18 | 1.1 release:
19 | https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511
20 |
21 | Args:
22 | max_seq_len (int): The maximum sequence length (context length) of the model.
23 | Default: 512
24 |
25 | Returns:
26 | T5Encoder: Instantiation of the T5 encoder
27 | """
28 | return t5_encoder(
29 | embed_dim=4096,
30 | mlp_dim=10240,
31 | num_heads=64,
32 | head_dim=64,
33 | num_layers=24,
34 | rel_pos_num_buckets=32,
35 | rel_pos_max_dist=128,
36 | vocab_size=32128,
37 | norm_eps=1e-6,
38 | max_seq_len=max_seq_len,
39 | )
40 |
41 |
42 | def t5_tokenizer(path: str, max_seq_len: int = 512, truncate: bool = True):
43 | """
44 | Builder for the T5 tokenizer.
45 |
46 | Args:
47 | path (str): the path to the T5 sentencepiece tokenizer file
48 | max_seq_len (int): the context length
49 | truncate (bool): whether to truncate the token sequence when longer than max_seq_len
50 |
51 | Returns:
52 | T5Tokenizer: Instantiation of the T5 tokenizer
53 | """
54 | return T5Tokenizer(path, max_seq_len=max_seq_len, truncate=truncate)
55 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/llama2/test_llama2_prompt_template.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from tests.test_utils import assert_dialogue_equal, MESSAGE_SAMPLE
8 | from torchtune.data import Message
9 | from torchtune.models.llama2 import Llama2ChatTemplate
10 |
11 |
12 | class TestLlama2ChatTemplate:
13 | expected_dialogue = [
14 | Message(
15 | role="user",
16 | content="[INST] <>\nYou are an AI assistant. User will you give you a task. "
17 | "Your goal is to complete the task as faithfully as you can. While performing "
18 | "the task think step-by-step and justify your steps.\n<>\n\nPlease "
19 | "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
20 | "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
21 | "How about on an icy road? Well one father in Russia did just that, and recorded "
22 | "the entire thing. To her credit, the child seemed to be doing a great job. "
23 | "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
24 | "Summary: [/INST] ",
25 | ),
26 | Message(
27 | role="assistant",
28 | content="A father in Russia allowed his 8-year-old child to drive his car on an "
29 | "icy road and recorded the event. The child appeared to be handling the situation well, "
30 | "showcasing their driving skills despite the challenging conditions.",
31 | ),
32 | ]
33 |
34 | def test_call(self):
35 | actual = Llama2ChatTemplate()(MESSAGE_SAMPLE)
36 | assert_dialogue_equal(actual, self.expected_dialogue)
37 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/modules/transforms/tokenizers/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 |
9 | from tests.test_utils import DummyTokenizer
10 | from torchtune.data import Message
11 |
12 | from torchtune.modules.transforms.tokenizers import tokenize_messages_no_special_tokens
13 |
14 |
15 | class TestTokenizerUtils:
16 | @pytest.fixture
17 | def tokenizer(self):
18 | return DummyTokenizer(max_seq_len=100)
19 |
20 | @pytest.fixture
21 | def messages(self):
22 | return [
23 | Message(role="user", content="hello world!", masked=True),
24 | Message(role="assistant", content="hello back!"),
25 | ]
26 |
27 | @pytest.mark.parametrize(
28 | "add_bos, add_eos",
29 | [
30 | (True, True),
31 | (False, False),
32 | ],
33 | )
34 | def test_tokenize_no_special_tokens(self, tokenizer, messages, add_bos, add_eos):
35 | tokens, mask = tokenize_messages_no_special_tokens(
36 | tokenizer,
37 | messages,
38 | bos_id=tokenizer.bos_id if add_bos else None,
39 | eos_id=tokenizer.eos_id if add_eos else None,
40 | )
41 |
42 | assert len(tokens) == len(mask)
43 |
44 | # User message should be masked
45 | assert mask[0] is True
46 | # Assistant message should not be masked
47 | assert mask[-1] is False
48 |
49 | if add_bos:
50 | assert tokens[0] == tokenizer.bos_id
51 | else:
52 | assert tokens[0] != tokenizer.bos_id
53 |
54 | if add_eos:
55 | assert tokens[-1] == tokenizer.eos_id
56 | else:
57 | assert tokens[-1] != tokenizer.eos_id
58 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/mistral/test_mistral_classifier.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | import torch
9 | from tests.test_utils import fixed_init_model
10 | from torchtune.models.mistral import mistral_classifier
11 | from torchtune.training.seed import set_seed
12 |
13 | NUM_LAYERS = 4
14 | NUM_HEADS = 16
15 | NUM_KV_HEADS = 8
16 | VOCAB_SIZE = 32000
17 | MAX_SEQ_LEN = 2048
18 | INTERMEDIATE_DIM = 512
19 |
20 |
21 | @pytest.fixture(autouse=True)
22 | def random():
23 | set_seed(16)
24 |
25 |
26 | class TestMistralClassifier:
27 | # expected values are calculated using
28 | # `tests.torchtune.models.scripts.compare_mistral_classifier`
29 | @pytest.mark.parametrize(
30 | "bsz, embed_dim, seq_len, n_classes, expected",
31 | [
32 | (2, 64, 64, 2, 22.6879),
33 | (1, 256, 256, 1, 110.2561),
34 | ],
35 | )
36 | def test_forward(
37 | self, bsz: int, embed_dim: int, seq_len: int, n_classes: int, expected: float
38 | ):
39 | inputs = torch.randint(low=0, high=VOCAB_SIZE, size=(bsz, seq_len))
40 | model = mistral_classifier(
41 | num_classes=n_classes,
42 | vocab_size=VOCAB_SIZE,
43 | num_layers=n_classes,
44 | num_heads=NUM_HEADS,
45 | num_kv_heads=NUM_KV_HEADS,
46 | embed_dim=embed_dim,
47 | intermediate_dim=INTERMEDIATE_DIM,
48 | max_seq_len=MAX_SEQ_LEN,
49 | )
50 | fixed_init_model(model)
51 | actual = model(inputs)
52 | expected = torch.tensor(expected)
53 | assert actual.shape == (bsz, seq_len, n_classes)
54 | torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-4)
55 |
--------------------------------------------------------------------------------
/torchtune/torchtune/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from torchtune.data._collate import (
8 | left_pad_sequence,
9 | padded_collate,
10 | padded_collate_dpo,
11 | padded_collate_packed,
12 | padded_collate_sft,
13 | padded_collate_tiled_images_and_mask,
14 | )
15 | from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
16 | from torchtune.data._messages import (
17 | AlpacaToMessages,
18 | ChosenRejectedToMessages,
19 | InputOutputToMessages,
20 | mask_messages,
21 | Message,
22 | OpenAIToMessages,
23 | Role,
24 | ShareGPTToMessages,
25 | validate_messages,
26 | )
27 | from torchtune.data._prompt_templates import (
28 | ChatMLTemplate,
29 | GrammarErrorCorrectionTemplate,
30 | PromptTemplate,
31 | PromptTemplateInterface,
32 | QuestionAnswerTemplate,
33 | SummarizeTemplate,
34 | )
35 | from torchtune.data._utils import format_content_with_images, load_image, truncate
36 |
37 | __all__ = [
38 | "CROSS_ENTROPY_IGNORE_IDX",
39 | "GrammarErrorCorrectionTemplate",
40 | "SummarizeTemplate",
41 | "OpenAIToMessages",
42 | "ShareGPTToMessages",
43 | "AlpacaToMessages",
44 | "truncate",
45 | "Message",
46 | "validate_messages",
47 | "mask_messages",
48 | "Role",
49 | "format_content_with_images",
50 | "PromptTemplateInterface",
51 | "PromptTemplate",
52 | "InputOutputToMessages",
53 | "ChosenRejectedToMessages",
54 | "QuestionAnswerTemplate",
55 | "ChatMLTemplate",
56 | "padded_collate_sft",
57 | "padded_collate_dpo",
58 | "left_pad_sequence",
59 | "padded_collate",
60 | "padded_collate_tiled_images_and_mask",
61 | "padded_collate_packed",
62 | "load_image",
63 | ]
64 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/modules/test_feed_forward.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Tuple
8 |
9 | import pytest
10 |
11 | import torch
12 |
13 | from tests.test_utils import assert_expected, fixed_init_model
14 | from torch import nn
15 |
16 | from torchtune.modules import FeedForward
17 | from torchtune.training.seed import set_seed
18 |
19 |
20 | @pytest.fixture(autouse=True)
21 | def random():
22 | set_seed(0)
23 |
24 |
25 | class TestFeedForward:
26 | """Class for testing FFN implementation."""
27 |
28 | @pytest.fixture
29 | def input_params(self) -> Tuple[int, int]:
30 | dim = 4096
31 | hidden_dim = 11008 # Scaled for SwiGLU
32 | return dim, hidden_dim
33 |
34 | @pytest.fixture
35 | def input(self, input_params: Tuple[int, int]) -> torch.Tensor:
36 | dim, _ = input_params
37 | return torch.randn(1, dim)
38 |
39 | @pytest.fixture
40 | def ffn(self, input_params: Tuple[int, int]) -> FeedForward:
41 | dim, hidden_dim = input_params
42 | gate_proj = nn.Linear(dim, hidden_dim, bias=False)
43 | down_proj = nn.Linear(hidden_dim, dim, bias=False)
44 | up_proj = nn.Linear(dim, hidden_dim, bias=False)
45 | ff = FeedForward(
46 | gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj
47 | ).eval()
48 | fixed_init_model(ff)
49 | ff.eval()
50 | return ff
51 |
52 | def test_forward(self, input: torch.Tensor, ffn: FeedForward) -> None:
53 | with torch.no_grad():
54 | x_out = ffn(input)
55 | assert_expected(x_out.mean(), torch.tensor(251.5356), atol=1e-7, rtol=1e-3)
56 | assert_expected(x_out.max(), torch.tensor(503.0614), atol=1e-7, rtol=1e-3)
57 |
--------------------------------------------------------------------------------
/torchtune/torchtune/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from torchtune.datasets import multimodal
8 | from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset
9 | from torchtune.datasets._chat import chat_dataset
10 | from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset
11 | from torchtune.datasets._concat import ConcatDataset
12 | from torchtune.datasets._grammar import grammar_dataset
13 | from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset
14 | from torchtune.datasets._instruct import instruct_dataset
15 | from torchtune.datasets._packed import PackedDataset
16 | from torchtune.datasets._preference import preference_dataset, PreferenceDataset
17 | from torchtune.datasets._samsum import samsum_dataset
18 | from torchtune.datasets._sft import SFTDataset
19 | from torchtune.datasets._slimorca import slimorca_dataset
20 | from torchtune.datasets._stack_exchange_paired import stack_exchange_paired_dataset
21 | from torchtune.datasets._text_completion import (
22 | text_completion_dataset,
23 | TextCompletionDataset,
24 | )
25 | from torchtune.datasets._wikitext import wikitext_dataset
26 |
27 | __all__ = [
28 | "alpaca_dataset",
29 | "alpaca_cleaned_dataset",
30 | "grammar_dataset",
31 | "samsum_dataset",
32 | "stack_exchange_paired_dataset",
33 | "slimorca_dataset",
34 | "instruct_dataset",
35 | "preference_dataset",
36 | "chat_dataset",
37 | "text_completion_dataset",
38 | "TextCompletionDataset",
39 | "cnn_dailymail_articles_dataset",
40 | "PackedDataset",
41 | "ConcatDataset",
42 | "wikitext_dataset",
43 | "PreferenceDataset",
44 | "SFTDataset",
45 | "hh_rlhf_helpful_dataset",
46 | "multimodal",
47 | ]
48 |
--------------------------------------------------------------------------------
/torchtune/docs/source/basics/datasets_overview.rst:
--------------------------------------------------------------------------------
1 | .. _datasets_overview:
2 |
3 | =================
4 | Datasets Overview
5 | =================
6 | torchtune lets you fine-tune LLMs and VLMs using any dataset found on Hugging Face Hub, downloaded locally,
7 | or on a remote url. We provide built-in dataset builders to help you quickly bootstrap your fine-tuning project
8 | for workflows including instruct tuning, preference alignment, continued pretraining, and more. Beyond those, torchtune
9 | enables full customizability on your dataset pipeline, letting you train on any data format or schema.
10 |
11 | The following tasks are supported:
12 |
13 | - Text supervised fine-tuning
14 | - :ref:`instruct_dataset_usage_label`
15 | - :ref:`chat_dataset_usage_label`
16 | - Multimodal supervised fine-tuning
17 | - :ref:`multimodal_dataset_usage_label`
18 | - RLHF
19 | - :ref:`preference_dataset_usage_label`
20 | - Continued pre-training
21 | - :ref:`text_completion_dataset_usage_label`
22 |
23 | Data pipeline
24 | -------------
25 | .. image:: /_static/img/torchtune_datasets.svg
26 |
27 | From raw data samples to the model inputs in the training recipe, all torchtune datasets follow
28 | the same pipeline:
29 |
30 | 1. Raw data is queried one sample at a time from a Hugging Face dataset, local file, or remote file
31 | 2. :ref:`message_transform_usage_label` convert the raw sample which can take any format into a list of torchtune
32 | :ref:`messages_usage_label`. Images are contained in the message object they are associated with.
33 | 3. :ref:`model_transform_usage_label` applies model-specific transforms to the messages, including tokenization (see :ref:`tokenizers_usage_label`),
34 | prompt templating (see :ref:`prompt_templates_usage_label`), image transforms, and anything else required for that particular model.
35 | 4. The collater packages the processed samples together in a batch and the batch is passed into the model during training.
36 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/modules/loss/test_ce_chunked_output_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | import torch
9 | from tests.test_utils import assert_expected
10 | from torchtune.modules.loss import CEWithChunkedOutputLoss
11 | from torchtune.training.seed import set_seed
12 |
13 |
14 | @pytest.fixture(autouse=True)
15 | def random():
16 | set_seed(42)
17 |
18 |
19 | class TestCEWithChunkedOutputLoss:
20 | def test_chunked_cross_entropy_loss(self):
21 | # Create a sample input and label
22 | ignore_index = -100
23 | batch_size = 3
24 | num_tokens = 50
25 | vocab_size = 50
26 | logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16)
27 | labels = torch.randint(
28 | 0, vocab_size, (batch_size, num_tokens), dtype=torch.long
29 | )
30 |
31 | # add random ignore index to random tokens in the label
32 | random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens))
33 | labels[random_indices < num_tokens // 5] = ignore_index
34 |
35 | # chunked CE
36 | ce_loss = CEWithChunkedOutputLoss(
37 | num_output_chunks=8, ignore_index=ignore_index
38 | )
39 | logits_chunks = logits.tensor_split(ce_loss.num_output_chunks, dim=1)
40 | chunked_loss = ce_loss(logits_chunks, labels)
41 |
42 | # vanilla CE
43 | logits = logits.reshape(-1, logits.size(-1))
44 | labels = labels.reshape(-1)
45 | standard_loss = torch.nn.functional.cross_entropy(
46 | logits.float(), labels, reduction="mean", ignore_index=ignore_index
47 | )
48 |
49 | # Assert
50 | assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)
51 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/clip/test_clip_tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import pytest
7 |
8 | from tests.common import ASSETS
9 | from torchtune.models.clip._model_builders import clip_tokenizer
10 |
11 |
12 | class TestCLIPTokenizer:
13 | @pytest.fixture
14 | def tokenizer(self):
15 | return clip_tokenizer(ASSETS / "tiny_bpe_merges.txt")
16 |
17 | def test_encoding(self, tokenizer):
18 | texts = [
19 | "a cow jumping over the moon",
20 | "a helpful AI assistant",
21 | ]
22 | correct_tokens = [
23 | [
24 | 2416,
25 | 320,
26 | 66,
27 | 78,
28 | 342,
29 | 73,
30 | 669,
31 | 79,
32 | 515,
33 | 326,
34 | 1190,
35 | 337,
36 | 673,
37 | 324,
38 | 76,
39 | 819,
40 | 333,
41 | 2417,
42 | ],
43 | [2416, 320, 516, 75, 79, 69, 84, 331, 64, 328, 813, 667, 540, 339, 2417],
44 | ]
45 | for text, correct in zip(texts, correct_tokens):
46 | tokens = tokenizer.encode(text)
47 | assert tokens == correct
48 |
49 | def test_decoding(self, tokenizer):
50 | text = "this is torchtune"
51 | decoded_text = "<|startoftext|>this is torchtune <|endoftext|>"
52 | assert decoded_text == tokenizer.decode(tokenizer.encode(text))
53 |
54 | def test_call(self, tokenizer):
55 | sample = {"text": "hello world"}
56 | sample = tokenizer(sample)
57 | assert "text" not in sample
58 | assert "tokens" in sample
59 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/feed_forward.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | from typing import Optional
9 |
10 | import torch
11 | from torch import nn
12 |
13 |
class FeedForward(nn.Module):
    """This class implements the feed-forward network derived from Llama2.

    Computes ``down_proj(activation(gate_proj(x)) * up_proj(x))``, or
    ``down_proj(activation(gate_proj(x)))`` when ``up_proj`` is ``None``.

    Args:
        gate_proj (nn.Module): Projection from input dim to hidden dim, fed through activation
            and multiplied by up_proj.
        down_proj (nn.Module): Final projection to output dim.
        up_proj (Optional[nn.Module]): Projection from input dim to hidden dim, multiplied by
            activation(gate_proj). If ``None``, no gating multiplication is applied.
            Default: None.
        activation (Optional[nn.Module]): Activation function to use. If ``None``, a new
            ``nn.SiLU()`` is created for this instance. Default: None (i.e. SiLU).
    """

    def __init__(
        self,
        *,
        gate_proj: nn.Module,
        down_proj: nn.Module,
        up_proj: Optional[nn.Module] = None,
        activation: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.w1 = gate_proj
        self.w2 = down_proj
        self.w3 = up_proj
        # Build the default activation per instance. A default of ``nn.SiLU()`` in
        # the signature is evaluated only once, so every FeedForward created
        # without an explicit activation would share the same submodule object
        # (and any hooks registered on it).
        self.activation = activation if activation is not None else nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor with shape ``(..., in_dim)``, where ``in_dim`` is the
                input dimension of both ``gate_proj`` and ``up_proj``.

        Returns:
            torch.Tensor: output tensor with shape ``(..., out_dim)``, where ``out_dim`` is the \
                output dimension of ``down_proj``.
        """
        h = self.activation(self.w1(x))
        # Gated variant: only multiply by the up projection when one was given.
        if self.w3 is not None:
            h = h * self.w3(x)
        h = self.w2(h)
        return h
55 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/models/mistral/test_mistral_prompt_template.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | from tests.test_utils import assert_dialogue_equal, MESSAGE_SAMPLE
9 | from torchtune.data import Message
10 | from torchtune.models.mistral import MistralChatTemplate
11 |
12 |
class TestMistralChatTemplate:
    """Tests for how ``MistralChatTemplate`` formats a list of messages."""

    # Expected output of the template applied to ``MESSAGE_SAMPLE[1:]`` (the
    # sample without its first message); the user turn is wrapped in
    # ``[INST] ... [/INST] ``.
    expected_dialogue = [
        Message(
            role="user",
            content="[INST] Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
            "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
            "How about on an icy road? Well one father in Russia did just that, and recorded "
            "the entire thing. To her credit, the child seemed to be doing a great job. "
            "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
            "Summary: [/INST] ",
        ),
        Message(
            role="assistant",
            content="A father in Russia allowed his 8-year-old child to drive his car on an "
            "icy road and recorded the event. The child appeared to be handling the situation well, "
            "showcasing their driving skills despite the challenging conditions.",
        ),
    ]

    def test_format(self):
        template = MistralChatTemplate()
        assert_dialogue_equal(template(MESSAGE_SAMPLE[1:]), self.expected_dialogue)

    def test_format_with_system_prompt_raises(self):
        # The full sample (including its first message) must be rejected.
        template = MistralChatTemplate()
        expected_error = "System prompts are not supported in MistralChatTemplate"
        with pytest.raises(ValueError, match=expected_error):
            template(MESSAGE_SAMPLE)
42 |
--------------------------------------------------------------------------------
/torchtune/torchtune/modules/rms_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 |
10 |
class RMSNorm(nn.Module):
    """
    Root Mean Square Normalization in fp32.

    Each vector along the last dimension of the input is divided by its root
    mean square (with ``eps`` added under the square root). The result is cast
    back to the input dtype and then multiplied by a learnable ``scale``.

    See: https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html

    Args:
        dim (int): embedding size
        eps (float): small value to avoid division by zero. Default: 1e-6
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.normalized_shape = (dim,)
        self.eps = eps
        # Learnable per-feature gain, initialised to ones.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor to normalize

        Returns:
            torch.Tensor: The normalized and scaled tensor having the same shape as ``x``.
        """
        # Upcast so the mean-square reduction and rsqrt are computed in fp32.
        upcast = x.float()
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = (upcast * inv_rms).type_as(x)
        # The product with ``scale`` follows PyTorch type promotion, so e.g. a
        # bf16 ``x`` times an fp32 ``scale`` yields an fp32 result.
        return normalized * self.scale
42 |
43 |
def rms_norm(x: torch.Tensor, eps: float = 1e-6):
    """
    Functional RMSNorm without the trainable scale parameter.

    Divides each vector along the last dimension of ``x`` by its root mean
    square (with ``eps`` added under the square root). The computation is done
    in fp32 and the result is cast back to ``x``'s dtype.

    Args:
        x (torch.Tensor): input tensor to normalize
        eps (float): small value to avoid division by zero. Default: 1e-6

    Returns:
        torch.Tensor: The normalized tensor having the same shape as ``x``.
    """
    upcast = x.float()
    inv_rms = torch.rsqrt(upcast.pow(2).mean(-1, keepdim=True) + eps)
    return (upcast * inv_rms).type_as(x)
61 |
--------------------------------------------------------------------------------
/torchtune/tests/torchtune/datasets/multimodal/test_vqa_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 |
9 | import torch
10 | from tests.common import ASSETS
11 | from tests.test_utils import DummyTokenizer
12 | from torchtune.datasets.multimodal import vqa_dataset
13 |
14 |
class TestMultimodalInstructDataset:
    """Tests for the ``vqa_dataset`` builder using the tiny VQA JSON asset."""

    @pytest.fixture
    def tokenizer(self):
        return DummyTokenizer()

    def test_get_item(self, tokenizer):
        system_prompt = "follow this prompt"

        dataset = vqa_dataset(
            model_transform=tokenizer,
            source="json",
            data_files=str(ASSETS / "vqa_tiny.json"),
            split="train",
            new_system_prompt=system_prompt,
        )

        expected_tokens = [
            [0, 6, 4, 6, -2, 4, 2, 9, 2, 6, 7, 5, -1],
        ]

        expected_labels = [
            [-100, -100, -100, -100, -100, -100, -100, -100, -100, 7, 5, -1, -100]
        ]

        assert len(dataset) == 1

        for idx, (want_tokens, want_labels) in enumerate(
            zip(expected_tokens, expected_labels)
        ):
            # Fetch each sample once: every ``dataset[idx]`` call may re-run the
            # full transform, so indexing it per field would repeat that work.
            sample = dataset[idx]
            assert sample["tokens"] == want_tokens
            assert sample["labels"] == want_labels
            assert isinstance(sample["images"][0], torch.Tensor)

    def test_dataset_fails_with_packed(self, tokenizer):
        """Packing is rejected for multimodal datasets."""
        with pytest.raises(
            ValueError, match="Multimodal datasets don't support packing yet."
        ):
            vqa_dataset(
                model_transform=tokenizer,
                source="json",
                packed=True,
            )
60 |
--------------------------------------------------------------------------------
/torchtune/torchtune/models/llama4/_chunked_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional
8 |
9 | import torch
10 | from torchtune.modules.attention_utils import (
11 | _MaskType,
12 | _SUPPORTS_FLEX_ATTENTION,
13 | causal_mask_flex,
14 | )
15 |
16 | if _SUPPORTS_FLEX_ATTENTION:
17 | from torch.nn.attention.flex_attention import BlockMask, create_block_mask
18 |
19 |
def get_chunked_attention_mask(
    mask: Optional[_MaskType],
    chunk_size: int,
    bsz: int,
    seq_len: int,
) -> _MaskType:
    """
    Build a flex-attention block mask that restricts attention to fixed-size
    chunks, intersected with an existing mask.

    Positions ``[0, chunk_size)`` form chunk 0, ``[chunk_size, 2 * chunk_size)``
    form chunk 1, and so on. A query may attend to a key only if both fall in
    the same chunk AND the underlying mask allows the pair.

    Args:
        mask (Optional[_MaskType]): Mask to intersect with the chunk restriction.
            If ``None``, ``causal_mask_flex`` over ``seq_len`` queries and keys is
            used. If a ``BlockMask``, its ``mask_mod`` and ``seq_lengths`` are reused.
        chunk_size (int): Number of consecutive positions per chunk. Must be positive.
        bsz (int): Batch size passed to ``create_block_mask``.
        seq_len (int): Query/key length used when ``mask`` is ``None``; ignored
            when ``mask`` is a ``BlockMask``.

    Returns:
        _MaskType: The ``BlockMask`` returned by ``create_block_mask``, built with
        ``H=None`` so it is broadcast across attention heads.

    Raises:
        ValueError: If flex attention is not supported, if ``chunk_size`` is not
            positive, or if ``mask`` is neither ``None`` nor a ``BlockMask``.
    """
    # TODO: check this somewhere that doesn't get called every forward
    if not _SUPPORTS_FLEX_ATTENTION:
        raise ValueError("Local attention is only supported with flex attention.")
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
    if mask is None:
        mask_mod = causal_mask_flex
        q_seq_len, kv_seq_len = seq_len, seq_len
    elif isinstance(mask, BlockMask):
        mask_mod = mask.mask_mod
        q_seq_len, kv_seq_len = mask.seq_lengths
    else:
        raise ValueError("Unsupported mask type")

    def chunked_attention_mask_mod(
        b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
    ):
        """Allow (q_idx, kv_idx) only within one chunk and where ``mask_mod`` allows it."""
        # Only allow attention within the same chunk.
        same_chunk = (q_idx // chunk_size) == (kv_idx // chunk_size)
        # Evaluate the original mask on absolute positions. A mask_mod that looks
        # up per-position data (such as document ids for packed samples) would
        # read the wrong entries if given chunk-relative indices. For a causal
        # mask the two are equivalent once ``same_chunk`` holds.
        inner_mask = mask_mod(b, h, q_idx, kv_idx)
        return same_chunk & inner_mask

    return create_block_mask(
        chunked_attention_mask_mod,
        bsz,
        None,  # H=None: same mask for every head
        q_seq_len,
        kv_seq_len,
    )
58 |
--------------------------------------------------------------------------------