├── version.txt ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── documentation.yml │ ├── feature-request.yml │ └── bug-report.yml ├── PULL_REQUEST_TEMPLATE.md ├── packaging │ ├── post_build_script.sh │ ├── pre_build_cpu.sh │ ├── vllm_reqs_12_8.txt │ └── vllm_reqs_12_9.txt └── workflows │ ├── lint.yaml │ ├── gpu_test.yaml │ ├── build_vllm.yaml │ └── docs.yml ├── docs ├── source │ ├── metric_logging.md │ ├── tutorial_sources │ │ ├── README.txt │ │ └── template_tutorial.py │ ├── tutorials.md │ ├── api_service.md │ ├── api_actors.md │ ├── api_generator.md │ ├── _static │ │ ├── logo-icon.svg │ │ └── custom.css │ ├── api_model.md │ ├── zero-to-forge-intro.md │ ├── api.md │ └── api_trainer.md ├── license_header.txt ├── Makefile └── make.bat ├── apps ├── grpo │ ├── wandb_llama8b.png │ ├── __init__.py │ ├── slurm │ │ ├── submit.sh │ │ ├── submit_grpo.sh │ │ ├── qwen3_8b.yaml │ │ ├── qwen3_30b_a3b.yaml │ │ └── qwen3_32b.yaml │ ├── README.md │ ├── data.py │ ├── grading.py │ ├── qwen3_8b.yaml │ ├── llama3_8b.yaml │ └── qwen3_1_7b.yaml └── sft │ ├── qwen3_8b.yaml │ └── llama3_8b.yaml ├── tests ├── __init__.py ├── unit_tests │ ├── __init__.py │ ├── util │ │ └── __init__.py │ ├── datasets │ │ └── __init__.py │ ├── examples │ │ ├── __init__.py │ │ └── gsm8k │ │ │ └── __init__.py │ ├── observability │ │ ├── __init__.py │ │ ├── test_utils.py │ │ └── conftest.py │ ├── test_torchstore_utils.py │ ├── test_coder.py │ └── test_env_constants.py ├── integration_tests │ ├── __init__.py │ ├── conftest.py │ ├── fixtures │ │ ├── qwen3_1_7b_no_tp.yaml │ │ └── qwen3_1_7b_tp.yaml │ └── test_coder.py ├── sandbox │ ├── toy_rl │ │ ├── __init__.py │ │ ├── sumdigits-tp.yaml │ │ ├── sumdigits.yaml │ │ └── toy_metrics │ │ │ └── main.py │ ├── vllm │ │ ├── qwen2_5_32b.yaml │ │ ├── llama3_8b.yaml │ │ ├── deepseek_r1.yaml │ │ └── main.py │ └── weight_sync │ │ └── qwen3_1_7b.yaml ├── assets │ ├── custom_schedule.csv │ └── extend_jobconfig_example.py ├── test_utils.py ├── conftest.py └── README.md ├── src └── forge │ ├── losses │ ├── __init__.py │ ├── grpo_loss.py │ └── reinforce_loss.py │ ├── data_models │ ├── __init__.py │ ├── completion.py │ └── prompt.py │ ├── util │ ├── __init__.py │ ├── distributed.py │ ├── checkpoint.py │ └── logging.py │ ├── rl │ ├── __init__.py │ ├── advantage.py │ ├── collate.py │ ├── grading.py │ └── types.py │ ├── data │ ├── __init__.py │ ├── datasets │ │ └── __init__.py │ └── metric_transform.py │ ├── controller │ ├── __init__.py │ └── service │ │ ├── __init__.py │ │ ├── spawn.py │ │ ├── metrics.py │ │ └── router.py │ ├── __init__.py │ ├── actors │ ├── trainer │ │ └── __init__.py │ ├── __init__.py │ └── _torchstore_utils.py │ ├── api │ └── __init__.py │ ├── observability │ ├── __init__.py │ └── utils.py │ ├── env.py │ ├── interfaces.py │ └── types.py ├── CONTRIBUTING.md ├── assets └── versions.sh ├── .flake8 ├── .pre-commit-config.yaml ├── LICENSE ├── pyproject.toml ├── CODE_OF_CONDUCT.md ├── README.md └── .gitignore /version.txt: -------------------------------------------------------------------------------- 1 | 0.1.0 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: True 2 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | ## Test plan 4 | 
-------------------------------------------------------------------------------- /docs/source/metric_logging.md: -------------------------------------------------------------------------------- 1 | ```{include} ../../src/forge/observability/README.md 2 | -------------------------------------------------------------------------------- /apps/grpo/wandb_llama8b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meta-pytorch/torchforge/HEAD/apps/grpo/wandb_llama8b.png -------------------------------------------------------------------------------- /docs/license_header.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) Meta Platforms, Inc. and affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the BSD-style license found in the 5 | LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /apps/grpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /docs/source/tutorial_sources/README.txt: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | This gallery contains tutorials and examples to help you get started with TorchForge. 5 | Each tutorial demonstrates specific features and use cases with practical examples. 6 | -------------------------------------------------------------------------------- /src/forge/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /src/forge/data_models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/sandbox/toy_rl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/unit_tests/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/unit_tests/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/unit_tests/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/unit_tests/examples/gsm8k/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/unit_tests/observability/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /docs/source/tutorials.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | This section provides step-by-step guides to help you master TorchForge's capabilities, 4 | from basic model fine-tuning to advanced distributed training scenarios. 
5 | 6 | ```{toctree} 7 | :maxdepth: 1 8 | 9 | zero-to-forge-intro 10 | metric_logging 11 | ``` 12 | -------------------------------------------------------------------------------- /docs/source/api_service.md: -------------------------------------------------------------------------------- 1 | # Service 2 | 3 | ```{eval-rst} 4 | .. currentmodule:: forge.controller.service.service 5 | ``` 6 | 7 | ```{eval-rst} 8 | .. autoclass:: Service 9 | :members: call_all, start_session, get_metrics, get_metrics_summary, terminate_session, stop 10 | :show-inheritance: 11 | ``` 12 | -------------------------------------------------------------------------------- /tests/assets/custom_schedule.csv: -------------------------------------------------------------------------------- 1 | 0F0,0F1,0F2,0F3,0F4,0F5,0F6,0F7,2F0,2F1,2F2,2F3,2F4,2F5,2F6,2F7,2I0,2W0,2I1,2W1,0I0,0W0,0I1,0W1,2I2,2W2,2I3,2W3,0I2,0W2,0I3,0W3,2I4,2W4,2I5,2W5,0I4,0W4,0I5,0W5,2I6,2W6,2I7,2W7,0I6,0W6,0I7,0W7 2 | 1F0,1F1,1F2,1F3,1F4,1F5,1F6,1F7,3F0,3F1,3F2,3F3,3F4,3F5,3F6,3F7,3I0,3W0,3I1,3W1,1I0,1W0,1I1,1W1,3I2,3W2,3I3,3W3,1I2,1W2,1I3,1W3,3I4,3W4,3I5,3W5,1I4,1W4,1I5,1W5,3I6,3W6,3I7,3W7,1I6,1W6,1I7,1W7 3 | -------------------------------------------------------------------------------- /.github/packaging/post_build_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | FORGE_WHEEL=${GITHUB_WORKSPACE}/${REPOSITORY}/dist/*.whl 5 | WHL_DIR="${GITHUB_WORKSPACE}/wheels/dist" 6 | DIST=dist/ 7 | 8 | ls -l "${WHL_DIR}" 9 | ls ${FORGE_WHEEL} 10 | echo "Copying files from $WHL_DIR to $DIST" 11 | mkdir -p $DIST && find "$WHL_DIR" -maxdepth 1 -type f -exec cp {} "$DIST/" \; 12 | echo "The following wheels will be uploaded to S3" 13 | ls -l "${DIST}" 14 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to forge 2 | We want to make contributing to this project as easy and transparent as possible. 3 | 4 | 5 | 6 | ## Coding Style 7 | `forge` uses pre-commit hooks to ensure style consistency and prevent common mistakes. Enable them by running: 8 | 9 | ``` 10 | pre-commit install 11 | ``` 12 | After this, the pre-commit hooks will run before every commit. 13 | 14 | You can also run them manually on all files using: 15 | 16 | ``` 17 | pre-commit run --all-files 18 | ``` 19 | -------------------------------------------------------------------------------- /src/forge/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree.
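# Example usage (a minimal sketch; `get_world_size_and_rank` returns (1, 0)
# when torch.distributed is not initialized):
#
#   from forge.util import get_world_size_and_rank
#
#   world_size, rank = get_world_size_and_rank()
#   if rank == 0:
#       print(f"Running with {world_size} rank(s)")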
6 | from .distributed import get_world_size_and_rank 7 | from .logging import get_logger, log_once, log_rank_zero 8 | 9 | __all__ = [ 10 | "get_world_size_and_rank", 11 | "get_logger", 12 | "log_once", 13 | "log_rank_zero", 14 | ] 15 | -------------------------------------------------------------------------------- /tests/sandbox/vllm/qwen2_5_32b.yaml: -------------------------------------------------------------------------------- 1 | # >>> python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/qwen2_5_32b.yaml 2 | 3 | policy: 4 | engine_args: 5 | model: "Qwen/Qwen2.5-32B" 6 | tensor_parallel_size: 4 7 | pipeline_parallel_size: 1 8 | enforce_eager: true 9 | sampling_params: 10 | n: 2 11 | max_tokens: 512 12 | 13 | services: 14 | policy: 15 | procs: 4 16 | num_replicas: 1 17 | with_gpus: true 18 | 19 | 20 | # Optional, otherwise argparse fallback kicks in 21 | prompt: "Tell me a joke" 22 | -------------------------------------------------------------------------------- /tests/sandbox/vllm/llama3_8b.yaml: -------------------------------------------------------------------------------- 1 | # >>> python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/llama3_8b.yaml 2 | 3 | policy: 4 | engine_args: 5 | model: "meta-llama/Llama-3.1-8B-Instruct" 6 | tensor_parallel_size: 2 7 | pipeline_parallel_size: 1 8 | enforce_eager: true 9 | sampling_params: 10 | n: 2 11 | max_tokens: 512 12 | 13 | services: 14 | policy: 15 | procs: ${policy.engine_args.tensor_parallel_size} 16 | num_replicas: 4 17 | with_gpus: true 18 | 19 | 20 | # Optional, otherwise argparse fallback kicks in 21 | prompt: "Tell me a joke" 22 | -------------------------------------------------------------------------------- /apps/grpo/slurm/submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | CONFIG_NAME="${1}" 9 | 10 | sbatch --job-name="${CONFIG_NAME}" \ 11 | --export=ALL,CONFIG_NAME="${CONFIG_NAME}" \ 12 | apps/grpo/slurm/submit_grpo.sh 13 | 14 | 15 | # Usage: 16 | # ./apps/grpo/slurm/submit.sh qwen3_8b 17 | # ./apps/grpo/slurm/submit.sh qwen3_32b 18 | # ./apps/grpo/slurm/submit.sh qwen3_30b_a3b 19 | -------------------------------------------------------------------------------- /src/forge/rl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from forge.rl.advantage import ComputeAdvantages 8 | from forge.rl.collate import collate 9 | from forge.rl.grading import RewardActor 10 | from forge.rl.types import Episode, Group, Policy 11 | 12 | __all__ = [ 13 | "Episode", 14 | "Group", 15 | "Policy", 16 | "collate", 17 | "ComputeAdvantages", 18 | "RewardActor", 19 | ] 20 | -------------------------------------------------------------------------------- /src/forge/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from .collate import collate_packed, collate_padded 7 | from .metric_transform import DefaultDatasetMetricTransform, MetricTransform 8 | from .utils import CROSS_ENTROPY_IGNORE_IDX 9 | 10 | __all__ = [ 11 | "collate_packed", 12 | "collate_padded", 13 | "CROSS_ENTROPY_IGNORE_IDX", 14 | "MetricTransform", 15 | "DefaultDatasetMetricTransform", 16 | ] 17 | -------------------------------------------------------------------------------- /src/forge/controller/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from .actor import ForgeActor 7 | from .provisioner import ( 8 | get_proc_mesh, 9 | host_mesh_from_proc, 10 | init_provisioner, 11 | shutdown, 12 | stop_proc_mesh, 13 | ) 14 | 15 | __all__ = [ 16 | "ForgeActor", 17 | "get_proc_mesh", 18 | "stop_proc_mesh", 19 | "init_provisioner", 20 | "shutdown", 21 | "host_mesh_from_proc", 22 | ] 23 | -------------------------------------------------------------------------------- /docs/source/api_actors.md: -------------------------------------------------------------------------------- 1 | # ForgeActor 2 | 3 | ```{eval-rst} 4 | .. currentmodule:: forge.actors 5 | ``` 6 | 7 | The actors module contains the core components for model training 8 | and inference in TorchForge. These pre-built actors provide essential 9 | functionality for reinforcement learning workflows and can be used 10 | as building blocks for complex distributed training systems. 11 | 12 | ```{eval-rst} 13 | .. currentmodule:: forge.controller.actor 14 | 15 | .. autoclass:: ForgeActor 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | :exclude-members: logger, setup, set_env, __init__, as_service 20 | ``` 21 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import pytest 8 | import torch 9 | 10 | 11 | def gpu_test(gpu_count: int = 1): 12 | """ 13 | Decorator for GPU tests; skips the test if the 14 | required number of GPUs is not available. 15 | """ 16 | message = f"Not enough GPUs to run the test: requires {gpu_count}" 17 | local_gpu_count: int = torch.cuda.device_count() 18 | return pytest.mark.skipif(local_gpu_count < gpu_count, reason=message) 19 | -------------------------------------------------------------------------------- /src/forge/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | __version__ = "" 8 | 9 | # Enables faster downloading.
For more info: https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads 10 | # To disable, run `HF_HUB_ENABLE_HF_TRANSFER=0 tune download ` 11 | try: 12 | import os 13 | 14 | import hf_transfer # noqa 15 | 16 | if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") is None: 17 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" 18 | except ImportError: 19 | pass 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.yml: -------------------------------------------------------------------------------- 1 | name: 📚 Documentation 2 | description: Report an issue or make a request related to documentation 3 | 4 | body: 5 | - type: textarea 6 | attributes: 7 | label: 📚 The doc issue or request 8 | description: > 9 | A clear and concise description of what content is missing or an issue. Please include the URL to the page or pages with the problem if it exists. 10 | validations: 11 | required: true 12 | - type: textarea 13 | attributes: 14 | label: Suggest a potential alternative/fix 15 | description: > 16 | Tell us how we could improve the documentation. 17 | - type: markdown 18 | attributes: 19 | value: > 20 | Thanks for contributing 🎉! 21 | -------------------------------------------------------------------------------- /tests/sandbox/vllm/deepseek_r1.yaml: -------------------------------------------------------------------------------- 1 | # >>> python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/deepseek_r1.yaml 2 | 3 | # NOTE - this won't work until we have proper HostMesh support 4 | policy: 5 | engine_args: 6 | model: "deepseek-ai/DeepSeek-R1-0528" 7 | tensor_parallel_size: 16 8 | pipeline_parallel_size: 1 9 | enable_expert_parallel: true 10 | # enforce_eager: true 11 | sampling_params: 12 | n: 2 13 | max_tokens: 512 14 | 15 | provisioner: 16 | launcher: slurm 17 | 18 | services: 19 | policy: 20 | procs: 8 21 | hosts: 2 22 | num_replicas: 1 23 | with_gpus: true 24 | 25 | 26 | # Optional, otherwise argparse fallback kicks in 27 | prompt: "Tell me a joke" 28 | -------------------------------------------------------------------------------- /assets/versions.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Version Configuration for Forge Wheel Building 8 | # This file contains all pinned versions and commits for dependencies 9 | 10 | # Stable versions of upstream libraries for OSS repo 11 | PYTORCH_VERSION="2.9.0" 12 | VLLM_VERSION="v0.10.0" 13 | MONARCH_NIGHTLY_VERSION="2025.12.17" 14 | TORCHTITAN_VERSION="0.2.0" 15 | TORCHSTORE_BRANCH="no-monarch-2025.12.17" 16 | 17 | # Torchtitan commit hash for launching on MAST 18 | TORCHTITAN_COMMIT_MAST="d0e25450bcac2332359b13fbda430dc701f073d4" 19 | -------------------------------------------------------------------------------- /src/forge/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from .dataset import DatasetInfo, InfiniteTuneIterableDataset, InterleavedDataset 8 | from .hf_dataset import HfIterableDataset 9 | from .packed import PackedDataset 10 | from .sft_dataset import sft_iterable_dataset, SFTOutputTransform 11 | 12 | __all__ = [ 13 | "DatasetInfo", 14 | "HfIterableDataset", 15 | "InterleavedDataset", 16 | "InfiniteTuneIterableDataset", 17 | "PackedDataset", 18 | "SFTOutputTransform", 19 | "sft_iterable_dataset", 20 | ] 21 | -------------------------------------------------------------------------------- /src/forge/util/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | 10 | def get_world_size_and_rank() -> tuple[int, int]: 11 | """Function that gets the current world size (aka total number 12 | of ranks) and rank number of the current process in the default process group. 13 | 14 | Returns: 15 | tuple[int, int]: world size, rank 16 | """ 17 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 18 | return torch.distributed.get_world_size(), torch.distributed.get_rank() 19 | else: 20 | return 1, 0 21 | -------------------------------------------------------------------------------- /docs/source/api_generator.md: -------------------------------------------------------------------------------- 1 | # Generator 2 | 3 | ```{eval-rst} 4 | .. currentmodule:: forge.actors.generator 5 | ``` 6 | 7 | The Generator (Policy) is the core inference engine in TorchForge, 8 | built on top of [vLLM](https://docs.vllm.ai/en/latest/). 9 | It manages model serving, text generation, and weight updates for reinforcement learning workflows. 10 | 11 | ## Generator 12 | 13 | ```{eval-rst} 14 | .. autoclass:: Generator 15 | :members: generate, update_weights, get_version, stop 16 | :exclude-members: __init__, launch 17 | :no-inherited-members: 18 | ``` 19 | 20 | ## GeneratorWorker 21 | 22 | ```{eval-rst} 23 | .. autoclass:: GeneratorWorker 24 | :members: execute_model, update, setup_kv_cache 25 | :show-inheritance: 26 | :exclude-members: __init__ 27 | ``` 28 | -------------------------------------------------------------------------------- /src/forge/actors/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import warnings 8 | 9 | from .titan import TitanTrainer 10 | 11 | __all__ = ["TitanTrainer", "RLTrainer"] 12 | 13 | 14 | def __getattr__(name): 15 | if name == "RLTrainer": 16 | warnings.warn( 17 | "RLTrainer is deprecated and will be removed in a future version. 
" 18 | "Please use TitanTrainer instead.", 19 | FutureWarning, 20 | stacklevel=2, 21 | ) 22 | return TitanTrainer 23 | raise AttributeError(f"module {__name__} has no attribute {name}") 24 | -------------------------------------------------------------------------------- /docs/source/_static/logo-icon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /src/forge/api/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Forge public API module. 8 | 9 | This module defines the public interfaces that all Forge implementations conform to. 10 | """ 11 | 12 | from forge.api.trainer import Trainer 13 | from forge.api.types import ( 14 | ForwardBackwardResult, 15 | LossFn, 16 | OptimStepResult, 17 | ParallelismConfig, 18 | TextTrainBatch, 19 | TrainerConfig, 20 | TrainerStatus, 21 | ) 22 | 23 | __all__ = [ 24 | "Trainer", 25 | "TextTrainBatch", 26 | "ForwardBackwardResult", 27 | "OptimStepResult", 28 | "TrainerConfig", 29 | "TrainerStatus", 30 | "ParallelismConfig", 31 | "LossFn", 32 | ] 33 | -------------------------------------------------------------------------------- /apps/grpo/slurm/submit_grpo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | #SBATCH --qos=h200_capabilities_shared 9 | #SBATCH --account=agentic-models 10 | #SBATCH --nodes=1 11 | #SBATCH --ntasks-per-node=1 12 | #SBATCH --gpus-per-node=8 13 | #SBATCH --cpus-per-task=128 14 | #SBATCH --mem=500G 15 | #SBATCH --time=72:00:00 16 | 17 | echo "Starting GRPO training job" 18 | 19 | eval "$(conda shell.bash hook)" 20 | 21 | conda activate forge 22 | 23 | export TORCH_COMPILE_DISABLE=1 24 | unset SLURM_MEM_PER_CPU SLURM_MEM_PER_GPU SLURM_MEM_PER_NODE 25 | export TORCHSTORE_RDMA_ENABLED=0 26 | 27 | cd /storage/home/$USER/torchforge 28 | 29 | srun python -m apps.grpo.main --config apps/grpo/slurm/${CONFIG_NAME}.yaml 30 | -------------------------------------------------------------------------------- /tests/assets/extend_jobconfig_example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from dataclasses import dataclass, field 8 | 9 | 10 | @dataclass 11 | class CustomArgs: 12 | how_is_your_day: str = "good" 13 | """Just an example helptext""" 14 | 15 | num_days: int = 7 16 | """Number of days in a week""" 17 | 18 | 19 | @dataclass 20 | class Training: 21 | steps: int = 99 22 | my_custom_steps: int = 32 23 | 24 | 25 | @dataclass 26 | class JobConfig: 27 | """ 28 | This is an example of how to extend the tyro parser with custom config classes. 
29 | """ 30 | 31 | custom_args: CustomArgs = field(default_factory=CustomArgs) 32 | training: Training = field(default_factory=Training) 33 | -------------------------------------------------------------------------------- /src/forge/controller/service/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .interface import ServiceInterface, Session, SessionContext 8 | from .metrics import ServiceMetrics 9 | from .replica import Replica, ReplicaMetrics, ReplicaState 10 | from .router import LeastLoadedRouter, RoundRobinRouter, SessionRouter 11 | from .service import Service, ServiceActor, ServiceConfig 12 | 13 | __all__ = [ 14 | "Replica", 15 | "ReplicaMetrics", 16 | "ReplicaState", 17 | "Service", 18 | "ServiceConfig", 19 | "ServiceInterface", 20 | "ServiceMetrics", 21 | "Session", 22 | "SessionContext", 23 | "ServiceActor", 24 | "LeastLoadedRouter", 25 | "RoundRobinRouter", 26 | "SessionRouter", 27 | ] 28 | -------------------------------------------------------------------------------- /src/forge/rl/advantage.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from dataclasses import dataclass 8 | 9 | import torch 10 | 11 | from forge.controller.actor import ForgeActor 12 | from forge.rl.types import Group 13 | from monarch.actor import endpoint 14 | 15 | 16 | # TODO: this doesn't need to be an actor 17 | @dataclass 18 | class ComputeAdvantages(ForgeActor): 19 | @endpoint 20 | async def compute(self, group: Group) -> list[float]: 21 | # TODO: add batch processing 22 | rewards = torch.tensor([[e.reward for e in group]]) 23 | mean = rewards.mean(1, keepdim=True) 24 | std = rewards.std(1, keepdim=True) 25 | advantages = (rewards - mean) / (std + 1e-4) 26 | return advantages.squeeze(0).tolist() 27 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Minimal makefile for Sphinx documentation 7 | # 8 | 9 | # You can set these variables from the command line, and also 10 | # from the environment for the first two. 11 | SPHINXOPTS ?= 12 | SPHINXBUILD ?= sphinx-build 13 | SOURCEDIR = source 14 | BUILDDIR = build 15 | 16 | # Put it first so that "make" without argument is like "make help". 17 | help: 18 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 19 | 20 | .PHONY: help Makefile 21 | 22 | # Catch-all target: route all unknown targets to Sphinx using the new 23 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
24 | %: Makefile 25 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /.github/packaging/pre_build_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | # Builds vLLM 5 | # This script builds vLLM and places its wheel into dist/. 6 | 7 | # Script runs relative to forge root 8 | CURRENT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 9 | echo "current dir is $CURRENT_DIR" 10 | VERSIONS_FILE="$CURRENT_DIR/../../assets/versions.sh" 11 | echo "versions file is $VERSIONS_FILE" 12 | source "$VERSIONS_FILE" 13 | 14 | BUILD_DIR="$HOME/forge-build" 15 | 16 | # Push other files to the dist folder 17 | WHL_DIR="${GITHUB_WORKSPACE}/wheels/dist" 18 | 19 | mkdir -p $BUILD_DIR 20 | mkdir -p $WHL_DIR 21 | echo "build dir is $BUILD_DIR" 22 | echo "wheel dir is $WHL_DIR" 23 | 24 | build_vllm() { 25 | cd "$BUILD_DIR" 26 | 27 | git clone https://github.com/vllm-project/vllm.git --branch $VLLM_VERSION 28 | cd "$BUILD_DIR/vllm" 29 | 30 | python use_existing_torch.py 31 | pip install -r requirements/build.txt 32 | export VERBOSE=1 33 | export CMAKE_VERBOSE_MAKEFILE=1 34 | export FORCE_CMAKE=1 35 | pip wheel -v --no-build-isolation --no-deps . -w "$WHL_DIR" 36 | } 37 | 38 | build_vllm -------------------------------------------------------------------------------- /src/forge/util/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import time 8 | 9 | import torchstore as ts 10 | 11 | from forge.actors._torchstore_utils import ( 12 | get_dcp_whole_state_dict_key, 13 | get_param_prefix, 14 | ) 15 | 16 | 17 | async def drop_weights(version: int): 18 | print(f"Dropping weights @ version {version}") 19 | start_time = time.perf_counter() 20 | prefix = get_param_prefix(version) 21 | matching_keys = await ts.keys(prefix) 22 | # TODO: once we have something like `get_meta()` in torchstore, we can just 23 | # query the type of the object instead of relying on keys. 24 | dcp_key = get_dcp_whole_state_dict_key(version) 25 | if dcp_key in matching_keys: 26 | dcp_handle = await ts.get(dcp_key) 27 | dcp_handle.drop() 28 | for key in matching_keys: 29 | await ts.delete(key) 30 | elapsed = time.perf_counter() - start_time 31 | print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds") 32 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | REM Copyright (c) Meta Platforms, Inc. and affiliates. 2 | REM All rights reserved. 3 | REM 4 | REM This source code is licensed under the BSD-style license found in the 5 | REM LICENSE file in the root directory of this source tree. 6 | 7 | @ECHO OFF 8 | pushd %~dp0 9 | 10 | REM Command file for Sphinx documentation 11 | 12 | if "%SPHINXBUILD%" == "" ( 13 | set SPHINXBUILD=sphinx-build 14 | ) 15 | set SOURCEDIR=source 16 | set BUILDDIR=build 17 | 18 | %SPHINXBUILD% >NUL 2>NUL 19 | if errorlevel 9009 ( 20 | echo. 21 | echo.The 'sphinx-build' command was not found.
Make sure you have Sphinx 22 | echo.installed, then set the SPHINXBUILD environment variable to point 23 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 24 | echo.may add the Sphinx directory to PATH. 25 | echo. 26 | echo.If you don't have Sphinx installed, grab it from 27 | echo.https://www.sphinx-doc.org/ 28 | exit /b 1 29 | ) 30 | 31 | if "%1" == "" goto help 32 | 33 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 34 | goto end 35 | 36 | :help 37 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 38 | 39 | :end 40 | popd 41 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: ✨ Feature Request 2 | description: Suggest a new feature or enhancement for this project 3 | 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: > 8 | #### Before submitting a feature request, please search through [existing issues](https://github.com/meta-pytorch/forge/issues?q=is%3Aissue+sort%3Acreated-desc+) to see if something similar has already been proposed. 9 | - type: textarea 10 | attributes: 11 | label: Context/Motivation 12 | description: | 13 | Describe the problem you're trying to solve or the use case for this feature. Include any relevant links and context. 14 | validations: 15 | required: true 16 | - type: textarea 17 | attributes: 18 | label: Pseudo-code + acceptance criteria [Optional] 19 | description: | 20 | Provide a rough sketch of what the API or implementation might look like. This helps us understand your vision for how the feature would work. 21 | Also, if possible, include what would need to be true for this feature to be considered complete. 22 | validations: 23 | required: false 24 | - type: markdown 25 | attributes: 26 | value: > 27 | Thanks for contributing 🎉! 28 | -------------------------------------------------------------------------------- /docs/source/api_model.md: -------------------------------------------------------------------------------- 1 | # Model 2 | 3 | ```{eval-rst} 4 | .. currentmodule:: forge.actors.reference_model 5 | ``` 6 | 7 | The {class}`forge.actors.reference_model.ReferenceModel` provides a frozen 8 | copy of the policy model used for computing advantages in reinforcement 9 | learning. It performs inference on input sequences and returns logits or 10 | log probabilities for computing KL divergence and other RL metrics. 11 | 12 | ## ReferenceModel 13 | 14 | ```{eval-rst} 15 | .. autoclass:: forge.actors.reference_model.ReferenceModel 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | ``` 20 | 21 | The ReferenceModel uses a subset of TorchTitan's configuration system: 22 | 23 | - **model**: Model architecture settings (Model dataclass) 24 | - **parallelism**: Parallelism configuration for distributed inference (Parallelism dataclass) 25 | - **checkpoint**: Checkpoint loading settings (Checkpoint dataclass) 26 | - **compile**: Model compilation settings (Compile dataclass) 27 | - **training**: Training configuration for dtype and other settings (Training dataclass) 28 | 29 | For detailed configuration options, refer to the [TorchTitan documentation](https://github.com/pytorch/torchtitan). 30 | -------------------------------------------------------------------------------- /src/forge/losses/grpo_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | 10 | 11 | class SimpleGRPOLoss(nn.Module): 12 | """Simplified GRPO loss for single-step updates. 13 | Inspired by the Hugging Face TRL implementation: 14 | https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624. 15 | """ 16 | 17 | def __init__(self, beta: float = 0.1): 18 | super().__init__() 19 | self.beta = beta 20 | 21 | def forward(self, logprobs, ref_logprobs, advantages, padding_mask): 22 | kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1  # k3 estimator of KL(policy || ref) 23 | per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages  # ratio trick: equals `advantages` in value, carries the policy gradient 24 | per_token_loss = -(per_token_policy_loss - self.beta * kl) 25 | loss = (  # masked mean over response tokens, then mean over the batch 26 | ((per_token_loss * padding_mask).sum(dim=1)) 27 | / (padding_mask.sum(dim=1).clamp(min=1.0)) 28 | ).mean() 29 | return loss 30 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # Suggested config from pytorch that we can adapt 3 | select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2 4 | max-line-length = 120 5 | # C408 ignored because we like the dict keyword argument syntax 6 | # E501 is not flexible enough, we're using B950 instead 7 | # N812 ignored because import torch.nn.functional as F is PyTorch convention 8 | # N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP) 9 | # E731 allow usage of assigning lambda expressions 10 | # N803,N806 allow caps and mixed case in function params. This is to work with Triton kernel coding style. 11 | # E704 ignored to allow black's formatting of Protocol stub methods (def method(self) -> None: ...) 12 | ignore = 13 | E203,E305,E402,E501,E704,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806 14 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying 15 | # to line this up with executable bit 16 | EXE001, 17 | # these ignores are from flake8-bugbear; please fix!
18 | B007,B008, 19 | optional-ascii-coding = True 20 | exclude = 21 | ./.git, 22 | ./docs 23 | ./build 24 | ./scripts, 25 | ./venv, 26 | *.pyi 27 | .pre-commit-config.yaml 28 | *.md 29 | .flake8 30 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | pull_request: 5 | workflow_dispatch: 6 | 7 | 8 | concurrency: 9 | group: lint-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} 10 | cancel-in-progress: true 11 | 12 | defaults: 13 | run: 14 | shell: bash -l -eo pipefail {0} 15 | 16 | jobs: 17 | lint: 18 | if: github.repository_owner == 'meta-pytorch' 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | python-version: ['3.10'] 23 | steps: 24 | - name: Check out repo 25 | uses: actions/checkout@v4 26 | - name: Setup python 27 | uses: actions/setup-python@v4 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | - name: Update pip 31 | run: python -m pip install --upgrade pip 32 | - name: Install lint utilities 33 | run: | 34 | python -m pip install pre-commit 35 | pre-commit install-hooks 36 | - name: Get changed files 37 | id: changed-files 38 | uses: tj-actions/changed-files@d6e91a2266cdb9d62096cebf1e8546899c6aa18f # v45.0.6 39 | - name: Lint modified files 40 | run: pre-commit run --files ${{ steps.changed-files.outputs.all_changed_files }} 41 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Global test configuration for all tests. 9 | 10 | This file contains pytest fixtures that are automatically applied to all tests 11 | in the forge test suite. 12 | """ 13 | 14 | from unittest.mock import Mock 15 | 16 | import pytest 17 | 18 | from forge.env import FORGE_DISABLE_METRICS 19 | 20 | 21 | @pytest.fixture(autouse=True) 22 | def mock_metrics_globally(monkeypatch): 23 | """ 24 | Automatically disable `forge.observability.metrics.record_metrics` during tests, 25 | which could otherwise introduce flakiness if not properly configured. 26 | 27 | To disable this mock in a specific test, override the fixture: 28 | 29 | @pytest.fixture 30 | def mock_metrics_globally(): 31 | # Return None to disable the mock for this test 32 | return None 33 | 34 | def test_real_metrics(mock_metrics_globally): 35 | # This test will use the real metrics system 36 | pass 37 | """ 38 | 39 | monkeypatch.setenv(FORGE_DISABLE_METRICS.name, "true") 40 | return Mock() 41 | -------------------------------------------------------------------------------- /src/forge/data_models/completion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
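# Example construction (an illustrative sketch, not taken from the codebase;
# the token ids are dummies):
#
#   completion = Completion(
#       prompt=prompt,
#       text="The answer is 4.",
#       prompt_ids=torch.tensor([101, 2054, 2003]),
#       token_ids=torch.tensor([1996, 3437, 2003]),
#       stop_reason="stop",
#   )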
6 | 7 | from dataclasses import dataclass 8 | from typing import Any 9 | 10 | import torch 11 | 12 | from forge.data_models.prompt import Prompt 13 | 14 | 15 | @dataclass 16 | class Completion: 17 | """A model-generated completion for a given prompt.""" 18 | 19 | # The original prompt. 20 | prompt: Prompt 21 | 22 | # the decoded text returned by the model 23 | text: str 24 | 25 | # the encoded text (token ids) that were fed into the model 26 | prompt_ids: torch.Tensor 27 | 28 | # the encoded text (token ids) that were generated by the model 29 | token_ids: torch.Tensor 30 | 31 | # the log probabilities of the target tokens 32 | logprobs: torch.Tensor | None = None 33 | 34 | # the reason for stopping the generation 35 | stop_reason: str | None = None 36 | 37 | # the version identifier of the model when the generation was performed 38 | generator_version: int | None = None 39 | 40 | # extra information that might be useful for debugging 41 | metadata: dict[str, Any] | None = None 42 | -------------------------------------------------------------------------------- /tests/integration_tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | 9 | import pytest 10 | 11 | 12 | def str_to_bool(value): 13 | if value.lower() in ("yes", "true", "t", "y", "1"): 14 | return True 15 | elif value.lower() in ("no", "false", "f", "n", "0"): 16 | return False 17 | else: 18 | raise argparse.ArgumentTypeError(f"Boolean value expected, got '{value}'") 19 | 20 | 21 | def pytest_addoption(parser): 22 | """Add custom command line options for pytest.""" 23 | parser.addoption( 24 | "--config", 25 | action="store", 26 | default=None, 27 | help="Path to YAML config file for sanity check tests", 28 | ) 29 | 30 | parser.addoption( 31 | "--use_dcp", 32 | action="store", 33 | type=str_to_bool, 34 | default=None, 35 | help="Overrides the YAML config `trainer.use_dcp` field.", 36 | ) 37 | 38 | 39 | @pytest.fixture 40 | def config_path(request): 41 | """Fixture to provide the config path from command line.""" 42 | return request.config.getoption("--config") 43 | -------------------------------------------------------------------------------- /tests/sandbox/toy_rl/sumdigits-tp.yaml: -------------------------------------------------------------------------------- 1 | # Toy app Training Configuration 2 | 3 | # Global configuration 4 | group_size: 16 5 | batch_size: 64 6 | max_req_tokens: 64 7 | max_res_tokens: 64 8 | model: "Qwen/Qwen2.5-0.5B-Instruct" 9 | 10 | # Dataset configuration 11 | dataset: 12 | model: ${model} 13 | 14 | # Policy configuration 15 | policy: 16 | engine_args: 17 | model: ${model} 18 | tensor_parallel_size: 2 19 | pipeline_parallel_size: 1 20 | enforce_eager: false 21 | sampling_params: 22 | n: ${group_size} 23 | max_tokens: ${max_res_tokens} 24 | temperature: 1.0 25 | top_p: 1.0 26 | 27 | # Trainer configuration 28 | trainer: 29 | model_name: ${model} 30 | learning_rate: 1e-5 31 | 32 | # Reference model configuration 33 | ref_model: 34 | model_name: ${model} 35 | 36 | # Replay buffer configuration 37 | replay_buffer: 38 | batch_size: ${batch_size} 39 | max_policy_age: 1 # Async by 1 40 | dp_size: 1 41 | 42 | services: 43 | policy: 44 | procs: 1 45 | num_replicas: 1 46 | with_gpus: true 47 | reward_actor: 48 | 
procs: 1 49 | num_replicas: 1 50 | with_gpus: false 51 | ref_model: 52 | procs: 1 53 | num_replicas: 1 54 | with_gpus: true 55 | 56 | actors: 57 | dataset: 58 | procs: 1 59 | with_gpus: false 60 | trainer: 61 | procs: 1 62 | with_gpus: true 63 | replay_buffer: 64 | procs: 1 65 | with_gpus: false 66 | -------------------------------------------------------------------------------- /tests/sandbox/toy_rl/sumdigits.yaml: -------------------------------------------------------------------------------- 1 | # Toy app Training Configuration 2 | 3 | # Global configuration 4 | group_size: 6 5 | batch_size: 12 6 | max_req_tokens: 64 7 | max_res_tokens: 64 8 | model: "Qwen/Qwen2.5-0.5B-Instruct" 9 | 10 | # Dataset configuration 11 | dataset: 12 | model: ${model} 13 | 14 | # Policy configuration 15 | policy: 16 | use_dcp: false 17 | engine_args: 18 | model: ${model} 19 | tensor_parallel_size: 1 20 | pipeline_parallel_size: 1 21 | enforce_eager: false 22 | sampling_params: 23 | n: ${group_size} 24 | max_tokens: ${max_res_tokens} 25 | temperature: 1.0 26 | top_p: 1.0 27 | 28 | 29 | # Trainer configuration 30 | trainer: 31 | model_name: ${model} 32 | learning_rate: 1e-5 33 | 34 | # Reference model configuration 35 | ref_model: 36 | model_name: ${model} 37 | 38 | # Replay buffer configuration 39 | replay_buffer: 40 | batch_size: ${batch_size} 41 | max_policy_age: 1 # Async by 1 42 | dp_size: 1 43 | 44 | services: 45 | policy: 46 | procs: 1 47 | num_replicas: 1 48 | with_gpus: true 49 | reward_actor: 50 | procs: 1 51 | num_replicas: 1 52 | with_gpus: false 53 | ref_model: 54 | procs: 1 55 | num_replicas: 1 56 | with_gpus: true 57 | 58 | actors: 59 | dataset: 60 | procs: 1 61 | with_gpus: false 62 | trainer: 63 | procs: 1 64 | with_gpus: true 65 | replay_buffer: 66 | procs: 1 67 | with_gpus: false 68 | -------------------------------------------------------------------------------- /.github/workflows/gpu_test.yaml: -------------------------------------------------------------------------------- 1 | name: Unit Tests (GPU) 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | workflow_dispatch: 8 | 9 | concurrency: 10 | group: gpu-test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} 11 | cancel-in-progress: true 12 | 13 | permissions: 14 | id-token: write 15 | contents: read 16 | 17 | defaults: 18 | run: 19 | shell: bash -l -eo pipefail {0} 20 | 21 | jobs: 22 | gpu_test: 23 | if: github.repository_owner == 'meta-pytorch' 24 | runs-on: linux.g5.12xlarge.nvidia.gpu 25 | strategy: 26 | matrix: 27 | python-version: ['3.10', '3.11', '3.12'] 28 | steps: 29 | - name: Check out repo 30 | uses: actions/checkout@v4 31 | - name: Setup conda env 32 | uses: conda-incubator/setup-miniconda@v2 33 | with: 34 | auto-update-conda: true 35 | miniconda-version: "latest" 36 | activate-environment: test 37 | python-version: ${{ matrix.python-version }} 38 | - name: Update pip 39 | run: python -m pip install --upgrade pip 40 | - name: Install torchforge 41 | run: pip install uv && uv pip install . && uv pip install .[dev] 42 | - name: Run unit tests with coverage 43 | # TODO add all tests 44 | run: pytest tests/unit_tests --cov=. 
--cov-report=xml --durations=20 -vv 45 | - name: Upload Coverage to Codecov 46 | uses: codecov/codecov-action@v3 47 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: 'build' 2 | 3 | default_language_version: 4 | python: python3 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v5.0.0 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: check-ast 12 | - id: check-merge-conflict 13 | - id: no-commit-to-branch 14 | args: ['--branch=main'] 15 | - id: end-of-file-fixer 16 | exclude: '^(.*\.svg)$' 17 | 18 | - repo: https://github.com/Lucas-C/pre-commit-hooks 19 | rev: v1.5.5 20 | hooks: 21 | - id: insert-license 22 | files: \.py$|\.sh$ 23 | args: 24 | - --license-filepath 25 | - docs/license_header.txt 26 | 27 | - repo: https://github.com/pycqa/flake8 28 | rev: 7.1.1 29 | hooks: 30 | - id: flake8 31 | additional_dependencies: 32 | - flake8-bugbear == 22.4.25 33 | - pep8-naming == 0.12.1 34 | - torchfix 35 | args: ['--config=.flake8'] 36 | 37 | - repo: https://github.com/omnilib/ufmt 38 | rev: v2.8.0 39 | hooks: 40 | - id: ufmt 41 | additional_dependencies: 42 | - black == 24.4.2 43 | - usort == 1.0.8.post1 44 | 45 | - repo: https://github.com/jsh9/pydoclint 46 | rev: 0.5.12 47 | hooks: 48 | - id: pydoclint 49 | args: [--config=pyproject.toml] 50 | 51 | - repo: https://github.com/fastai/nbdev.git 52 | rev: 2.4.5 53 | hooks: 54 | - id: nbdev_clean 55 | args: [--clear_all] 56 | -------------------------------------------------------------------------------- /docs/source/zero-to-forge-intro.md: -------------------------------------------------------------------------------- 1 | # Zero to TorchForge: From RL Theory to Production-Scale Implementation 2 | 3 | A comprehensive guide for ML Engineers building distributed RL systems for language models. 4 | 5 | Some of the examples mentioned below are conceptual in nature, to aid understanding. 6 | Please refer to the [API Docs](./api) for more details. 7 | 8 | Welcome to the Tutorials section! This section is inspired by the A-Z 9 | PyTorch tutorial; shout-out to our PyTorch friends who remember! 10 | 11 | ## Tutorial Structure 12 | 13 | This section is currently structured in three detailed parts: 14 | 15 | 1. [Part 1: RL Fundamentals - Using TorchForge Terminology](tutorials/zero-to-forge/1_RL_and_Forge_Fundamentals): Gives a quick refresher on Reinforcement Learning and teaches you the TorchForge fundamentals 16 | 2. [Part 2: Peeling Back the Abstraction - What Are Services?](tutorials/zero-to-forge/2_Forge_Internals): Goes a layer deeper and explains the internals of TorchForge 17 | 3. [Part 3: The TorchForge-Monarch Connection](tutorials/zero-to-forge/3_Monarch_101): A Monarch 101 covering what Monarch is and how TorchForge talks to it 18 | 19 | Each part builds on the previous one, and the entire section can be consumed in roughly an hour - grab a chai and enjoy! 20 | 21 | If you're eager, please check out our SFT tutorial too (coming soon)!
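As a taste of the pattern these parts build on, here is a minimal, conceptual sketch of a TorchForge actor (the actor and method names are illustrative, not a real TorchForge API):

```python
from forge.controller import ForgeActor
from monarch.actor import endpoint


class Greeter(ForgeActor):
    """A toy actor; real actors wrap components like generators and trainers."""

    @endpoint
    async def greet(self, name: str) -> str:
        return f"Hello, {name}!"
```

Part 1 shows how actors like this are spawned as services and scaled across replicas.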
22 | 23 | ```{toctree} 24 | :maxdepth: 1 25 | :hidden: 26 | tutorials/zero-to-forge/1_RL_and_Forge_Fundamentals 27 | tutorials/zero-to-forge/2_Forge_Internals 28 | tutorials/zero-to-forge/3_Monarch_101 29 | ``` 30 | -------------------------------------------------------------------------------- /docs/source/api.md: -------------------------------------------------------------------------------- 1 | # API Reference 2 | 3 | This section provides comprehensive API documentation for TorchForge. 4 | 5 | ## Overview 6 | 7 | TorchForge is a PyTorch native platform for post-training generative AI models, 8 | designed to streamline reinforcement learning workflows for large language 9 | models. The platform leverages PyTorch's distributed computing capabilities 10 | and is built on top of [Monarch](https://meta-pytorch.org/monarch/), 11 | making extensive use of actors for distributed computation and fault tolerance. 12 | 13 | Key Features of TorchForge include: 14 | 15 | - **Actor-Based Architecture**: TorchForge uses an actor-based system for distributed training, providing excellent scalability and fault tolerance. 16 | - **PyTorch Native**: Built natively on PyTorch, ensuring seamless integration with existing PyTorch workflows. 17 | - **Post-Training Focus**: Specifically designed for post-training techniques like RLVR, SFT, and other alignment methods. 18 | - **Distributed by Design**: Supports multi-GPU and multi-node training out of the box. 19 | 20 | 21 | For most use cases, you'll interact with the high-level service 22 | interfaces, which handle the complexity of actor coordination and 23 | distributed training automatically. 24 | 25 | For advanced users who need fine-grained control, the individual actor 26 | APIs provide direct access to the underlying distributed components. 27 | 28 | ```{toctree} 29 | :maxdepth: 1 30 | api_actors 31 | api_service 32 | api_generator 33 | api_model 34 | api_trainer 35 | ``` 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice,this list 9 | of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, this 12 | list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its contributors may 16 | be used to endorse or promote products derived from this software without specific 17 | prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY 20 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 21 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT 22 | SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 23 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 24 | TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR 25 | BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 27 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH 28 | DAMAGE. 29 | -------------------------------------------------------------------------------- /src/forge/actors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import warnings 8 | 9 | __all__ = [ 10 | "Generator", 11 | "TitanTrainer", 12 | "RLTrainer", # Deprecated, use TitanTrainer 13 | "ReplayBuffer", 14 | "ReferenceModel", 15 | "SandboxedPythonCoder", 16 | ] 17 | 18 | 19 | def __getattr__(name): 20 | if name == "Generator": 21 | from .generator import Generator 22 | 23 | return Generator 24 | elif name == "TitanTrainer": 25 | from .trainer import TitanTrainer 26 | 27 | return TitanTrainer 28 | elif name == "RLTrainer": 29 | warnings.warn( 30 | "RLTrainer is deprecated and will be removed in a future version. " 31 | "Please use TitanTrainer instead.", 32 | FutureWarning, 33 | stacklevel=2, 34 | ) 35 | from .trainer import RLTrainer 36 | 37 | return RLTrainer 38 | elif name == "ReplayBuffer": 39 | from .replay_buffer import ReplayBuffer 40 | 41 | return ReplayBuffer 42 | elif name == "ReferenceModel": 43 | from .reference_model import ReferenceModel 44 | 45 | return ReferenceModel 46 | elif name == "SandboxedPythonCoder": 47 | from .coder import SandboxedPythonCoder 48 | 49 | return SandboxedPythonCoder 50 | else: 51 | raise AttributeError(f"module {__name__} has no attribute {name}") 52 | -------------------------------------------------------------------------------- /docs/source/api_trainer.md: -------------------------------------------------------------------------------- 1 | # Trainer 2 | 3 | ```{eval-rst} 4 | .. currentmodule:: forge.actors.trainer 5 | ``` 6 | 7 | The Trainer manages model training in TorchForge, built on top of TorchTitan. 8 | It handles forward/backward passes, weight updates, and checkpoint management for reinforcement learning workflows. 9 | 10 | ## TitanTrainer 11 | 12 | ```{eval-rst} 13 | .. autoclass:: TitanTrainer 14 | :members: train_step, push_weights, cleanup 15 | :exclude-members: __init__ 16 | ``` 17 | 18 | ## Configuration 19 | 20 | The TitanTrainer uses TorchTitan's configuration system with the following components: 21 | 22 | ### Job Configuration 23 | 24 | ```{eval-rst} 25 | .. autoclass:: torchtitan.config.job_config.Job 26 | :members: 27 | :undoc-members: 28 | ``` 29 | 30 | ### Model Configuration 31 | 32 | ```{eval-rst} 33 | .. autoclass:: torchtitan.config.job_config.Model 34 | :members: 35 | :undoc-members: 36 | ``` 37 | 38 | ### Optimizer Configuration 39 | 40 | ```{eval-rst} 41 | .. autoclass:: torchtitan.config.job_config.Optimizer 42 | :members: 43 | :undoc-members: 44 | ``` 45 | 46 | ### Training Configuration 47 | 48 | ```{eval-rst} 49 | .. 
autoclass:: torchtitan.config.job_config.Training 50 | :members: 51 | :undoc-members: 52 | ``` 53 | 54 | ### Parallelism Configuration 55 | 56 | ```{eval-rst} 57 | .. autoclass:: torchtitan.config.job_config.Parallelism 58 | :members: 59 | :undoc-members: 60 | ``` 61 | 62 | ### Checkpoint Configuration 63 | 64 | ```{eval-rst} 65 | .. autoclass:: torchtitan.config.job_config.Checkpoint 66 | :members: 67 | :undoc-members: 68 | ``` 69 | -------------------------------------------------------------------------------- /src/forge/rl/collate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any 8 | 9 | import torch 10 | 11 | from forge.rl.types import Group 12 | 13 | 14 | def collate( 15 | batches: list[Group], 16 | ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: 17 | """ 18 | Collates a list of episode groups into parallel lists of input and target dicts, one entry per group. 19 | Each group is a list of episodes whose request/response tensors, ref logprobs, and advantages are stacked along the batch dimension. 20 | """ 21 | inputs = [] 22 | targets = [] 23 | for batch in batches: 24 | request = [e.request_tensor for e in batch] 25 | request = torch.stack(request) # [b x s] 26 | 27 | response = [e.response_tensor for e in batch] 28 | response = torch.stack(response) # [b x s] 29 | 30 | ref_logprobs = [e.ref_logprobs for e in batch] 31 | ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s] 32 | 33 | advantages = [e.advantage for e in batch] 34 | advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1] 35 | 36 | pad_id = batch[0].pad_id 37 | mask = response != pad_id 38 | 39 | input = {"tokens": torch.cat([request, response], dim=1)} 40 | target = { 41 | "response": response, 42 | "ref_logprobs": ref_logprobs, 43 | "advantages": advantages, 44 | "padding_mask": mask, 45 | } 46 | inputs.append(input) 47 | targets.append(target) 48 | return inputs, targets 49 | -------------------------------------------------------------------------------- /src/forge/data_models/prompt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from collections.abc import Sequence 8 | from dataclasses import dataclass 9 | from enum import Enum 10 | 11 | 12 | class Role(Enum): 13 | SYSTEM = "system" 14 | USER = "user" 15 | ASSISTANT = "assistant" 16 | NONE = "none" 17 | 18 | 19 | @dataclass 20 | class Message: 21 | """A single message in a conversation.""" 22 | 23 | chunks: Sequence[str] 24 | role: Role 25 | 26 | 27 | @dataclass 28 | class Prompt: 29 | """A multi-turn prompt (conversation history).""" 30 | 31 | # Multi-turn messages, each turn is a message.
32 | messages: Sequence[Message] 33 | 34 | @classmethod 35 | def from_prompt( 36 | cls, prompt: str, system_instruction: str | None = None 37 | ) -> "Prompt": 38 | messages = prompt_to_messages(prompt, system_instruction) 39 | return Prompt( 40 | messages=messages, 41 | ) 42 | 43 | 44 | def prompt_to_messages( 45 | prompt: str, system_instruction: str | None = None 46 | ) -> Sequence[Message]: 47 | """Convert a prompt to a sequence of messages.""" 48 | messages = [] 49 | if system_instruction is not None: 50 | messages.append(Message(chunks=[system_instruction], role=Role.SYSTEM)) 51 | messages.append( 52 | Message(chunks=[prompt], role=Role.USER), 53 | ) 54 | return messages 55 | 56 | 57 | def to_prompt(prompt: str, system_instruction: str | None = None) -> Prompt: 58 | """Converts a prompt string to a Prompt object.""" 59 | return Prompt( 60 | messages=prompt_to_messages(prompt, system_instruction), 61 | ) 62 | -------------------------------------------------------------------------------- /src/forge/observability/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .metric_actors import ( 8 | get_or_create_metric_logger, 9 | GlobalLoggingActor, 10 | LocalFetcherActor, 11 | ) 12 | from .metrics import ( 13 | BackendRole, 14 | ConsoleBackend, 15 | get_logger_backend_class, 16 | LoggerBackend, 17 | LoggingMode, 18 | MaxAccumulator, 19 | MeanAccumulator, 20 | Metric, 21 | MetricAccumulator, 22 | MetricCollector, 23 | MinAccumulator, 24 | record_metric, 25 | Reduce, 26 | reduce_metrics_states, 27 | SampleAccumulator, 28 | StdAccumulator, 29 | SumAccumulator, 30 | WandbBackend, 31 | ) 32 | from .perf_tracker import trace, Tracer 33 | from .utils import get_proc_name_with_rank 34 | 35 | __all__ = [ 36 | # Main API functions 37 | "record_metric", 38 | "reduce_metrics_states", 39 | "get_logger_backend_class", 40 | "get_or_create_metric_logger", 41 | # Performance tracking 42 | "Tracer", 43 | "trace", 44 | # Data classes 45 | "Metric", 46 | "BackendRole", 47 | # Enums 48 | "Reduce", 49 | "LoggingMode", 50 | # Utility functions 51 | "get_proc_name_with_rank", 52 | # Actor classes 53 | "GlobalLoggingActor", 54 | "LocalFetcherActor", 55 | # Collector 56 | "MetricCollector", 57 | # Backend classes 58 | "LoggerBackend", 59 | "ConsoleBackend", 60 | "WandbBackend", 61 | # Accumulator classes 62 | "MetricAccumulator", 63 | "MeanAccumulator", 64 | "SumAccumulator", 65 | "MaxAccumulator", 66 | "MinAccumulator", 67 | "StdAccumulator", 68 | "SampleAccumulator", 69 | ] 70 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Tests 2 | 3 | This directory contains tests for the forge project, including unit tests and integration tests.
4 | 5 | ## Test Structure 6 | 7 | - `unit_tests/`: Contains unit tests for individual components 8 | - `integration_tests/`: Contains integration tests that test multiple components together 9 | - `sandbox/`: Contains experimental ad hoc scripts used for development and debugging 10 | - `assets/`: Contains test assets and fixtures used by the tests 11 | 12 | ## Running Tests 13 | 14 | ### Prerequisites 15 | 16 | Ensure you have all development dependencies installed (run from forge root): 17 | 18 | ```bash 19 | pip install .[dev] 20 | ``` 21 | 22 | ### Running Integration Tests 23 | 24 | To run all integration tests: 25 | 26 | ```bash 27 | pytest -s tests/integration_tests/ 28 | ``` 29 | 30 | To run a specific integration test file: 31 | 32 | ```bash 33 | pytest -s tests/integration_tests/test_vllm_policy_correctness.py 34 | ``` 35 | 36 | To run a specific integration test function: 37 | 38 | ```bash 39 | pytest -s tests/integration_tests/test_vllm_policy_correctness.py::test_same_output 40 | ``` 41 | 42 | Integration tests support custom options defined in `conftest.py`: 43 | - `--config`: Path to YAML config file for sanity check tests 44 | - `--use_dcp`: Override the YAML config `trainer.use_dcp` field (true/false) 45 | 46 | Example with options: 47 | ```bash 48 | pytest -s tests/integration_tests/ --config ./path/to/config.yaml --use_dcp true 49 | ``` 50 | 51 | ### Running Unit Tests 52 | 53 | To run all unit tests: 54 | 55 | ```bash 56 | pytest -s tests/unit_tests/ 57 | ``` 58 | 59 | To run a specific unit test file: 60 | 61 | ```bash 62 | pytest -s tests/unit_tests/test_config.py 63 | ``` 64 | 65 | To run a specific unit test function: 66 | 67 | ```bash 68 | pytest -s tests/unit_tests/test_config.py::test_cache_hit_scenario 69 | ``` 70 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: 🐛 Bug Report 2 | description: Create a report to help us reproduce and fix the bug 3 | 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: > 8 | #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/meta-pytorch/forge/issues?q=is%3Aissue+sort%3Acreated-desc+). 9 | - type: textarea 10 | attributes: 11 | label: 🐛 Describe the bug 12 | description: | 13 | Please provide a clear and concise description of what the bug is. 14 | 15 | If relevant, add a minimal example so that we can reproduce the error by running the code. 16 | 17 | If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. 18 | 19 | Please also paste or describe the results you observe along with the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. 20 | placeholder: | 21 | A clear and concise description of what the bug is. 22 | 23 | ```python 24 | # Sample code to reproduce the problem 25 | ``` 26 | 27 | ``` 28 | The error message you got, with the full traceback.
29 | ``` 30 | validations: 31 | required: true 32 | - type: textarea 33 | attributes: 34 | label: Versions 35 | description: | 36 | Please share the relevant package versions/commit hash of [pytorch](https://github.com/pytorch/pytorch), [torchtitan](https://github.com/pytorch/torchtitan), [monarch](https://github.com/meta-pytorch/monarch), and [vllm](https://github.com/vllm-project/vllm) 37 | - type: markdown 38 | attributes: 39 | value: > 40 | Thanks for contributing 🎉! 41 | -------------------------------------------------------------------------------- /src/forge/rl/grading.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from dataclasses import dataclass 8 | from typing import Callable 9 | 10 | from forge.controller.actor import ForgeActor 11 | from forge.observability.metrics import record_metric, Reduce 12 | 13 | from monarch.actor import endpoint 14 | 15 | 16 | @dataclass 17 | class RewardActor(ForgeActor): 18 | reward_functions: list[Callable] 19 | 20 | @endpoint 21 | async def evaluate_response( 22 | self, prompt: str, response: str, target: str 23 | ) -> tuple[dict[str, float], float]: 24 | total_rewards = 0.0 25 | reward_breakdown = {} # reward breakdown by function 26 | for reward_fn in self.reward_functions: 27 | reward = reward_fn(prompt, response, target) 28 | total_rewards += reward 29 | 30 | # Get a name for the reward function (works for classes, functions, lambdas) 31 | reward_fn_name = getattr( 32 | reward_fn, "__name__", reward_fn.__class__.__name__ 33 | ) 34 | reward_breakdown[reward_fn_name] = reward 35 | 36 | # log per fn reward and avg total 37 | record_metric( 38 | f"reward/evaluate_response/avg_{reward_fn_name}_reward", 39 | reward, 40 | Reduce.MEAN, 41 | ) 42 | record_metric( 43 | f"reward/evaluate_response/std_{reward_fn_name}_reward", 44 | reward, 45 | Reduce.STD, 46 | ) 47 | 48 | record_metric( 49 | "reward/evaluate_response/avg_total_reward", 50 | reward, 51 | Reduce.MEAN, 52 | ) 53 | 54 | avg_reward: float = total_rewards / len(self.reward_functions) 55 | return reward_breakdown, avg_reward 56 | -------------------------------------------------------------------------------- /tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml: -------------------------------------------------------------------------------- 1 | # Global configuration 2 | group_size: 8 3 | batch_size: 16 4 | max_req_tokens: 512 5 | max_res_tokens: 512 6 | model: "Qwen/Qwen3-1.7B" 7 | off_by_n: 1 # Off by one by default 8 | compile: true # Enable torch.compile for trainer, and CUDA graphs for vLLM 9 | 10 | 11 | # Policy configuration 12 | policy: 13 | engine_args: 14 | model: ${model} 15 | tensor_parallel_size: 1 16 | pipeline_parallel_size: 1 17 | enforce_eager: ${not:${compile}} 18 | sampling_params: 19 | n: ${group_size} 20 | max_tokens: ${max_res_tokens} 21 | temperature: 1.0 22 | top_p: 1.0 23 | 24 | # Trainer configuration 25 | trainer: 26 | model: 27 | name: qwen3 28 | flavor: 1.7B 29 | hf_assets_path: hf://${model} 30 | optimizer: 31 | name: AdamW 32 | lr: 1e-5 33 | eps: 1e-8 34 | lr_scheduler: 35 | warmup_steps: 1 36 | training: 37 | local_batch_size: ${batch_size} 38 | seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens 39 | max_norm: 1.0 40 | steps: 1000000 41 | dtype:
bfloat16 42 | gc_freq: 1 43 | compile: 44 | enable: ${compile} 45 | parallelism: 46 | data_parallel_replicate_degree: 1 47 | data_parallel_shard_degree: 1 48 | tensor_parallel_degree: 1 49 | pipeline_parallel_degree: 1 50 | context_parallel_degree: 1 51 | expert_parallel_degree: 1 52 | disable_loss_parallel: true 53 | checkpoint: 54 | enable: true 55 | initial_load_path: hf://${model} 56 | initial_load_in_hf: true 57 | last_save_in_hf: true 58 | interval: 500 59 | async_mode: "disabled" 60 | activation_checkpoint: 61 | mode: selective 62 | selective_ac_option: op 63 | 64 | # All resource allocations 65 | services: 66 | policy: 67 | procs: ${policy.engine_args.tensor_parallel_size} 68 | num_replicas: 1 69 | with_gpus: true 70 | 71 | actors: 72 | trainer: 73 | procs: 1 74 | num_replicas: 1 75 | with_gpus: true 76 | -------------------------------------------------------------------------------- /tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml: -------------------------------------------------------------------------------- 1 | # trainer tp = 2, policy tp = 4 2 | 3 | # Global configuration 4 | group_size: 8 5 | batch_size: 16 6 | max_req_tokens: 512 7 | max_res_tokens: 512 8 | model: "Qwen/Qwen3-1.7B" 9 | off_by_n: 1 # Off by one by default 10 | compile: true # Enable torch.compile for trainer, and CUDA graphs for vLLM 11 | 12 | 13 | # Policy configuration 14 | policy: 15 | engine_args: 16 | model: ${model} 17 | tensor_parallel_size: 4 18 | pipeline_parallel_size: 1 19 | enforce_eager: ${not:${compile}} 20 | sampling_params: 21 | n: ${group_size} 22 | max_tokens: ${max_res_tokens} 23 | temperature: 1.0 24 | top_p: 1.0 25 | 26 | # Trainer configuration 27 | trainer: 28 | model: 29 | name: qwen3 30 | flavor: 1.7B 31 | hf_assets_path: hf://${model} 32 | optimizer: 33 | name: AdamW 34 | lr: 1e-5 35 | eps: 1e-8 36 | lr_scheduler: 37 | warmup_steps: 1 38 | training: 39 | local_batch_size: ${batch_size} 40 | seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens 41 | max_norm: 1.0 42 | steps: 1000000 43 | dtype: bfloat16 44 | gc_freq: 1 45 | compile: 46 | enable: ${compile} 47 | parallelism: 48 | data_parallel_replicate_degree: 1 49 | data_parallel_shard_degree: 1 50 | tensor_parallel_degree: 2 51 | pipeline_parallel_degree: 1 52 | context_parallel_degree: 1 53 | expert_parallel_degree: 1 54 | disable_loss_parallel: true 55 | checkpoint: 56 | enable: true 57 | initial_load_path: hf://${model} 58 | initial_load_in_hf: true 59 | last_save_in_hf: true 60 | interval: 500 61 | async_mode: "disabled" 62 | activation_checkpoint: 63 | mode: selective 64 | selective_ac_option: op 65 | 66 | # All resource allocations 67 | services: 68 | policy: 69 | procs: ${policy.engine_args.tensor_parallel_size} 70 | num_replicas: 1 71 | with_gpus: true 72 | 73 | actors: 74 | trainer: 75 | procs: 2 76 | num_replicas: 1 77 | with_gpus: true 78 | -------------------------------------------------------------------------------- /src/forge/losses/reinforce_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | import torch 8 | 9 | from forge.util.ops import compute_logprobs 10 | from torch import nn 11 | 12 | 13 | class ReinforceLoss(nn.Module): 14 | """Reinforce loss function with optional importance ratio clipping. 15 | 16 | Reinforce with importance ratio is NOT GRPO. GRPO uses ratio clipping, where 17 | tokens outside the trust region don't have gradients. Reinforce with importance 18 | ratio clips a detached importance ratio, where tokens outside the trust region 19 | still have gradients. 20 | 21 | This difference is important when very bad things happen, e.g. silent data 22 | corruption (SDC) or an expert selection mismatch between sampling and policy 23 | update due to numerical noise. GRPO is more resilient in these cases. 24 | """ 25 | 26 | def __init__(self): 27 | super().__init__() 28 | 29 | def forward( 30 | self, trainer_logits, target_ids, target_mask, target_weights, target_log_probs 31 | ): 32 | trainer_log_probs = compute_logprobs(trainer_logits, target_ids, align=False) 33 | target_mask = target_mask.detach() 34 | target_weights = target_weights 35 | target_mask_sum = target_mask.sum() 36 | target_mask_sum = torch.maximum( 37 | target_mask_sum, torch.ones_like(target_mask_sum) 38 | ) 39 | sampler_log_probs = target_log_probs 40 | 41 | # Importance sampling ratio 42 | logp_diff = trainer_log_probs - sampler_log_probs.detach() 43 | importance_weights = torch.exp(logp_diff).detach() 44 | importance_weights = torch.clamp(importance_weights, min=0.1, max=10.0) 45 | weighted_advantages = target_weights * importance_weights 46 | 47 | numerator = (-trainer_log_probs * weighted_advantages * target_mask).sum() 48 | 49 | denominator = target_mask_sum 50 | return numerator / denominator 51 | -------------------------------------------------------------------------------- /src/forge/observability/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | from typing import Optional 9 | 10 | from monarch.actor import context, current_rank 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def get_proc_name_with_rank(proc_name: Optional[str] = None) -> str: 16 | """ 17 | Returns a unique identifier for the current rank from Monarch actor context. 18 | 19 | Multiple ranks from the same ProcMesh will share the same ProcMesh hash suffix, 20 | but have different rank numbers. 21 | 22 | Format: "{ProcessName}_{ProcMeshHash}_r{rank}" where: 23 | - ProcessName: The provided proc_name (e.g., "TrainActor") or extracted from actor_name if None. 24 | - ProcMeshHash: Hash suffix identifying the ProcMesh (e.g., "1abc2def") 25 | - rank: Local rank within the ProcMesh (0, 1, 2, ...) 26 | 27 | Note: If called from the main process (e.g. main.py), returns "client_r0". 28 | 29 | Args: 30 | proc_name: Optional override for process name. If None, uses actor_id.actor_name. 31 | 32 | Returns: 33 | str: Unique identifier per rank (e.g., "TrainActor_1abc2def_r0" or "client_r0").
34 | """ 35 | ctx = context() 36 | actor_id = ctx.actor_instance.actor_id 37 | actor_name = actor_id.actor_name 38 | rank = current_rank().rank 39 | 40 | # If proc_name provided, extract procmesh hash from actor_name and combine 41 | if proc_name is not None: 42 | parts = actor_name.split("_") 43 | if len(parts) > 1: 44 | replica_hash = parts[-1] # (e.g., "MyActor_1abc2def" -> "1abc2def") 45 | return f"{proc_name}_{replica_hash}_r{rank}" 46 | else: 47 | # if a direct process (e.g. called from main), actor_name == "client" -> len(parts) == 1 48 | return f"{proc_name}_r{rank}" 49 | 50 | # No proc_name override - use full actor_name with rank 51 | return f"{actor_name}_r{rank}" 52 | -------------------------------------------------------------------------------- /apps/sft/qwen3_8b.yaml: -------------------------------------------------------------------------------- 1 | # >>> python -m apps.sft.main --config apps/sft/qwen3_8b.yaml 2 | 3 | 4 | # TODO: required by torchtitan 5 | # https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265 6 | comm: 7 | trace_buf_size: 0 8 | 9 | model_name: "Qwen/Qwen3-8B" 10 | 11 | model: 12 | name: qwen3 13 | flavor: 8B 14 | hf_assets_path: hf://${model_name} 15 | 16 | processes: 17 | procs: 8 18 | with_gpus: true 19 | 20 | optimizer: 21 | name: AdamW 22 | lr: 1e-5 23 | eps: 1e-8 24 | 25 | lr_scheduler: 26 | warmup_steps: 200 27 | 28 | training: 29 | local_batch_size: 8 30 | seq_len: 2048 31 | max_norm: 1.0 32 | steps: 1000 33 | compile: false 34 | datasets: 35 | - path: "yahma/alpaca-cleaned" 36 | split: "train[:95%]" 37 | 38 | eval: 39 | eval_every_n_steps: 50 # null = disabled 40 | max_eval_steps: null # null = run until epoch completes 41 | datasets: 42 | - path: "yahma/alpaca-cleaned" 43 | split: "train[95%:]" 44 | 45 | parallelism: 46 | data_parallel_replicate_degree: 1 47 | data_parallel_shard_degree: -1 48 | tensor_parallel_degree: 1 49 | pipeline_parallel_degree: 1 50 | context_parallel_degree: 1 51 | expert_parallel_degree: 1 52 | disable_loss_parallel: false 53 | 54 | checkpoint: 55 | enable: true 56 | folder: ./checkpoint # The folder to save checkpoints to. 57 | initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists. 
58 | initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo 59 | last_save_in_hf: true 60 | interval: 500 61 | async_mode: "disabled" 62 | 63 | activation_checkpoint: 64 | mode: selective 65 | selective_ac_option: op 66 | 67 | metric_logging: 68 | wandb: 69 | project: sft-training 70 | group: sft_exp_${oc.env:USER} 71 | logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce 72 | 73 | # profiling: 74 | # enable_profiling: false 75 | 76 | # metrics: 77 | # log_freq: 10 78 | # enable_tensorboard: true 79 | # save_tb_folder: "tb" 80 | -------------------------------------------------------------------------------- /tests/sandbox/weight_sync/qwen3_1_7b.yaml: -------------------------------------------------------------------------------- 1 | # Weight Sync Sandbox Configuration 2 | # >>> python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml 3 | 4 | model: "Qwen/Qwen3-1.7B" 5 | local_batch_size: 4 6 | max_req_tokens: 64 7 | max_res_tokens: 64 8 | compile: true # Enable torch.compile for trainer, and CUDA graphs for vLLM 9 | 10 | metric_logging: 11 | console: 12 | logging_mode: global_reduce 13 | 14 | policy: 15 | prefetch_weights_to_shm: false # Disable to avoid shared memory warnings in test 16 | engine_args: 17 | model: ${model} 18 | tensor_parallel_size: 1 19 | pipeline_parallel_size: 1 20 | enforce_eager: ${not:${compile}} 21 | sampling_params: 22 | n: 1 23 | max_tokens: 32 # Just for verification forward pass 24 | temperature: 1.0 25 | top_p: 1.0 26 | 27 | trainer: 28 | model: 29 | name: qwen3 30 | flavor: 1.7B 31 | hf_assets_path: hf://${model} 32 | optimizer: 33 | name: AdamW 34 | lr: 1e-5 35 | eps: 1e-8 36 | lr_scheduler: 37 | warmup_steps: 1 38 | training: 39 | local_batch_size: ${local_batch_size} 40 | seq_len: 128 # max_req_tokens + max_res_tokens 41 | max_norm: 1.0 42 | steps: 1 # We only run 1 step 43 | dtype: bfloat16 44 | gc_freq: 1 45 | compile: 46 | enable: ${compile} 47 | parallelism: 48 | data_parallel_replicate_degree: 1 49 | data_parallel_shard_degree: 1 # Single GPU, no FSDP 50 | tensor_parallel_degree: 1 51 | pipeline_parallel_degree: 1 52 | context_parallel_degree: 1 53 | expert_parallel_degree: 1 54 | disable_loss_parallel: true 55 | checkpoint: 56 | enable: true 57 | folder: ./checkpoint 58 | initial_load_path: hf://${model} 59 | initial_load_in_hf: true 60 | last_save_in_hf: true 61 | async_mode: "disabled" 62 | activation_checkpoint: 63 | mode: selective 64 | selective_ac_option: op 65 | 66 | # Resource allocation - both as actors 67 | actors: 68 | policy: 69 | procs: 1 # Single process for generator 70 | with_gpus: true 71 | mesh_name: policy 72 | trainer: 73 | procs: 1 # Single process for trainer 74 | with_gpus: true 75 | mesh_name: trainer 76 | -------------------------------------------------------------------------------- /tests/unit_tests/observability/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | """Tests for observability utility functions.""" 8 | 9 | from forge.controller.actor import ForgeActor 10 | 11 | from forge.observability.utils import get_proc_name_with_rank 12 | from monarch.actor import endpoint 13 | 14 | 15 | class UtilActor(ForgeActor): 16 | """Actor for testing get_proc_name_with_rank in spawned context.""" 17 | 18 | @endpoint 19 | async def get_name(self) -> str: 20 | return get_proc_name_with_rank() 21 | 22 | 23 | class TestGetProcNameWithRank: 24 | """Tests for get_proc_name_with_rank utility.""" 25 | 26 | def test_direct_proc(self): 27 | """Direct proc should return 'client_r0'.""" 28 | assert get_proc_name_with_rank() == "client_r0" 29 | 30 | def test_direct_proc_with_override(self): 31 | """Direct proc with override should use provided name.""" 32 | result = get_proc_name_with_rank(proc_name="MyProcess") 33 | assert result == "MyProcess_r0" 34 | 35 | # TODO (felipemello): currently not working with CI wheel, but passes locally 36 | # reactivate once the wheel is updated with the new monarch version 37 | # @pytest.mark.timeout(10) 38 | # @pytest.mark.asyncio 39 | # async def test_replicas(self): 40 | # """Test service with replicas returns unique names and hashes per replica.""" 41 | # actor = await UtilActor.options( 42 | # procs=1, num_replicas=2, with_gpus=False 43 | # ).as_service() 44 | # results = await actor.get_name.fanout() 45 | 46 | # assert len(results) == 2 47 | # assert len(set(results)) == 2 # All names are unique 48 | # for name in results: 49 | # assert name.startswith("UtilActor") 50 | # assert name.endswith("_r0") 51 | 52 | # # Extract hashes from names (format: ActorName_replicaIdx_hash_r0) 53 | # hashes = [name.split("_")[-2] for name in results] 54 | # assert hashes[0] != hashes[1] # Hashes are different between replicas 55 | -------------------------------------------------------------------------------- /apps/sft/llama3_8b.yaml: -------------------------------------------------------------------------------- 1 | # >>> python -m apps.sft.main --config apps/sft/llama3_8b.yaml 2 | 3 | # Config for supervised full finetuning using a Llama3.1 8B Instruct model 4 | 5 | # TODO: required by torchtitan 6 | # https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265 7 | comm: 8 | trace_buf_size: 0 9 | 10 | model_name: "meta-llama/Meta-Llama-3.1-8B-Instruct" 11 | 12 | model: 13 | name: llama3 14 | flavor: 8B 15 | hf_assets_path: hf://${model_name} 16 | 17 | processes: 18 | procs: 8 19 | with_gpus: true 20 | 21 | optimizer: 22 | name: AdamW 23 | lr: 1e-5 24 | eps: 1e-8 25 | 26 | lr_scheduler: 27 | warmup_steps: 200 28 | 29 | training: 30 | local_batch_size: 8 31 | seq_len: 2048 32 | max_norm: 1.0 33 | steps: 1000 34 | compile: false 35 | datasets: 36 | - path: "yahma/alpaca-cleaned" 37 | split: "train[:95%]" 38 | 39 | eval: 40 | eval_every_n_steps: 50 # null = disabled 41 | max_eval_steps: null # null = run until epoch completes 42 | datasets: 43 | - path: "yahma/alpaca-cleaned" 44 | split: "train[95%:]" 45 | 46 | parallelism: 47 | data_parallel_replicate_degree: 1 48 | data_parallel_shard_degree: -1 49 | tensor_parallel_degree: 1 50 | pipeline_parallel_degree: 1 51 | context_parallel_degree: 1 52 | expert_parallel_degree: 1 53 | disable_loss_parallel: false 54 | 55 | checkpoint: 56 | enable: true 57 | folder: ./checkpoint # The folder to save checkpoints to. 58 | initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists.
59 | initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo 60 | last_save_in_hf: true 61 | interval: 500 62 | async_mode: "disabled" 63 | 64 | activation_checkpoint: 65 | mode: selective 66 | selective_ac_option: op 67 | 68 | metric_logging: 69 | wandb: 70 | project: sft-training 71 | group: sft_exp_${oc.env:USER} 72 | logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce 73 | 74 | 75 | # profiling: 76 | # enable_profiling: false 77 | 78 | # metrics: 79 | # log_freq: 10 80 | # enable_tensorboard: true 81 | # save_tb_folder: "tb" 82 | -------------------------------------------------------------------------------- /src/forge/controller/service/spawn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Factory-based service spawning for the Monarch rollout system.""" 7 | 8 | import logging 9 | from typing import Type 10 | 11 | from forge.controller import ForgeActor 12 | from forge.controller.service import ServiceActor, ServiceConfig 13 | 14 | from forge.controller.service.interface import ServiceInterfaceV2 15 | 16 | from monarch.actor import proc_mesh 17 | 18 | logger = logging.getLogger(__name__) 19 | logger.setLevel(logging.INFO) 20 | 21 | 22 | async def spawn_service_v2( 23 | service_cfg: ServiceConfig, actor_def: Type[ForgeActor], **actor_kwargs 24 | ) -> ServiceInterfaceV2: 25 | """Spawns a service based on the actor class. 26 | 27 | Args: 28 | service_cfg: Service configuration 29 | actor_def: Actor class definition 30 | **actor_kwargs: Keyword arguments to pass to actor constructor 31 | 32 | Returns: 33 | A ServiceInterface that provides access to the Service Actor 34 | """ 35 | # Assert that actor_def is a subclass of ForgeActor 36 | if not issubclass(actor_def, ForgeActor): 37 | raise TypeError( 38 | f"actor_def must be a subclass of ForgeActor, got {type(actor_def).__name__}" 39 | ) 40 | 41 | # Create a single-node proc_mesh and actor_mesh for the Service Actor 42 | logger.info("Spawning Service Actor for %s", actor_def.__name__) 43 | m = await proc_mesh(gpus=1) 44 | service_actor = m.spawn( 45 | "service", ServiceActor, service_cfg, actor_def, actor_kwargs 46 | ) 47 | await service_actor.__initialize__.call_one() 48 | 49 | # Return the ServiceInterface that wraps the proc_mesh, actor_mesh, and actor_def 50 | return ServiceInterfaceV2(m, service_actor, actor_def) 51 | 52 | 53 | async def shutdown_service_v2(service: ServiceInterfaceV2) -> None: 54 | """Shuts down the service. 55 | 56 | Implemented in this way to avoid actors overriding stop() unintentionally. 57 | 58 | """ 59 | await service._service.stop.call_one() 60 | await service._proc_mesh.stop() 61 | -------------------------------------------------------------------------------- /tests/unit_tests/observability/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | """Shared fixtures and mocks for observability unit tests.""" 8 | 9 | from unittest.mock import MagicMock, patch 10 | 11 | import pytest 12 | from forge.observability.metrics import MetricCollector 13 | 14 | 15 | @pytest.fixture(autouse=True) 16 | def clear_metric_collector_singletons(): 17 | """Clear MetricCollector singletons before each test to avoid state leakage.""" 18 | MetricCollector._instances.clear() 19 | yield 20 | MetricCollector._instances.clear() 21 | 22 | 23 | @pytest.fixture(autouse=True) 24 | def clean_metrics_environment(): 25 | """Override the global mock_metrics_globally fixture to allow real metrics testing.""" 26 | import os 27 | 28 | from forge.env import FORGE_DISABLE_METRICS 29 | 30 | # Set default state for tests (metrics enabled) 31 | if FORGE_DISABLE_METRICS.name in os.environ: 32 | del os.environ[FORGE_DISABLE_METRICS.name] 33 | 34 | yield 35 | 36 | 37 | @pytest.fixture 38 | def mock_rank(): 39 | """Mock current_rank function with configurable rank.""" 40 | with patch("forge.observability.metrics.current_rank") as mock: 41 | rank_obj = MagicMock() 42 | rank_obj.rank = 0 43 | mock.return_value = rank_obj 44 | yield mock 45 | 46 | 47 | @pytest.fixture 48 | def mock_actor_context(): 49 | """Mock Monarch actor context for testing actor name generation.""" 50 | with ( 51 | patch("forge.observability.metrics.context") as mock_context, 52 | patch("forge.observability.metrics.current_rank") as mock_rank, 53 | ): 54 | # Setup mock context 55 | ctx = MagicMock() 56 | actor_instance = MagicMock() 57 | actor_instance.actor_id = "_1rjutFUXQrEJ[0].TestActorConfigured[0]" 58 | ctx.actor_instance = actor_instance 59 | mock_context.return_value = ctx 60 | 61 | # Setup mock rank 62 | rank_obj = MagicMock() 63 | rank_obj.rank = 0 64 | mock_rank.return_value = rank_obj 65 | 66 | yield { 67 | "context": mock_context, 68 | "rank": mock_rank, 69 | "expected_name": "TestActor_0XQr_r0", 70 | } 71 | -------------------------------------------------------------------------------- /apps/grpo/README.md: -------------------------------------------------------------------------------- 1 | Last updated: 2025-12-05 2 | 3 | # GRPO on GSM8K 4 | 5 | Training GRPO (Grouped Relative Policy Optimization) on GSM8K grade school math word problems. 6 | 7 | ## Dataset 8 | 9 | GSM8K consists of grade school math word problems that require multi-step arithmetic reasoning. Models generate solutions with chain-of-thought reasoning and provide final answers in `\boxed{}` format. 10 | 11 | **Example Input:** 12 | ``` 13 | <|begin_of_text|><|start_header_id|>system<|end_header_id|> 14 | 15 | Cutting Knowledge Date: December 2023 16 | Today Date: 26 Jul 2024 17 | 18 | Put all your scratchpad work between <think> and </think> tags. 19 | Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.<|eot_id|><|start_header_id|>user<|end_header_id|> 20 | 21 | Lucille is painting her room. Two of her walls are 3 meters wide and 2 meters tall. The third wall is 5 meters wide and 2 meters tall. The final wall is 4 meters wide and 2 meters tall. If each can of paint covers 2 square meters, how many cans of paint does Lucille need?<|eot_id|><|start_header_id|>assistant<|end_header_id|> 22 | ``` 23 | 24 | **Example Output:** 25 | ``` 26 | <think> 27 | First, let's calculate the area of each wall: 28 | 29 | - Two walls of 3 meters x 2 meters = 3 x 2 = 6 square meters per wall. Since there are two of these, 6 * 2 = 12 square meters. 30 | - The third wall is 5 meters x 2 meters = 10 square meters.
31 | - The final wall is 4 meters x 2 meters = 8 square meters. 32 | Total wall area = 12 + 10 + 8 = 30 square meters. 33 | 34 | Since each can covers 2 square meters, we need to divide the total wall area by the area one can covers: 30 / 2 = 15. 35 | 36 | </think> 37 | 38 | <answer> 15 </answer> 39 | ``` 40 | 41 | ## Quick Start 42 | 43 | **Llama 3.1 8B** (recommended for learning, requires 5 GPUs as is, not optimized): 44 | ```bash 45 | python -m apps.grpo.main --config apps/grpo/llama3_8b.yaml 46 | ``` 47 | 48 | **Qwen3 1.7B** (NOTE: Qwen3 is already saturated on GSM8K, so rewards will **not** increase. Requires 3 GPUs, not optimized): 49 | ```bash 50 | python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml 51 | ``` 52 | 53 | ## Expected Results 54 | 55 | For **Llama 3.1 8B**, training rewards should rise above 0.8 within the first few steps as the model learns the task. 56 | 57 | ![Llama 3.1 8B Training Rewards](wandb_llama8b.png) 58 | 59 | ## Configurations 60 | 61 | - `llama3_8b.yaml` - Meta Llama 3.1 8B Instruct 62 | - `qwen3_1_7b.yaml` - Qwen3 1.7B 63 | - `qwen3_8b.yaml` - Qwen3 8B 64 | - `qwen3_32b.yaml` - Qwen3 32B 65 | -------------------------------------------------------------------------------- /tests/unit_tests/test_torchstore_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import tempfile 9 | import unittest 10 | 11 | from pathlib import Path 12 | 13 | import pytest 14 | 15 | import torch 16 | import torch.distributed.checkpoint as dcp 17 | from forge.actors._torchstore_utils import DcpHandle 18 | 19 | ignore_torch_distributed_uninitialized_warning = pytest.mark.filterwarnings( 20 | r"ignore:.*torch.distributed" 21 | ) 22 | 23 | 24 | class TestDcpHandle(unittest.TestCase): 25 | def _prepare_dcp_handle(self, test_dir: str) -> tuple[str, DcpHandle]: 26 | """Returns path to checkpoint and DcpHandle.""" 27 | checkpoint_id = str(Path(test_dir) / "test_checkpoint_id") 28 | state_dict = {"a": torch.rand(1, 1), "b": torch.rand(1, 1)} 29 | metadata = dcp.save(checkpoint_id=checkpoint_id, state_dict=state_dict) 30 | assert os.path.exists(checkpoint_id), "failed to set up test checkpoint" 31 | return checkpoint_id, DcpHandle( 32 | checkpoint_id=checkpoint_id, 33 | metadata=metadata, 34 | param_names=list(state_dict.keys()), 35 | ) 36 | 37 | @ignore_torch_distributed_uninitialized_warning 38 | def test_dcp_handle_drop_deletes(self): 39 | with tempfile.TemporaryDirectory() as test_dir: 40 | ckpt_path, handle = self._prepare_dcp_handle(test_dir) 41 | handle.drop() 42 | self.assertFalse(os.path.exists(ckpt_path)) 43 | 44 | @ignore_torch_distributed_uninitialized_warning 45 | def test_dcp_handle_drop_sets_none(self): 46 | with tempfile.TemporaryDirectory() as test_dir: 47 | _, handle = self._prepare_dcp_handle(test_dir) 48 | handle.drop() 49 | self.assertEqual(handle.checkpoint_id, None) 50 | self.assertEqual(handle.metadata, None) 51 | self.assertEqual(handle.param_names, None) 52 | 53 | @ignore_torch_distributed_uninitialized_warning 54 | def test_dcp_handle_drop_sets_none_for_manifold(self): 55 | with tempfile.TemporaryDirectory() as test_dir: 56 | _, handle = self._prepare_dcp_handle(test_dir) 57 | handle.checkpoint_id = "manifold://test_bucket/tree/test_path" 58 | handle.drop() 59 | self.assertEqual(handle.checkpoint_id, None)
60 | self.assertEqual(handle.metadata, None) 61 | self.assertEqual(handle.param_names, None) 62 | -------------------------------------------------------------------------------- /tests/sandbox/vllm/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """To run: 8 | export HF_HUB_DISABLE_XET=1 9 | python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/llama3_8b.yaml 10 | """ 11 | 12 | import asyncio 13 | 14 | import os 15 | 16 | from forge.actors.generator import Generator 17 | 18 | from forge.controller.provisioner import init_provisioner, shutdown 19 | 20 | from forge.data_models.completion import Completion 21 | from forge.observability.metric_actors import get_or_create_metric_logger 22 | from forge.types import LauncherConfig, ProvisionerConfig 23 | from forge.util.config import parse 24 | from omegaconf import DictConfig 25 | 26 | os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600" 27 | os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824" 28 | 29 | 30 | async def run(cfg: DictConfig): 31 | if cfg.get("provisioner", None) is not None: 32 | await init_provisioner( 33 | ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) 34 | ) 35 | metric_logging_cfg = cfg.get( 36 | "metric_logging", {"console": {"logging_mode": "global_reduce"}} 37 | ) 38 | mlogger = await get_or_create_metric_logger(process_name="Controller") 39 | await mlogger.init_backends.call_one(metric_logging_cfg) 40 | 41 | if (prompt := cfg.get("prompt")) is None: 42 | prompt = "Tell me a joke" 43 | 44 | print("Spawning service...") 45 | policy = await Generator.options(**cfg.services.policy).as_service(**cfg.policy) 46 | 47 | import time 48 | 49 | print("Requesting generation...") 50 | n = 100 51 | start = time.time() 52 | response_outputs: list[Completion] = await asyncio.gather( 53 | *[policy.generate.route(prompt=prompt) for _ in range(n)] 54 | ) 55 | end = time.time() 56 | 57 | print(f"Generation of {n} requests completed in {end - start:.2f} seconds.") 58 | print( 59 | f"Generation with procs {cfg.services.policy.procs}, replicas {cfg.services.policy.num_replicas}" 60 | ) 61 | 62 | print(f"\nGeneration Results (last one of {n} requests):") 63 | print("=" * 80) 64 | for batch, response in enumerate(response_outputs[-1]): 65 | print(f"Sample {batch + 1}:") 66 | print(f"User: {prompt}") 67 | print(f"Assistant: {response.text}") 68 | print("-" * 80) 69 | 70 | print("\nShutting down...") 71 | await shutdown() 72 | 73 | 74 | @parse 75 | def recipe_main(cfg: DictConfig) -> None: 76 | asyncio.run(run(cfg)) 77 | 78 | 79 | if __name__ == "__main__": 80 | recipe_main() 81 | -------------------------------------------------------------------------------- /docs/source/tutorial_sources/template_tutorial.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Template Tutorial 9 | ================= 10 | 11 | **Author:** `FirstName LastName `_ 12 | 13 | .. grid:: 2 14 | 15 | .. 
grid-item-card:: :octicon:`mortar-board;1em;` What you will learn 16 | :class-card: card-prerequisites 17 | 18 | * Item 1 19 | * Item 2 20 | * Item 3 21 | 22 | .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites 23 | :class-card: card-prerequisites 24 | 25 | * PyTorch v2.0.0 26 | * GPU ??? 27 | * Other items 3 28 | 29 | 30 | To test your tutorial locally, you can do one of the following: 31 | 32 | * You can control specific files that generate the results by using 33 | ``GALLERY_PATTERN`` environment variable. The GALLERY_PATTERN variable 34 | respects regular expressions. 35 | For example to run only ``neural_style_transfer_tutorial.py``, 36 | use the following command: 37 | 38 | .. code-block:: sh 39 | 40 | GALLERY_PATTERN="neural_style_transfer_tutorial.py" make html 41 | 42 | or 43 | 44 | .. code-block:: sh 45 | 46 | GALLERY_PATTERN="neural_style_transfer_tutorial.py" sphinx-build . _build 47 | 48 | * Make a copy of this repository and add only your 49 | tutorial to the `beginner_source` directory removing all other tutorials. 50 | Then run ``make html``. 51 | 52 | Verify that all outputs were generated correctly in the created HTML. 53 | """ 54 | 55 | ######################################################################### 56 | # Overview 57 | # -------- 58 | # 59 | # Describe Why is this topic important? Add Links to relevant research papers. 60 | # 61 | # This tutorial walks you through the process of.... 62 | # 63 | # Steps 64 | # ----- 65 | # 66 | # Example code (the output below is generated automatically): 67 | # 68 | import torch 69 | 70 | x = torch.rand(5, 3) 71 | print(x) 72 | 73 | ###################################################################### 74 | # (Optional) Additional Exercises 75 | # ------------------------------- 76 | # 77 | # Add additional practice exercises for users to test their knowledge. 78 | # Example: `NLP from Scratch `__. 79 | # 80 | 81 | ###################################################################### 82 | # Conclusion 83 | # ---------- 84 | # 85 | # Summarize the steps and concepts covered. Highlight key takeaways. 86 | # 87 | # Further Reading 88 | # --------------- 89 | # 90 | # * Link1 91 | # * Link2 92 | -------------------------------------------------------------------------------- /src/forge/actors/_torchstore_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | import shutil 8 | from dataclasses import dataclass 9 | 10 | import torch 11 | import torch.distributed.checkpoint as dcp 12 | from torch.distributed.checkpoint.metadata import Metadata as DcpMeta 13 | from torchstore.transport.buffers import rdma_available 14 | 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.DEBUG) 17 | 18 | KEY_DELIM = "." 
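# A hypothetical illustration of the key scheme built by the helpers further
# down in this module (the parameter name is made up for the example):
#   get_param_prefix(3)  -> "policy_ver_0000000003"
#   get_param_key(3, "model.layers.0.weight")
#       -> "policy_ver_0000000003.model.layers.0.weight"
#   extract_param_name("policy_ver_0000000003.model.layers.0.weight")
#       -> "model.layers.0.weight"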
19 | DCP_WHOLE_STATE_TAG = "dcp_whole_state_dict" 20 | 21 | 22 | @dataclass 23 | class DcpHandle: 24 | checkpoint_id: str | None = None 25 | metadata: DcpMeta | None = None 26 | param_names: list[str] | None = None 27 | 28 | def drop(self) -> None: 29 | if self.checkpoint_id is None: 30 | raise ValueError("Dropping a null DcpHandle") 31 | if self.checkpoint_id.startswith("manifold://"): 32 | # Probably don't need to delete the checkpoint if it's on manifold 33 | logger.warning( 34 | f"Skipping deletion of {self.checkpoint_id} since it's on manifold" 35 | ) 36 | self.checkpoint_id = None 37 | self.metadata = None 38 | self.param_names = None 39 | return 40 | 41 | try: 42 | shutil.rmtree(self.checkpoint_id, ignore_errors=False) 43 | logger.debug(f"Removed old weights at {self.checkpoint_id}") 44 | except OSError as e: 45 | logger.error(f"Error deleting {self.checkpoint_id}: {e}") 46 | finally: 47 | self.checkpoint_id = None 48 | self.metadata = None 49 | self.param_names = None 50 | 51 | 52 | def load_tensor_from_dcp(handle: DcpHandle, param_name) -> torch.Tensor: 53 | tensor_meta = handle.metadata.state_dict_metadata[param_name] 54 | buffer = torch.empty(tensor_meta.size, dtype=tensor_meta.properties.dtype) 55 | dcp.load(checkpoint_id=handle.checkpoint_id, state_dict={param_name: buffer}) 56 | return buffer 57 | 58 | 59 | def get_param_prefix(policy_version: int) -> str: 60 | return f"policy_ver_{policy_version:010d}" 61 | 62 | 63 | def get_param_key(policy_version: int, name: str) -> str: 64 | return f"policy_ver_{policy_version:010d}{KEY_DELIM}{name}" 65 | 66 | 67 | def extract_param_name(key: str) -> str: 68 | return KEY_DELIM.join(key.split(KEY_DELIM)[1:]) 69 | 70 | 71 | def get_dcp_whole_state_dict_key(policy_version: int) -> str: 72 | return f"{get_param_prefix(policy_version)}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}" 73 | 74 | 75 | def rdma_enabled() -> bool: 76 | """Return if TorchStore thinks we're using RDMA""" 77 | return rdma_available() 78 | -------------------------------------------------------------------------------- /tests/integration_tests/test_coder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Integration tests for forge.actors.coder.SandboxedPythonCoder. 9 | 10 | Requires enroot to be installed. 
11 | 12 | """ 13 | 14 | import os 15 | import uuid 16 | 17 | import pytest 18 | 19 | from forge.actors.coder import SandboxedPythonCoder 20 | 21 | 22 | @pytest.mark.timeout(30) 23 | @pytest.mark.asyncio 24 | async def test_coder_runs_python(): 25 | """Integration test for SandboxedPythonCoder with real container execution.""" 26 | # Create unique names to avoid test conflicts 27 | unique_id = str(uuid.uuid1()) 28 | container_name = f"test_sandbox_{unique_id}" 29 | image_path = f"/tmp/python_test_{unique_id}.sqsh" 30 | 31 | coder = None 32 | try: 33 | coder = await SandboxedPythonCoder.as_actor( 34 | docker_image="docker://python:3.10", 35 | sqsh_image_path=image_path, 36 | container_name=container_name, 37 | ) 38 | 39 | # Execute code 40 | results, _ = await coder.execute.call_one( 41 | code="print('hello world')", 42 | ) 43 | print("Got results", results) 44 | assert results == "hello world\n" 45 | 46 | finally: 47 | # Clean up resources 48 | if coder: 49 | await SandboxedPythonCoder.shutdown(coder) 50 | 51 | # Clean up the image file 52 | if os.path.exists(image_path): 53 | os.unlink(image_path) 54 | 55 | 56 | @pytest.mark.timeout(30) 57 | @pytest.mark.asyncio 58 | async def test_coder_catches_error(): 59 | """Integration test for SandboxedPythonCoder with real container execution.""" 60 | # Create unique names to avoid test conflicts 61 | unique_id = str(uuid.uuid1()) 62 | container_name = f"test_sandbox_{unique_id}" 63 | image_path = f"/tmp/python_test_{unique_id}.sqsh" 64 | 65 | coder = None 66 | try: 67 | print("starting test") 68 | coder = await SandboxedPythonCoder.as_actor( 69 | docker_image="docker://python:3.10", 70 | sqsh_image_path=image_path, 71 | container_name=container_name, 72 | ) 73 | print("Got coder") 74 | 75 | # Execute code 76 | _, stderr = await coder.execute.call_one( 77 | code="hello world", 78 | ) 79 | print("got stderr", stderr) 80 | assert "SyntaxError" in stderr 81 | 82 | finally: 83 | # Clean up resources 84 | if coder: 85 | await SandboxedPythonCoder.shutdown(coder) 86 | 87 | # Clean up the image file 88 | if os.path.exists(image_path): 89 | os.unlink(image_path) 90 | -------------------------------------------------------------------------------- /src/forge/rl/types.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from dataclasses import dataclass 8 | from typing import Any 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from forge.actors.generator import Generator 14 | from forge.data_models.completion import Completion 15 | 16 | 17 | @dataclass 18 | class Episode: 19 | episode_id: str 20 | pad_id: int 21 | request_len: int 22 | response_len: int 23 | target: Any | None = None 24 | request: str | None = None 25 | response: str | None = None 26 | # Processed data 27 | completion: Completion | None = None 28 | ref_logprobs: torch.Tensor | None = None 29 | reward: float | None = None 30 | reward_breakdown: dict[str, float] | None = None 31 | advantage: float | None = None 32 | 33 | @property 34 | def policy_version(self) -> int | None: 35 | return self.completion.generator_version 36 | 37 | @property 38 | def request_tensor(self) -> torch.Tensor: 39 | tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long) 40 | if tensor.shape[0] < self.request_len: # left pad 41 | diff = self.request_len - tensor.shape[0] 42 | tensor = F.pad(tensor, (diff, 0), value=self.pad_id) 43 | return tensor 44 | 45 | @property 46 | def response_tensor(self) -> torch.Tensor: 47 | tensor: torch.Tensor = self.completion.token_ids.to(torch.long) 48 | if tensor.shape[0] < self.response_len: # right pad 49 | diff = self.response_len - tensor.shape[0] 50 | tensor = F.pad(tensor, (0, diff), value=self.pad_id) 51 | return tensor 52 | 53 | def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]: 54 | """Convert episode to dict, optionally excluding specified fields.""" 55 | result = { 56 | "episode_id": self.episode_id, 57 | "policy_version": self.policy_version, 58 | "prompt": self.request, 59 | "response": self.response, 60 | "target": str(self.target), 61 | "reward": self.reward, 62 | "advantage": self.advantage, 63 | "request_len": self.request_len, 64 | "response_len": self.response_len, 65 | "pad_id": self.pad_id, 66 | "ref_logprobs": self.ref_logprobs, 67 | "completion": self.completion, 68 | } 69 | 70 | if self.reward_breakdown is not None and "reward_breakdown" not in (exclude or []): 71 | result.update(self.reward_breakdown) 72 | 73 | if exclude: 74 | for key in exclude: 75 | result.pop(key, None) 76 | 77 | return result 78 | 79 | 80 | # Represents the group (G) of episodes in GRPO 81 | Group = list[Episode] 82 | 83 | # Represents the Policy Model to collect data from 84 | Policy = Generator 85 | -------------------------------------------------------------------------------- /src/forge/controller/service/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Service metrics collection and aggregation. 8 | 9 | This module provides comprehensive metrics tracking for distributed services, 10 | including per-replica performance data, service-wide aggregations, and 11 | health status information. 12 | """ 13 | 14 | from dataclasses import dataclass, field 15 | 16 | from forge.controller.service.replica import ReplicaMetrics 17 | 18 | 19 | # TODO - tie this into metrics logger when it exists. 20 | @dataclass 21 | class ServiceMetrics: 22 | """ 23 | Aggregated metrics collection for the entire service.
24 | 25 | Provides service-wide visibility into performance, health, and scaling metrics 26 | by aggregating data from all replica instances. 27 | 28 | Attributes: 29 | replica_metrics: Per-replica metrics indexed by replica ID 30 | total_sessions: Number of active sessions across all replicas 31 | healthy_replicas: Number of currently healthy replicas 32 | total_replicas: Total number of replicas (healthy + unhealthy) 33 | last_scale_event: Timestamp of the last scaling operation 34 | """ 35 | 36 | # Replica metrics 37 | replica_metrics: dict[int, ReplicaMetrics] = field(default_factory=dict) 38 | # Service-level metrics 39 | total_sessions: int = 0 40 | healthy_replicas: int = 0 41 | total_replicas: int = 0 42 | # Time-based metrics 43 | last_scale_event: float = 0.0 44 | 45 | def get_total_request_rate(self, window_seconds: float = 60.0) -> float: 46 | """Get total requests per second across all replicas.""" 47 | return sum( 48 | metrics.get_request_rate(window_seconds) 49 | for metrics in self.replica_metrics.values() 50 | ) 51 | 52 | def get_avg_queue_depth(self, replicas: list) -> float: 53 | """Get average queue depth across all healthy replicas.""" 54 | healthy_replicas = [r for r in replicas if r.healthy] 55 | if not healthy_replicas: 56 | return 0.0 57 | total_queue_depth = sum(r.request_queue.qsize() for r in healthy_replicas) 58 | return total_queue_depth / len(healthy_replicas) 59 | 60 | def get_avg_capacity_utilization(self, replicas: list) -> float: 61 | """Get average capacity utilization across all healthy replicas.""" 62 | healthy_replicas = [r for r in replicas if r.healthy] 63 | if not healthy_replicas: 64 | return 0.0 65 | total_utilization = sum(r.capacity_utilization for r in healthy_replicas) 66 | return total_utilization / len(healthy_replicas) 67 | 68 | def get_sessions_per_replica(self) -> float: 69 | """Get average sessions per replica.""" 70 | if self.total_replicas == 0: 71 | return 0.0 72 | return self.total_sessions / self.total_replicas 73 | -------------------------------------------------------------------------------- /src/forge/controller/service/router.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
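# Note: the routers below compose. SessionRouter keeps sticky sessions and
# delegates brand-new sessions to any fallback Router. A minimal usage sketch
# (assumes `replicas` is a list of healthy Replica objects):
#
#     session_map: dict[str, int] = {}
#     router = SessionRouter(fallback_router=LeastLoadedRouter())
#     replica = router.get_replica(replicas, sess_id="sess-1", session_map=session_map)
#     # "sess-1" stays pinned to replica.idx until that replica becomes unhealthy.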
6 | 
7 | import logging
8 | 
9 | from .interface import Router
10 | from .replica import Replica
11 | 
12 | logger = logging.getLogger(__name__)
13 | logger.setLevel(logging.DEBUG)
14 | 
15 | 
16 | class RoundRobinRouter(Router):
17 |     """Round-robin router for stateless requests."""
18 | 
19 |     def __init__(self):
20 |         self._next_idx = 0
21 | 
22 |     def get_replica(
23 |         self,
24 |         healthy_replicas: list[Replica],
25 |         sess_id: str | None = None,
26 |         session_map: dict[str, int] | None = None,
27 |     ) -> Replica:
28 |         if not healthy_replicas:
29 |             raise RuntimeError("No healthy replicas available for load balancing")
30 | 
31 |         self._next_idx = (self._next_idx + 1) % len(healthy_replicas)
32 |         replica = healthy_replicas[self._next_idx]
33 | 
34 |         return replica
35 | 
36 | 
37 | class LeastLoadedRouter(Router):
38 |     """Always routes to the replica with the lowest current load."""
39 | 
40 |     def get_replica(
41 |         self,
42 |         healthy_replicas: list[Replica],
43 |         sess_id: str | None = None,
44 |         session_map: dict[str, int] | None = None,
45 |     ) -> Replica:
46 |         if not healthy_replicas:
47 |             raise RuntimeError("No healthy replicas available for session assignment")
48 |         return min(healthy_replicas, key=lambda r: r.current_load)
49 | 
50 | 
51 | class SessionRouter(Router):
52 |     """Session-based routing: sticky sessions with a fallback router."""
53 | 
54 |     def __init__(self, fallback_router: Router):
55 |         self.fallback_router = fallback_router
56 | 
57 |     def get_replica(
58 |         self,
59 |         healthy_replicas: list[Replica],
60 |         sess_id: str | None = None,
61 |         session_map: dict[str, int] | None = None,
62 |     ) -> Replica:
63 |         if sess_id is None:
64 |             raise ValueError("SessionRouter requires a session ID")
65 | 
66 |         if session_map is None:
67 |             raise ValueError("Session map must be provided for SessionRouter")
68 | 
69 |         # Check if session already has a replica
70 |         if sess_id in session_map:
71 |             replica_idx = session_map[sess_id]
72 |             # Find the replica with this index
73 |             for r in healthy_replicas:
74 |                 if r.idx == replica_idx:
75 |                     return r
76 |             # If the replica is no longer healthy, remove from session map and reassign
77 |             del session_map[sess_id]
78 | 
79 |         # Use fallback router to assign a new replica
80 |         replica = self.fallback_router.get_replica(
81 |             healthy_replicas, sess_id, session_map
82 |         )
83 |         session_map[sess_id] = replica.idx
84 |         logger.debug(
85 |             "Assigning session %s to replica %d",
86 |             sess_id,
87 |             replica.idx,
88 |         )
89 |         return replica
90 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # ---- All project specifications ---- #
2 | [project]
3 | name = "forge"
4 | description = "A PyTorch native platform for post-training generative AI models"
5 | readme = "README.md"
6 | requires-python = ">=3.10, <3.13"
7 | license = {file = "LICENSE"}
8 | authors = [
9 |     { name = "PyTorch Team", email = "packages@pytorch.org" },
10 | ]
11 | keywords = ["pytorch", "training", "llm"]
12 | dependencies = [
13 |     # PyTorch
14 |     "torch==2.9.0",
15 |     "torchdata>=0.8.0",
16 |     "torchtitan==0.2.0",
17 |     "torchmonarch-nightly==2025.12.17",
18 |     # Issue 656: switch to pin torchstore nightly
19 |     "torchstore",
20 |     # vLLM
21 |     "vllm",
22 |     # Hugging Face integrations
23 |     "datasets>=2.21.0",
24 |     "tokenizers",
25 |     # Miscellaneous
26 |     "omegaconf",
27 |     "wandb",
28 |     "hf_transfer",
29 |     "six",
30 |     "setuptools<80",
31 | ]
32 | dynamic = ["version"]
33 | 
34 | [project.urls]
35 | GitHub = "https://github.com/meta-pytorch/torchforge"
GitHub = "https://github.com/meta-pytorch/torchforge" 36 | Documentation = "https://meta-pytorch.org/torchforge" 37 | Issues = "https://github.com/meta-pytorch/torchforge/issues" 38 | 39 | [project.optional-dependencies] 40 | dev = [ 41 | "pre-commit", 42 | "pytest", 43 | "pytest-cov", 44 | "pytest-timeout", 45 | "tensorboard", 46 | "expecttest", 47 | "tomli>=1.1.0", 48 | "anyio", 49 | "pytest-asyncio", 50 | "multiprocess", 51 | ] 52 | docs = [ 53 | "sphinx==7.2.6", 54 | "pytorch-sphinx-theme2==0.1.0", 55 | "docutils>=0.18.1,<0.21", 56 | "sphinx-design==0.6.1", 57 | "sphinxcontrib-mermaid==1.0.0", 58 | "sphinx-gallery==0.19.0", 59 | "matplotlib", 60 | "myst-parser", 61 | "sphinx-sitemap==2.7.1", 62 | "sphinx-autodoc-typehints==1.25.3", 63 | ] 64 | 65 | # ---- Explicit project build information ---- # 66 | [build-system] 67 | requires = ["setuptools>=61.0"] 68 | build-backend = "setuptools.build_meta" 69 | 70 | [tool.setuptools.packages.find] 71 | where = ["src"] 72 | include = ["forge*"] 73 | 74 | [tool.pytest.ini_options] 75 | addopts = ["--showlocals"] # show local variables in tracebacks 76 | pythonpath = "." 77 | 78 | [tool.uv.workspace] 79 | members = [ 80 | "forge", 81 | ] 82 | 83 | # pytorch 84 | [[tool.uv.index]] 85 | name = "pytorch-cu128" 86 | url = "https://download.pytorch.org/whl/cu128" 87 | 88 | # vllm 89 | [[tool.uv.index]] 90 | name = "vllm-forge" 91 | url = "https://download.pytorch.org/whl/preview/forge" 92 | 93 | [tool.uv.sources] 94 | torch = { index = "pytorch-cu128" } 95 | vllm = { index = "vllm-forge" } 96 | # Issue 656: switch to ping torchstore nightly 97 | torchstore = { git = "https://github.com/meta-pytorch/torchstore.git", branch = "no-monarch-2025.12.17" } 98 | 99 | [tool.uv] 100 | # TODO: revert to stricter default uv strategy 101 | index-strategy = "unsafe-best-match" 102 | prerelease = "allow" 103 | # TODO: add more backends 104 | environments = [ 105 | "sys_platform == 'linux'", 106 | ] 107 | 108 | [tool.black] 109 | target-version = ["py310"] # match the minium supported python version 110 | 111 | [tool.usort] 112 | first_party_detection = false 113 | -------------------------------------------------------------------------------- /src/forge/util/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # FIXME: remove this once wandb fixed this issue 8 | # https://github.com/wandb/wandb/issues/10890 9 | # Patch importlib.metadata.distributions before wandb imports it 10 | # to filter out packages with None metadata 11 | import importlib.metadata 12 | 13 | # Guard to ensure this runs only once 14 | if not hasattr(importlib.metadata, "_distributions_patched"): 15 | _original_distributions = importlib.metadata.distributions 16 | 17 | def _patched_distributions(): 18 | """Filter out distributions with None metadata""" 19 | for distribution in _original_distributions(): 20 | if distribution.metadata is not None: 21 | yield distribution 22 | 23 | importlib.metadata.distributions = _patched_distributions 24 | importlib.metadata._distributions_patched = True 25 | 26 | import logging 27 | from functools import lru_cache 28 | 29 | from torch import distributed as dist 30 | 31 | 32 | def get_logger(level: str | None = None) -> logging.Logger: 33 | """ 34 | Get a logger with a stream handler. 
35 | 36 | Args: 37 | level (str | None): The logging level. See https://docs.python.org/3/library/logging.html#levels for list of levels. 38 | 39 | Example: 40 | >>> logger = get_logger("INFO") 41 | >>> logger.info("Hello world!") 42 | INFO:forge.util.logging: Hello world! 43 | 44 | Returns: 45 | logging.Logger: The logger. 46 | """ 47 | logger = logging.getLogger(__name__) 48 | if not logger.hasHandlers(): 49 | handler = logging.StreamHandler() 50 | formatter = logging.Formatter("%(levelname)s:%(name)s: %(message)s") 51 | handler.setFormatter(formatter) 52 | logger.addHandler(handler) 53 | if level is not None: 54 | level = getattr(logging, level.upper()) 55 | logger.setLevel(level) 56 | return logger 57 | 58 | 59 | def log_rank_zero(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None: 60 | """ 61 | Logs a message only on rank zero. 62 | 63 | Args: 64 | logger (logging.Logger): The logger. 65 | msg (str): The warning message. 66 | level (int): The logging level. See https://docs.python.org/3/library/logging.html#levels for values. 67 | Defaults to ``logging.INFO``. 68 | """ 69 | rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 70 | if rank != 0: 71 | return 72 | logger.log(level, msg, stacklevel=2) 73 | 74 | 75 | @lru_cache(None) 76 | def log_once(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None: 77 | """ 78 | Logs a message only once. LRU cache is used to ensure a specific message is 79 | logged only once, similar to how :func:`~warnings.warn` works when the ``once`` 80 | rule is set via command-line or environment variable. 81 | 82 | Args: 83 | logger (logging.Logger): The logger. 84 | msg (str): The warning message. 85 | level (int): The logging level. See https://docs.python.org/3/library/logging.html#levels for values. 86 | Defaults to ``logging.INFO``. 
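    Example:
        >>> logger = get_logger("INFO")
        >>> for _ in range(3):
        ...     log_once(logger, "Hello world!")
        INFO:forge.util.logging: Hello world!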
87 | """ 88 | log_rank_zero(logger=logger, msg=msg, level=level) 89 | -------------------------------------------------------------------------------- /.github/workflows/build_vllm.yaml: -------------------------------------------------------------------------------- 1 | name: Build pinned vLLM against PyTorch nightly and upload 2 | 3 | on: 4 | push: 5 | branches: 6 | - nightly 7 | workflow_dispatch: 8 | 9 | permissions: 10 | id-token: write 11 | contents: read 12 | 13 | jobs: 14 | build: 15 | name: forge-cu128-nightly 16 | uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@vllm-push 17 | strategy: 18 | fail-fast: false 19 | with: 20 | repository: meta-pytorch/forge 21 | ref: "" 22 | test-infra-repository: pytorch/test-infra 23 | test-infra-ref: vllm-push 24 | run-smoke-test: false 25 | wheel-nightly-policy: gha_workflow_preview_build_wheels 26 | wheel-upload-path: whl/preview/forge/ 27 | package-name: forge 28 | channel: test # Hack here to make sure stable pytorch is used 29 | build-matrix: | 30 | { 31 | "include": [ 32 | { 33 | "python_version": "3.10", 34 | "gpu_arch_type": "cpu", 35 | "gpu_arch_version": "12.8", 36 | "desired_cuda": "cu128", 37 | "container_image": "pytorch/manylinux2_28-builder:cuda12.8", 38 | "package_type": "manywheel", 39 | "build_name": "manywheel-py3_10-cuda12_8", 40 | "validation_runner": "linux.12xlarge.memory", 41 | "installation": "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128", 42 | "channel": "test", 43 | "upload_to_base_bucket": "no", 44 | "stable_version": "2.9.0", 45 | "use_split_build": false 46 | }, 47 | { 48 | "python_version": "3.11", 49 | "gpu_arch_type": "cpu", 50 | "gpu_arch_version": "12.8", 51 | "desired_cuda": "cu128", 52 | "container_image": "pytorch/manylinux2_28-builder:cuda12.8", 53 | "package_type": "manywheel", 54 | "build_name": "manywheel-py3_11-cuda12_8", 55 | "validation_runner": "linux.12xlarge.memory", 56 | "installation": "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128", 57 | "channel": "test", 58 | "upload_to_base_bucket": "no", 59 | "stable_version": "2.9.0", 60 | "use_split_build": false 61 | }, 62 | { 63 | "python_version": "3.12", 64 | "gpu_arch_type": "cpu", 65 | "gpu_arch_version": "12.8", 66 | "desired_cuda": "cu128", 67 | "container_image": "pytorch/manylinux2_28-builder:cuda12.8", 68 | "package_type": "manywheel", 69 | "build_name": "manywheel-py3_12-cuda12_8", 70 | "validation_runner": "linux.12xlarge.memory", 71 | "installation": "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128", 72 | "channel": "test", 73 | "upload_to_base_bucket": "no", 74 | "stable_version": "2.9.0", 75 | "use_split_build": false 76 | }, 77 | ] 78 | } 79 | pre-script: .github/packaging/pre_build_cpu.sh 80 | post-script: .github/packaging/post_build_script.sh 81 | trigger-event: ${{ github.event_name }} 82 | build-platform: 'python-build-package' 83 | -------------------------------------------------------------------------------- /.github/packaging/vllm_reqs_12_8.txt: -------------------------------------------------------------------------------- 1 | # This file was generated by running ./scripts/generate_vllm_reqs.sh 2 | aiohappyeyeballs==2.6.1 3 | aiohttp==3.13.1 4 | aiosignal==1.4.0 5 | annotated-types==0.7.0 6 | anyio==4.11.0 7 | astor==0.8.1 8 | async-timeout==5.0.1 9 | attrs==25.4.0 10 | blake3==1.0.8 11 | cachetools==6.2.1 12 | cbor2==5.7.0 13 | certifi==2025.10.5 14 | cffi==2.0.0 15 | 
charset-normalizer==3.4.4 16 | click==8.2.1 17 | cloudpickle==3.1.1 18 | cmake==4.1.0 19 | compressed-tensors==0.10.2 20 | cupy-cuda12x==13.6.0 21 | depyf==0.19.0 22 | dill==0.4.0 23 | diskcache==5.6.3 24 | distro==1.9.0 25 | dnspython==2.8.0 26 | einops==0.8.1 27 | email-validator==2.3.0 28 | exceptiongroup==1.3.0 29 | fastapi==0.119.1 30 | fastapi-cli==0.0.14 31 | fastapi-cloud-cli==0.3.1 32 | fastrlock==0.8.3 33 | filelock==3.19.1 34 | frozenlist==1.8.0 35 | fsspec==2025.9.0 36 | gguf==0.17.1 37 | h11==0.16.0 38 | hf-xet==1.1.10 39 | httpcore==1.0.9 40 | httptools==0.7.1 41 | httpx==0.28.1 42 | huggingface-hub==0.35.3 43 | idna==3.11 44 | interegular==0.3.3 45 | Jinja2==3.1.6 46 | jiter==0.11.1 47 | jsonschema==4.25.1 48 | jsonschema-specifications==2025.9.1 49 | lark==1.2.2 50 | llguidance==0.7.30 51 | llvmlite==0.44.0 52 | lm-format-enforcer==0.10.12 53 | markdown-it-py==4.0.0 54 | MarkupSafe==2.1.5 55 | mdurl==0.1.2 56 | mistral_common==1.8.5 57 | mpmath==1.3.0 58 | msgpack==1.1.2 59 | msgspec==0.19.0 60 | multidict==6.7.0 61 | networkx==3.3 62 | ninja==1.13.0 63 | numba==0.61.2 64 | numpy==2.2.6 65 | nvidia-cublas-cu12==12.8.4.1 66 | nvidia-cuda-cupti-cu12==12.8.90 67 | nvidia-cuda-nvrtc-cu12==12.8.93 68 | nvidia-cuda-runtime-cu12==12.8.90 69 | nvidia-cudnn-cu12==9.10.2.21 70 | nvidia-cufft-cu12==11.3.3.83 71 | nvidia-cufile-cu12==1.13.1.3 72 | nvidia-curand-cu12==10.3.9.90 73 | nvidia-cusolver-cu12==11.7.3.90 74 | nvidia-cusparse-cu12==12.5.8.93 75 | nvidia-cusparselt-cu12==0.7.1 76 | nvidia-nccl-cu12==2.27.5 77 | nvidia-nvjitlink-cu12==12.8.93 78 | nvidia-nvshmem-cu12==3.3.20 79 | nvidia-nvtx-cu12==12.8.90 80 | openai==1.90.0 81 | opencv-python-headless==4.12.0.88 82 | outlines_core==0.2.10 83 | packaging==25.0 84 | partial-json-parser==0.2.1.1.post6 85 | pillow==12.0.0 86 | prometheus-fastapi-instrumentator==7.1.0 87 | prometheus_client==0.23.1 88 | propcache==0.4.1 89 | protobuf==6.33.0 90 | psutil==7.1.1 91 | py-cpuinfo==9.0.0 92 | pybase64==1.4.2 93 | pycountry==24.6.1 94 | pycparser==2.23 95 | pydantic==2.12.3 96 | pydantic-extra-types==2.10.6 97 | pydantic_core==2.41.4 98 | Pygments==2.19.2 99 | python-dotenv==1.1.1 100 | python-json-logger==4.0.0 101 | python-multipart==0.0.20 102 | PyYAML==6.0.3 103 | pyzmq==27.1.0 104 | ray==2.50.1 105 | referencing==0.37.0 106 | regex==2025.10.23 107 | requests==2.32.5 108 | rich==14.2.0 109 | rich-toolkit==0.15.1 110 | rignore==0.7.1 111 | rpds-py==0.27.1 112 | safetensors==0.6.2 113 | scipy==1.15.3 114 | sentencepiece==0.2.1 115 | sentry-sdk==2.42.1 116 | setuptools-scm==9.2.2 117 | shellingham==1.5.4 118 | sniffio==1.3.1 119 | soundfile==0.13.1 120 | soxr==1.0.0 121 | starlette==0.48.0 122 | sympy==1.14.0 123 | tiktoken==0.12.0 124 | tokenizers==0.22.1 125 | tomli==2.3.0 126 | torch==2.9.0+cu128 127 | tqdm==4.67.1 128 | transformers==4.57.1 129 | triton==3.5.0 130 | typer==0.20.0 131 | typing-inspection==0.4.2 132 | typing_extensions==4.15.0 133 | urllib3==2.5.0 134 | uvicorn==0.38.0 135 | uvloop==0.22.1 136 | watchfiles==1.1.1 137 | websockets==15.0.1 138 | xgrammar==0.1.21 139 | yarl==1.22.0 140 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for 
everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 
68 | ## Attribution
69 | 
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 | 
73 | [homepage]: https://www.contributor-covenant.org
74 | 
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 | 
--------------------------------------------------------------------------------
/apps/grpo/data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from dataclasses import dataclass
8 | 
9 | from datasets import load_dataset
10 | from forge.controller.actor import ForgeActor
11 | from forge.observability.metrics import record_metric, Reduce
12 | from monarch.actor import endpoint
13 | from vllm.transformers_utils.tokenizer import get_tokenizer
14 | 
15 | 
16 | @dataclass
17 | class DatasetActor(ForgeActor):
18 |     """Actor wrapper for HuggingFace dataset to provide async interface."""
19 | 
20 |     path: str = "openai/gsm8k"
21 |     revision: str = "main"
22 |     data_split: str = "train"
23 |     streaming: bool = True
24 |     model: str = ""
25 |     seed: int = 42
26 | 
27 |     @endpoint
28 |     async def setup(self):
29 |         self._tokenizer = get_tokenizer(self.model)
30 |         self._epoch = 0
31 | 
32 |         def gsm8k_transform(sample):
33 |             system_prompt = (
34 |                 "A conversation between User and Assistant. The user asks a question, "
35 |                 "and the Assistant solves it. The assistant first thinks about the reasoning "
36 |                 "process and then provides the user with the answer. The reasoning "
37 |                 "process and answer are enclosed within <think> </think> and <answer> </answer> "
38 |                 "tags, respectively, i.e., <think> reasoning process here </think> "
39 |                 "<answer> answer here </answer>."
40 |             )
41 |             request: str = sample["question"]
42 |             as_chat = [
43 |                 {"role": "system", "content": system_prompt},
44 |                 {"role": "user", "content": request},
45 |             ]
46 |             formatted_request = self._tokenizer.apply_chat_template(
47 |                 as_chat,
48 |                 tokenize=False,
49 |                 add_generation_prompt=True,
50 |             )
51 |             target: str = sample["answer"]
52 |             formatted_target = target.split("#### ")[1]
53 |             return {"request": formatted_request, "target": formatted_target}
54 | 
55 |         self._base_dataset = load_dataset(
56 |             self.path, self.revision, split=self.data_split, streaming=self.streaming
57 |         )
58 |         self._base_dataset = self._base_dataset.map(gsm8k_transform)
59 |         self._base_dataset = self._base_dataset.shuffle(seed=self.seed)
60 |         self._base_dataset.set_epoch(self._epoch)
61 |         self._iterator = iter(self._base_dataset)
62 | 
63 |     @endpoint
64 |     async def sample(self) -> dict[str, str] | None:
65 |         try:
66 |             sample = next(self._iterator)
67 | 
68 |             record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX)
69 | 
70 |             return sample
71 |         except StopIteration:
72 |             # Restart iterator for next epoch with reshuffling
73 |             self._epoch += 1
74 |             print(
75 |                 f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}"
76 |             )
77 |             self._base_dataset.set_epoch(self._epoch)
78 |             self._iterator = iter(self._base_dataset)
79 |             return next(self._iterator)
80 | 
81 |     @endpoint
82 |     async def pad_token(self):
83 |         # Use pad_token_id if available, otherwise use eos_token_id
84 |         # Llama models don't have a pad token by default
85 |         if self._tokenizer.pad_token_id is not None:
86 |             return self._tokenizer.pad_token_id
87 |         else:
88 |             return self._tokenizer.eos_token_id
89 | 
--------------------------------------------------------------------------------
/apps/grpo/grading.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import re
8 | 
9 | 
10 | class MathReward:
11 |     """Reward class for evaluating math correctness."""
12 | 
13 |     def __init__(self, tolerance: float = 1e-6, partial_credit: float = 0.1):
14 |         self.tolerance = tolerance
15 |         self.partial_credit = partial_credit
16 | 
17 |     def __call__(self, prompt: str, response: str, target: str) -> float:
18 |         """Compute math correctness reward."""
19 |         target_number = self._to_float(target)
20 |         if target_number is None:
21 |             return 0.0
22 | 
23 |         # Look for the answer inside <answer> ... </answer> tags
24 |         answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
25 | 
26 |         if answer_match:
27 |             model_answer = self._to_float(answer_match.group(1).strip())
28 |             if (
29 |                 model_answer is not None
30 |                 and abs(target_number - model_answer) < self.tolerance
31 |             ):
32 |                 return 1.0  # Correct answer
33 | 
34 |         # Check for partial credit: target number appears elsewhere in response
35 |         response_without_answer_tags = re.sub(
36 |             r"<answer>.*?</answer>", "", response, flags=re.DOTALL
37 |         )
38 |         # Convert to int if it's a whole number to avoid "117.0" vs "117" mismatch
39 |         target_str = (
40 |             str(int(target_number))
41 |             if target_number.is_integer()
42 |             else str(target_number)
43 |         )
44 |         if target_str in response_without_answer_tags:
45 |             return self.partial_credit
46 | 
47 |         return 0.0  # No match
48 | 
49 |     def _to_float(self, text: str) -> float | None:
50 |         """Convert text to float, return None if invalid."""
51 |         try:
52 |             # Remove common non-numeric characters like $, commas, etc.
53 |             cleaned_text = re.sub(r"[$,\s]", "", text.strip())
54 |             return float(cleaned_text)
55 |         except (ValueError, AttributeError):
56 |             return None
57 | 
58 | 
59 | class ThinkingReward:
60 |     """Reward class for evaluating use of thinking tags in reasoning.
61 | 
62 |     Args:
63 |         partial_reward: Reward for partial tag usage (incomplete/malformed)
64 |         full_reward: Reward for well-formed thinking blocks with content
65 |         tag: Tag name to use (default "think", can use "思考" for Japanese, etc.)
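    Example:
        >>> reward = ThinkingReward()
        >>> reward("prompt", "<think>step by step</think> The answer is 4.")
        1.0
        >>> reward("prompt", "<think>opened but never closed")
        0.2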
66 | """ 67 | 68 | def __init__( 69 | self, partial_reward: float = 0.2, full_reward: float = 1.0, tag: str = "think" 70 | ): 71 | self.partial_reward = partial_reward 72 | self.full_reward = full_reward 73 | self.tag = tag 74 | # Build regex patterns for the specified tag 75 | self._THINK_BLOCK_RE = re.compile( 76 | rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>", 77 | re.IGNORECASE | re.DOTALL, 78 | ) 79 | self._THINK_TAG_ATTEMPT_RE = re.compile( 80 | rf"<\s*/?\s*{re.escape(tag)}\s*>", re.IGNORECASE 81 | ) 82 | 83 | def __call__(self, prompt: str, response: str, target: str | None = None) -> float: 84 | """Compute thinking reward.""" 85 | if not response: 86 | return 0.0 87 | 88 | matches = self._THINK_BLOCK_RE.findall(response) 89 | has_well_formed = any(len(re.sub(r"\s+", "", m)) >= 1 for m in matches) 90 | has_attempt = bool(self._THINK_TAG_ATTEMPT_RE.search(response)) or bool(matches) 91 | if has_well_formed: 92 | return self.full_reward 93 | elif has_attempt: 94 | return self.partial_reward 95 | return 0.0 96 | -------------------------------------------------------------------------------- /src/forge/env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Centralized constants for environment variable names used in the project.""" 8 | 9 | import os 10 | from dataclasses import dataclass 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class EnvVar: 16 | """Configuration for an environment variable.""" 17 | 18 | name: str 19 | default: Any 20 | description: str 21 | 22 | def get_value(self) -> Any: 23 | """Get the value of this environment variable with fallback to default. 24 | 25 | Returns: 26 | The environment variable value, auto-converted to the appropriate type 27 | based on the default value, or the default value if not set. 28 | 29 | Example: 30 | >>> DISABLE_PERF_METRICS.get_value() 31 | False 32 | >>> os.environ["DISABLE_PERF_METRICS"] = "true" 33 | >>> DISABLE_PERF_METRICS.get_value() 34 | True 35 | """ 36 | value = os.environ.get(self.name) 37 | 38 | if value is None: 39 | return self.default 40 | 41 | # Auto-convert based on the default type 42 | if isinstance(self.default, bool): 43 | return value.lower() in ("true", "1", "yes") 44 | elif isinstance(self.default, int): 45 | return int(value) 46 | elif isinstance(self.default, float): 47 | return float(value) 48 | else: 49 | # Return as string for other types 50 | return value 51 | 52 | 53 | # Environment variable definitions 54 | DISABLE_PERF_METRICS = EnvVar( 55 | name="DISABLE_PERF_METRICS", 56 | default=False, 57 | description="Performance metrics in forge.observability.perf_tracker.py becomes no-op", 58 | ) 59 | 60 | METRIC_TIMER_USES_GPU = EnvVar( 61 | name="METRIC_TIMER_USES_GPU", 62 | default=None, 63 | description=( 64 | "Force all timing methods in forge.observability.perf_tracker.py " 65 | "to use CPU timer if False or GPU timer if True. If unset (None), defaults to the timer parameter." 
66 | ), 67 | ) 68 | 69 | FORGE_DISABLE_METRICS = EnvVar( 70 | name="FORGE_DISABLE_METRICS", 71 | default=False, 72 | description=( 73 | "Makes forge.observability.metrics.record_metric a no-op and disables spawning LocalFetcherActor" 74 | " in get_or_create_metric_logger" 75 | ), 76 | ) 77 | 78 | MONARCH_STDERR_LEVEL = EnvVar( 79 | name="MONARCH_STDERR_LOG", 80 | default="warning", 81 | description="Sets Monarch's stderr log level, i.e. set to 'info' or 'debug'", 82 | ) 83 | 84 | RUST_BACKTRACE = EnvVar( 85 | name="RUST_BACKTRACE", 86 | default="full", 87 | description="Sets the level for Rust-level failures. I.e. set to full for full stack traces.", 88 | ) 89 | 90 | MONARCH_MESSAGE_DELIVERY_TIMEOUT = EnvVar( 91 | name="HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS", 92 | default=600, 93 | description="Sets the timeout limit for Monarch's actor message delivery in seconds.", 94 | ) 95 | 96 | MONARCH_MAX_FRAME_LENGTH = EnvVar( 97 | name="HYPERACTOR_CODE_MAX_FRAME_LENGTH", 98 | default=1073741824, 99 | description="Sets the maximum frame length for Monarch's actor message delivery in bytes.", 100 | ) 101 | 102 | TORCHSTORE_USE_RDMA = EnvVar( 103 | name="TORCHSTORE_RDMA_ENABLED", 104 | default=1, 105 | description="Whether or not to use RDMA in TorchStore.", 106 | ) 107 | 108 | 109 | def all_env_vars() -> list[EnvVar]: 110 | """Retrieves all registered environment variable names.""" 111 | env_vars = [] 112 | for _, value in globals().items(): 113 | if isinstance(value, EnvVar): 114 | env_vars.append(value) 115 | return env_vars 116 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # image torchforge 2 | 3 | #### A PyTorch-native agentic RL library that lets you focus on algorithms—not infra. 4 | [![GPU Tests](https://github.com/meta-pytorch/forge/actions/workflows/gpu_test.yaml/badge.svg?branch=main)](https://github.com/meta-pytorch/forge/actions/workflows/gpu_test.yaml?query=branch%3Amain) 5 | [![Documentation](https://img.shields.io/badge/Docs-meta--pytorch.org-blue?style=flat&logo=readthedocs&logoColor=white)](https://meta-pytorch.org/torchforge/) 6 | [![Discord](https://img.shields.io/badge/Discord-OpenEnv-7289da?style=flat&logo=discord&logoColor=white)](https://discord.gg/YsTYBh6PD9) 7 | 8 | ## Overview 9 | The primary purpose of the torchforge ecosystem is to separate infra concerns from model concerns thereby making RL experimentation easier. torchforge delivers this by providing clear RL abstractions and one scalable implementation of these abstractions. When you need fine-grained control over placement, fault handling/redirecting training loads during a run, or communication patterns, the primitives are there. When you don’t, you can focus purely on your RL algorithm. 10 | 11 | Key features: 12 | - Usability for rapid research (isolating the RL loop from infrastructure) 13 | - Hackability for power users (all parts of the RL loop can be easily modified without interacting with infrastructure) 14 | - Scalability (ability to shift between async and synchronous training and across thousands of GPUs) 15 | 16 | > ⚠️ **Early Development Warning** torchforge is currently in an experimental 17 | > stage. You should expect bugs, incomplete features, and APIs that may change 18 | > in future versions. The project welcomes bugfixes, but to make sure things are 19 | > well coordinated you should discuss any significant change before starting the 20 | > work. 
It's recommended that you signal your intention to contribute in the
21 | > issue tracker, either by filing a new issue or by claiming an existing one.
22 | 
23 | ## 📖 Documentation
24 | 
25 | View torchforge's hosted documentation: https://meta-pytorch.org/torchforge.
26 | 
27 | ## Tutorials
28 | 
29 | Notebook tutorials are coming soon.
30 | 
31 | ## Installation
32 | 
33 | torchforge requires PyTorch 2.9.0 with [Monarch](https://github.com/meta-pytorch/monarch), [vLLM](https://github.com/vllm-project/vllm), and [torchtitan](https://github.com/pytorch/torchtitan).
34 | 
35 | Install torchforge with:
36 | 
37 | ```bash
38 | conda create -n forge python=3.12
39 | conda activate forge
40 | ./scripts/install.sh
41 | ```
42 | 
43 | The install script installs system dependencies along with torchforge. Note that this install script uses [DNF](https://docs.fedoraproject.org/en-US/quick-docs/dnf/), but it could easily be extended to other Linux distributions.
44 | 
45 | Optional: By default, package installation uses conda. If you want to install system packages directly on the target machine instead, you can pass the `--use-sudo` flag to the installation script: `./scripts/install.sh --use-sudo`.
46 | 
47 | > **Note:** We are actively working on enabling pure `uv` installation. Currently, Conda is the recommended approach. `uv` support is not fully working at the moment but is being tracked in [issue #494](https://github.com/meta-pytorch/torchforge/issues/494).
48 | 
49 | After installation, run the following command; you should see output confirming that GRPO training is running (a minimum of 3 GPUs is required):
50 | 
51 | ```
52 | python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
53 | ```
54 | 
55 | ## Quick Start
56 | 
57 | To run SFT on a Llama3 8B model, run
58 | 
59 | ```bash
60 | python -m apps.sft.main --config apps/sft/llama3_8b.yaml
61 | ```
62 | 
63 | ### Citation
64 | 
65 | ## License
66 | 
67 | Source code is made available under a [BSD 3 license](./LICENSE); however, you may have other legal obligations that govern your use of other content linked in this repository, such as the license or terms of service for third-party data and models.
68 | 
--------------------------------------------------------------------------------
/.github/packaging/vllm_reqs_12_9.txt:
--------------------------------------------------------------------------------
1 | # These requirements were generated by running steps 1-3 of scripts/build_wheels.sh
2 | # then running pip freeze and manually removing the vllm dependency.
3 | # The intention of this file is to use these known requirements for a fixed
4 | # vLLM build to supplement a vLLM install from download.pytorch.org without
5 | # resorting to --extra-index-url https://download.pytorch.org/whl/nightly to find
6 | # vLLM dependencies (as this results in a ResolutionTooDeep error from pip).
7 | # See the file .github/workflows/gpu_test.yaml for an E2E forge installation using this approach.
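# For example, an installation might look like this (illustrative only; the
# canonical, tested steps live in .github/workflows/gpu_test.yaml):
#   pip install -r .github/packaging/vllm_reqs_12_9.txt
#   pip install vllm --index-url https://download.pytorch.org/whl/preview/forge --no-deps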
8 | # TODO: this should be done way less hackily 9 | aiohappyeyeballs==2.6.1 10 | aiohttp==3.13.0 11 | aiosignal==1.4.0 12 | annotated-types==0.7.0 13 | anyio==4.11.0 14 | astor==0.8.1 15 | async-timeout==5.0.1 16 | attrs==25.4.0 17 | blake3==1.0.7 18 | cachetools==6.2.0 19 | cbor2==5.7.0 20 | certifi==2025.10.5 21 | cffi==2.0.0 22 | charset-normalizer==3.4.3 23 | click==8.3.0 24 | cloudpickle==3.1.1 25 | cmake==4.1.0 26 | compressed-tensors==0.10.2 27 | cupy-cuda12x==13.6.0 28 | depyf==0.19.0 29 | dill==0.4.0 30 | diskcache==5.6.3 31 | distro==1.9.0 32 | dnspython==2.8.0 33 | einops==0.8.1 34 | email-validator==2.3.0 35 | exceptiongroup==1.3.0 36 | fastapi==0.118.3 37 | fastapi-cli==0.0.13 38 | fastapi-cloud-cli==0.3.1 39 | fastrlock==0.8.3 40 | filelock==3.19.1 41 | frozenlist==1.8.0 42 | fsspec==2025.9.0 43 | gguf==0.17.1 44 | h11==0.16.0 45 | hf-xet==1.1.10 46 | httpcore==1.0.9 47 | httptools==0.7.1 48 | httpx==0.28.1 49 | huggingface-hub==0.35.3 50 | idna==3.10 51 | interegular==0.3.3 52 | Jinja2==3.1.6 53 | jiter==0.11.0 54 | jsonschema==4.25.1 55 | jsonschema-specifications==2025.9.1 56 | lark==1.2.2 57 | llguidance==0.7.30 58 | llvmlite==0.44.0 59 | lm-format-enforcer==0.10.12 60 | markdown-it-py==4.0.0 61 | MarkupSafe==3.0.2 62 | mdurl==0.1.2 63 | mistral_common==1.8.5 64 | mpmath==1.3.0 65 | msgpack==1.1.2 66 | msgspec==0.19.0 67 | multidict==6.7.0 68 | networkx==3.4.2 69 | ninja==1.13.0 70 | numba==0.61.2 71 | numpy==2.2.6 72 | nvidia-cublas-cu12==12.9.1.4 73 | nvidia-cuda-cupti-cu12==12.9.79 74 | nvidia-cuda-nvrtc-cu12==12.9.86 75 | nvidia-cuda-runtime-cu12==12.9.79 76 | nvidia-cudnn-cu12==9.10.2.21 77 | nvidia-cufft-cu12==11.4.1.4 78 | nvidia-cufile-cu12==1.14.1.1 79 | nvidia-curand-cu12==10.3.10.19 80 | nvidia-cusolver-cu12==11.7.5.82 81 | nvidia-cusparse-cu12==12.5.10.65 82 | nvidia-cusparselt-cu12==0.7.1 83 | nvidia-nccl-cu12==2.27.5 84 | nvidia-nvjitlink-cu12==12.9.86 85 | nvidia-nvshmem-cu12==3.3.20 86 | nvidia-nvtx-cu12==12.9.79 87 | openai==1.90.0 88 | opencv-python-headless==4.12.0.88 89 | outlines_core==0.2.10 90 | packaging==25.0 91 | partial-json-parser==0.2.1.1.post6 92 | pillow==11.3.0 93 | prometheus-fastapi-instrumentator==7.1.0 94 | prometheus_client==0.23.1 95 | propcache==0.4.1 96 | protobuf==6.32.1 97 | psutil==7.1.0 98 | py-cpuinfo==9.0.0 99 | pybase64==1.4.2 100 | pycountry==24.6.1 101 | pycparser==2.23 102 | pydantic==2.12.0 103 | pydantic-extra-types==2.10.6 104 | pydantic_core==2.41.1 105 | Pygments==2.19.2 106 | python-dotenv==1.1.1 107 | python-json-logger==4.0.0 108 | python-multipart==0.0.20 109 | pytorch-triton==3.4.0+gitf7888497 110 | PyYAML==6.0.3 111 | pyzmq==27.1.0 112 | ray==2.49.2 113 | referencing==0.36.2 114 | regex==2025.9.18 115 | requests==2.32.5 116 | rich==14.2.0 117 | rich-toolkit==0.15.1 118 | rignore==0.7.0 119 | rpds-py==0.27.1 120 | safetensors==0.6.2 121 | scipy==1.15.3 122 | sentencepiece==0.2.1 123 | sentry-sdk==2.41.0 124 | setuptools-scm==9.2.0 125 | shellingham==1.5.4 126 | sniffio==1.3.1 127 | soundfile==0.13.1 128 | soxr==1.0.0 129 | starlette==0.48.0 130 | sympy==1.14.0 131 | tiktoken==0.12.0 132 | tokenizers==0.22.1 133 | tomli==2.3.0 134 | torch==2.9.0.dev20250905+cu129 135 | tqdm==4.67.1 136 | transformers==4.57.0 137 | triton==3.4.0 138 | typer==0.19.2 139 | typing-inspection==0.4.2 140 | typing_extensions==4.15.0 141 | urllib3==2.5.0 142 | uvicorn==0.37.0 143 | uvloop==0.21.0 144 | watchfiles==1.1.0 145 | websockets==15.0.1 146 | xgrammar==0.1.21 147 | yarl==1.22.0 148 | 
--------------------------------------------------------------------------------
/src/forge/interfaces.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from abc import ABC, abstractmethod
8 | from typing import Any, Mapping
9 | 
10 | from forge.types import Message, Scalar
11 | 
12 | 
13 | class BaseTokenizer(ABC):
14 |     """
15 |     Abstract token encoding model that implements ``encode`` and ``decode`` methods.
16 |     See :class:`forge.data.HuggingFaceModelTokenizer` for an example implementation of this protocol.
17 |     """
18 | 
19 |     @abstractmethod
20 |     def encode(self, text: str, **kwargs: dict[str, Any]) -> list[int]:
21 |         """
22 |         Given a string, return the encoded list of token ids.
23 | 
24 |         Args:
25 |             text (str): The text to encode.
26 |             **kwargs (dict[str, Any]): kwargs.
27 | 
28 |         Returns:
29 |             list[int]: The encoded list of token ids.
30 |         """
31 |         pass
32 | 
33 |     @abstractmethod
34 |     def decode(self, token_ids: list[int], **kwargs: dict[str, Any]) -> str:
35 |         """
36 |         Given a list of token ids, return the decoded text, optionally including special tokens.
37 | 
38 |         Args:
39 |             token_ids (list[int]): The list of token ids to decode.
40 |             **kwargs (dict[str, Any]): kwargs.
41 | 
42 |         Returns:
43 |             str: The decoded text.
44 |         """
45 |         pass
46 | 
47 | 
48 | class ModelTokenizer(ABC):
49 |     """
50 |     Abstract tokenizer that implements model-specific special token logic in
51 |     the ``tokenize_messages`` method. See :class:`forge.data.HuggingFaceModelTokenizer`
52 |     for an example implementation of this protocol.
53 |     """
54 | 
55 |     special_tokens: dict[str, int]
56 |     max_seq_len: int | None
57 | 
58 |     @abstractmethod
59 |     def tokenize_messages(
60 |         self, messages: list[Message], **kwargs: dict[str, Any]
61 |     ) -> tuple[list[int], list[bool]]:
62 |         """
63 |         Given a list of messages, return a list of tokens and list of masks for
64 |         the concatenated and formatted messages.
65 | 
66 |         Args:
67 |             messages (list[Message]): The list of messages to tokenize.
68 |             **kwargs (dict[str, Any]): kwargs.
69 | 
70 |         Returns:
71 |             tuple[list[int], list[bool]]: The list of token ids and the list of masks.
72 |         """
73 |         pass
74 | 
75 | 
76 | class MetricLogger(ABC):
77 |     """Abstract metric logger."""
78 | 
79 |     @abstractmethod
80 |     def is_log_step(self, name: str, step: int) -> bool:
81 |         """Returns true if the current step is a logging step.
82 | 
83 |         Args:
84 |             name (str): metric name (for checking the freq for this metric)
85 |             step (int): current step
86 |         """
87 |         pass
88 | 
89 |     @abstractmethod
90 |     def log(self, name: str, data: Scalar, step: int) -> None:
91 |         """Log scalar data if this is a logging step.
92 | 
93 |         Args:
94 |             name (str): tag name used to group scalars
95 |             data (Scalar): scalar data to log
96 |             step (int): step value to record
97 |         """
98 |         pass
99 | 
100 |     @abstractmethod
101 |     def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
102 |         """Log multiple scalar values if this is a logging step.
103 | 
104 |         Args:
105 |             metrics (Mapping[str, Scalar]): dictionary of tag name and scalar value
106 |             step (int): step value to record
107 |         """
108 |         pass
109 | 
110 |     def __del__(self) -> None:
111 |         self.close()
112 | 
113 |     def close(self) -> None:
114 |         """
115 |         Close log resource, flushing if necessary.
116 | This will automatically be called via __del__ when the instance goes out of scope. 117 | Logs should not be written after `close` is called. 118 | """ 119 | -------------------------------------------------------------------------------- /tests/sandbox/toy_rl/toy_metrics/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import asyncio 8 | 9 | import logging 10 | import time 11 | 12 | from forge.controller.actor import ForgeActor 13 | from forge.controller.provisioner import shutdown 14 | from forge.observability.metric_actors import get_or_create_metric_logger 15 | from forge.observability.metrics import record_metric, Reduce 16 | from forge.observability.perf_tracker import trace, Tracer 17 | 18 | from monarch.actor import current_rank, endpoint 19 | 20 | logging.basicConfig(level=logging.DEBUG) 21 | 22 | 23 | class TrainActor(ForgeActor): 24 | """Example training actor that records loss metrics.""" 25 | 26 | @endpoint 27 | async def train_step(self, step: int): 28 | rank = current_rank().rank 29 | 30 | # Phase 2: Use Tracer for detailed step timing 31 | tracer = Tracer("trainer_perf/step", track_memory=True, timer="gpu") 32 | tracer.start() 33 | 34 | # Simulate forward pass 35 | tracer.step("forward") 36 | 37 | # Simulate backward pass 38 | tracer.step("backward") 39 | 40 | value = rank * 1000 + 100 * step 41 | 42 | # Record training metrics 43 | record_metric("trainer/avg_grpo_loss", value, Reduce.MEAN) 44 | record_metric("trainer/std_grpo_loss", value, Reduce.STD) 45 | record_metric("trainer/count_training_steps", 1, Reduce.SUM) 46 | record_metric("trainer/learning_rate", 0.001, Reduce.MEAN) 47 | 48 | print(f"🔧 Train rank {rank}: Step {step}, loss={value}") 49 | 50 | tracer.stop() 51 | return value 52 | 53 | 54 | class GeneratorActor(ForgeActor): 55 | """Example generation actor that records token count metrics.""" 56 | 57 | @endpoint 58 | async def generate_step(self, step: int, substep: int): 59 | rank = current_rank().rank 60 | 61 | with trace("policy_perf", track_memory=False, timer="gpu") as tracer: 62 | value = rank * 1000 + step * 100 + substep * 10 63 | tracer.step("time_to_value") 64 | # Record generation metrics following the plan 65 | record_metric("policy/count_requests", 1, Reduce.SUM) 66 | record_metric( 67 | "policy/sum_tokens_requested", 50, Reduce.SUM 68 | ) # Simulated max_tokens 69 | record_metric("policy/sum_tokens_generated", value, Reduce.SUM) 70 | record_metric("policy/count_sequences_completed", 1, Reduce.SUM) 71 | record_metric("policy/avg_tokens_per_sample", value, Reduce.MEAN) 72 | 73 | print(f"🎯 Gen rank {rank}: Step {step}.{substep}, tokens={value}") 74 | 75 | return value 76 | 77 | 78 | # Main 79 | async def main(): 80 | """Example demonstrating distributed metric logging with different backends.""" 81 | group = f"grpo_exp_{int(time.time())}" 82 | 83 | # Config format: {backend_name: backend_config_dict} 84 | config = { 85 | "console": {"logging_mode": "global_reduce"}, 86 | "wandb": { 87 | "project": "toy_metrics", 88 | "group": group, 89 | "logging_mode": "per_rank_reduce", # global_reduce, per_rank_reduce, per_rank_no_reduce 90 | "per_rank_share_run": True, 91 | }, 92 | } 93 | 94 | service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False} 95 | mlogger = await 
get_or_create_metric_logger(process_name="Controller") 96 | await mlogger.init_backends.call_one(config) 97 | 98 | # Spawn services first (triggers registrations via provisioner hook) 99 | trainer = await TrainActor.options( 100 | **service_config, mesh_name="TrainActor" 101 | ).as_service() 102 | generator = await GeneratorActor.options( 103 | **service_config, mesh_name="GeneratorActor" 104 | ).as_service() 105 | 106 | for i in range(3): 107 | print(f"\n=== Global Step {i} ===") 108 | await trainer.train_step.fanout(i) 109 | for sub in range(3): 110 | await generator.generate_step.fanout(i, sub) 111 | await mlogger.flush.call_one(i) 112 | 113 | # shutdown 114 | await shutdown() 115 | 116 | 117 | if __name__ == "__main__": 118 | asyncio.run(main()) 119 | -------------------------------------------------------------------------------- /docs/source/_static/custom.css: -------------------------------------------------------------------------------- 1 | /* Center all Mermaid diagrams */ 2 | .mermaid { 3 | display: block; 4 | margin-left: auto; 5 | margin-right: auto; 6 | text-align: center; 7 | } 8 | 9 | /* Center the pre element that contains mermaid diagrams */ 10 | pre.mermaid { 11 | display: flex; 12 | justify-content: center; 13 | } 14 | 15 | /* Adjust Mermaid line colors based on theme */ 16 | /* Light mode - darker lines for visibility on white background */ 17 | html[data-theme="light"] .mermaid .edgePath .path, 18 | html[data-theme="light"] .mermaid .flowchart-link { 19 | stroke: #555 !important; 20 | stroke-width: 2px !important; 21 | } 22 | 23 | /* Light mode - darker arrow tips */ 24 | html[data-theme="light"] .mermaid .arrowheadPath, 25 | html[data-theme="light"] .mermaid marker path { 26 | fill: #555 !important; 27 | stroke: #555 !important; 28 | } 29 | 30 | html[data-theme="dark"] .mermaid .arrowheadPath, 31 | html[data-theme="dark"] .mermaid marker path { 32 | fill: #aaa !important; 33 | stroke: #aaa !important; 34 | } 35 | 36 | /* Dark mode - lighter lines for visibility on dark background */ 37 | html[data-theme="dark"] .mermaid .edgePath .path, 38 | html[data-theme="dark"] .mermaid .flowchart-link { 39 | stroke: #aaa !important; 40 | stroke-width: 2px !important; 41 | } 42 | 43 | /* Dark mode - lighter arrow tips */ 44 | html[data-theme="dark"] .mermaid .arrowheadPath, 45 | html[data-theme="dark"] .mermaid marker path { 46 | fill: #aaa !important; 47 | stroke: #aaa !important; 48 | } 49 | 50 | /* Adjust edge labels background based on theme */ 51 | html[data-theme="light"] .mermaid .edgeLabel { 52 | background-color: #fff !important; 53 | } 54 | 55 | html[data-theme="dark"] .mermaid .edgeLabel { 56 | background-color: #1a1a1a !important; 57 | color: #fff !important; 58 | } 59 | 60 | /* Custom CSS for collapsible parameter lists */ 61 | 62 | /* Hide parameters in signatures */ 63 | .sig-param-hidden { 64 | display: none !important; 65 | } 66 | 67 | /* Inline toggle button for signatures */ 68 | .params-toggle-btn-inline { 69 | display: inline; 70 | padding: 0.2rem 0.5rem; 71 | margin: 0 0.25rem; 72 | background-color: var(--pst-color-background); 73 | border: 1px solid var(--pst-color-border); 74 | border-radius: 3px; 75 | cursor: pointer; 76 | font-size: 0.85em; 77 | font-family: var(--pst-font-family-base); 78 | color: var(--pst-color-primary); 79 | transition: all 0.2s ease; 80 | vertical-align: middle; 81 | } 82 | 83 | .params-toggle-btn-inline:hover { 84 | background-color: var(--pst-color-background); 85 | border-color: var(--pst-color-border); 86 | } 87 | 88 | 
.params-toggle-btn-inline:focus { 89 | outline: none; 90 | } 91 | 92 | .toggle-icon { 93 | display: inline-block; 94 | font-size: 0.8em; 95 | transition: transform 0.2s ease; 96 | } 97 | 98 | /* Wrapper for the button */ 99 | .sig-params-wrapper { 100 | display: inline; 101 | } 102 | 103 | /* Old styles for field-list collapsing (kept for backward compatibility) */ 104 | .collapsible-params { 105 | margin: 1rem 0; 106 | } 107 | 108 | .params-toggle-btn { 109 | display: inline-block; 110 | padding: 0.5rem 1rem; 111 | margin-bottom: 0.5rem; 112 | background-color: var(--pst-color-background); 113 | border: 1px solid var(--pst-color-border); 114 | border-radius: 4px; 115 | cursor: pointer; 116 | font-size: 0.9rem; 117 | color: var(--pst-color-primary); 118 | transition: all 0.3s ease; 119 | } 120 | 121 | .params-toggle-btn:hover { 122 | background-color: var(--pst-color-background); 123 | border-color: var(--pst-color-border); 124 | } 125 | 126 | .params-content { 127 | max-height: 10000px; 128 | overflow: hidden; 129 | transition: max-height 0.5s ease, opacity 0.3s ease; 130 | opacity: 1; 131 | } 132 | 133 | .params-content.collapsed { 134 | max-height: 0; 135 | opacity: 0; 136 | } 137 | 138 | /* Ensure the collapsed parameters look good */ 139 | .params-content dl.field-list { 140 | margin-top: 0; 141 | } 142 | 143 | .params-content > dt { 144 | margin-top: 0.5rem; 145 | } 146 | 147 | .params-content > dt:first-child { 148 | margin-top: 0; 149 | } 150 | 151 | /* Responsive adjustments */ 152 | @media (max-width: 768px) { 153 | .params-toggle-btn { 154 | width: 100%; 155 | text-align: left; 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /apps/grpo/qwen3_8b.yaml: -------------------------------------------------------------------------------- 1 | # Grouped Relative Policy Optimization (GRPO) 2 | # >>> python -m apps.grpo.main --config apps/grpo/qwen3_8b.yaml 3 | 4 | # Global configuration 5 | group_size: 16 6 | local_batch_size: 4 # per-device batch size 7 | max_req_tokens: 1024 8 | max_res_tokens: 2048 9 | model: "Qwen/Qwen3-8B" 10 | off_by_n: 1 # Off by one by default 11 | compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM 12 | 13 | # Observability configuration 14 | metric_logging: 15 | wandb: 16 | project: grpo-training 17 | group: grpo_exp_${oc.env:USER} 18 | logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce 19 | console: 20 | logging_mode: global_reduce 21 | 22 | # Dataset configuration 23 | dataset: 24 | path: "openai/gsm8k" 25 | revision: "main" 26 | data_split: "train" 27 | streaming: true 28 | model: ${model} 29 | 30 | # Policy configuration 31 | policy: 32 | engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs 33 | model: ${model} 34 | tensor_parallel_size: 2 35 | pipeline_parallel_size: 1 36 | enforce_eager: ${not:${compile}} 37 | sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams 38 | n: ${group_size} 39 | max_tokens: ${max_res_tokens} 40 | temperature: 1.0 41 | top_p: 1.0 42 | 43 | # Trainer configuration 44 | trainer: 45 | model: 46 | name: qwen3 47 | flavor: 8B 48 | hf_assets_path: hf://${model} 49 | optimizer: 50 | name: AdamW 51 | lr: 1e-5 52 | eps: 1e-8 53 | lr_scheduler: 54 | warmup_steps: 1 55 | training: 56 | local_batch_size: ${local_batch_size} 57 | seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens 
+ max_res_tokens 58 | max_norm: 1.0 59 | steps: 1000000 60 | dtype: bfloat16 61 | gc_freq: 1 62 | compile: 63 | enable: ${compile} 64 | parallelism: 65 | data_parallel_replicate_degree: 1 66 | data_parallel_shard_degree: -1 67 | tensor_parallel_degree: 1 68 | pipeline_parallel_degree: 1 69 | context_parallel_degree: 1 70 | expert_parallel_degree: 1 71 | disable_loss_parallel: true 72 | checkpoint: 73 | enable: true 74 | folder: ./checkpoint # The folder to save checkpoints to. 75 | initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. 76 | initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo 77 | last_save_in_hf: true 78 | interval: 500 79 | async_mode: "disabled" 80 | activation_checkpoint: 81 | mode: selective 82 | selective_ac_option: op 83 | 84 | # Replay buffer configuration 85 | replay_buffer: 86 | batch_size: ${local_batch_size} 87 | max_policy_age: ${off_by_n} 88 | # This should match the dp_size of TorchTitan 89 | # Here it's set explicitly to 2, because we've set 90 | # 2 GPUs for the trainer and we're using full FSDP. 91 | dp_size: 2 92 | 93 | # Reference model configuration 94 | ref_model: 95 | model: 96 | name: qwen3 97 | flavor: 8B 98 | hf_assets_path: hf://${model} 99 | training: 100 | seq_len: ${trainer.training.seq_len} 101 | dtype: bfloat16 102 | gc_freq: 1 103 | compile: 104 | enable: ${compile} 105 | parallelism: 106 | data_parallel_replicate_degree: 1 107 | data_parallel_shard_degree: 1 108 | tensor_parallel_degree: 1 109 | pipeline_parallel_degree: 1 110 | context_parallel_degree: 1 111 | expert_parallel_degree: 1 112 | checkpoint: 113 | initial_load_path: hf://${model} 114 | initial_load_in_hf: true 115 | 116 | # All resource allocations 117 | services: 118 | policy: 119 | procs: ${policy.engine_args.tensor_parallel_size} 120 | num_replicas: 1 121 | with_gpus: true 122 | mesh_name: policy 123 | ref_model: 124 | procs: 1 125 | num_replicas: 1 126 | with_gpus: true 127 | mesh_name: ref_model 128 | reward_actor: 129 | procs: 1 130 | num_replicas: 1 131 | with_gpus: false 132 | mesh_name: reward_actor 133 | 134 | actors: 135 | dataset: 136 | procs: 1 137 | with_gpus: false 138 | mesh_name: dataset 139 | trainer: 140 | procs: 2 141 | with_gpus: true 142 | mesh_name: trainer 143 | replay_buffer: 144 | procs: 1 145 | with_gpus: false 146 | mesh_name: replay_buffer 147 | compute_advantages: 148 | procs: 1 149 | with_gpus: false 150 | mesh_name: compute_advantages 151 | -------------------------------------------------------------------------------- /src/forge/data/metric_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any 8 | 9 | from forge.observability.metrics import Metric, Reduce 10 | 11 | 12 | class MetricTransform: 13 | """ 14 | Base class for transforms that collect observability metrics from dataset samples. 15 | 16 | This class provides a foundation for implementing dataset-level metric collection 17 | during data processing pipelines. Subclasses should override the __call__ method 18 | to add specific metrics to each sample that passes through the transform. 
19 | 20 | Metrics are collected as `forge.observability.metrics.Metric` objects and made available 21 | in batch["metrics"]. 22 | 23 | Attributes: 24 | source (str, optional): The source name for metrics, typically the dataset name. 25 | This is used as a prefix in metric keys to distinguish metrics from different 26 | data sources. 27 | 28 | Example: 29 | >>> transform = SomeMetricTransform() 30 | >>> transform.set_source("training_data") 31 | >>> processed_sample = transform(sample) 32 | >>> # Metrics are automatically added to sample["metrics"] 33 | """ 34 | 35 | def __init__(self): 36 | self.source = None 37 | 38 | def set_source(self, source: str): 39 | """Set the source name for metrics (typically the dataset name).""" 40 | self.source = source 41 | 42 | def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: 43 | """Transform a sample by adding metrics to it.""" 44 | return sample 45 | 46 | 47 | class DefaultDatasetMetricTransform(MetricTransform): 48 | """ 49 | Collects basic dataset processing metrics during data pipeline execution. 50 | 51 | Metrics collected: 52 | - samples_processed: Total number of samples that have passed through this transform (SUM) 53 | - tokens_processed: Total number of tokens processed across all samples (SUM) 54 | - mean_seq_len: Average sequence length across samples (MEAN) 55 | - max_seq_len: Maximum sequence length observed (MAX) 56 | - min_seq_len: Minimum sequence length observed (MIN) 57 | 58 | Note: Token-related metrics are only collected if the sample contains a 'tokens' field. 59 | Sequence length is measured as the number of tokens in each sample. 60 | 61 | Example: 62 | >>> collector = DefaultDatasetMetricTransform() 63 | >>> collector.set_source("training_data") 64 | >>> sample = {"tokens": ["hello", "world"]} 65 | >>> processed_sample = collector(sample) 66 | >>> # Metrics are automatically added to processed_sample["metrics"] 67 | """ 68 | 69 | def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: 70 | if "metrics" not in sample: 71 | sample["metrics"] = [] 72 | 73 | source_name = self.source or "unnamed_ds" 74 | 75 | # Add samples_processed metric 76 | sample["metrics"].append( 77 | Metric( 78 | key=f"dataset/{source_name}/samples_processed", 79 | value=1, 80 | reduction=Reduce.SUM, 81 | ) 82 | ) 83 | 84 | # Add token-based metrics if tokens are present 85 | if "tokens" in sample: 86 | token_count = len(sample.get("tokens", [])) 87 | 88 | sample["metrics"].extend( 89 | [ 90 | Metric( 91 | key=f"dataset/{source_name}/tokens_processed", 92 | value=token_count, 93 | reduction=Reduce.SUM, 94 | ), 95 | Metric( 96 | key=f"dataset/{source_name}/mean_seq_len", 97 | value=token_count, 98 | reduction=Reduce.MEAN, 99 | ), 100 | Metric( 101 | key=f"dataset/{source_name}/max_seq_len", 102 | value=token_count, 103 | reduction=Reduce.MAX, 104 | ), 105 | Metric( 106 | key=f"dataset/{source_name}/min_seq_len", 107 | value=token_count, 108 | reduction=Reduce.MIN, 109 | ), 110 | ] 111 | ) 112 | 113 | return sample 114 | -------------------------------------------------------------------------------- /tests/unit_tests/test_coder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Unit tests for forge.actors.coder.SandboxedPythonCoder. 
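The container layer is fully mocked: mock_subprocess_run dispatches on the
subcommand ("import", "create", "start", "remove") appearing in each command
line, so the tests run without a real container runtime (the .sqsh image path
suggests an enroot-style sandbox).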
9 | """ 10 | 11 | import os 12 | import tempfile 13 | import uuid 14 | from unittest.mock import Mock, patch 15 | 16 | import pytest 17 | 18 | from forge.actors.coder import _SandboxedPythonCoder 19 | 20 | 21 | @pytest.mark.asyncio 22 | async def test_coder_success(): 23 | """Test successful execution.""" 24 | unique_id = str(uuid.uuid4())[:8] 25 | container_name = f"test_sandbox_{unique_id}" 26 | 27 | with tempfile.NamedTemporaryFile(suffix=".sqsh", delete=False) as temp_image: 28 | image_path = temp_image.name 29 | 30 | def mock_subprocess_run(*args, **kwargs): 31 | """Mock subprocess.run for testing.""" 32 | cmd = args[0] if args else kwargs.get("args", []) 33 | cmd_str = " ".join(cmd) if isinstance(cmd, list) else str(cmd) 34 | 35 | if "import" in cmd_str: 36 | result = Mock() 37 | result.returncode = 0 38 | result.stderr = "" 39 | return result 40 | elif "remove" in cmd_str: 41 | result = Mock() 42 | result.returncode = 0 43 | return result 44 | elif "create" in cmd_str: 45 | result = Mock() 46 | result.returncode = 0 47 | result.stderr = "" 48 | return result 49 | elif "start" in cmd_str: 50 | result = Mock() 51 | result.returncode = 0 52 | result.stdout = "Hello World\n" 53 | result.stderr = "" 54 | return result 55 | else: 56 | raise ValueError(f"Unexpected subprocess call: {cmd_str}") 57 | 58 | try: 59 | with patch( 60 | "forge.actors.coder.subprocess.run", side_effect=mock_subprocess_run 61 | ): 62 | coder = _SandboxedPythonCoder( 63 | docker_image="docker://python:3.10", 64 | sqsh_image_path=image_path, 65 | container_name=container_name, 66 | ) 67 | 68 | await coder.setup() 69 | result, _ = await coder.execute(code="print('Hello World')") 70 | assert result == "Hello World\n" 71 | finally: 72 | if os.path.exists(image_path): 73 | os.unlink(image_path) 74 | 75 | 76 | @pytest.mark.asyncio 77 | async def test_coder_execution_failure(): 78 | """Test execution failure.""" 79 | unique_id = str(uuid.uuid4())[:8] 80 | container_name = f"test_sandbox_{unique_id}" 81 | 82 | with tempfile.NamedTemporaryFile(suffix=".sqsh", delete=False) as temp_image: 83 | image_path = temp_image.name 84 | 85 | def mock_subprocess_run(*args, **kwargs): 86 | """Mock subprocess.run for testing.""" 87 | cmd = args[0] if args else kwargs.get("args", []) 88 | cmd_str = " ".join(cmd) if isinstance(cmd, list) else str(cmd) 89 | 90 | if "import" in cmd_str: 91 | result = Mock() 92 | result.returncode = 0 93 | result.stderr = "" 94 | return result 95 | elif "remove" in cmd_str: 96 | result = Mock() 97 | result.returncode = 0 98 | return result 99 | elif "create" in cmd_str: 100 | result = Mock() 101 | result.returncode = 0 102 | result.stderr = "" 103 | return result 104 | elif "start" in cmd_str: 105 | result = Mock() 106 | result.returncode = 1 107 | result.stdout = "" 108 | result.stderr = "SyntaxError: invalid syntax" 109 | return result 110 | else: 111 | raise ValueError(f"Unexpected subprocess call: {cmd_str}") 112 | 113 | try: 114 | with patch( 115 | "forge.actors.coder.subprocess.run", side_effect=mock_subprocess_run 116 | ): 117 | coder = _SandboxedPythonCoder( 118 | docker_image="docker://python:3.10", 119 | sqsh_image_path=image_path, 120 | container_name=container_name, 121 | ) 122 | 123 | await coder.setup() 124 | output, err = await coder.execute(code="invalid syntax") 125 | assert "SyntaxError" in err 126 | finally: 127 | if os.path.exists(image_path): 128 | os.unlink(image_path) 129 | -------------------------------------------------------------------------------- /apps/grpo/llama3_8b.yaml: 
-------------------------------------------------------------------------------- 1 | # Grouped Relative Policy Optimization (GRPO) 2 | # >>> python -m apps.grpo.main --config apps/grpo/llama3_8b.yaml 3 | 4 | # Global configuration 5 | group_size: 4 6 | local_batch_size: 4 # per-device batch size 7 | max_req_tokens: 1024 8 | max_res_tokens: 2048 9 | model: "meta-llama/Meta-Llama-3.1-8B-Instruct" 10 | off_by_n: 1 # Off by one by default 11 | compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM 12 | 13 | # Observability configuration 14 | metric_logging: 15 | wandb: 16 | project: grpo-training 17 | group: grpo_exp_${oc.env:USER} 18 | logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce 19 | console: 20 | logging_mode: global_reduce 21 | 22 | # Dataset configuration 23 | dataset: 24 | path: "openai/gsm8k" 25 | revision: "main" 26 | data_split: "train" 27 | streaming: true 28 | model: ${model} 29 | 30 | # Policy configuration 31 | policy: 32 | engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs 33 | model: ${model} 34 | tensor_parallel_size: 2 35 | pipeline_parallel_size: 1 36 | enforce_eager: ${not:${compile}} 37 | sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams 38 | n: ${group_size} 39 | max_tokens: ${max_res_tokens} 40 | temperature: 1.0 41 | top_p: 1.0 42 | 43 | # Trainer configuration 44 | trainer: 45 | model: 46 | name: llama3 47 | flavor: 8B 48 | hf_assets_path: hf://${model} 49 | optimizer: 50 | name: AdamW 51 | lr: 1e-5 52 | eps: 1e-8 53 | lr_scheduler: 54 | warmup_steps: 1 55 | training: 56 | local_batch_size: ${local_batch_size} 57 | seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens 58 | max_norm: 1.0 59 | steps: 1000000 60 | dtype: bfloat16 61 | gc_freq: 1 62 | compile: 63 | enable: ${compile} 64 | parallelism: 65 | data_parallel_replicate_degree: 1 66 | data_parallel_shard_degree: -1 67 | tensor_parallel_degree: 1 68 | pipeline_parallel_degree: 1 69 | context_parallel_degree: 1 70 | expert_parallel_degree: 1 71 | disable_loss_parallel: true 72 | checkpoint: 73 | enable: true 74 | folder: ./checkpoint # The folder to save checkpoints to. 75 | initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. 76 | initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo 77 | last_save_in_hf: true 78 | interval: 500 79 | async_mode: "disabled" 80 | activation_checkpoint: 81 | mode: selective 82 | selective_ac_option: op 83 | 84 | # Replay buffer configuration 85 | replay_buffer: 86 | batch_size: ${local_batch_size} 87 | max_policy_age: ${off_by_n} 88 | # This should match the dp_size of TorchTitan 89 | # Here it's set explicitly to 2, because we've set 90 | # 2 GPUs for the trainer and we're using full FSDP. 
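# (Concretely: data_parallel_shard_degree is -1 above, which shards FSDP over
# all trainer ranks, and actors.trainer.procs is 2 below -- hence dp_size: 2.)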
91 | dp_size: 2 92 | 93 | # Reference model configuration 94 | ref_model: 95 | model: 96 | name: llama3 97 | flavor: 8B 98 | hf_assets_path: hf://${model} 99 | training: 100 | seq_len: ${trainer.training.seq_len} 101 | dtype: bfloat16 102 | gc_freq: 1 103 | compile: 104 | enable: ${compile} 105 | parallelism: 106 | data_parallel_replicate_degree: 1 107 | data_parallel_shard_degree: 1 108 | tensor_parallel_degree: 1 109 | pipeline_parallel_degree: 1 110 | context_parallel_degree: 1 111 | expert_parallel_degree: 1 112 | checkpoint: 113 | initial_load_path: hf://${model} 114 | initial_load_in_hf: true 115 | 116 | # All resource allocations 117 | services: 118 | policy: 119 | procs: ${policy.engine_args.tensor_parallel_size} 120 | num_replicas: 1 121 | with_gpus: true 122 | mesh_name: policy 123 | ref_model: 124 | procs: 1 125 | num_replicas: 1 126 | with_gpus: true 127 | mesh_name: ref_model 128 | reward_actor: 129 | procs: 1 130 | num_replicas: 1 131 | with_gpus: false 132 | mesh_name: reward_actor 133 | 134 | actors: 135 | dataset: 136 | procs: 1 137 | with_gpus: false 138 | mesh_name: dataset 139 | trainer: 140 | procs: 2 141 | with_gpus: true 142 | mesh_name: trainer 143 | replay_buffer: 144 | procs: 1 145 | with_gpus: false 146 | mesh_name: replay_buffer 147 | compute_advantages: 148 | procs: 1 149 | with_gpus: false 150 | mesh_name: compute_advantages 151 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | workflow_dispatch: 9 | 10 | jobs: 11 | build-docs: 12 | if: github.repository_owner == 'meta-pytorch' 13 | name: Build Documentation 14 | runs-on: linux.g5.4xlarge.nvidia.gpu 15 | timeout-minutes: 30 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v4 19 | with: 20 | fetch-depth: 0 21 | - name: Setup conda env 22 | uses: conda-incubator/setup-miniconda@v2 23 | with: 24 | auto-update-conda: true 25 | miniconda-version: "latest" 26 | activate-environment: test 27 | python-version: '3.10' 28 | auto-activate: false 29 | - name: Update pip 30 | shell: bash -l {0} 31 | run: python -m pip install --upgrade pip 32 | - name: Install torchforge 33 | shell: bash -l {0} 34 | run: pip install uv && uv pip install . 
&& uv pip install .[docs] 35 | - name: Build docs 36 | shell: bash -l {0} 37 | working-directory: docs 38 | run: make html 39 | - name: Upload docs artifact 40 | uses: actions/upload-artifact@v4 41 | with: 42 | name: docs 43 | path: docs/build/html/ 44 | 45 | doc-preview: 46 | runs-on: linux.large 47 | needs: build-docs 48 | if: ${{ github.event_name == 'pull_request' }} 49 | steps: 50 | - name: Checkout 51 | uses: actions/checkout@v4 52 | - name: Download artifact 53 | uses: actions/download-artifact@v4 54 | with: 55 | name: docs 56 | path: docs 57 | - name: Add noindex to preview docs 58 | run: | 59 | echo "Adding noindex meta tag to prevent search engine indexing of preview docs" 60 | find docs -name "*.html" -print0 | xargs -0 sed -i 's/<head>/<head>\n  <meta name="robots" content="noindex">/' 61 | - name: Upload docs preview 62 | uses: seemethere/upload-artifact-s3@v5 63 | if: ${{ github.event_name == 'pull_request' }} 64 | with: 65 | retention-days: 14 66 | s3-bucket: doc-previews 67 | if-no-files-found: error 68 | path: docs 69 | s3-prefix: meta-pytorch/torchforge/${{ github.event.pull_request.number }} 70 | 71 | upload: 72 | runs-on: ubuntu-latest 73 | permissions: 74 | # Grant write permission here so that the doc can be pushed to gh-pages branch 75 | contents: write 76 | needs: build-docs 77 | if: github.repository == 'meta-pytorch/torchforge' && github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/v') || github.event_name == 'workflow_dispatch') 78 | steps: 79 | - name: Checkout 80 | uses: actions/checkout@v4 81 | with: 82 | ref: gh-pages 83 | persist-credentials: true 84 | - name: Download artifact 85 | uses: actions/download-artifact@v4 86 | with: 87 | name: docs 88 | path: docs 89 | #- name: Add no-index tag 90 | # run: | 91 | # REF_NAME=$(echo "${{ github.ref }}") 92 | # echo "Ref name: ${REF_NAME}" 93 | # if [[ "${{ github.ref }}" == 'refs/heads/main' ]]; then 94 | # find docs -name "*.html" -print0 | xargs -0 sed -i '/<head>/a \ \ <meta name="robots" content="noindex">'; 95 | # fi 96 | - name: Move and commit changes 97 | run: | 98 | set -euo pipefail 99 | # Get github.ref for the output doc folder. By default "main" 100 | # If matches a tag like refs/tags/v1.12.0-rc3 or 101 | # refs/tags/v1.12.0 convert to 1.12 102 | GITHUB_REF=${{ github.ref }} 103 | 104 | # Convert refs/tags/v1.12.0rc3 into 1.12.
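# For illustration: refs/tags/v1.12.0rc3 matches the pattern below, so
# BASH_REMATCH[1] captures "1.12", while refs/heads/main falls through
# to the default "main" folder.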
105 | # Adopted from https://github.com/pytorch/pytorch/blob/main/.github/workflows/_docs.yml#L150C11-L155C13 106 | if [[ "${GITHUB_REF}" =~ ^refs/tags/v([0-9]+\.[0-9]+)\.* ]]; then 107 | TARGET_FOLDER="${BASH_REMATCH[1]}" 108 | else 109 | TARGET_FOLDER="main" 110 | fi 111 | echo "Target Folder: ${TARGET_FOLDER}" 112 | 113 | mkdir -p "${TARGET_FOLDER}" 114 | rm -rf "${TARGET_FOLDER}"/* 115 | mv docs/* "${TARGET_FOLDER}" 116 | 117 | git config user.name 'pytorchbot' 118 | git config user.email 'soumith+bot@pytorch.org' 119 | git add "${TARGET_FOLDER}" || true 120 | git commit -m "auto-generating sphinx docs" || true 121 | git push -f 122 | -------------------------------------------------------------------------------- /apps/grpo/qwen3_1_7b.yaml: -------------------------------------------------------------------------------- 1 | # Grouped Relative Policy Optimization (GRPO) 2 | # >>> python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml 3 | 4 | # Global configuration 5 | group_size: 8 6 | local_batch_size: 16 # per-device batch size 7 | max_req_tokens: 1024 8 | max_res_tokens: 2048 9 | model: "Qwen/Qwen3-1.7B" 10 | off_by_n: 1 # Off by one by default 11 | compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM 12 | 13 | # Main loop configuration 14 | rollout_threads: 1 # Recommended to set equal to policy.num_replicas 15 | 16 | 17 | # Observability configuration 18 | metric_logging: 19 | wandb: 20 | project: grpo-training 21 | group: grpo_exp_${oc.env:USER} 22 | logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce 23 | console: 24 | logging_mode: global_reduce 25 | 26 | # Dataset configuration 27 | dataset: 28 | path: "openai/gsm8k" 29 | revision: "main" 30 | data_split: "train" 31 | streaming: true 32 | model: ${model} 33 | 34 | # Policy configuration 35 | policy: 36 | engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs 37 | model: ${model} 38 | tensor_parallel_size: 1 39 | pipeline_parallel_size: 1 40 | enforce_eager: ${not:${compile}} 41 | sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams 42 | n: ${group_size} 43 | max_tokens: ${max_res_tokens} 44 | temperature: 1.0 45 | top_p: 1.0 46 | 47 | # Trainer configuration 48 | trainer: 49 | model: 50 | name: qwen3 51 | flavor: 1.7B 52 | hf_assets_path: hf://${model} 53 | optimizer: 54 | name: AdamW 55 | lr: 1e-5 56 | eps: 1e-8 57 | lr_scheduler: 58 | warmup_steps: 1 59 | training: 60 | local_batch_size: ${local_batch_size} 61 | seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens 62 | max_norm: 1.0 63 | steps: 1000000 64 | dtype: bfloat16 65 | gc_freq: 1 66 | compile: 67 | enable: ${compile} 68 | parallelism: 69 | data_parallel_replicate_degree: 1 70 | data_parallel_shard_degree: 1 71 | tensor_parallel_degree: 1 72 | pipeline_parallel_degree: 1 73 | context_parallel_degree: 1 74 | expert_parallel_degree: 1 75 | disable_loss_parallel: true 76 | checkpoint: 77 | enable: true 78 | folder: ./checkpoint # The folder to save checkpoints to. 79 | initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. 
80 | initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo 81 | last_save_in_hf: true 82 | interval: 500 83 | async_mode: "disabled" 84 | activation_checkpoint: 85 | mode: selective 86 | selective_ac_option: op 87 | 88 | # Replay buffer configuration 89 | replay_buffer: 90 | batch_size: ${local_batch_size} 91 | max_policy_age: ${off_by_n} 92 | dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree 93 | 94 | # Reference model configuration 95 | ref_model: 96 | model: 97 | name: qwen3 98 | flavor: 1.7B 99 | hf_assets_path: hf://${model} 100 | training: 101 | seq_len: ${trainer.training.seq_len} 102 | dtype: bfloat16 103 | gc_freq: 1 104 | compile: 105 | enable: ${compile} 106 | parallelism: 107 | data_parallel_replicate_degree: 1 108 | data_parallel_shard_degree: 1 109 | tensor_parallel_degree: 1 110 | pipeline_parallel_degree: 1 111 | context_parallel_degree: 1 112 | expert_parallel_degree: 1 113 | checkpoint: 114 | enable: true 115 | initial_load_path: hf://${model} 116 | initial_load_in_hf: true 117 | 118 | # All resource allocations 119 | services: 120 | policy: 121 | procs: ${policy.engine_args.tensor_parallel_size} 122 | num_replicas: 1 123 | mesh_name: policy 124 | with_gpus: true 125 | ref_model: 126 | procs: 1 127 | num_replicas: 1 128 | mesh_name: ref_model 129 | with_gpus: true 130 | reward_actor: 131 | procs: 1 132 | num_replicas: 1 133 | mesh_name: reward_actor 134 | with_gpus: false 135 | 136 | actors: 137 | dataset: 138 | procs: 1 139 | with_gpus: false 140 | mesh_name: dataset 141 | trainer: 142 | procs: 1 143 | with_gpus: true 144 | mesh_name: trainer 145 | replay_buffer: 146 | procs: 1 147 | with_gpus: false 148 | mesh_name: replay_buffer 149 | compute_advantages: 150 | procs: 1 151 | with_gpus: false 152 | mesh_name: compute_advantages 153 | -------------------------------------------------------------------------------- /apps/grpo/slurm/qwen3_8b.yaml: -------------------------------------------------------------------------------- 1 | # Grouped Relative Policy Optimization (GRPO) 2 | # ./apps/grpo/slurm/submit.sh qwen3_8b 3 | 4 | # Global configuration 5 | group_size: 16 6 | local_batch_size: 4 # per-device batch size 7 | max_req_tokens: 1024 8 | max_res_tokens: 2048 9 | model: "Qwen/Qwen3-8B" 10 | off_by_n: 1 # Off by one by default 11 | compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM 12 | 13 | 14 | provisioner: 15 | launcher: slurm 16 | memMB: 2047962 17 | cpu: 192 18 | account: agentic-models 19 | qos: h200_capabilities_shared 20 | 21 | # Observability configuration 22 | metric_logging: 23 | wandb: 24 | entity: agentic-models 25 | project: grpo-training 26 | group: grpo_exp_${oc.env:USER} 27 | logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce 28 | console: 29 | logging_mode: global_reduce 30 | 31 | # Dataset configuration 32 | dataset: 33 | path: "openai/gsm8k" 34 | revision: "main" 35 | data_split: "train" 36 | streaming: true 37 | model: ${model} 38 | 39 | # Policy configuration 40 | policy: 41 | engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs 42 | model: ${model} 43 | tensor_parallel_size: 2 44 | pipeline_parallel_size: 1 45 | enforce_eager: ${not:${compile}} 46 | sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams 47 | n: ${group_size} 48 | max_tokens: ${max_res_tokens} 49 | temperature: 1.0 50 | 
top_p: 1.0 51 | 52 | # Trainer configuration 53 | trainer: 54 | model: 55 | name: qwen3 56 | flavor: 8B 57 | hf_assets_path: hf://${model} 58 | optimizer: 59 | name: AdamW 60 | lr: 1e-5 61 | eps: 1e-8 62 | lr_scheduler: 63 | warmup_steps: 1 64 | training: 65 | local_batch_size: ${local_batch_size} 66 | seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens 67 | max_norm: 1.0 68 | steps: 1000000 69 | dtype: bfloat16 70 | gc_freq: 1 71 | compile: 72 | enable: ${compile} 73 | parallelism: 74 | data_parallel_replicate_degree: 1 75 | data_parallel_shard_degree: -1 76 | tensor_parallel_degree: 1 77 | pipeline_parallel_degree: 1 78 | context_parallel_degree: 1 79 | expert_parallel_degree: 1 80 | disable_loss_parallel: true 81 | checkpoint: 82 | enable: true 83 | folder: ./checkpoint # The folder to save checkpoints to. 84 | initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. 85 | initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo 86 | last_save_in_hf: true 87 | interval: 500 88 | async_mode: "disabled" 89 | activation_checkpoint: 90 | mode: selective 91 | selective_ac_option: op 92 | 93 | # Replay buffer configuration 94 | replay_buffer: 95 | batch_size: ${local_batch_size} 96 | max_policy_age: ${off_by_n} 97 | # This should match the dp_size of TorchTitan 98 | # Here it's set explicitly to 2, because we've set 99 | # 2 GPUs for the trainer and we're using full FSDP. 100 | dp_size: 2 101 | 102 | # Reference model configuration 103 | ref_model: 104 | model: 105 | name: qwen3 106 | flavor: 8B 107 | hf_assets_path: hf://${model} 108 | training: 109 | seq_len: ${trainer.training.seq_len} 110 | dtype: bfloat16 111 | gc_freq: 1 112 | compile: 113 | enable: ${compile} 114 | parallelism: 115 | data_parallel_replicate_degree: 1 116 | data_parallel_shard_degree: 1 117 | tensor_parallel_degree: 1 118 | pipeline_parallel_degree: 1 119 | context_parallel_degree: 1 120 | expert_parallel_degree: 1 121 | checkpoint: 122 | initial_load_path: hf://${model} 123 | initial_load_in_hf: true 124 | 125 | # All resource allocations 126 | services: 127 | policy: 128 | procs: ${policy.engine_args.tensor_parallel_size} 129 | num_replicas: 1 130 | hosts: 1 131 | with_gpus: true 132 | mesh_name: policy 133 | ref_model: 134 | procs: 1 135 | num_replicas: 1 136 | with_gpus: true 137 | mesh_name: ref_model 138 | reward_actor: 139 | procs: 1 140 | num_replicas: 1 141 | with_gpus: false 142 | mesh_name: reward_actor 143 | 144 | actors: 145 | dataset: 146 | procs: 1 147 | with_gpus: false 148 | mesh_name: dataset 149 | trainer: 150 | procs: 2 151 | with_gpus: true 152 | mesh_name: trainer 153 | replay_buffer: 154 | procs: 1 155 | with_gpus: false 156 | mesh_name: replay_buffer 157 | compute_advantages: 158 | procs: 1 159 | with_gpus: false 160 | mesh_name: compute_advantages 161 | -------------------------------------------------------------------------------- /tests/unit_tests/test_env_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | """Unit tests for env_constants module.""" 8 | 9 | import os 10 | 11 | from forge.env import all_env_vars, DISABLE_PERF_METRICS, EnvVar, FORGE_DISABLE_METRICS 12 | 13 | 14 | class TestEnvVarGetValue: 15 | """Test the EnvVar.get_value() method.""" 16 | 17 | def test_get_value_returns_default_when_unset(self): 18 | """Test get_value returns default when env var is not set.""" 19 | if "DISABLE_PERF_METRICS" in os.environ: 20 | del os.environ["DISABLE_PERF_METRICS"] 21 | 22 | value = DISABLE_PERF_METRICS.get_value() 23 | assert value is False 24 | 25 | def test_get_value_returns_env_value_when_set(self): 26 | """Test get_value returns env var value when set.""" 27 | from forge.env import MONARCH_STDERR_LEVEL 28 | 29 | os.environ["MONARCH_STDERR_LOG"] = "debug" 30 | 31 | try: 32 | value = MONARCH_STDERR_LEVEL.get_value() 33 | assert value == "debug" 34 | finally: 35 | del os.environ["MONARCH_STDERR_LOG"] 36 | 37 | def test_get_value_bool_auto_cast_with_true(self): 38 | """Test get_value auto-casts 'true' to bool.""" 39 | os.environ["DISABLE_PERF_METRICS"] = "true" 40 | try: 41 | assert DISABLE_PERF_METRICS.get_value() is True 42 | finally: 43 | del os.environ["DISABLE_PERF_METRICS"] 44 | 45 | def test_get_value_bool_auto_cast_with_one(self): 46 | """Test get_value auto-casts '1' to bool.""" 47 | os.environ["DISABLE_PERF_METRICS"] = "1" 48 | try: 49 | assert DISABLE_PERF_METRICS.get_value() is True 50 | finally: 51 | del os.environ["DISABLE_PERF_METRICS"] 52 | 53 | def test_get_value_bool_auto_cast_with_false(self): 54 | """Test get_value auto-casts 'false' to bool.""" 55 | os.environ["DISABLE_PERF_METRICS"] = "false" 56 | try: 57 | assert DISABLE_PERF_METRICS.get_value() is False 58 | finally: 59 | del os.environ["DISABLE_PERF_METRICS"] 60 | 61 | 62 | class TestPredefinedConstants: 63 | """Test the predefined environment variable constants.""" 64 | 65 | def test_predefined_constants_structure(self): 66 | """Test that predefined constants are properly defined.""" 67 | assert isinstance(DISABLE_PERF_METRICS, EnvVar) 68 | assert DISABLE_PERF_METRICS.name == "DISABLE_PERF_METRICS" 69 | assert DISABLE_PERF_METRICS.default is False 70 | 71 | assert isinstance(FORGE_DISABLE_METRICS, EnvVar) 72 | assert FORGE_DISABLE_METRICS.name == "FORGE_DISABLE_METRICS" 73 | assert FORGE_DISABLE_METRICS.default is False 74 | 75 | def test_predefined_constants_work_with_get_value(self): 76 | """Test that predefined constants work with get_value method.""" 77 | if DISABLE_PERF_METRICS.name in os.environ: 78 | del os.environ[DISABLE_PERF_METRICS.name] 79 | 80 | assert DISABLE_PERF_METRICS.get_value() is False 81 | 82 | os.environ[DISABLE_PERF_METRICS.name] = "true" 83 | try: 84 | assert DISABLE_PERF_METRICS.get_value() is True 85 | finally: 86 | del os.environ[DISABLE_PERF_METRICS.name] 87 | 88 | 89 | class TestAllEnvVars: 90 | """Test the all_env_vars() function.""" 91 | 92 | def test_all_env_vars_returns_list(self): 93 | """Test that all_env_vars returns a list.""" 94 | env_vars = all_env_vars() 95 | assert isinstance(env_vars, list) 96 | 97 | def test_all_env_vars_contains_only_env_var_instances(self): 98 | """Test that all_env_vars returns only EnvVar instances.""" 99 | env_vars = all_env_vars() 100 | assert len(env_vars) > 0 101 | assert all(isinstance(var, EnvVar) for var in env_vars) 102 | 103 | def test_all_env_vars_contains_expected_constants(self): 104 | """Test that all_env_vars includes known constants.""" 105 | env_vars = all_env_vars() 106 | env_var_names = {var.name for var in env_vars} 107 | 108 
| assert "DISABLE_PERF_METRICS" in env_var_names 109 | assert "FORGE_DISABLE_METRICS" in env_var_names 110 | assert "MONARCH_STDERR_LOG" in env_var_names 111 | 112 | def test_all_env_vars_can_iterate_and_get_values(self): 113 | """Test that all_env_vars can be used to iterate and get values.""" 114 | for env_var in all_env_vars(): 115 | value = env_var.get_value() 116 | assert value is not None or env_var.default is None 117 | -------------------------------------------------------------------------------- /src/forge/types.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from dataclasses import dataclass, field 8 | from enum import Enum 9 | from typing import Any, TypedDict, Union 10 | 11 | 12 | class Message(TypedDict): 13 | role: str 14 | content: str | dict[str, Any] 15 | tools: dict[str, Any] | None 16 | 17 | 18 | @dataclass(kw_only=True) 19 | class Observation: 20 | """Base class for environment observations. 21 | 22 | Contract: 23 | - Should contain all information needed by an agent to make decisions 24 | - Should be serializable/deserializable 25 | - Should be immutable (or treated as such) 26 | Args: 27 | done: Whether the episode/conversation is complete 28 | reward: Optional reward signal (can be boolean, int, or float) 29 | metadata: Additional data that doesn't affect agent decisions but may be useful 30 | for transforms, logging, evaluation, etc. 31 | """ 32 | 33 | done: bool = False 34 | reward: bool | int | float | None = None 35 | metadata: dict[str, Any] = field(default_factory=dict) 36 | 37 | 38 | class Launcher(Enum): 39 | MAST = "mast" 40 | SLURM = "slurm" 41 | 42 | 43 | @dataclass 44 | class ProcessConfig: 45 | """A configuration for allocating Monarch ProcMeshes. 46 | 47 | Args: 48 | procs (int): Number of processes to launch for each replica of the service. 49 | with_gpus (bool, optional): Whether to allocate GPUs for the service processes. 50 | hosts (int | None, optional): Number of hosts to allocate for each replica. 51 | If this is set to None, it will use the local host. 52 | If this is set to a positive integer, it will run on a remote host. 53 | mesh_name (str | None, optional): Name of the mesh to use for the proc_mesh. 54 | 55 | """ 56 | 57 | procs: int = 1 58 | with_gpus: bool = False 59 | hosts: int | None = None 60 | mesh_name: str | None = None 61 | 62 | 63 | @dataclass 64 | class ServiceConfig: 65 | """The configuration for a Forge service. 66 | 67 | Args: 68 | procs (int): Number of processes to launch for each replica of the service. 69 | num_replicas (int): Number of replicas to launch for the service. 70 | with_gpus (bool, optional): Whether to allocate GPUs for the service processes. 71 | hosts (int | None, optional): Number of hosts to allocate for each replica. 72 | If this is set to None, it will use the local host. 73 | If this is set to a positive integer, it will run on a remote host. 74 | health_poll_rate (float, optional): Frequency (in seconds) to poll for health status. 75 | replica_max_concurrent_requests (int, optional): Maximum number of concurrent requests per replica. 76 | return_first_rank_result (bool, optional): Whether to auto-unwrap ValueMesh to the first rank's result. 
77 | """ 78 | 79 | procs: int 80 | num_replicas: int 81 | with_gpus: bool = False 82 | hosts: int | None = None 83 | health_poll_rate: float = 0.2 84 | replica_max_concurrent_requests: int = 10 85 | return_first_rank_result: bool = True 86 | mesh_name: str | None = None 87 | 88 | def to_process_config(self) -> ProcessConfig: 89 | """Extract ProcessConfig from this ServiceConfig. 90 | 91 | Maps procs to procs for ProcessConfig. 92 | """ 93 | return ProcessConfig( 94 | procs=self.procs, 95 | with_gpus=self.with_gpus, 96 | hosts=self.hosts, 97 | mesh_name=self.mesh_name, 98 | ) 99 | 100 | 101 | Scalar = Union[int, float] 102 | 103 | 104 | @dataclass 105 | class LauncherConfig: 106 | """A launcher config for the scheduler.""" 107 | 108 | launcher: Launcher 109 | job_name: str = "" 110 | services: dict[str, ServiceConfig] = field(default_factory=dict) 111 | actors: dict[str, ProcessConfig] = field(default_factory=dict) 112 | cpu: int | None = None # CPUs per node (required for SLURM, can get with sinfo) 113 | memMB: int | None = ( # noqa: N815 114 | None # Memory in MB per node (required for SLURM, can get with sinfo) 115 | ) 116 | gpu: int = 8 # GPUs per node (required for SLURM, can get with sinfo) 117 | account: str = "" 118 | qos: str = "" 119 | 120 | def __post_init__(self): 121 | if isinstance(self.launcher, str): 122 | self.launcher = Launcher(self.launcher) 123 | 124 | 125 | @dataclass 126 | class ProvisionerConfig: 127 | """A config for the forge provisioner.""" 128 | 129 | launcher_config: LauncherConfig 130 | -------------------------------------------------------------------------------- /apps/grpo/slurm/qwen3_30b_a3b.yaml: -------------------------------------------------------------------------------- 1 | # Grouped Relative Policy Optimization (GRPO) 2 | # NOTE - This has not been tested for correctness yet! 
All testing so far has been only for infrastructure stability 3 | # ./apps/grpo/slurm/submit.sh qwen3_30b_a3b 4 | 5 | # Global configuration 6 | group_size: 4 7 | local_batch_size: 1 # per-device batch size 8 | max_req_tokens: 1024 9 | max_res_tokens: 1024 10 | model: "Qwen/Qwen3-30B-A3B" 11 | off_by_n: 1 # Off by one by default 12 | 13 | provisioner: 14 | launcher: slurm 15 | memMB: 2047962 16 | cpu: 192 17 | account: agentic-models 18 | qos: h200_capabilities_shared 19 | 20 | # Main loop configuration 21 | rollout_threads: 32 # make this 4x the number of policy replicas seems to work well 22 | 23 | # Observability configuration 24 | metric_logging: 25 | wandb: 26 | entity: agentic-models 27 | project: grpo-training 28 | group: grpo_exp_${oc.env:USER} 29 | logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce 30 | console: 31 | logging_mode: global_reduce 32 | 33 | # Dataset configuration 34 | dataset: 35 | path: "openai/gsm8k" 36 | revision: "main" 37 | data_split: "train" 38 | streaming: true 39 | model: ${model} 40 | 41 | # Policy configuration 42 | policy: 43 | engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs 44 | model: ${model} 45 | tensor_parallel_size: 4 46 | pipeline_parallel_size: 1 47 | enforce_eager: false 48 | sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams 49 | n: ${group_size} 50 | max_tokens: ${max_res_tokens} 51 | temperature: 1.0 52 | top_p: 1.0 53 | 54 | # Trainer configuration 55 | trainer: 56 | model: 57 | name: qwen3 58 | flavor: 30B-A3B 59 | hf_assets_path: hf://${model} 60 | optimizer: 61 | name: AdamW 62 | lr: 1e-5 63 | eps: 1e-8 64 | lr_scheduler: 65 | warmup_steps: 1 66 | training: 67 | local_batch_size: ${local_batch_size} 68 | seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens 69 | max_norm: 1.0 70 | steps: 1000000 71 | dtype: bfloat16 72 | gc_freq: 1 73 | compile: 74 | enable: false 75 | parallelism: 76 | data_parallel_replicate_degree: 1 77 | data_parallel_shard_degree: -1 78 | tensor_parallel_degree: 1 79 | pipeline_parallel_degree: 1 80 | context_parallel_degree: 1 81 | expert_parallel_degree: 1 82 | expert_tensor_parallel_degree: 1 83 | disable_loss_parallel: true 84 | checkpoint: 85 | enable: true 86 | folder: ./checkpoint # The folder to save checkpoints to. 87 | initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. 
88 | initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo 89 | last_save_in_hf: true 90 | interval: 500 91 | async_mode: "disabled" 92 | activation_checkpoint: 93 | mode: full 94 | 95 | # Replay buffer configuration 96 | replay_buffer: 97 | batch_size: ${local_batch_size} 98 | max_policy_age: ${off_by_n} 99 | # dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree 100 | dp_size: 4 101 | 102 | # Reference model configuration 103 | ref_model: 104 | model: 105 | name: qwen3 106 | flavor: 30B-A3B 107 | hf_assets_path: hf://${model} 108 | training: 109 | seq_len: ${trainer.training.seq_len} 110 | dtype: bfloat16 111 | gc_freq: 1 112 | compile: 113 | enable: false 114 | parallelism: 115 | data_parallel_replicate_degree: 1 116 | data_parallel_shard_degree: -1 117 | tensor_parallel_degree: 1 118 | pipeline_parallel_degree: 1 119 | context_parallel_degree: 1 120 | expert_parallel_degree: 1 121 | checkpoint: 122 | enable: true 123 | initial_load_path: hf://${model} 124 | initial_load_in_hf: true 125 | 126 | # All resource allocations 127 | services: 128 | policy: 129 | procs: ${policy.engine_args.tensor_parallel_size} 130 | num_replicas: 1 131 | hosts: 1 132 | with_gpus: true 133 | mesh_name: policy 134 | ref_model: 135 | procs: 4 136 | num_replicas: 1 137 | with_gpus: true 138 | mesh_name: ref_model 139 | reward_actor: 140 | procs: 1 141 | num_replicas: 1 142 | with_gpus: false 143 | mesh_name: reward_actor 144 | 145 | actors: 146 | dataset: 147 | procs: 1 148 | with_gpus: false 149 | mesh_name: dataset 150 | trainer: 151 | procs: 4 152 | hosts: 1 153 | with_gpus: true 154 | mesh_name: trainer 155 | replay_buffer: 156 | procs: 1 157 | with_gpus: false 158 | mesh_name: replay_buffer 159 | compute_advantages: 160 | procs: 1 161 | with_gpus: false 162 | mesh_name: compute_advantages 163 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Derived from basic .gitignore template for python projects: 2 | # https://github.com/github/gitignore/blob/main/Python.gitignore 3 | # Please maintain the alphabetic order of the section titles 4 | # To debug why a file is being ignored, use the command: 5 | # git check-ignore -v $my_ignored_file 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Cython debug symbols 16 | cython_debug/ 17 | 18 | # SLURM logs 19 | slogs/ 20 | slurm-* 21 | 22 | # Celery stuff 23 | celerybeat-schedule 24 | celerybeat.pid 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | share/python-wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | MANIFEST 44 | 45 | # Django stuff 46 | *.log 47 | local_settings.py 48 | db.sqlite3 49 | db.sqlite3-journal 50 | 51 | # Environments 52 | .env 53 | .venv 54 | uv.lock 55 | env/ 56 | venv/ 57 | ENV/ 58 | env.bak/ 59 | venv.bak/ 60 | 61 | # Flask stuff 62 | instance/ 63 | .webassets-cache 64 | 65 | # Installer logs 66 | pip-log.txt 67 | pip-delete-this-directory.txt 68 | 69 | # IPython 70 | profile_default/ 71 | ipython_config.py 72 | 73 | # Jupyter Notebook 74 | *.ipynb_checkpoints 75 | 76 | # mkdocs documentation 77 | /site 78 | 79 | # Model saving / checkpointing 80 | *.pt 81 | *.pth 82 | *.ckpt 83 | *.distcp 84 | 
.metadata 85 | 86 | # mypy 87 | .mypy_cache/ 88 | .dmypy.json 89 | dmypy.json 90 | 91 | # PyBuilder 92 | .pybuilder/ 93 | target/ 94 | 95 | # PyCharm 96 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 97 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 98 | # and can be added to the global gitignore or merged into this file. For a more nuclear 99 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 100 | #.idea/ 101 | 102 | # PyInstaller 103 | # Usually these files are written by a python script from a template 104 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 105 | *.manifest 106 | *.spec 107 | 108 | # pyenv 109 | # For a library or package, you might want to ignore these files since the code is 110 | # intended to run in multiple environments; otherwise, check them in: 111 | # .python-version 112 | 113 | # pipenv 114 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 115 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 116 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 117 | # install all needed dependencies. 118 | # Pipfile.lock 119 | 120 | # poetry 121 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 122 | # This is especially recommended for binary packages to ensure reproducibility, and is more 123 | # commonly ignored for libraries. 124 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 125 | # poetry.lock 126 | 127 | # PEP 582: https://peps.python.org/pep-0582/ 128 | # This PEP proposes to add to Python a mechanism to automatically recognize a __pypackages__ 129 | # directory and prefer importing packages installed in this location over user or global site-packages. 130 | # This will avoid the steps to create, activate or deactivate virtual environments. Python will use 131 | # the __pypackages__ from the base directory of the script when present. 
132 | __pypackages__/ 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # pytype static type analyzer 138 | .pytype/ 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # SageMath parsed files 144 | *.sage.py 145 | 146 | # Scrapy stuff: 147 | .scrapy 148 | 149 | # Sphinx documentation 150 | docs/build 151 | # sphinx-gallery 152 | docs/source/generated_examples/ 153 | docs/source/gen_modules/ 154 | docs/source/generated/ 155 | docs/source/sg_execution_times.rst 156 | docs/source/tutorials/* 157 | # pytorch-sphinx-theme gets installed here 158 | docs/src 159 | 160 | # Spyder project settings 161 | .spyderproject 162 | .spyproject 163 | 164 | # System / program generated files 165 | *.err 166 | *.log 167 | *.swp 168 | .DS_Store 169 | 170 | # Translations 171 | *.mo 172 | *.pot 173 | 174 | # TorchX 175 | *.torchxconfig 176 | 177 | # Unit test / coverage reports 178 | htmlcov/ 179 | .tox/ 180 | .nox/ 181 | .coverage 182 | .coverage.* 183 | .cache 184 | nosetests.xml 185 | coverage.xml 186 | *.cover 187 | *.py,cover 188 | .hypothesis/ 189 | .pytest_cache/ 190 | cover/ 191 | 192 | # VSCode 193 | .vscode/ 194 | 195 | # wandb 196 | wandb/ 197 | 198 | assets/wheels/vllm*.whl 199 | 200 | # DCP artifacts 201 | forge_dcp_tmp/ 202 | demo_top_down.md 203 | 204 | 205 | # enroot / sqsh 206 | *.sqsh 207 | -------------------------------------------------------------------------------- /apps/grpo/slurm/qwen3_32b.yaml: -------------------------------------------------------------------------------- 1 | # Grouped Relative Policy Optimization (GRPO) 2 | # NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability 3 | # ./apps/grpo/slurm/submit.sh qwen3_32b 4 | 5 | # Global configuration 6 | group_size: 16 7 | local_batch_size: 2 # per-device batch size 8 | max_req_tokens: 1024 9 | max_res_tokens: 1024 10 | model: "Qwen/Qwen3-32B" 11 | off_by_n: 1 # Off by one by default 12 | compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM 13 | 14 | provisioner: 15 | launcher: slurm 16 | memMB: 2047962 17 | cpu: 192 18 | account: agentic-models 19 | qos: h200_capabilities_shared 20 | 21 | # Main loop configuration 22 | rollout_threads: 32 # make this 4x the number of policy replicas seems to work well 23 | 24 | # Observability configuration 25 | metric_logging: 26 | wandb: 27 | entity: agentic-models 28 | project: grpo-training 29 | group: grpo_exp_${oc.env:USER} 30 | logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce 31 | console: 32 | logging_mode: global_reduce 33 | 34 | # Dataset configuration 35 | dataset: 36 | path: "openai/gsm8k" 37 | revision: "main" 38 | data_split: "train" 39 | streaming: true 40 | model: ${model} 41 | 42 | # Policy configuration 43 | policy: 44 | engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs 45 | model: ${model} 46 | tensor_parallel_size: 4 47 | pipeline_parallel_size: 1 48 | enforce_eager: ${not:${compile}} 49 | sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams 50 | n: ${group_size} 51 | max_tokens: ${max_res_tokens} 52 | temperature: 1.0 53 | top_p: 1.0 54 | 55 | # Trainer configuration 56 | trainer: 57 | model: 58 | name: qwen3 59 | flavor: 32B 60 | hf_assets_path: hf://${model} 61 | optimizer: 62 | name: AdamW 63 | lr: 1e-5 64 | eps: 1e-8 65 | lr_scheduler: 66 | warmup_steps: 1 67 | training: 68 | local_batch_size: 
${local_batch_size} 69 | seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens 70 | max_norm: 1.0 71 | steps: 1000000 72 | dtype: bfloat16 73 | gc_freq: 1 74 | compile: 75 | enable: ${compile} 76 | parallelism: 77 | data_parallel_replicate_degree: 1 78 | data_parallel_shard_degree: 1 79 | tensor_parallel_degree: 8 80 | pipeline_parallel_degree: 1 81 | context_parallel_degree: 1 82 | expert_parallel_degree: 1 83 | disable_loss_parallel: true 84 | checkpoint: 85 | enable: true 86 | folder: ./checkpoint # The folder to save checkpoints to. 87 | initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. 88 | initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo 89 | last_save_in_hf: true 90 | interval: 500 91 | async_mode: "disabled" 92 | activation_checkpoint: 93 | mode: full 94 | 95 | # Replay buffer configuration 96 | replay_buffer: 97 | batch_size: ${local_batch_size} 98 | max_policy_age: ${off_by_n} 99 | # dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree 100 | dp_size: 1 101 | 102 | # Reference model configuration 103 | ref_model: 104 | model: 105 | name: qwen3 106 | flavor: 32B 107 | hf_assets_path: hf://${model} 108 | training: 109 | seq_len: ${trainer.training.seq_len} 110 | dtype: bfloat16 111 | gc_freq: 1 112 | compile: 113 | enable: ${compile} 114 | parallelism: 115 | data_parallel_replicate_degree: 1 116 | data_parallel_shard_degree: 1 117 | tensor_parallel_degree: 4 118 | pipeline_parallel_degree: 1 119 | context_parallel_degree: 1 120 | expert_parallel_degree: 1 121 | checkpoint: 122 | enable: true 123 | initial_load_path: hf://${model} 124 | initial_load_in_hf: true 125 | 126 | # All resource allocations 127 | services: 128 | policy: 129 | procs: ${policy.engine_args.tensor_parallel_size} 130 | num_replicas: 4 131 | hosts: 1 132 | with_gpus: true 133 | mesh_name: policy 134 | ref_model: 135 | procs: ${ref_model.parallelism.tensor_parallel_degree} 136 | num_replicas: 1 137 | with_gpus: true 138 | mesh_name: ref_model 139 | reward_actor: 140 | procs: 1 141 | num_replicas: 1 142 | with_gpus: false 143 | mesh_name: reward_actor 144 | 145 | actors: 146 | dataset: 147 | procs: 1 148 | with_gpus: false 149 | mesh_name: dataset 150 | trainer: 151 | procs: 8 152 | hosts: 1 153 | with_gpus: true 154 | mesh_name: trainer 155 | replay_buffer: 156 | procs: 1 157 | with_gpus: false 158 | mesh_name: replay_buffer 159 | compute_advantages: 160 | procs: 1 161 | with_gpus: false 162 | mesh_name: compute_advantages 163 | --------------------------------------------------------------------------------
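A note on the ${sum:...} and ${not:...} interpolations used throughout the
GRPO configs above: these are not built-in OmegaConf syntax, so the config
loader must register custom resolvers for them. A minimal sketch of how such
resolvers can be registered (where forge actually does this is an assumption;
only register_new_resolver itself is standard OmegaConf API):

```python
# Sketch: custom resolvers matching the interpolations in the YAML configs.
# Hypothetical registration site -- forge's config loader presumably does
# something equivalent before resolving these files.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("sum", lambda *vals: sum(vals))
OmegaConf.register_new_resolver("not", lambda x: not x)

cfg = OmegaConf.create(
    {
        "max_req_tokens": 1024,
        "max_res_tokens": 1024,
        "compile": True,
        "seq_len": "${sum:${max_req_tokens},${max_res_tokens}}",
        "enforce_eager": "${not:${compile}}",
    }
)
assert cfg.seq_len == 2048         # 1024 + 1024, as in qwen3_32b.yaml above
assert cfg.enforce_eager is False  # compile: true keeps CUDA graphs enabled
```

As a sanity check on the qwen3_32b allocation above: policy uses 4 procs x 4
replicas, ref_model 4 procs, and the trainer 8 procs, so the config implies
4*4 + 4 + 8 = 28 GPUs in total.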