├── .env ├── .coveragerc ├── docs ├── source │ ├── images │ │ ├── joining.png │ │ ├── data_flow.png │ │ └── energon3_darkbg_border.png │ ├── api │ │ ├── modules.md │ │ ├── modules_data.md │ │ └── cli.md │ ├── internals │ │ ├── code_structure.md │ │ └── contrib_guidelines.md │ ├── intro │ │ ├── introduction.md │ │ └── installation.md │ ├── basic │ │ ├── basics_flow.md │ │ └── save_restore.md │ ├── _static │ │ └── css │ │ │ └── custom.css │ ├── advanced │ │ ├── repro_scaling.md │ │ ├── epochized_blending.md │ │ ├── custom_blending.md │ │ ├── remote_dataset.md │ │ ├── grouping.md │ │ ├── packing.md │ │ └── joining_datasets.md │ └── index.md ├── Makefile └── README.md ├── src └── megatron │ └── energon │ ├── cli │ ├── __init__.py │ └── main.py │ ├── tools │ ├── __init__.py │ ├── info.py │ └── lint.py │ ├── epathlib │ ├── __init__.py │ └── rclone_config.py │ ├── wrappers │ ├── skip.py │ ├── _log_exception.py │ ├── __init__.py │ ├── concat_dataset.py │ ├── filter_dataset.py │ ├── shuffle_buffer_dataset.py │ ├── repeat_dataset.py │ ├── log_sample_dataset.py │ ├── gc_dataset.py │ ├── mix_batch_dataset.py │ ├── epochize_dataset.py │ └── limit_dataset.py │ ├── flavors │ ├── webdataset │ │ ├── empty_dataset_error.py │ │ ├── config.py │ │ ├── video_data.py │ │ ├── __init__.py │ │ ├── error_handler.py │ │ ├── structs.py │ │ ├── metadata.py │ │ ├── field_access.py │ │ ├── standard_webdataset.py │ │ ├── default_decoder_webdataset.py │ │ ├── itar.py │ │ └── default_generic_webdataset.py │ ├── text.py │ ├── image.py │ ├── captioning.py │ ├── interleaved.py │ ├── image_classification.py │ ├── multichoice_vqa.py │ ├── vqa.py │ ├── vid_qa.py │ ├── similarity_interleaved.py │ ├── vqa_and_ocr.py │ ├── crude.py │ ├── ocr.py │ └── __init__.py │ ├── transforms │ ├── common.py │ ├── custom.py │ └── __init__.py │ ├── metadataset │ ├── __init__.py │ ├── loader.py │ ├── loader_interface.py │ ├── dataset_loader.py │ └── join_dataset_loader.py │ ├── task_encoder │ ├── __init__.py │ └── cooking.py │ ├── module_loader.py │ ├── retry_stream.py │ ├── errors.py │ ├── loader.py │ ├── __init__.py │ └── dataset_config.py ├── docker └── energon-ci.Dockerfile ├── .github └── workflows │ ├── black.yml │ ├── isort.yml │ ├── tests.yml │ ├── license_headers.yml │ ├── release.yml │ └── documentation.yml ├── .gitignore ├── pyproject.toml ├── tests └── test_epathlib.py ├── README.md └── scripts └── license_headers.py /.env: -------------------------------------------------------------------------------- 1 | PYTHONPATH=src 2 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | include = ./src/megatron/energon/** 3 | 4 | [xml] 5 | output = ./coverage.xml -------------------------------------------------------------------------------- /docs/source/images/joining.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/benjaminegger/Megatron-Energon/HEAD/docs/source/images/joining.png -------------------------------------------------------------------------------- /src/megatron/energon/cli/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | -------------------------------------------------------------------------------- /src/megatron/energon/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | -------------------------------------------------------------------------------- /docs/source/images/data_flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/benjaminegger/Megatron-Energon/HEAD/docs/source/images/data_flow.png -------------------------------------------------------------------------------- /docs/source/images/energon3_darkbg_border.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/benjaminegger/Megatron-Energon/HEAD/docs/source/images/energon3_darkbg_border.png -------------------------------------------------------------------------------- /src/megatron/energon/epathlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from megatron.energon.epathlib.epath import EPath 5 | 6 | __all__ = ["EPath"] 7 | -------------------------------------------------------------------------------- /src/megatron/energon/wrappers/skip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | 5 | class SkipSample(Exception): 6 | """Exception to raise in the map_fn to skip a sample.""" 7 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/webdataset/empty_dataset_error.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | 5 | class EmptyDatasetError(Exception): 6 | """Raised when a dataset is empty.""" 7 | -------------------------------------------------------------------------------- /docs/source/api/modules.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Packages and Modules 5 | 6 | ```{toctree} 7 | --- 8 | maxdepth: 2 9 | --- 10 | 11 | modules_data 12 | ``` 13 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/webdataset/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import re 5 | 6 | split_name_re = re.compile(r"^((?:.*/|)[^.]+)[.]([^/]*)$") 7 | skip_meta_re = re.compile(r"__[^/]*__($|/)") 8 | MAIN_FOLDER_NAME = ".nv-meta" 9 | -------------------------------------------------------------------------------- /docs/source/internals/code_structure.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Code Structure 5 | 6 | This section is meant to provide an introduction to Megatron Energon for developers who want to contribute to energon itself. 7 | 8 | For now, this is still a placeholder and we encourage you to get in touch with us for an introduction. 
9 | -------------------------------------------------------------------------------- /src/megatron/energon/transforms/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from enum import Enum 5 | from typing import TypeVar 6 | 7 | import PIL.Image 8 | import torch 9 | 10 | TTransform = TypeVar("TTransform", bound=torch.nn.Module) 11 | 12 | 13 | class Sampling(Enum): 14 | BILINEAR = PIL.Image.BILINEAR 15 | NEAREST = PIL.Image.NEAREST 16 | -------------------------------------------------------------------------------- /docker/energon-ci.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:24.02-py3 2 | 3 | ENV TORCH_CUDA_ARCH_LIST "7.0;7.5;8.0;8.6;8.7;8.9;9.0" 4 | ENV MMCV_WITH_OPS 1 5 | ENV FORCE_CUDA 1 6 | 7 | RUN python3 -m pip install --upgrade pip 8 | 9 | # Install, then uninstall to get only the deps 10 | COPY . ./megatron-energon 11 | RUN pip install -e ./megatron-energon && pip uninstall -y megatron-energon && rm -rf ./megatron-energon 12 | -------------------------------------------------------------------------------- /docs/source/api/modules_data.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # megatron.energon 5 | 6 | ```{eval-rst} 7 | .. automodule:: megatron.energon 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | .. automodule:: megatron.energon.task_encoder.cooking 13 | :members: 14 | :undoc-members: 15 | :show-inheritance: 16 | ``` 17 | -------------------------------------------------------------------------------- /src/megatron/energon/transforms/custom.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import torch 5 | 6 | from megatron.energon.transforms.mappers import TransformMapper 7 | 8 | 9 | class CustomTransform(TransformMapper[torch.nn.Module]): 10 | """Abstract class for additional custom transforms. 11 | Inherit from this and override at least `apply_transform`. 12 | """ 13 | -------------------------------------------------------------------------------- /src/megatron/energon/wrappers/_log_exception.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import traceback 5 | 6 | 7 | def log_exception(_e: Exception, sample, /): 8 | traceback.print_exc() 9 | print("-" * 10) 10 | 11 | sample_str = str(sample) 12 | if len(sample_str) > 400: 13 | sample_str = sample_str[:200] + "..." + sample_str[-200:] 14 | 15 | print(sample_str) 16 | 17 | print("-" * 10) 18 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/webdataset/video_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import dataclasses 5 | from typing import Dict, Union 6 | 7 | import torch 8 | 9 | 10 | @dataclasses.dataclass 11 | class VideoData: 12 | #: The input video tensor in the shape (frames, channel, h, w) 13 | frames: torch.Tensor 14 | #: The input audio frames in the shape (number of channels, number of points) 15 | aframes: torch.Tensor 16 | #: Metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int) 17 | info: Dict[str, Union[bool, int, float, str]] 18 | -------------------------------------------------------------------------------- /src/megatron/energon/metadataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from megatron.energon.metadataset.dataset_loader import DatasetLoader 5 | from megatron.energon.metadataset.loader import load_dataset 6 | from megatron.energon.metadataset.loader_interface import DatasetLoaderInterface 7 | from megatron.energon.metadataset.metadataset import Metadataset 8 | from megatron.energon.metadataset.metadataset_v2 import MetadatasetV2 9 | 10 | __all__ = [ 11 | "DatasetLoader", 12 | "load_dataset", 13 | "DatasetLoaderInterface", 14 | "Metadataset", 15 | "MetadatasetV2", 16 | ] 17 | -------------------------------------------------------------------------------- /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: black formatting 2 | 3 | on: 4 | push: 5 | branches: 6 | - develop 7 | workflow_dispatch: 8 | pull_request: 9 | branches: 10 | - develop 11 | 12 | jobs: 13 | black: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@v4 19 | 20 | - name: Set up Python 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: '3.10' 24 | 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install black 29 | 30 | - name: Run Black 31 | run: black --check . 32 | -------------------------------------------------------------------------------- /.github/workflows/isort.yml: -------------------------------------------------------------------------------- 1 | name: isort formatting 2 | 3 | on: 4 | push: 5 | branches: 6 | - develop 7 | workflow_dispatch: 8 | pull_request: 9 | branches: 10 | - develop 11 | 12 | jobs: 13 | isort: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@v4 19 | 20 | - name: Set up Python 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: '3.10' 24 | 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install isort 29 | 30 | - name: Run isort 31 | run: isort --check-only . -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python3 -m sphinx.cmd.build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 
11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - develop 7 | workflow_dispatch: 8 | pull_request: 9 | branches: 10 | - develop 11 | 12 | jobs: 13 | unittest: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up Python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: '3.8' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | python -m pip install -e .[transforms] 25 | - name: Run unit tests 26 | run: | 27 | python -m unittest discover -s tests 28 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Building the documentation 5 | 6 | To build the documentation, you need sphinx and additional packages: 7 | 8 | - sphinx-rtd-theme 9 | - sphinx 10 | - sphinxcontrib-napoleon 11 | - myst-parser 12 | 13 | You can install these like 14 | 15 | `pip install sphinx-rtd-theme sphinx sphinxcontrib-napoleon myst-parser sphinx-click` 16 | 17 | Use `make html` to build it. 18 | 19 | Or use PyCharm by adding a configuration: 20 | 21 | `Run -> Edit Configurations -> Add new Configuration -> Python docs -> Sphinx task` 22 | 23 | Use the `src/docs/source` folder as input folder and the `src/docs/build` as output. 24 | -------------------------------------------------------------------------------- /.github/workflows/license_headers.yml: -------------------------------------------------------------------------------- 1 | name: verify license headers 2 | 3 | on: 4 | push: 5 | branches: 6 | - develop 7 | workflow_dispatch: 8 | pull_request: 9 | branches: 10 | - develop 11 | 12 | jobs: 13 | license-check: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - name: Checkout Repository 18 | uses: actions/checkout@v3 19 | - name: Set up Python 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: 3.9 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install click==8.1.7 27 | - name: Run License Header Check 28 | run: python scripts/license_headers.py . 29 | -------------------------------------------------------------------------------- /docs/source/intro/introduction.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # General 5 | 6 | Megatron-Energon is a data loader that works best with your [Megatron](https://github.com/NVIDIA/Megatron-LM) project. 7 | However, you can use it in any of your PyTorch-based deep learning projects. 8 | 9 | What can it offer compared to other data loaders? 
10 | 11 | The most important features are: 12 | 13 | * Comes with a standardized WebDataset-based format on disk 14 | * Optimized for high-speed multi-rank training 15 | * Can easily mix and blend multiple datasets 16 | * Its state is savable and restorable 17 | * Handles various kinds of multi-modal data even in one training 18 | 19 | Energon also comes with a command line tool that you can use to prepare your datasets. 20 | -------------------------------------------------------------------------------- /docs/source/basic/basics_flow.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Data Flow 5 | 6 | ![energon data flow](../images/data_flow.png) 7 | 8 | 9 | (flavors_general)= 10 | ## Dataset Flavors 11 | 12 | The datasets are organized in "flavors", i.e. each modality returned by the dataset is a "flavor". 13 | A modality can for example be a {py:class}`CaptioningSample ` or a 14 | {py:class}`VQASample `. The dataset class combines the source data format 15 | and the iterated sample format. For example, the {py:class}`CaptioningWebdataset ` 16 | combines the webdataset loader with the {py:class}`CaptioningSample `. 17 | 18 | For all types, see [](sect-sample-types) 19 | -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /*.sig-param:nth-child(1 of .sig-param):nth-last-child(n + 3 of .sig-param)::before, 2 | .sig-param:nth-child(1 of .sig-param):nth-last-child(n + 3 of .sig-param) ~ .sig-param ::before { 3 | content: "\a\20\20\20\20\20\20\20\20\20\20\20\20\20\20\20\20"; 4 | white-space: pre; 5 | }*/ 6 | 7 | /* Newlines (\a) and spaces (\20) before each parameter */ 8 | .sig-param::before { 9 | content: "\a\20\20\20\20\20\20\20\20\20\20\20\20\20\20\20\20"; 10 | white-space: pre; 11 | } 12 | 13 | /* Newline after the last parameter (so the closing bracket is on a new line) */ 14 | dt em.sig-param:last-of-type::after { 15 | content: "\a"; 16 | white-space: pre; 17 | } 18 | 19 | /* To have blue background of width of the block (instead of width of content) */ 20 | dl.class > dt:first-of-type { 21 | display: block !important; 22 | } 23 | -------------------------------------------------------------------------------- /src/megatron/energon/task_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from megatron.energon.task_encoder.base import ( 5 | AugmentTaskEncoder, 6 | Batch, 7 | DefaultTaskEncoder, 8 | TaskEncoder, 9 | batch_list, 10 | batch_pad_stack, 11 | batch_stack, 12 | generic_batch, 13 | stateless, 14 | ) 15 | from megatron.energon.task_encoder.loader import ( 16 | get_train_dataset, 17 | get_val_dataset, 18 | get_val_datasets, 19 | ) 20 | 21 | __all__ = [ 22 | "AugmentTaskEncoder", 23 | "Batch", 24 | "DefaultTaskEncoder", 25 | "TaskEncoder", 26 | "batch_list", 27 | "batch_pad_stack", 28 | "batch_stack", 29 | "generic_batch", 30 | "get_train_dataset", 31 | "get_val_dataset", 32 | "get_val_datasets", 33 | "stateless", 34 | ] 35 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | deploy: 9 | name: Upload release to PyPI 10 | runs-on: ubuntu-latest 11 | environment: 12 | name: release 13 | url: https://pypi.org/project/megatron-energon 14 | permissions: 15 | id-token: write # This permission is mandatory for trusted publishing 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: '3.8' 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install build 26 | - name: Build package 27 | run: | 28 | python -m build -w 29 | python -m build -s 30 | - name: Publish package 31 | uses: pypa/gh-action-pypi-publish@release/v1 32 | -------------------------------------------------------------------------------- /src/megatron/energon/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from megatron.energon.transforms.custom import CustomTransform 5 | from megatron.energon.transforms.mappers import ( 6 | CenterCropMapper, 7 | RandomCropMapper, 8 | RandomHorizontalFlipMapper, 9 | RandomPerspectiveMapper, 10 | RandomResizedCropMapper, 11 | RandomRotationMapper, 12 | RandomVerticalFlipMapper, 13 | ResizeMapper, 14 | TransformMapper, 15 | ) 16 | from megatron.energon.transforms.merge import MergeTransform 17 | 18 | __all__ = [ 19 | "TransformMapper", 20 | "ResizeMapper", 21 | "RandomResizedCropMapper", 22 | "CenterCropMapper", 23 | "RandomCropMapper", 24 | "RandomHorizontalFlipMapper", 25 | "RandomVerticalFlipMapper", 26 | "RandomRotationMapper", 27 | "RandomPerspectiveMapper", 28 | "MergeTransform", 29 | "CustomTransform", 30 | ] 31 | -------------------------------------------------------------------------------- /src/megatron/energon/metadataset/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from pathlib import Path 5 | from typing import Union 6 | 7 | from megatron.energon.dataset_config import load_config 8 | from megatron.energon.epathlib import EPath 9 | from megatron.energon.metadataset.dataset_loader import DatasetLoader 10 | from megatron.energon.metadataset.loader_interface import DatasetLoaderInterface 11 | from megatron.energon.metadataset.metadataset import Metadataset 12 | 13 | 14 | def load_dataset( 15 | path: Union[str, EPath, Path], 16 | **kwargs, 17 | ) -> DatasetLoaderInterface: 18 | """Loads a (meta)dataset.""" 19 | 20 | if isinstance(path, dict): 21 | return load_config( 22 | path, 23 | default_type=Metadataset, 24 | strict=True, 25 | default_kwargs=dict(parent_path=EPath("/"), **kwargs), 26 | ) 27 | path = EPath(path).absolute() 28 | if path.is_file(): 29 | return load_config( 30 | path, 31 | default_type=Metadataset, 32 | strict=True, 33 | default_kwargs=dict(parent_path=path.parent, **kwargs), 34 | ) 35 | else: 36 | return DatasetLoader(path=path, **kwargs) 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | 7 | # Distribution / packaging 8 | .Python 9 | build/ 10 | develop-eggs/ 11 | dist/ 12 | downloads/ 13 | eggs/ 14 | .eggs/ 15 | lib/ 16 | lib64/ 17 | parts/ 18 | sdist/ 19 | var/ 20 | wheels/ 21 | pip-wheel-metadata/ 22 | share/python-wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # C extensions 29 | *.so 30 | 31 | # C objects 32 | *.o 33 | 34 | # IDE 35 | /.idea 36 | /.vscode 37 | .DS_Store 38 | .spyderproject 39 | .spyproject 40 | .ropeproject 41 | *.code-workspace 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Jupyter Notebook 58 | .ipynb_checkpoints 59 | 60 | # IPython 61 | profile_default/ 62 | ipython_config.py 63 | 64 | # pyenv 65 | .python-version 66 | 67 | # Environments 68 | #.env 69 | .venv 70 | env/ 71 | venv/ 72 | ENV/ 73 | env.bak/ 74 | venv.bak/ 75 | 76 | # Temporary Data 77 | /test_dataset 78 | 79 | # drawio temp files 80 | .$* 81 | 82 | # ninja build 83 | .ninja* 84 | 85 | # Other files 86 | *.log 87 | *.csv 88 | /*.png 89 | /public 90 | /.run 91 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/text.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import warnings 5 | from dataclasses import dataclass 6 | 7 | from megatron.energon.epathlib.epath import EPath 8 | from megatron.energon.errors import warn_deprecated 9 | from megatron.energon.flavors.base_dataset import Sample 10 | from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory 11 | 12 | 13 | @dataclass 14 | class TextSample(Sample): 15 | """Sample type for simple text.""" 16 | 17 | #: The text of the sample 18 | text: str 19 | 20 | 21 | class TextWebdataset(DefaultDecoderWebdatasetFactory[TextSample]): 22 | __sample_type__ = TextSample 23 | 24 | def __init__(self, path: EPath, **kwargs): 25 | warn_deprecated( 26 | f"{type(self)} is deprecated, use the default instead and set the sample_type:\n" 27 | f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n" 28 | f"# remove top-level __module__ and __class__\n" 29 | f"sample_type:\n" 30 | f" __module__: megatron.energon\n" 31 | f" __class__: {self.__sample_type__.__name__}\n" 32 | f"# Keep the remaining content" 33 | ) 34 | super().__init__(path, **kwargs) 35 | -------------------------------------------------------------------------------- /src/megatron/energon/cli/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | """Main entry point for energon CLI tools.""" 5 | 6 | import click 7 | 8 | from megatron.energon.tools.analyze_debug import command as analyze_debug_command 9 | from megatron.energon.tools.info import command as info_command 10 | from megatron.energon.tools.lint import command as lint_command 11 | from megatron.energon.tools.prepare import command as prepare_command 12 | from megatron.energon.tools.preview import command as preview_command 13 | 14 | 15 | @click.group(context_settings=dict(help_option_names=["-h", "--help"]), invoke_without_command=True) 16 | @click.pass_context 17 | def main(ctx): 18 | """A set of tools that energon provides. 19 | 20 | Among other things, you can use it to lint or preprocess your dataset. 21 | 22 | See help of commands to learn more.""" 23 | 24 | # This is needed to show help if no subcommand is provided 25 | if ctx.invoked_subcommand is None: 26 | click.echo(main.get_help(ctx)) 27 | 28 | 29 | main.add_command(analyze_debug_command) 30 | main.add_command(lint_command) 31 | main.add_command(info_command) 32 | main.add_command(prepare_command) 33 | main.add_command(preview_command) 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/webdataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from megatron.energon.flavors.webdataset.base_webdataset import BaseWebdatasetFactory 5 | from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME 6 | from megatron.energon.flavors.webdataset.default_decoder_webdataset import ( 7 | DefaultDecoderWebdatasetFactory, 8 | ImageDecoder, 9 | ) 10 | from megatron.energon.flavors.webdataset.default_generic_webdataset import ( 11 | DefaultGenericWebdatasetFactory, 12 | ) 13 | from megatron.energon.flavors.webdataset.empty_dataset_error import EmptyDatasetError 14 | from megatron.energon.flavors.webdataset.joined_webdataset import JoinedWebdatasetFactory 15 | from megatron.energon.flavors.webdataset.metadata import WebdatasetMeta 16 | from megatron.energon.flavors.webdataset.standard_webdataset import StandardWebdatasetFactory 17 | from megatron.energon.flavors.webdataset.video_data import VideoData 18 | 19 | __all__ = [ 20 | "BaseWebdatasetFactory", 21 | "DefaultDecoderWebdatasetFactory", 22 | "DefaultGenericWebdatasetFactory", 23 | "EmptyDatasetError", 24 | "ImageDecoder", 25 | "JoinedWebdatasetFactory", 26 | "StandardWebdatasetFactory", 27 | "VideoData", 28 | "WebdatasetMeta", 29 | "MAIN_FOLDER_NAME", 30 | ] 31 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/image.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import warnings 5 | from dataclasses import dataclass 6 | 7 | import torch 8 | 9 | from megatron.energon.epathlib.epath import EPath 10 | from megatron.energon.errors import warn_deprecated 11 | from megatron.energon.flavors.base_dataset import Sample 12 | from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory 13 | 14 | 15 | @dataclass 16 | class ImageSample(Sample): 17 | """Sample type for an image, e.g. for image reconstruction.""" 18 | 19 | #: The input image tensor in the shape (C, H, W) 20 | image: torch.Tensor 21 | 22 | 23 | class ImageWebdataset(DefaultDecoderWebdatasetFactory[ImageSample]): 24 | __sample_type__ = ImageSample 25 | 26 | def __init__(self, path: EPath, **kwargs): 27 | warn_deprecated( 28 | f"{type(self)} is deprecated, use the default instead and set the sample_type:\n" 29 | f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n" 30 | f"# remove top-level __module__ and __class__\n" 31 | f"sample_type:\n" 32 | f" __module__: megatron.energon\n" 33 | f" __class__: {self.__sample_type__.__name__}\n" 34 | f"# Keep the remaining content" 35 | ) 36 | super().__init__(path, **kwargs) 37 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/captioning.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import warnings 5 | from dataclasses import dataclass 6 | 7 | import torch 8 | 9 | from megatron.energon.epathlib.epath import EPath 10 | from megatron.energon.errors import warn_deprecated 11 | from megatron.energon.flavors.base_dataset import Sample 12 | from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory 13 | 14 | 15 | @dataclass 16 | class CaptioningSample(Sample): 17 | """Sample type for image captioning.""" 18 | 19 | #: The input image tensor in the shape (C, H, W) 20 | image: torch.Tensor 21 | 22 | #: The caption string 23 | caption: str 24 | 25 | 26 | class CaptioningWebdataset(DefaultDecoderWebdatasetFactory[CaptioningSample]): 27 | __sample_type__ = CaptioningSample 28 | 29 | def __init__(self, path: EPath, **kwargs): 30 | warn_deprecated( 31 | f"{type(self)} is deprecated, use the default instead and set the sample_type:\n" 32 | f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n" 33 | f"# remove top-level __module__ and __class__\n" 34 | f"sample_type:\n" 35 | f" __module__: megatron.energon\n" 36 | f" __class__: {self.__sample_type__.__name__}\n" 37 | f"# Keep the remaining content" 38 | ) 39 | super().__init__(path, **kwargs) 40 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/interleaved.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import warnings 5 | from dataclasses import dataclass 6 | from typing import List, Union 7 | 8 | import torch 9 | 10 | from megatron.energon.epathlib.epath import EPath 11 | from megatron.energon.errors import warn_deprecated 12 | from megatron.energon.flavors.base_dataset import Sample 13 | from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory 14 | 15 | 16 | @dataclass 17 | class InterleavedSample(Sample): 18 | """Sample type for interleaved media such as text with images.""" 19 | 20 | #: The interleaved media (either torch.tensor for an image, or str for text) 21 | sequence: List[Union[torch.Tensor, str]] 22 | 23 | 24 | class InterleavedWebdataset(DefaultDecoderWebdatasetFactory[InterleavedSample]): 25 | __sample_type__ = InterleavedSample 26 | 27 | def __init__(self, path: EPath, **kwargs): 28 | warn_deprecated( 29 | f"{type(self)} is deprecated, use the default instead and set the sample_type:\n" 30 | f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n" 31 | f"# remove top-level __module__ and __class__\n" 32 | f"sample_type:\n" 33 | f" __module__: megatron.energon\n" 34 | f" __class__: {self.__sample_type__.__name__}\n" 35 | f"# Keep the remaining content" 36 | ) 37 | super().__init__(path, **kwargs) 38 | -------------------------------------------------------------------------------- /.github/workflows/documentation.yml: -------------------------------------------------------------------------------- 1 | name: documentation 2 | 3 | # Runs on any pushes 4 | on: [push, workflow_dispatch] 5 | 6 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 7 | permissions: 8 | contents: read 9 | pages: write 10 | id-token: write 11 | 12 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 13 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 
14 | concurrency: 15 | group: "pages" 16 | cancel-in-progress: false 17 | 18 | jobs: 19 | # Build job 20 | build: 21 | runs-on: ubuntu-latest 22 | steps: 23 | - name: Checkout 24 | uses: actions/checkout@v4 25 | - uses: actions/setup-python@v5 26 | with: 27 | python-version: '3.10' 28 | - name: Setup Pages 29 | uses: actions/configure-pages@v5 30 | - name: Install dependencies 31 | run: | 32 | pip install -U sphinx-rtd-theme sphinx sphinxcontrib-napoleon myst-parser sphinx-click 33 | - name: Sphinx build 34 | run: | 35 | sphinx-build -b html docs/source _site 36 | - name: Upload artifact 37 | uses: actions/upload-pages-artifact@v3 38 | 39 | # Deployment job 40 | deploy: 41 | environment: 42 | name: github-pages 43 | url: ${{ steps.deployment.outputs.page_url }} 44 | runs-on: ubuntu-latest 45 | needs: build 46 | if: github.ref_name == 'main' 47 | steps: 48 | - name: Deploy to GitHub Pages 49 | id: deployment 50 | uses: actions/deploy-pages@v4 51 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/image_classification.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import warnings 5 | from dataclasses import dataclass 6 | from typing import Optional 7 | 8 | import torch 9 | 10 | from megatron.energon.epathlib.epath import EPath 11 | from megatron.energon.errors import warn_deprecated 12 | from megatron.energon.flavors.base_dataset import Sample 13 | from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory 14 | 15 | 16 | @dataclass 17 | class ImageClassificationSample(Sample): 18 | """Sample type for classifying an image.""" 19 | 20 | #: The input image tensor in the shape (C, H, W) 21 | image: torch.Tensor 22 | #: The class label of the image 23 | label: Optional[int] = None 24 | #: The class label of the image 25 | label_name: Optional[str] = None 26 | 27 | 28 | class ImageClassificationWebdataset(DefaultDecoderWebdatasetFactory[ImageClassificationSample]): 29 | __sample_type__ = ImageClassificationSample 30 | 31 | def __init__(self, path: EPath, **kwargs): 32 | warn_deprecated( 33 | f"{type(self)} is deprecated, use the default instead and set the sample_type:\n" 34 | f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n" 35 | f"# remove top-level __module__ and __class__\n" 36 | f"sample_type:\n" 37 | f" __module__: megatron.energon\n" 38 | f" __class__: {self.__sample_type__.__name__}\n" 39 | f"# Keep the remaining content" 40 | ) 41 | super().__init__(path, **kwargs) 42 | -------------------------------------------------------------------------------- /docs/source/internals/contrib_guidelines.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Contribution Guidelines 5 | 6 | If you want to contribute to this repository please adhere to the following guidelines 7 | 8 | - Always use [black](https://pypi.org/project/black/) and [isort](https://pycqa.github.io/isort/) to format your code before committing 9 | - Check that all license headers are present using `python3 scripts/license_headers.py --fix .` 10 | - Python `@dataclass` and `NamedTuple` are preferred over dictionaries, which don't allow for IDE 11 | auto-completion and type checking 12 | - User-exposed classes and methods should be documented in Google-style docstrings that are parsed by sphinx 13 | and end up in this documentation 14 | - Breaking changes should 
be marked in the message of pull requests: 15 | - `CHECKPOINT BREAKING CHANGE`: When the save/restore structure changed incompatibly (check test `test_metadataset:TestDataset.test_save_restore_state_train`) 16 | - `ITERATION ORDER BREAKING CHANGE`: When the order of iterating samples changed, i.e. experiments would not be exactly reproducible (check tests `test_dataset:TestDataset.test_current_batch_index_generator`, `test_dataset:TestDataset.test_current_batch_index`, maybe more) 17 | - `API BREAKING CHANGE`: When the external programming api changed incompatibly 18 | - `DATASET CONFIG BREAKING CHANGE`: When the dataset config (`.nv-meta` folder) changed incompatibly 19 | - `METADATASET CONFIG BREAKING CHANGE`: When the metadataset config changed 20 | - In a release, all breaking changes except checkpoint lead to a new major version. 21 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/multichoice_vqa.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import warnings 5 | from dataclasses import dataclass 6 | from typing import List, Optional 7 | 8 | import torch 9 | 10 | from megatron.energon.epathlib.epath import EPath 11 | from megatron.energon.errors import warn_deprecated 12 | from megatron.energon.flavors.base_dataset import Sample 13 | from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory 14 | 15 | 16 | @dataclass 17 | class MultiChoiceVQASample(Sample): 18 | """Sample type for visual question answering.""" 19 | 20 | #: The input image tensor in the shape (C, H, W) 21 | image: torch.Tensor 22 | #: The context/question for the image 23 | context: str 24 | 25 | #: The candidate answers. 26 | choices: Optional[List[str]] = None 27 | #: The index of the correct answer. 28 | correct_choice_idx: int = 0 29 | 30 | 31 | class MultiChoiceVQAWebdataset(DefaultDecoderWebdatasetFactory[MultiChoiceVQASample]): 32 | __sample_type__ = MultiChoiceVQASample 33 | 34 | def __init__(self, path: EPath, **kwargs): 35 | warn_deprecated( 36 | f"{type(self)} is deprecated, use the default instead and set the sample_type:\n" 37 | f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n" 38 | f"# remove top-level __module__ and __class__\n" 39 | f"sample_type:\n" 40 | f" __module__: megatron.energon\n" 41 | f" __class__: {self.__sample_type__.__name__}\n" 42 | f"# Keep the remaining content" 43 | ) 44 | super().__init__(path, **kwargs) 45 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/vqa.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import warnings 5 | from dataclasses import dataclass 6 | from typing import List, Optional 7 | 8 | import torch 9 | 10 | from megatron.energon.epathlib.epath import EPath 11 | from megatron.energon.errors import warn_deprecated 12 | from megatron.energon.flavors.base_dataset import Sample 13 | from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory 14 | 15 | 16 | @dataclass 17 | class VQASample(Sample): 18 | """Sample type for visual question answering.""" 19 | 20 | #: The input image tensor in the shape (C, H, W) 21 | image: torch.Tensor 22 | #: The context/question for the image 23 | context: str 24 | 25 | #: The possible answers. 
Not set for testing. 26 | answers: Optional[List[str]] = None 27 | #: The weights of the possible answers. Optionally available. 28 | answer_weights: Optional[torch.Tensor] = None 29 | 30 | 31 | class VQAWebdataset(DefaultDecoderWebdatasetFactory[VQASample]): 32 | __sample_type__ = VQASample 33 | 34 | def __init__(self, path: EPath, **kwargs): 35 | warn_deprecated( 36 | f"{type(self)} is deprecated, use the default instead and set the sample_type:\n" 37 | f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n" 38 | f"# remove top-level __module__ and __class__\n" 39 | f"sample_type:\n" 40 | f" __module__: megatron.energon\n" 41 | f" __class__: {self.__sample_type__.__name__}\n" 42 | f"# Keep the remaining content" 43 | ) 44 | super().__init__(path, **kwargs) 45 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/vid_qa.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import warnings 5 | from dataclasses import dataclass 6 | from typing import List, Optional 7 | 8 | import torch 9 | 10 | from megatron.energon.epathlib.epath import EPath 11 | from megatron.energon.errors import warn_deprecated 12 | from megatron.energon.flavors.base_dataset import Sample 13 | from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory, VideoData 14 | 15 | 16 | @dataclass 17 | class VidQASample(Sample): 18 | """Sample type for video question answering.""" 19 | 20 | #: The video data containing the image and audio info. 21 | video: VideoData 22 | #: The context/question for the image. 23 | context: str 24 | #: The possible answers. Not set for testing. 25 | answers: Optional[List[str]] = None 26 | #: The weights of the possible answers. Optionally available. 27 | answer_weights: Optional[torch.Tensor] = None 28 | 29 | 30 | class VidQAWebdataset(DefaultDecoderWebdatasetFactory[VidQASample]): 31 | __sample_type__ = VidQASample 32 | 33 | def __init__(self, path: EPath, **kwargs): 34 | warn_deprecated( 35 | f"{type(self)} is deprecated, use the default instead and set the sample_type:\n" 36 | f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n" 37 | f"# remove top-level __module__ and __class__\n" 38 | f"sample_type:\n" 39 | f" __module__: megatron.energon\n" 40 | f" __class__: {self.__sample_type__.__name__}\n" 41 | f"# Keep the remaining content" 42 | ) 43 | super().__init__(path, **kwargs) 44 | -------------------------------------------------------------------------------- /docs/source/advanced/repro_scaling.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Reproducible Scaling 5 | 6 | A special use case is to re-run or continue a training run with the exact same data order, but using a different number of nodes or ranks. 
7 | 8 | Since version 2.0.0, Megatron Energon supports this behavior if a few constraints are met: 9 | 10 | * The energon major version must be the same across runs 11 | * The global batch size must stay the same across runs 12 | * The global batch size must be a multiple of `micro-batch size * world_size * num_workers` 13 | * The multiple of that is the number of gradient accumulation steps in your training 14 | * The product `world_size * num_workers` must stay the same across runs, such that the global number of workers stays the same 15 | * When using random seed offsets in your {py:class}`WorkerConfig `, those need to be the same 16 | 17 | By obeying these rules, you will be able to reproduce the same global batches. Let's look at an example. 18 | 19 | | Name | Global batch size | Micro batch size | World size | Number of Workers | Gradient accumulation steps | 20 | | ----- | ----------------- | ---------------- | ---------- | ----------------- | --------------------------- | 21 | | Run 1 | 8 | 2 | 4 | 1 | 1 | 22 | | Run 2 | 8 | 2 | 1 | 4 | 4 | 23 | 24 | Iterating the dataset will yield the same global batches for both of these runs, if the seed is set correctly. 25 | 26 | In practice, you will need to adapt your worker config accordingly. 27 | 28 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/webdataset/error_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from typing import Callable, Generic, Optional, Tuple, TypeVar, Union 5 | 6 | from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError 7 | from megatron.energon.flavors.base_dataset import Sample 8 | from megatron.energon.flavors.webdataset.structs import FilteredSample, reraise_exception 9 | 10 | T_sample = TypeVar("T_sample", covariant=True) 11 | 12 | 13 | class ErrorHandler(Generic[T_sample]): 14 | handler: Callable[[Exception, Optional[str]], None] = reraise_exception 15 | 16 | def sample_error_handler(self, e: Exception, sample_key: Optional[str]): 17 | if isinstance(e, SYSTEM_EXCEPTIONS): 18 | raise FatalSampleError(f"Error in sample {sample_key!r}: {e}") from e 19 | 20 | self.handler(e, sample_key) 21 | 22 | def error_handler( 23 | self, 24 | e: Exception, 25 | sample: Union[ 26 | T_sample, 27 | dict, 28 | FilteredSample, 29 | None, 30 | Tuple[Union[T_sample, dict, FilteredSample, None], ...], 31 | ], 32 | ): 33 | if isinstance(sample, dict): 34 | key = sample.get("__key__") 35 | elif isinstance(sample, list): 36 | if isinstance(sample[0], dict): 37 | key = ",".join("None" if s is None else s.get("__key__") for s in sample) 38 | elif isinstance(sample[0], Sample): 39 | key = ",".join("None" if s is None else s.__key__ for s in sample) 40 | else: 41 | key = None 42 | elif isinstance(sample, Sample): 43 | key = sample.__key__ 44 | else: 45 | key = None 46 | self.sample_error_handler(e, key) 47 | -------------------------------------------------------------------------------- /docs/source/api/cli.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Command-Line Interface 5 | 6 | After you [installed](../intro/installation) energon, a script called `energon` will be added to your PATH. 7 | It provides commands to prepare, preview, or lint datasets on disk. 
8 | 9 | Here's a simple example: 10 | 11 | ```shell 12 | energon prepare /mnt/data/my_captioning_webdataset 13 | ``` 14 | 15 | The above command will scan your existing off-the-shelf [web dataset](https://webdataset.github.io/webdataset/) 16 | and add the [needed metadata](data-on-disk) to make it compatible with Energon. 17 | 18 | Below, you can see the available sub-commands under `energon`. 19 | 20 | 21 | ```{eval-rst} 22 | .. click:: megatron.energon.cli.main:main 23 | :prog: energon 24 | :nested: short 25 | ``` 26 | 27 | (energon_data_prepare)= 28 | ## energon prepare 29 | 30 | An interactive tool to generate metadata for your existing webdataset. 31 | This will help make the dataset compliant with our [format](data-on-disk). 32 | 33 | The tool will ask you for a train/val/test split and how to assign the webdataset fields to the 34 | fields of the corresponding sample type in Energon. 35 | 36 | See [Data Preparation](../basic/data_prep) for more details on how to use this command. 37 | 38 | 39 | ## energon info 40 | 41 | Prints information about the dataset, such as the overall number of samples and size. 42 | 43 | 44 | ## energon lint 45 | 46 | You can execute this tool on the prepared dataset to check if the data is valid and loadable. 47 | It will report any problems such as non-readable images. 48 | 49 | 50 | ## energon preview 51 | 52 | This command will load a dataset and display samples one-by-one on the console. 53 | Note that this will not work for datasets with non-standard flavors or crude datasets. 54 | -------------------------------------------------------------------------------- /docs/source/index.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Megatron-Energon Documentation 5 | 6 | This is the documentation of Megatron's multi-modal data loader "Energon". 7 | 8 | We recommend getting started in the [Introduction](intro/introduction) section, which explains what Energon is and how to install it. 9 | 10 | Once installed, check out the **Basic Usage** section starting with [Quickstart](basic/quickstart) for some basic examples and tutorials. 11 | Some underlying concepts will be explained in the rest of that section. 12 | 13 | For specific use cases and advanced usage, please read **Advanced Usage**. 14 | 15 | At the end, you will also find some documentation on how to interface with energon programmatically and how to contribute to the code base. 
16 | 17 | ```{toctree} 18 | --- 19 | caption: Introduction 20 | maxdepth: 2 21 | --- 22 | 23 | intro/introduction 24 | intro/installation 25 | ``` 26 | 27 | 28 | ```{toctree} 29 | --- 30 | caption: Basic Usage 31 | maxdepth: 2 32 | --- 33 | basic/quickstart 34 | basic/data_prep 35 | basic/basics_flow 36 | basic/task_encoder 37 | basic/metadataset 38 | basic/save_restore 39 | ``` 40 | 41 | 42 | ```{toctree} 43 | --- 44 | caption: Advanced Usage 45 | maxdepth: 2 46 | --- 47 | advanced/remote_dataset 48 | advanced/advanced_dataformat 49 | advanced/repro_scaling 50 | advanced/packing 51 | advanced/grouping 52 | advanced/joining_datasets 53 | advanced/epochized_blending 54 | advanced/custom_blending 55 | ``` 56 | 57 | 58 | ```{toctree} 59 | --- 60 | caption: API 61 | maxdepth: 2 62 | --- 63 | api/modules 64 | api/cli 65 | ``` 66 | 67 | 68 | ```{toctree} 69 | --- 70 | caption: Internals 71 | maxdepth: 2 72 | --- 73 | internals/contrib_guidelines 74 | internals/code_structure 75 | ``` 76 | 77 | # Indices and tables 78 | 79 | - [](genindex) 80 | - [](modindex) 81 | -------------------------------------------------------------------------------- /src/megatron/energon/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from megatron.energon.wrappers.batch_dataset import BatchDataset 5 | from megatron.energon.wrappers.blend_dataset import BlendDataset 6 | from megatron.energon.wrappers.concat_dataset import ConcatDataset 7 | from megatron.energon.wrappers.epochize_dataset import EpochizeDataset 8 | from megatron.energon.wrappers.filter_dataset import FilterDataset 9 | from megatron.energon.wrappers.gc_dataset import GcDataset 10 | from megatron.energon.wrappers.group_batch_dataset import GroupBatchDataset 11 | from megatron.energon.wrappers.iter_map_dataset import IterMapDataset 12 | from megatron.energon.wrappers.limit_dataset import LimitDataset 13 | from megatron.energon.wrappers.log_sample_dataset import LogSampleDataset 14 | from megatron.energon.wrappers.map_dataset import MapDataset 15 | from megatron.energon.wrappers.mix_batch_dataset import ( 16 | MixBatchDataset, 17 | concat_pad, 18 | generic_concat, 19 | homogeneous_concat_mix, 20 | ) 21 | from megatron.energon.wrappers.packing_dataset import PackingDataset 22 | from megatron.energon.wrappers.repeat_dataset import RepeatDataset 23 | from megatron.energon.wrappers.shuffle_buffer_dataset import ShuffleBufferDataset 24 | from megatron.energon.wrappers.skip import SkipSample 25 | 26 | __all__ = [ 27 | "BatchDataset", 28 | "BlendDataset", 29 | "ConcatDataset", 30 | "EpochizeDataset", 31 | "FilterDataset", 32 | "GcDataset", 33 | "GroupBatchDataset", 34 | "IterMapDataset", 35 | "LimitDataset", 36 | "LogSampleDataset", 37 | "MapDataset", 38 | "MixBatchDataset", 39 | "RepeatDataset", 40 | "ShuffleBufferDataset", 41 | "SkipSample", 42 | "PackingDataset", 43 | "concat_pad", 44 | "generic_concat", 45 | "homogeneous_concat_mix", 46 | ] 47 | -------------------------------------------------------------------------------- /docs/source/intro/installation.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Installation 5 | 6 | If you simply want to use this package without modifying it, the best option is to install it 7 | as a dependency of your project like you would with any other pip package. 
8 | 9 | ## Normal Installation 10 | 11 | To install the most recent version, run 12 | 13 | ```shell 14 | pip install megatron-energon 15 | ``` 16 | 17 | in your project's Python environment, which could be a virtualenv, or a conda environment. 18 | You can even install it inside a `Dockerfile` to include it in your custom docker container. 19 | 20 | ## Installation from Source Code 21 | 22 | If you want to manage, debug or modify the code of the energon package, we recommend that you clone this repository 23 | on your disk or even as a submodule of your project. 24 | You can then install the package in "development" mode in-place. This way, the package will not be hidden inside 25 | your pip package management, but will stay in the location where you cloned it and you can even modify it in-place. 26 | 27 | To check out locally and install in development mode: 28 | ```shell 29 | git clone https://github.com/NVIDIA/Megatron-Energon.git megatron-energon 30 | pip install -e ./megatron-energon 31 | ``` 32 | 33 | Or to add as a submodule to your project and install in development mode: 34 | ```shell 35 | git submodule add https://github.com/NVIDIA/Megatron-Energon.git megatron-energon 36 | pip install -e ./megatron-energon 37 | ``` 38 | 39 | ```{warning} 40 | **We discourage importing the cloned repo without pip install** 41 | - You will not be able to use the command line tool 42 | - You would have to use hacks to get the package into your `PYTHONPATH` 43 | - You would need to take care of the dependencies yourself. 44 | 45 | Instead, simply install in development mode. 46 | ``` 47 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/similarity_interleaved.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import warnings 5 | from dataclasses import dataclass 6 | from typing import List, Optional 7 | 8 | import torch 9 | 10 | from megatron.energon.epathlib.epath import EPath 11 | from megatron.energon.errors import warn_deprecated 12 | from megatron.energon.flavors.base_dataset import Sample 13 | from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory 14 | 15 | 16 | @dataclass 17 | class SimilarityInterleavedSample(Sample): 18 | """Sample type for interleaved media such as text with images, but without image-text alignment. 
19 | That alignment has to be assigned from the similarity matrix.""" 20 | 21 | #: The images of the sequence 22 | images: List[torch.Tensor] 23 | #: The texts of the sequence 24 | texts: List[str] 25 | 26 | #: Similarity matrix between image and text entries in the sequence 27 | similarity_matrix: Optional[torch.Tensor] = None 28 | 29 | #: The index within texts representing the sentence that this image is matched to 30 | matched_text_indices: Optional[List[int]] = None 31 | 32 | 33 | class SimilarityInterleavedWebdataset(DefaultDecoderWebdatasetFactory[SimilarityInterleavedSample]): 34 | __sample_type__ = SimilarityInterleavedSample 35 | 36 | def __init__(self, path: EPath, **kwargs): 37 | warn_deprecated( 38 | f"{type(self)} is deprecated, use the default instead and set the sample_type:\n" 39 | f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n" 40 | f"# remove top-level __module__ and __class__\n" 41 | f"sample_type:\n" 42 | f" __module__: megatron.energon\n" 43 | f" __class__: {self.__sample_type__.__name__}\n" 44 | f"# Keep the remaining content" 45 | ) 46 | super().__init__(path, **kwargs) 47 | -------------------------------------------------------------------------------- /docs/source/advanced/epochized_blending.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Epochized Blending 5 | 6 | As an alternative to blending with a weight for each dataset, blending can be made accurate and 7 | iterating the dataset can follow epochs (i.e. interrupt iteration after an epoch) with this concept. 8 | 9 | Here is an example `metadataset.yaml` config file that changes to epochized blending: 10 | 11 | ```yaml 12 | __module__: megatron.energon 13 | __class__: MetadatasetV2 14 | splits: 15 | train: 16 | # Blend the following datasets, repeating coco 5 times, coyo-train 2 times and coyo-val 1 times 17 | blend_epochized: 18 | - repetitions: 5 19 | path: ./coco 20 | # ... Other parameters 21 | - repetitions: 2 22 | path: ./coyo 23 | - repetitions: 1 24 | path: ./coyo 25 | split_part: val 26 | ``` 27 | 28 | Now, the call to `get_train_dataset` requires the additional parameter `repeat=False` to interrupt iterating after one epoch: 29 | 30 | ```py 31 | from megatron.energon import get_train_dataset, get_loader, WorkerConfig 32 | 33 | 34 | loader = get_loader(get_train_dataset( 35 | 'metadataset.yaml', 36 | batch_size=2, 37 | shuffle_buffer_size=100, 38 | max_samples_per_sequence=100, 39 | worker_config=WorkerConfig.default_worker_config(), 40 | repeat=False, 41 | )) 42 | 43 | # This will now stop iterating after the datasets have been iterated (coco 5 times, coyo-train 2 44 | # times and coyo-val 1 times). Of course, the data is still being shuffled between all those 45 | # datasets. 46 | for batch in loader: 47 | print(batch) 48 | 49 | # This will iterate the second epoch 50 | for batch in loader: 51 | print(batch) 52 | 53 | ``` 54 | 55 | If used as a dataset for `get_val_dataset`, the `repetitions` are ignored. 56 | The metadataset would also work without setting `repeat=False`, but then the shuffle buffer will shuffle samples across boundaries of epochs. 57 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/vqa_and_ocr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import warnings 5 | from dataclasses import dataclass 6 | from typing import List, Optional 7 | 8 | import torch 9 | 10 | from megatron.energon.epathlib.epath import EPath 11 | from megatron.energon.errors import warn_deprecated 12 | from megatron.energon.flavors.base_dataset import Sample 13 | from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory 14 | 15 | 16 | @dataclass 17 | class VQAOCRSample(Sample): 18 | """Sample type for visual question answering.""" 19 | 20 | #: The input image tensor in the shape (C, H, W) 21 | image: torch.Tensor 22 | 23 | #: The context/question for the image (VQA) 24 | context: str 25 | #: The text contained in the image (OCR) 26 | text: str 27 | 28 | #: The possible answers. Not set for testing. (VQA) 29 | answers: Optional[List[str]] = None 30 | #: The weights of the possible answers. Optionally available. (VQA) 31 | answer_weights: Optional[torch.Tensor] = None 32 | #: The bounding boxes of the words in the image (N, 4|5) (OCR) 33 | words_boxes: Optional[torch.Tensor] = None 34 | #: The text contained in each word (N,) (OCR) 35 | words_text: Optional[List[str]] = None 36 | 37 | 38 | class VQAOCRWebdataset(DefaultDecoderWebdatasetFactory[VQAOCRSample]): 39 | __sample_type__ = VQAOCRSample 40 | 41 | def __init__(self, path: EPath, **kwargs): 42 | warn_deprecated( 43 | f"{type(self)} is deprecated, use the default instead and set the sample_type:\n" 44 | f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n" 45 | f"# remove top-level __module__ and __class__\n" 46 | f"sample_type:\n" 47 | f" __module__: megatron.energon\n" 48 | f" __class__: {self.__sample_type__.__name__}\n" 49 | f"# Keep the remaining content" 50 | ) 51 | super().__init__(path, **kwargs) 52 | -------------------------------------------------------------------------------- /docs/source/advanced/custom_blending.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Customized Blending 5 | 6 | In your Task Encoder you could customize the blend of datasets by overriding the `build_train_datasets` method as shown below. 7 | 8 | 9 | ```{warning} 10 | This interface is not stable and may be subject of changes quite often for new features we add. So if you change 11 | how the datasets are plugged together, consider that this may have to be adapted to future changes. 12 | ``` 13 | 14 | ```py 15 | 16 | class CaptioningTaskEncoder( 17 | DefaultTaskEncoder[CaptioningSample, CaptioningSample, CaptioningRawBatch, CaptioningBatch] 18 | ): 19 | ... 20 | 21 | def build_train_datasets( 22 | self, 23 | *, 24 | datasets: List[Tuple[BaseCoreDatasetFactory[T_sample], float]], 25 | worker_config: WorkerConfig, 26 | batch_size: Optional[int], 27 | batch_drop_last: bool = False, 28 | packing_buffer_size: Optional[int] = None, 29 | virtual_epoch_length: int = 0, 30 | shuffle_buffer_size: Optional[int] = None, 31 | ) -> SavableDataset[T_batch]: 32 | # The default implementation uses MixDataset, which mixes the datasets according to their weights 33 | # This could be customized, e.g. to batch the datasets first (i.e. each batch only contains data from a single datset) 34 | # and then blend, which would yield the same distribution. 
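        # For illustration only (a sketch, not the default implementation): batching each dataset
        # separately and then blending the resulting batch-level datasets could look roughly like
        # this, assuming BlendDataset also accepts already-batched datasets with their weights:
        #
        #   batched = [
        #       (self.build_batch(ds, batch_size=batch_size, batch_drop_last=batch_drop_last,
        #                         worker_config=worker_config), weight)
        #       for ds, weight in datasets
        #   ]
        #   dataset = BlendDataset(*batched, worker_config=worker_config)
        #
        # The default shown below blends on the sample level instead: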
35 | dataset = BlendDataset( 36 | *datasets, 37 | worker_config=worker_config, 38 | ) 39 | # Build batches from blended samples 40 | dataset = self.build_batch( 41 | dataset, 42 | batch_size=batch_size, 43 | batch_drop_last=batch_drop_last, 44 | worker_config=worker_config, 45 | ) 46 | # Optionally epochize 47 | if virtual_epoch_length > 0: 48 | dataset = EpochizeDataset( 49 | dataset, 50 | length=virtual_epoch_length, 51 | worker_config=worker_config, 52 | ) 53 | return dataset 54 | 55 | ``` -------------------------------------------------------------------------------- /docs/source/advanced/remote_dataset.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Remote Dataset 5 | 6 | Since version 2.0.0, Megatron Energon supports the use of remote datasets that are stored in an object store bucket with an S3-compatible interface. 7 | This means you can train or validate with your data right from that object store by simply swapping the dataset path for a so-called _rclone URL_. 8 | 9 | ```{admonition} Note 10 | :class: important 11 | We are planning to migrate to a more advanced backend for remote datasets. Please contact us if you are planning to use this feature. 12 | ``` 13 | 14 | ## Prerequisites 15 | 16 | To use this feature, you need to set up an [Rclone](https://rclone.org/) configuration. Rclone is open-source software for managing files on cloud storage. While Energon does not depend on the Rclone software itself, we rely on the same configuration mechanism. 17 | 18 | If you prefer not to install or use Rclone, that's fine, but you will still need to set up a compatible config file. We nevertheless recommend using Rclone, since it's a great tool. 19 | 20 | Once you have set up your config at `~/.config/rclone/rclone.conf`, it may look like this: 21 | 22 | ``` 23 | [coolstore] 24 | type = s3 25 | provider = Other 26 | access_key_id = MY_ACCESS_KEY_ID 27 | secret_access_key = MY_SECRET_ACCESS_KEY 28 | region = us-east-1 29 | endpoint = pdx.s8k.io 30 | ``` 31 | 32 | ## The URL syntax 33 | 34 | The syntax is as simple as 35 | 36 | ``` 37 | rclone://RCLONE_NAME/BUCKET/PATH 38 | ``` 39 | 40 | For example: 41 | 42 | ``` 43 | rclone://coolstore/mainbucket/datasets/somedata 44 | ``` 45 | 46 | You can use such a URL instead of a dataset path in 47 | 48 | * Functions like `get_train_dataset`, `get_val_dataset` 49 | * Inside [metadataset](../basic/metadataset) specifications 50 | * As arguments to `energon prepare` or `energon lint`. Note that those may be slow for remote locations.
51 | 52 | Example usage: 53 | 54 | ```python 55 | ds = get_train_dataset( 56 | 'rclone://coolstore/mainbucket/datasets/somedata', 57 | batch_size=1, 58 | shuffle_buffer_size=100, 59 | max_samples_per_sequence=100, 60 | ) 61 | ``` 62 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "hatch-vcs"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "megatron-energon" 7 | dynamic = ["version"] 8 | authors = [ 9 | { name="Lukas Vögtle", email="lvoegtle@nvidia.com" }, 10 | { name="Philipp Fischer", email="pfischer@nvidia.com" }, 11 | ] 12 | description = "Megatron's multi-modal data loader" 13 | readme = "README.md" 14 | license = "BSD-3-Clause" 15 | requires-python = ">=3.8" 16 | classifiers = [ 17 | "Programming Language :: Python :: 3", 18 | "Programming Language :: Python :: 3.9", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | "Operating System :: OS Independent", 23 | ] 24 | dependencies = [ 25 | "braceexpand", 26 | "click", 27 | "numpy", 28 | "pillow>=10.0.1", # WEBP vulnerability fixed starting from 10.0.1 29 | "pyyaml", 30 | "torch", 31 | "tqdm", 32 | "webdataset", 33 | "s3fs", 34 | ] 35 | 36 | [project.optional-dependencies] 37 | dev = [ 38 | "black", 39 | "isort", 40 | "sphinxcontrib-napoleon", 41 | "sphinx", 42 | "myst-parser", 43 | "sphinx-rtd-theme", 44 | "sphinx-click", 45 | ] 46 | transforms = [ 47 | "torchvision", # Needed for megatron.energon.transforms 48 | ] 49 | 50 | [project.urls] 51 | Homepage = "https://github.com/NVIDIA/Megatron-Energon" 52 | 53 | [tool.hatch.build.targets.wheel] 54 | packages = ["src/megatron"] 55 | 56 | [tool.hatch.build.targets.sdist] 57 | packages = ["src/megatron"] 58 | 59 | [project.scripts] 60 | energon = "megatron.energon.cli.main:main" 61 | 62 | [tool.hatch.version] 63 | source = "vcs" 64 | 65 | [tool.black] 66 | line-length = 100 67 | target-version = ['py38'] 68 | include = '\.pyi?$' 69 | exclude = ''' 70 | ( 71 | /( 72 | \.eggs # exclude a few common directories in the 73 | | \.git # root of the project 74 | | \.hg 75 | | \.mypy_cache 76 | | \.tox 77 | | \.venv 78 | | \.vscode 79 | | \.idea 80 | | _build 81 | | build 82 | | dist 83 | )/ 84 | ) 85 | ''' 86 | 87 | [tool.isort] 88 | profile = "black" 89 | multi_line_output = 3 90 | include_trailing_comma = true 91 | force_grid_wrap = 0 92 | use_parentheses = true 93 | ensure_newline_before_comments = true 94 | line_length = 100 95 | -------------------------------------------------------------------------------- /src/megatron/energon/task_encoder/cooking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import dataclasses 5 | from dataclasses import dataclass 6 | from typing import Callable, Optional 7 | 8 | from megatron.energon.flavors.base_dataset import Sample 9 | from megatron.energon.flavors.crude import CrudeSample 10 | 11 | 12 | @dataclass 13 | class Cooker: 14 | """A cooker transforms a crude sample (simple dict) into a 15 | specific sample type inheriting from `Sample`. 16 | 17 | The `cook` method performs the transformation, 18 | the other fields are used to select the samples which this cooker 19 | can transform. 
If no filters are provided, the cooker will 20 | transform any sample. 21 | """ 22 | 23 | cook: Callable[[dict], Sample] 24 | 25 | # If multiple of the following conditions are provided 26 | # then the sample must satisfy all of them. 27 | is_subflavor: Optional[str] = None 28 | has_subflavors: Optional[dict] = None 29 | condition: Optional[Callable[[dict], bool]] = None 30 | 31 | def is_match(self, crude_sample: CrudeSample) -> bool: 32 | if self.is_subflavor is not None: 33 | if crude_sample["__subflavor__"] != self.is_subflavor: 34 | return False 35 | 36 | if self.has_subflavors is not None: 37 | # Checks if the dict entries provided as a filter all match 38 | # the ones in the sample. The sample may have additional entries. 39 | for k, v in self.has_subflavors.items(): 40 | if ( 41 | k not in crude_sample["__subflavors__"] 42 | or crude_sample["__subflavors__"][k] != v 43 | ): 44 | return False 45 | 46 | if self.condition is not None: 47 | if not self.condition(crude_sample): 48 | return False 49 | 50 | return True 51 | 52 | 53 | def basic_sample_keys(crude_sample: dict) -> dict: 54 | """A convenience helper to extract the basic keys from a crude sample, 55 | which you will always need to forward to the cooked sample.""" 56 | 57 | return { 58 | field.name: crude_sample[field.name] 59 | for field in dataclasses.fields(Sample) 60 | if field.name in crude_sample 61 | } 62 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/crude.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from typing import Any, Callable, Dict, List, Optional, Union 5 | 6 | from megatron.energon.epathlib import EPath 7 | from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory, ImageDecoder 8 | 9 | 10 | class CrudeSample(dict): 11 | """Generic sample type to be processed later.""" 12 | 13 | 14 | class CrudeWebdataset(DefaultDecoderWebdatasetFactory[CrudeSample]): 15 | """The CrudeWebdataset is used to load crude / raw samples and 16 | decode them in the user code using so-called cookers. 17 | 18 | See the documentation under "Crude Data" for more information. 19 | """ 20 | 21 | __sample_type__ = CrudeSample 22 | 23 | def __init__( 24 | self, 25 | path: EPath, 26 | *, 27 | subflavor: Optional[str] = None, 28 | subflavors: Optional[Dict[str, Any]] = None, 29 | part_filter: Union[str, List[str], Callable[[str], bool]] = lambda _: True, 30 | auto_decode: bool = True, 31 | image_decode: ImageDecoder = "torchrgb", 32 | ignore_decoder_errors: bool = False, 33 | **kwargs, 34 | ): 35 | """ 36 | Constructs a crude webdataset. 37 | 38 | Args: 39 | path: Root path to the joined datasets. 40 | subflavor: Deprecated. Subflavor to set for all loaded samples. 41 | subflavors: Subflavors dictionary to set for all loaded samples. 42 | part_filter: Function for filtering tar files to load by dict keys. 43 | auto_decode: Whether to decode the samples using webdataset decode or not. 44 | image_decode: Image decoding method to use. Only applies if `decode=True`. 45 | ignore_decoder_errors: Whether to ignore decoding errors or not. 46 | **kwargs: Additional arguments to the BaseWebdataset constructor. 
47 | """ 48 | # We skip the parent class __init__ and call the BaseWebdataset.__init__ directly 49 | super().__init__( 50 | path, 51 | auto_decode=auto_decode, 52 | image_decode=image_decode, 53 | ignore_decoder_errors=ignore_decoder_errors, 54 | subflavor=subflavor, 55 | subflavors=subflavors, 56 | sample_loader=lambda sample: sample, 57 | part_filter=part_filter, 58 | **kwargs, 59 | ) 60 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/ocr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import warnings 5 | from dataclasses import dataclass 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | 10 | from megatron.energon.epathlib.epath import EPath 11 | from megatron.energon.errors import warn_deprecated 12 | from megatron.energon.flavors.base_dataset import Sample 13 | from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory 14 | 15 | 16 | @dataclass 17 | class OCRSample(Sample): 18 | """Sample type for optical character recognition.""" 19 | 20 | #: The input image tensor in the shape (C, H, W) 21 | image: torch.Tensor 22 | #: The text contained in the image 23 | text: str 24 | #: The bounding boxes of the blocks in the image float(N, 4|5) 25 | block_boxes: Optional[torch.Tensor] = None 26 | #: The classes of the blocks in the image int(N, 1) 27 | block_classes: Optional[Union[torch.Tensor, List[str]]] = None 28 | #: The text contained in each block (N,) 29 | block_text: Optional[List[str]] = None 30 | #: The bounding boxes of the lines in the image float(N, 4|5) 31 | lines_boxes: Optional[torch.Tensor] = None 32 | #: The text contained in each line (N,) 33 | lines_text: Optional[List[str]] = None 34 | #: The bounding boxes of the words in the image float(N, 4|5) 35 | words_boxes: Optional[torch.Tensor] = None 36 | #: The text contained in each word (N,) 37 | words_text: Optional[List[str]] = None 38 | #: The bounding boxes of the chars in the image float(N, 4|5) 39 | chars_boxes: Optional[torch.Tensor] = None 40 | #: The character contained in each char (N,) 41 | chars_text: Optional[List[str]] = None 42 | 43 | 44 | class OCRWebdataset(DefaultDecoderWebdatasetFactory[OCRSample]): 45 | __sample_type__ = OCRSample 46 | 47 | def __init__(self, path: EPath, **kwargs): 48 | warn_deprecated( 49 | f"{type(self)} is deprecated, use the default instead and set the sample_type:\n" 50 | f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n" 51 | f"# remove top-level __module__ and __class__\n" 52 | f"sample_type:\n" 53 | f" __module__: megatron.energon\n" 54 | f" __class__: {self.__sample_type__.__name__}\n" 55 | f"# Keep the remaining content" 56 | ) 57 | super().__init__(path, **kwargs) 58 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/webdataset/structs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from dataclasses import dataclass, field 5 | from typing import Dict, List, Optional, Tuple, TypedDict 6 | 7 | from megatron.energon.epathlib import EPath 8 | 9 | 10 | @dataclass 11 | class WebdatasetInfo: 12 | # Maps shard name to number of samples in that shard 13 | shard_counts: Dict[str, int] 14 | 15 | 16 | @dataclass 17 | class WebdatasetSplits: 18 | # Maps split part to list of shard names 19 | split_parts: Dict[str, List[str]] 20 | # Set of "" or "/" to exclude 21 | exclude: List[str] = field(default_factory=list) 22 | 23 | 24 | @dataclass 25 | class ShardInfo: 26 | #: Name of the shard file (relative path from the nvinfo dir) 27 | name: str 28 | #: The path to the shard file 29 | path: EPath 30 | #: The offset of the first sample (counted in samples) where this sub shard starts 31 | offset: int 32 | #: The number of samples in this sub shard 33 | count: int 34 | #: The absolute file offset in bytes where this sub shard starts, None if not computed yet 35 | byte_offset: Optional[int] = field(default=None, compare=False) 36 | #: The length of the sub shard in bytes within the shard file, None if not computed yet 37 | byte_size: Optional[int] = field(default=None, compare=False) 38 | 39 | 40 | class FilteredSample(TypedDict): 41 | #: The key of the sample within the tar file. 42 | #: If the tar file contains files 12.jpg and 12.txt, 43 | #: those two files make one sample with the key "12" 44 | __key__: str 45 | #: The base name of the shard file e.g. "shard_000" 46 | __shard__: str 47 | #: Globally unique key to restore a sample from disk. 48 | #: For example `("Webdataset", "shard_000", 123)` would restore 49 | #: the sample from "shard_000.tar" at byte offset 123. 50 | __restore_key__: Tuple[str, str, int] 51 | 52 | 53 | def reraise_exception(exc: Exception, key: Optional[str]) -> None: 54 | if key: 55 | raise Exception(f"For sample {key!r}") from exc 56 | else: 57 | raise 58 | 59 | 60 | @dataclass 61 | class ShardState: 62 | """ 63 | Represents the iteration state of a single shard. I.e. contains a reference to the shard and the 64 | reader's offset. 65 | """ 66 | 67 | #: The shard of this state 68 | shard: ShardInfo 69 | #: Sample offset relative to the sub-shard file offset 70 | offset: int 71 | #: Byte offset relative to the sub-shard file byte offset 72 | byte_offset: int 73 | -------------------------------------------------------------------------------- /src/megatron/energon/module_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from typing import Any, Callable, ClassVar, Dict, Optional, Union 5 | 6 | from megatron.energon.epathlib import EPath 7 | 8 | 9 | class ModuleLoader: 10 | """Class providing access to python files to be loaded as modules.""" 11 | 12 | #: Loaded modules. Global singleton, no need to load a module multiple times. 13 | _modules: ClassVar[Dict[EPath, Dict[str, Any]]] = {} 14 | 15 | def _load_module(self, path: Union[str, EPath]): 16 | """Internally loads a module""" 17 | path = EPath(path) 18 | with path.open("r") as rf: 19 | src_code = rf.read() 20 | code = compile(src_code, path.url, "exec") 21 | mod = {} 22 | eval(code, mod) 23 | self._modules[path] = mod 24 | 25 | def __getitem__(self, path: Union[str, EPath]) -> Any: 26 | """ 27 | Gets a module by path. 28 | 29 | Args: 30 | path: The path to load 31 | 32 | Returns: 33 | Loaded module dict. 
34 | """ 35 | path = EPath(path) 36 | if path not in self._modules: 37 | self._load_module(path) 38 | return self._modules[path] 39 | 40 | def get_function( 41 | self, 42 | module_fn: str, 43 | default_fn_name: Optional[str] = None, 44 | relative_path: Optional[EPath] = None, 45 | ) -> Callable: 46 | """ 47 | Loads a function in the form `"module/submodule.py:function_name"`. 48 | 49 | Args: 50 | module_fn: The function string to load. Must contain the path of the python module 51 | and a function name separated by a colon (might be omitted if `default_fn_name` is 52 | specified). 53 | default_fn_name: Default name of the function if not given in `module_fn` string. 54 | relative_path: The relative parent path to the module. If not specified, the current 55 | working directory / absolute path is used. 56 | 57 | Returns: 58 | The function from the module 59 | """ 60 | if ":" in module_fn: 61 | module, fn_name = module_fn.rsplit(":", 1) 62 | else: 63 | if default_fn_name is None: 64 | raise ValueError("Function name must be specified") 65 | fn_name = default_fn_name 66 | module = module_fn 67 | if relative_path is not None: 68 | module = str(relative_path / module) 69 | fn = self[module][fn_name] 70 | assert callable(fn) 71 | return fn 72 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from megatron.energon.flavors.base_dataset import BaseCoreDatasetFactory, Sample, SavableDataset 5 | from megatron.energon.flavors.captioning import CaptioningSample, CaptioningWebdataset 6 | from megatron.energon.flavors.crude import CrudeSample, CrudeWebdataset 7 | from megatron.energon.flavors.image import ImageSample, ImageWebdataset 8 | from megatron.energon.flavors.image_classification import ( 9 | ImageClassificationSample, 10 | ImageClassificationWebdataset, 11 | ) 12 | from megatron.energon.flavors.interleaved import InterleavedSample, InterleavedWebdataset 13 | from megatron.energon.flavors.multichoice_vqa import MultiChoiceVQASample, MultiChoiceVQAWebdataset 14 | from megatron.energon.flavors.ocr import OCRSample, OCRWebdataset 15 | from megatron.energon.flavors.similarity_interleaved import ( 16 | SimilarityInterleavedSample, 17 | SimilarityInterleavedWebdataset, 18 | ) 19 | from megatron.energon.flavors.text import TextSample, TextWebdataset 20 | from megatron.energon.flavors.vid_qa import VidQASample, VidQAWebdataset 21 | from megatron.energon.flavors.vqa import VQASample, VQAWebdataset 22 | from megatron.energon.flavors.vqa_and_ocr import VQAOCRWebdataset 23 | from megatron.energon.flavors.webdataset import ( 24 | BaseWebdatasetFactory, 25 | DefaultDecoderWebdatasetFactory, 26 | DefaultGenericWebdatasetFactory, 27 | EmptyDatasetError, 28 | ImageDecoder, 29 | JoinedWebdatasetFactory, 30 | StandardWebdatasetFactory, 31 | VideoData, 32 | WebdatasetMeta, 33 | ) 34 | 35 | __all__ = [ 36 | "BaseCoreDatasetFactory", 37 | "BaseWebdatasetFactory", 38 | "CaptioningSample", 39 | "CaptioningWebdataset", 40 | "CrudeSample", 41 | "CrudeWebdataset", 42 | "DefaultGenericWebdatasetFactory", 43 | "DefaultDecoderWebdatasetFactory", 44 | "EmptyDatasetError", 45 | "ImageClassificationSample", 46 | "ImageClassificationWebdataset", 47 | "ImageDecoder", 48 | "ImageSample", 49 | "ImageWebdataset", 50 | "InterleavedSample", 51 | "InterleavedWebdataset", 52 | 
"JoinedWebdatasetFactory", 53 | "MultiChoiceVQASample", 54 | "MultiChoiceVQAWebdataset", 55 | "OCRSample", 56 | "OCRWebdataset", 57 | "Sample", 58 | "SavableDataset", 59 | "SimilarityInterleavedSample", 60 | "SimilarityInterleavedWebdataset", 61 | "StandardWebdatasetFactory", 62 | "TextSample", 63 | "TextWebdataset", 64 | "VQASample", 65 | "VQAWebdataset", 66 | "VQAOCRWebdataset", 67 | "VideoData", 68 | "VidQASample", 69 | "VidQAWebdataset", 70 | "WebdatasetMeta", 71 | ] 72 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/webdataset/metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from dataclasses import dataclass 5 | from typing import List, Set 6 | 7 | import braceexpand 8 | import yaml 9 | 10 | from megatron.energon.epathlib import EPath 11 | from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME 12 | from megatron.energon.flavors.webdataset.empty_dataset_error import EmptyDatasetError 13 | from megatron.energon.flavors.webdataset.structs import ShardInfo, WebdatasetInfo, WebdatasetSplits 14 | from megatron.energon.typed_converter import raw_to_typed 15 | 16 | 17 | @dataclass 18 | class WebdatasetMeta: 19 | """Class for getting metadata from a webdataset.""" 20 | 21 | sample_excludes: Set[str] 22 | shards: List[ShardInfo] 23 | 24 | @staticmethod 25 | def from_config( 26 | path: EPath, 27 | *, 28 | split_part: str, 29 | info_config: str = ".info.yaml", 30 | split_config: str = "split.yaml", 31 | ) -> "WebdatasetMeta": 32 | """ 33 | Loads the metadata for a webdataset, i.e. the shards and sample excludes. 34 | 35 | Args: 36 | split_part: Which part to load (e.g. 'train', 'val', 'test'). 37 | info_config: Config file to use for sample metadata. 38 | split_config: Config file to use for shard split definitions. 39 | """ 40 | info = raw_to_typed( 41 | yaml.safe_load((path / MAIN_FOLDER_NAME / info_config).read_text()), 42 | WebdatasetInfo, 43 | ) 44 | splits = raw_to_typed( 45 | yaml.safe_load((path / MAIN_FOLDER_NAME / split_config).read_text()), 46 | WebdatasetSplits, 47 | ) 48 | assert split_part in splits.split_parts, f"Invalid split part: {split_part!r}" 49 | split_excludes = { 50 | excluded 51 | for excluded in splits.exclude 52 | for excluded in braceexpand.braceexpand(excluded) 53 | } 54 | split_part_files = [ 55 | name 56 | for name in splits.split_parts[split_part] 57 | for name in braceexpand.braceexpand(name) 58 | if name not in split_excludes 59 | ] 60 | if len(split_part_files) == 0: 61 | raise EmptyDatasetError(f"No shards found in split part {split_part!r}") 62 | return WebdatasetMeta( 63 | sample_excludes={excluded for excluded in split_excludes if "/" in excluded}, 64 | shards=[ 65 | ShardInfo( 66 | name=name, 67 | path=path / name, 68 | offset=0, 69 | count=info.shard_counts[name], 70 | ) 71 | for name in split_part_files 72 | ], 73 | ) 74 | -------------------------------------------------------------------------------- /docs/source/advanced/grouping.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Grouping 5 | 6 | Grouping allows for rule-based batching of samples into one batch on the fly. 7 | 8 | Note how this is different from [packing](packing) which joins multiple samples into one (and is done before batching). 9 | On the other hand, grouping is an alternative to standard batching. 
10 | 11 | ## Example use cases 12 | 13 | * Select samples to batch based on image resolution, so that only samples of the same size are in one batch 14 | * Select blended samples based on their dataset origin, so that one batch does not mix different tasks or data types 15 | 16 | ## How to group 17 | 18 | To use grouping, you need to define the method {py:meth}`batch_group_criterion ` in your custom task encoder. 19 | This method gets a sample and returns a hashable value that will be used to cluster/group the samples 20 | and it also returns the batch size for that group. 21 | 22 | Samples with the same batch group criterion will be batched together. Once enough samples for one group 23 | have been collected (reached the batch size for that group), they will be batched and pushed down the pipeline 24 | to the next processing step. 25 | 26 | Here's an example task encoder that batches samples based on their image aspect ratios: 27 | 28 | ```python 29 | class GroupingTaskEncoder(DefaultTaskEncoder): 30 | def batch_group_criterion(self, sample: CaptioningSample) -> Tuple[Hashable, Optional[int]]: 31 | aspect_ratio = sample.image.shape[2] / sample.image.shape[1] 32 | 33 | # Bin aspect ratios into 3 groups 34 | if aspect_ratio < 0.8: 35 | return "portrait", 8 36 | elif aspect_ratio < 1.2: 37 | return "square", 8 38 | else: 39 | return "landscape", 8 40 | ``` 41 | 42 | In the example, the aspect ratio is sorted into one of three bins and a string is used as the grouping key. 43 | The batch size used here is always 8. 44 | 45 | Here is another example where each batch contains only images with the exact same size. 46 | Note how the image shape itself is used as the grouping key. 47 | 48 | ```python 49 | class GroupingTaskEncoder(DefaultTaskEncoder): 50 | def batch_group_criterion(self, sample: CaptioningSample) -> Tuple[Hashable, Optional[int]]: 51 | batch_size = 4 if sample.image.shape[1] < 512 else 2 52 | return sample.image.shape, batch_size 53 | ``` 54 | 55 | For images with a height of less than 512 pixels, the batch size will be 4, for larger images it's reduced to 2. 56 | 57 | 58 | ## Fixed global batch size 59 | 60 | Instead of specifying the batch size for each group individually, you can also specify the batch size as usually when calling 61 | `get_train_dataset`. The `batch_group_criterion` method should then return `None` for the batch_size. 62 | -------------------------------------------------------------------------------- /src/megatron/energon/retry_stream.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from types import TracebackType 5 | from typing import BinaryIO, Optional, Type 6 | 7 | from megatron.energon.epathlib import EPath 8 | 9 | 10 | class RetryReadStream(BinaryIO): 11 | """A stream that retries reading from a file. Only supports reading bytes.""" 12 | 13 | _path: EPath 14 | _file: Optional[BinaryIO] 15 | _pos: int 16 | _size: int 17 | 18 | def __init__(self, path: EPath): 19 | """Construct a RetryReadStream. 
It reads only bytes from a file.""" 20 | self._path = path 21 | self._file = None 22 | self._pos = 0 23 | self._size = path.size() 24 | 25 | def __enter__(self) -> "RetryReadStream": 26 | return self 27 | 28 | def __exit__( 29 | self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType 30 | ) -> None: 31 | self.close() 32 | 33 | def close(self) -> None: 34 | if self._file is not None: 35 | self._file.close() 36 | 37 | def read(self, n: int = -1) -> bytes: 38 | buf = b"" 39 | for retry in range(10): 40 | try: 41 | if self._file is None: 42 | self._file = self._path.open("rb") 43 | self._file.seek(self._pos) 44 | res = self._file.read(n) 45 | self._pos += len(res) 46 | buf += res 47 | if ( 48 | (n == -1 and self._pos >= self._size) 49 | or len(buf) == n 50 | or self._pos >= self._size 51 | ): 52 | return res 53 | except IOError: 54 | try: 55 | self._file.close() 56 | except IOError: 57 | pass 58 | self._file = None 59 | if retry == 9: 60 | raise 61 | continue 62 | 63 | def seek(self, offset: int, whence: int = 0) -> int: 64 | if whence == 0: 65 | pass 66 | elif whence == 1: 67 | offset += self._pos 68 | elif whence == 2: 69 | offset += self._size 70 | else: 71 | raise ValueError(f"Invalid whence value: {whence}") 72 | offset = min(max(offset, 0), self._size) 73 | self._pos = offset 74 | try: 75 | if self._file is not None: 76 | self._file.seek(offset) 77 | except IOError: 78 | pass 79 | return self._pos 80 | 81 | def tell(self) -> int: 82 | return self._pos 83 | 84 | def isatty(self) -> bool: 85 | return False 86 | 87 | def readable(self) -> bool: 88 | return True 89 | 90 | def seekable(self) -> bool: 91 | return True 92 | 93 | def writable(self) -> bool: 94 | return False 95 | -------------------------------------------------------------------------------- /src/megatron/energon/metadataset/loader_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from abc import ABC, abstractmethod 5 | from enum import Enum 6 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union 7 | 8 | from megatron.energon.flavors.base_dataset import BaseCoreDatasetFactory 9 | from megatron.energon.worker import WorkerConfig 10 | 11 | 12 | class DatasetBlendMode(Enum): 13 | """Determines how the the datasets are to be blended. Either by using the associated number as 14 | the weight for sampling from that dataset, or alternatively by using the number as the number 15 | of repetitions for samples in that dataset in one epoch (effectively, that corresponds to the 16 | weight for samples).""" 17 | 18 | NONE = "none" 19 | DATASET_WEIGHT = "dataset_weight" 20 | SAMPLE_REPETITIONS = "sample_repetitions" 21 | 22 | 23 | class DatasetLoaderInterface(ABC): 24 | """General interface for a dataset loader.""" 25 | 26 | @abstractmethod 27 | def get_datasets( 28 | self, 29 | *, 30 | training: bool, 31 | split_part: Union[Literal["train", "val", "test"], str], 32 | worker_config: WorkerConfig, 33 | subflavor: Optional[str] = None, 34 | subflavors: Optional[Dict[str, Any]] = None, 35 | shuffle_over_epochs_multiplier: int = 1, 36 | **kwargs, 37 | ) -> Tuple[DatasetBlendMode, List[Tuple[BaseCoreDatasetFactory, Union[float, int, None]]]]: 38 | """ 39 | Calls :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset) 40 | for all innermost datasets and resolves their relative weights to absolute weights. 
41 | 42 | Args: 43 | training: If true, apply training randomization. 44 | split_part: Default split part to use. 45 | worker_config: Worker configuration to use. 46 | subflavor: Set the default subflavor for all datasets. 47 | subflavors: Set the default subflavors for all datasets. 48 | shuffle_over_epochs_multiplier: Multiply the inner datasets 49 | `shuffle_over_epochs(_multiplier)` by this factor. E.g. if the inner dataset 50 | has `shuffle_over_epochs_multiplier=2` and this function has 51 | `shuffle_over_epochs_multiplier=3`, the inner dataset will be shuffled 52 | over 6 epochs. Shuffling over `n` epochs guarantees that each sample is seen 53 | exactly `n` times in `n` epochs of the inner dataset. Use -1 for shuffling over 54 | an infinite number of epochs (effectively, this will draw shard slices with 55 | replacement). 56 | **kwargs: Additional arguments to the dataset constructor. 57 | 58 | Returns: 59 | The dataset blending mode and the instantiated core datasets with their weights/repetitions. 60 | """ 61 | ... 62 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/webdataset/field_access.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import re 5 | from typing import Any, List, Union 6 | 7 | from megatron.energon.errors import SampleException, compact_str 8 | 9 | _field_split_re = re.compile(r"^(?P[^[\]]+)(?P(?:\[([^[\]]+)])*)$") 10 | 11 | 12 | def split_field_access(field_desc: str) -> List[List[str]]: 13 | """ 14 | Splits a field_map for access:: 15 | 16 | 'abcdef,ghi' -> [['abcdef'], ['ghi']] 17 | 'abcdef[ghi]' -> [['abcdef', 'ghi']] 18 | 'abcdef[ghi][jkl]' -> [['abcdef', 'ghi', 'jkl']] 19 | """ 20 | options = field_desc.split(",") 21 | option_fields = [] 22 | for option in options: 23 | match = _field_split_re.match(option) 24 | if match: 25 | option_fields.append( 26 | [match.group("field_name")] 27 | + [ 28 | access.lstrip("[").rstrip("]") 29 | for access in match.group("access").split("][") 30 | if access 31 | ] 32 | ) 33 | else: 34 | option_fields.append([field_desc]) 35 | return option_fields 36 | 37 | 38 | class FieldAccessError(SampleException): 39 | pass 40 | 41 | 42 | def _field_access(value: Union[dict, list, str, int, bool, None], field: List[str]) -> Any: 43 | """ 44 | Accesses a (nested) field in the value. 45 | 46 | Args: 47 | value: The value to access 48 | field: The access instruction (e.g. `['field1', 'field2']` for 49 | `value['field1']['field2']`) 50 | 51 | Returns: 52 | The accessed value 53 | """ 54 | try: 55 | if len(field) == 0: 56 | return value 57 | elif isinstance(value, dict): 58 | return _field_access(value[field[0]], field[1:]) 59 | elif isinstance(value, list): 60 | return _field_access(value[int(field[0])], field[1:]) 61 | else: 62 | raise FieldAccessError( 63 | f"Cannot access literal value {compact_str(value)} with {field!r}" 64 | ) 65 | except FieldAccessError: 66 | raise 67 | except KeyError: 68 | raise FieldAccessError(f"Cannot access {'.'.join(field)!r} in {compact_str(value)}") 69 | 70 | 71 | def field_access(value: Union[dict, list, str, int, bool, None], field: List[List[str]]) -> Any: 72 | """ 73 | Accesses a (nested) field in the value. 74 | 75 | Args: 76 | value: The value to access 77 | field: The access instruction (e.g. 
`[['field1', 'field2']]` for 78 | `value['field1']['field2']`, or `[['field1'], ['field2']]` for value.get('field1', value['field2'])`) 79 | 80 | Returns: 81 | The accessed value 82 | """ 83 | for f in field[:-1]: 84 | try: 85 | return _field_access(value, f) 86 | except (KeyError, ValueError, IndexError): 87 | pass 88 | return _field_access(value, field[-1]) 89 | -------------------------------------------------------------------------------- /src/megatron/energon/errors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import itertools 5 | import warnings 6 | from functools import wraps 7 | from typing import Any, Type, TypeVar, Union 8 | 9 | 10 | def compact_str( 11 | value: Union[dict, list, str, int, bool, None], 12 | depth: int = 3, 13 | max_items: int = 10, 14 | max_str_len: int = 50, 15 | ) -> str: 16 | """ 17 | Compact representation of a value as a string. 18 | 19 | Args: 20 | value: The value to compact 21 | depth: The maximum depth to compact 22 | max_items: The maximum number of items to show in a list or dict 23 | max_str_len: The maximum string length to show 24 | 25 | Returns: The printable string 26 | """ 27 | if isinstance(value, dict): 28 | if depth <= 0: 29 | return "{...}" 30 | return ( 31 | "{" 32 | + ", ".join( 33 | ( 34 | f"{k}: {v!r}" 35 | if isinstance(k, str) and k.startswith("__") 36 | else f"{k}: {compact_str(v, depth - 1, max_items, max_str_len)}" 37 | ) 38 | for k, v in itertools.islice(value.items(), max_items) 39 | ) 40 | + "}" 41 | ) 42 | elif isinstance(value, list): 43 | if depth <= 0: 44 | return "[...]" 45 | return ( 46 | "[" 47 | + ", ".join( 48 | compact_str(v, depth - 1, max_items, max_str_len) for v in value[:max_items] 49 | ) 50 | + "]" 51 | ) 52 | elif isinstance(value, str): 53 | if len(value) > max_str_len: 54 | return repr(value[:max_str_len] + "...") 55 | return repr(value) 56 | else: 57 | return repr(value) 58 | 59 | 60 | T = TypeVar("T") 61 | 62 | 63 | class SampleException(ValueError): 64 | @classmethod 65 | def from_sample_key(cls: Type[T], sample_key: str) -> T: 66 | return cls(f"Sample {sample_key} failed") 67 | 68 | @classmethod 69 | def from_sample(cls: Type[T], sample: Any) -> T: 70 | return cls(f"Sample {compact_str(sample)} failed") 71 | 72 | 73 | class FatalSampleError(SampleException): 74 | # This will not be handled by the error handler 75 | pass 76 | 77 | 78 | def warn_deprecated(reason, stacklevel=2): 79 | warnings.warn(reason, FutureWarning, stacklevel=stacklevel) 80 | 81 | 82 | def deprecated(reason): 83 | def decorator(func): 84 | @wraps(func) 85 | def wrapper(*args, **kwargs): 86 | warn_deprecated(f"{func.__name__} is deprecated: {reason}", stacklevel=3) 87 | return func(*args, **kwargs) 88 | 89 | return wrapper 90 | 91 | return decorator 92 | 93 | 94 | SYSTEM_EXCEPTIONS = ( 95 | SystemError, 96 | SyntaxError, 97 | ImportError, 98 | StopIteration, 99 | StopAsyncIteration, 100 | MemoryError, 101 | RecursionError, 102 | ReferenceError, 103 | NameError, 104 | UnboundLocalError, 105 | FatalSampleError, 106 | ) 107 | -------------------------------------------------------------------------------- /src/megatron/energon/epathlib/rclone_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import configparser 5 | import os 6 | import shutil 7 | from dataclasses import dataclass 8 | from pathlib import Path 9 | from typing import Dict, Optional 10 | 11 | 12 | @dataclass 13 | class ConfigEntry: 14 | name: str 15 | type: str 16 | provider: Optional[str] 17 | access_key_id: Optional[str] 18 | secret_access_key: Optional[str] 19 | region: Optional[str] 20 | endpoint: Optional[str] 21 | 22 | 23 | def find_executable_path(executable_name): 24 | """Find the path of an executable in the PATH environment variable. Returns None if not found.""" 25 | 26 | executable_path = shutil.which(executable_name) 27 | if executable_path: 28 | return Path(executable_path) 29 | return None 30 | 31 | 32 | def get_rclone_config_path() -> Optional[Path]: 33 | 34 | # First check if rclone executable is in PATH, if yes, check if rclone.conf is in the same directory 35 | rclone_exe_path = find_executable_path("rclone") 36 | if rclone_exe_path is not None and rclone_exe_path.is_file(): 37 | rclone_config_path = rclone_exe_path.with_name("rclone.conf") 38 | if rclone_config_path.is_file(): 39 | return rclone_config_path 40 | 41 | # As a second option check the XDG_CONFIG_HOME environment variable, if it is set, check for rclone/rclone.conf in that directory 42 | xdg_config_home = os.environ.get("XDG_CONFIG_HOME") 43 | if xdg_config_home and Path(xdg_config_home).is_dir(): 44 | rclone_config_path = Path(xdg_config_home) / "rclone" / "rclone.conf" 45 | if rclone_config_path.is_file(): 46 | return rclone_config_path 47 | 48 | # As a third option check the default location ~/.config/rclone/rclone.conf 49 | rclone_config_path = Path.home() / ".config" / "rclone" / "rclone.conf" 50 | if rclone_config_path.is_file(): 51 | return rclone_config_path 52 | 53 | # Last option is to check the legacy location ~/.rclone.conf 54 | legacy_config_path = Path.home() / ".rclone.conf" 55 | if legacy_config_path.is_file(): 56 | return legacy_config_path 57 | 58 | return None 59 | 60 | 61 | def read_rclone_config_at_path(config_path: Path) -> Dict[str, ConfigEntry]: 62 | """Reads the config file and returns a dictionary with the config entries.""" 63 | 64 | config = configparser.ConfigParser() 65 | config.read(config_path) 66 | 67 | config_entries = {} 68 | for section in config.sections(): 69 | entry = ConfigEntry( 70 | name=section, 71 | type=config[section].get("type"), 72 | provider=config[section].get("provider"), 73 | access_key_id=config[section].get("access_key_id"), 74 | secret_access_key=config[section].get("secret_access_key"), 75 | region=config[section].get("region"), 76 | endpoint=config[section].get("endpoint"), 77 | ) 78 | config_entries[section] = entry 79 | 80 | return config_entries 81 | 82 | 83 | def read_rclone_config() -> Dict[str, ConfigEntry]: 84 | config_path = get_rclone_config_path() 85 | if config_path is None: 86 | raise FileNotFoundError("Could not find rclone configuration file.") 87 | return read_rclone_config_at_path(config_path) 88 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/webdataset/standard_webdataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from typing import Generic, Type, TypeVar 5 | 6 | from megatron.energon.epathlib import EPath 7 | from megatron.energon.flavors.base_dataset import Sample 8 | from megatron.energon.flavors.webdataset.default_decoder_webdataset import ( 9 | DefaultDecoderWebdatasetFactory, 10 | ) 11 | 12 | T_sample = TypeVar("T_sample", covariant=True) 13 | 14 | 15 | class StandardWebdatasetFactory(DefaultDecoderWebdatasetFactory[T_sample], Generic[T_sample]): 16 | """ 17 | This dataset sample loader factory uses the sample type e.g. given from a dataset.yaml, and applies the default 18 | loading logic, which includes decoding images, videos and containers. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | path: EPath, 24 | *, 25 | sample_type: Type[T_sample], 26 | **kwargs, 27 | ): 28 | """ 29 | Factory for the standard webdataset sample loader. 30 | 31 | Args: 32 | path: Path to the dataset (passed to parent) 33 | sample_type: Type of the sample to be loaded 34 | auto_decode: If true, use the default webdataset sample decoder. 35 | image_decode: This defines the decoding results. 36 | ignore_decoder_errors: If true, ignore errors when decoding. 37 | subflavors: Subflavors dictionary to set for all loaded samples. 38 | field_map: Mapping from the webdataset fields to the sample fields. 39 | sample_loader: Function to load the sample from the webdataset fields. May be a string 40 | in order to load a function from a module, or a callable directly. 41 | part_filter: Filter for the parts to load. May be a string in order to load a function 42 | from a module, or a callable directly. 43 | split_part: Which part to load (e.g. 'train', 'val', 'test'). 44 | training: If true, apply shuffling and loop the dataset. 45 | worker_config: Configuration for the workers. 46 | shuffle_over_epochs: Only effective if training=True. 47 | How many epochs to shuffle over if training. 48 | If = 1, every sample is seen exactly once per epoch. 49 | If > 1, samples (or rather shard slices) are shuffled within this number of epochs 50 | (i.e. randomly selected without replacement). 51 | If -1, the shards are effectively shuffle over infinite epochs (i.e. shard slices 52 | are drawn with replacement). 53 | parallel_shard_iters: Number of parallel opened shards per worker, shuffling between. 54 | max_samples_per_sequence: Maximum number of samples per sequence (=how many samples 55 | will be sequentially iterated). 56 | info_config: Config file to use for sample metadata. 57 | split_config: Config file to use for shard split definitions. 58 | handler: Exception handler. Args: (exception, key). 59 | """ 60 | self.__sample_type__ = sample_type 61 | assert issubclass( 62 | sample_type, Sample 63 | ), f"sample_type must be subclass of megatron.energon.Sample, got {sample_type.__qualname__}" 64 | super().__init__(path, **kwargs) 65 | -------------------------------------------------------------------------------- /src/megatron/energon/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from typing import Optional, TypeVar 5 | 6 | from megatron.energon.errors import warn_deprecated 7 | from megatron.energon.flavors.base_dataset import SavableDataset 8 | from megatron.energon.savable_loader import BasicDataLoader, SavableDataLoader 9 | from megatron.energon.worker import WorkerConfig 10 | 11 | T = TypeVar("T") 12 | 13 | 14 | def get_savable_loader( 15 | dataset: SavableDataset[T], 16 | *, 17 | worker_config: Optional[WorkerConfig] = None, 18 | checkpoint_every_sec: float = 60, 19 | checkpoint_every_min_n_samples: Optional[int] = None, 20 | n_checkpoints: int = 2, 21 | ) -> SavableDataLoader[T]: 22 | """ 23 | 24 | Get a dataloader for the given dataset. 25 | 26 | Args: 27 | dataset: The dataset to create a loader for. 28 | worker_config: Deprecated. Please pass this to the dataset instead. 29 | checkpoint_every_sec: This is the time in seconds after which an internal checkpoint is 30 | saved. It may take the same duration to restore a checkpoint, but introduces additional 31 | overhead during reading data from the dataset, so this should be chosen accordingly. 32 | Only applies if using workers. 33 | checkpoint_every_min_n_samples: Overwrites the minimum number of samples between 34 | checkpoints. Defaults to `number of workers * 2`. Only applies if using workers. 35 | n_checkpoints: The number of internal checkpoints to keep. Only applies if using workers. 36 | 37 | Returns: 38 | The instantiated :class:`megatron.energon.SavableDataLoader`, yielding batches from the dataset, 39 | allowing to save the state of the dataset. 40 | """ 41 | if worker_config is not None: 42 | if worker_config != dataset.worker_config: 43 | raise AssertionError( 44 | "The worker_config passed to get_savable_loader() does not match the one of the dataset. " 45 | "Also note, it is deprecated to pass one to get_savable_loader() and it will have no effect." 46 | ) 47 | else: 48 | warn_deprecated( 49 | "Passing a worker_config to get_savable_loader() is deprecated and will have no effect." 50 | ) 51 | 52 | return SavableDataLoader( 53 | dataset, 54 | checkpoint_every_sec=checkpoint_every_sec, 55 | checkpoint_every_min_n_samples=checkpoint_every_min_n_samples, 56 | n_checkpoints=n_checkpoints, 57 | ) 58 | 59 | 60 | def get_loader( 61 | dataset: SavableDataset[T], 62 | *, 63 | worker_config: Optional[WorkerConfig] = None, 64 | ) -> BasicDataLoader[T]: 65 | """ 66 | Get a dataloader for the given dataset. 67 | 68 | Args: 69 | dataset: The dataset to create a loader for. 70 | worker_config: Deprecated. Please pass this to the dataset instead. 71 | 72 | Returns: 73 | The instantiated :class:`torch.data.DataLoader`, yielding batches from the dataset. 74 | """ 75 | if worker_config is not None: 76 | if worker_config != dataset.worker_config: 77 | raise AssertionError( 78 | "The worker_config passed to get_loader() does not match the one of the dataset. " 79 | "Also note, it is deprecated to pass one to get_loader() and it will have no effect." 80 | ) 81 | else: 82 | warn_deprecated( 83 | "Passing a worker_config to get_loader() is deprecated and will have no effect." 84 | ) 85 | 86 | return BasicDataLoader(dataset) 87 | -------------------------------------------------------------------------------- /src/megatron/energon/tools/info.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from importlib import import_module 5 | from io import StringIO 6 | from typing import Type 7 | 8 | import click 9 | import yaml 10 | 11 | import megatron.energon 12 | from megatron.energon.epathlib import EPath 13 | from megatron.energon.flavors.base_dataset import BaseCoreDatasetFactory 14 | from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME 15 | 16 | fmt = """ 17 | * {dataset_name}: `{path}` ({samples_count:,} samples, {samples_size} in {shards_count} shards) 18 | * Sample Type: {{py:class}}`{sample_name} <{sample_fullname}>` 19 | * Default Splits: 20 | {splits_str} 21 | """ 22 | 23 | split_fmt = """ * `{split_name}`: {split_ratio:.0f}%, {split_samples_count:,} samples in {split_shards_count} shards 24 | """ 25 | 26 | 27 | def fmt_size(size: int) -> str: 28 | keys = ["B", "KiB", "MiB", "GiB", "TiB"] 29 | for key in keys: 30 | if size < 1024: 31 | return f"{size:.2f} {key}" 32 | size /= 1024 33 | return f"{size:.2f} PiB" 34 | 35 | 36 | @click.command(name="info") 37 | @click.argument( 38 | "path", 39 | type=click.Path(file_okay=False, dir_okay=True, path_type=EPath), 40 | ) 41 | @click.option( 42 | "--split-config", default="split.yaml", help="Split config file name", show_default=True 43 | ) 44 | @click.option( 45 | "--dataset-config", default="dataset.yaml", help="Dataset config file name", show_default=True 46 | ) 47 | def command( 48 | path: EPath, 49 | split_config: str, 50 | dataset_config: str, 51 | ): 52 | """ 53 | Get summarizing information about a dataset. 54 | """ 55 | 56 | ds_config = yaml.safe_load(StringIO((path / MAIN_FOLDER_NAME / dataset_config).read_text())) 57 | info_config = yaml.safe_load(StringIO((path / MAIN_FOLDER_NAME / ".info.yaml").read_text())) 58 | split_config = yaml.safe_load(StringIO((path / MAIN_FOLDER_NAME / split_config).read_text())) 59 | samples_count = sum(info_config["shard_counts"].values()) 60 | dict_sample_type = ds_config["sample_type"] 61 | sample_module = import_module(dict_sample_type["__module__"]) 62 | 63 | sample_cls: Type[BaseCoreDatasetFactory] = getattr(sample_module, dict_sample_type["__class__"]) 64 | sample_module = sample_cls.__module__ 65 | if ( 66 | sample_module.startswith("megatron.energon") 67 | and getattr(megatron.energon, dict_sample_type["__class__"], None) == sample_cls 68 | ): 69 | sample_module = "megatron.energon" 70 | sample_name = sample_cls.__name__ 71 | sample_fullname = sample_module + "." 
+ sample_name 72 | 73 | def srt_key(pair): 74 | try: 75 | return ("train", "val", "test").index(pair[0]) 76 | except ValueError: 77 | return 3 78 | 79 | splits_str = "".join( 80 | split_fmt.format( 81 | split_name=split_name, 82 | split_ratio=round( 83 | 100 84 | * sum(info_config["shard_counts"][shard] for shard in split_parts) 85 | / samples_count, 86 | 2, 87 | ), 88 | split_samples_count=sum(info_config["shard_counts"][shard] for shard in split_parts), 89 | split_shards_count=len(split_parts), 90 | ) 91 | for split_name, split_parts in sorted(split_config["split_parts"].items(), key=srt_key) 92 | ) 93 | print( 94 | fmt.format( 95 | dataset_name=path.name, 96 | path=str(path), 97 | samples_count=samples_count, 98 | samples_size=fmt_size( 99 | sum((path / split_name).size() for split_name in info_config["shard_counts"].keys()) 100 | ), 101 | shards_count=len(info_config["shard_counts"]), 102 | sample_name=sample_name, 103 | sample_fullname=sample_fullname, 104 | splits_str=splits_str, 105 | ) 106 | ) 107 | -------------------------------------------------------------------------------- /tests/test_epathlib.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | """This module tests the EPath class, our custom version of pathlib.Path""" 5 | 6 | import logging 7 | import struct 8 | import sys 9 | import unittest 10 | 11 | from megatron.energon.epathlib import EPath 12 | from megatron.energon.epathlib.rclone_config import ConfigEntry 13 | 14 | 15 | class TestEPath(unittest.TestCase): 16 | def setUp(self): 17 | logging.basicConfig(stream=sys.stderr, level=logging.INFO) 18 | 19 | def tearDown(self): 20 | pass 21 | 22 | def test_basic(self): 23 | """Some basic functionality tests""" 24 | 25 | p_rel = EPath("./subdir") 26 | p_abs = EPath("/tmp") 27 | 28 | p_comb = p_abs / p_rel 29 | # logging.info(f"p_comb: {p_comb}") 30 | # logging.info(f"p_comb: {p_comb.internal_path}") 31 | 32 | # We don't want to work on relative paths 33 | self.assertRaises(AssertionError, lambda: p_rel.is_file()) 34 | 35 | # Those should not raise: 36 | assert p_comb.is_absolute() 37 | _ = p_comb.is_file() 38 | _ = p_abs.is_file() 39 | 40 | def test_contextman(self): 41 | """Test the context manager""" 42 | 43 | tmp_file_path = "/tmp/testfile.bin" 44 | # First create a file 45 | with open(tmp_file_path, "wb") as f: 46 | f.write(struct.pack("H10s", 1337, b"1234567890")) 47 | 48 | # Test context manager reading 49 | p = EPath(tmp_file_path).open("rb") 50 | with p: 51 | b = p.read() 52 | assert isinstance(b, bytes) 53 | 54 | num, data = struct.unpack("H10s", b) 55 | logging.info(f"num: {num}") 56 | assert num == 1337 57 | assert data == b"1234567890" 58 | 59 | assert not p.closed 60 | 61 | assert p.closed 62 | 63 | # Test context manager writing 64 | tmp_file_path2 = "/tmp/testfile2.bin" 65 | with EPath(tmp_file_path2).open("wb") as p: 66 | p.write(struct.pack("H10s", 1337, b"1234567890")) 67 | 68 | def test_glob(self): 69 | """Test the glob functionality""" 70 | 71 | # First create some files 72 | for i in range(10): 73 | with open(f"/tmp/epathtestfile_{i}.bin", "wb") as f: 74 | f.write(b"dummycontent") 75 | 76 | # Test globbing 77 | p = EPath("/tmp").glob("epathtestfile_*.bin") 78 | 79 | logging.info(f"p: {p}, type of p: {type(p)}") 80 | elems = list(p) 81 | assert len(elems) == 10 82 | for i, e in enumerate(elems): 83 | logging.info(f"glob_result[{i}]: {e}") 84 | assert isinstance(e, EPath) 85 | 
assert e.is_file() 86 | 87 | # Test globbing with a pattern 88 | p = EPath("/tmp").glob("epathtestfile_[0-3].bin") 89 | assert len(list(p)) == 4 90 | 91 | def test_s3_path_resolution(self): 92 | """Test s3 path resolution""" 93 | config_override = { 94 | "s3": ConfigEntry( 95 | name="s3", 96 | type="s3", 97 | provider="s3", 98 | access_key_id="dummy", 99 | secret_access_key="dummy", 100 | region="dummy", 101 | endpoint="https://localhost", 102 | ) 103 | } 104 | 105 | # Test globbing 106 | p = EPath("rclone://s3/tmp/path/subpath.txt", config_override=config_override) 107 | assert str(p) == "rclone://s3/tmp/path/subpath.txt", str(p) 108 | 109 | p2 = p / ".." / "subpath2.txt" 110 | assert str(p2) == "rclone://s3/tmp/path/subpath2.txt", str(p2) 111 | 112 | p3 = EPath("rclone://s3/tmp/path/.././subpath.txt", config_override=config_override) 113 | assert str(p3) == "rclone://s3/tmp/subpath.txt", str(p3) 114 | 115 | p4 = p3.parent / "../bla/bla/bla/../../../no/../subpath2.txt" 116 | assert str(p4) == "rclone://s3/subpath2.txt", str(p4) 117 | 118 | 119 | if __name__ == "__main__": 120 | unittest.main() 121 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | 5 |
6 | [README banner, HTML stripped: "Megatron Energon" logo with the tagline "Megatron's multi-modal data loader", Tests and Documentation badges, and Report Bug · Request Feature links] 17 |
18 | 19 | _**DISCLAIMER**: This package contains research code. APIs may change._ 20 | 21 | # What is this? 22 | 23 | **Megatron Energon** is the multi-modal data loader of [Megatron](https://github.com/NVIDIA/Megatron-LM) (you can also use it independently). 24 | 25 | It's best at 26 | 27 | - loading large training data to train large multi-modal models 28 | - blending many different datasets together 29 | - distributing the work across many nodes and processes of a cluster 30 | - ensuring reproducibility and resumability 31 | - adapting easily to various types of data samples and processing 32 | 33 | Try using it together with [Megatron](https://github.com/NVIDIA/Megatron-LM) Core. 34 | 35 | # Quickstart 36 | **Megatron Energon** is a pip-installable python package that offers 37 | - dataset-related classes that you can import in your project 38 | - a command line utility for data preprocessing and conversion 39 | 40 | This document is just a quick start. Please also check out the [documentation](https://nvidia.github.io/Megatron-Energon/). 41 | 42 | ## Installation 43 | 44 | To install the latest stable version: 45 | ```shell 46 | pip install megatron-energon 47 | ``` 48 | 49 | Or to install the current development version: 50 | ```shell 51 | pip install git+https://github.com/NVIDIA/Megatron-Energon.git 52 | ``` 53 | 54 | **NOTE**: We encourage you to install the package (and not just import a local copy). This will ensure you have all the needed dependencies and that you can use the command line tool. 55 | 56 | For more details on installing this package, see [here](https://nvidia.github.io/Megatron-Energon/installation.html). 57 | 58 | ## Usage of command line tool 59 | 60 | After installation, the command `energon` will be available. 61 | 62 | Here are some examples for things you can do: 63 | 64 | | Command | Description | 65 | |---|---| 66 | | `energon prepare DATASET_ROOT` | Take an existing WebDataset and add the required yaml files to turn it into an energon-compatible dataset | 67 | | `energon lint DATASET_ROOT` | Verify that the dataset complies with the energon dataset format and that all samples are loadable | 68 | 69 | 70 | ## Usage of the library 71 | 72 | To get started, pick a [WebDataset](https://github.com/webdataset/webdataset)-compliant dataset and run `energon prepare DATASET_ROOT` on it, to run the interactive assistant and create the `.nv-meta` folder. 73 | 74 | Once done, try to load it from your Python program: 75 | 76 | ```python 77 | from megatron.energon import get_train_dataset, get_loader, WorkerConfig 78 | 79 | 80 | simple_worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2) 81 | 82 | 83 | train_ds = get_train_dataset( 84 | '/my/dataset/path', 85 | batch_size=32, 86 | shuffle_buffer_size=None, 87 | max_samples_per_sequence=None, 88 | worker_config=simple_worker_config, 89 | ) 90 | 91 | train_loader = get_loader(train_ds) 92 | 93 | for batch in train_loader: 94 | # Do something with batch 95 | # Infer, gradient step, ... 96 | pass 97 | ``` 98 | 99 | For more details, read the [documentation](https://nvidia.github.io/Megatron-Energon/). 100 | 101 | Most likely, you'll need your own [task encoder](https://nvidia.github.io/Megatron-Energon/task_encoders.html). 102 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/webdataset/default_decoder_webdataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from typing import Any, Callable, Dict, Generic, Literal, Optional, TypeVar 5 | 6 | import webdataset 7 | 8 | from megatron.energon.epathlib import EPath 9 | from megatron.energon.flavors.webdataset.default_generic_webdataset import ( 10 | DefaultGenericWebdatasetFactory, 11 | ) 12 | from megatron.energon.flavors.webdataset.structs import FilteredSample 13 | from megatron.energon.flavors.webdataset.video_data import VideoData 14 | 15 | T_sample = TypeVar("T_sample", covariant=True) 16 | 17 | ImageDecoder = Literal[ 18 | "l8", 19 | "rgb8", 20 | "rgba8", 21 | "l", 22 | "rgb", 23 | "rgba", 24 | "torchl8", 25 | "torchrgb8", 26 | "torchrgba8", 27 | "torchl", 28 | "torchrgb", 29 | "torch", 30 | "torchrgba", 31 | "pill", 32 | "pil", 33 | "pilrgb", 34 | "pilrgba", 35 | ] 36 | 37 | 38 | class DefaultDecoderWebdatasetFactory(DefaultGenericWebdatasetFactory[T_sample], Generic[T_sample]): 39 | """ 40 | Extends the default webdataset loading with decoding of contained files, such as images, videos or nested 41 | containers. 42 | """ 43 | 44 | #: Image decoding result type 45 | image_decode: ImageDecoder 46 | #: If true, ignore errors when decoding. 47 | ignore_decoder_errors: bool 48 | 49 | # The webdataset decoder function, if to be applied 50 | _decoder: Optional[Callable[[FilteredSample], FilteredSample]] 51 | 52 | def __init__( 53 | self, 54 | path: EPath, 55 | *, 56 | auto_decode: bool = True, 57 | image_decode: ImageDecoder = "torchrgb", 58 | ignore_decoder_errors: bool = False, 59 | **kwargs, 60 | ): 61 | """ 62 | Factory for the webdataset sample loader including the decoder. 63 | 64 | Args: 65 | path: Path to the dataset (passed to parent) 66 | auto_decode: If true, use the default webdataset sample decoder. 67 | image_decode: This defines the decoding results. 68 | ignore_decoder_errors: If true, ignore errors when decoding. 69 | **kwargs: Args passed to parent constructor 70 | """ 71 | self.image_decode = image_decode 72 | self.ignore_decoder_errors = ignore_decoder_errors 73 | super().__init__(path, **kwargs) 74 | 75 | if auto_decode: 76 | self._decoder = webdataset.autodecode.Decoder( 77 | [ 78 | webdataset.autodecode.imagehandler(self.image_decode), 79 | self._video_decoder, 80 | ] 81 | ) 82 | else: 83 | self._decoder = None 84 | 85 | def _decode_error_handler(self, exc: Exception) -> bool: 86 | if self.ignore_decoder_errors: 87 | return True 88 | raise exc 89 | 90 | def _video_decoder(self, key, data): 91 | """Extract the video data from default video extensions.""" 92 | # TODO: This function could be more efficient. It will write the data to `/tmp`, 93 | # then load it using `torchvision.io.video.read_video` which uses `av.open` from pyav. 94 | # pyav allows providing a file-like object, but torchvision does not expose that interface. 
95 | # (https://github.com/pytorch/vision/issues/8438) 96 | video = webdataset.torch_video(key, data) 97 | if video is not None: 98 | return VideoData( 99 | frames=video[0].permute((0, 3, 1, 2)), 100 | aframes=video[1], 101 | info=video[2], 102 | ) 103 | return None 104 | 105 | def load_sample(self, sample: FilteredSample) -> T_sample: 106 | if self._decoder is not None: 107 | sample = self._decoder(sample) 108 | return super().load_sample(sample) 109 | 110 | def config(self) -> Dict[str, Any]: 111 | return dict( 112 | **super().config(), 113 | image_decode=self.image_decode, 114 | ignore_decoder_errors=self.ignore_decoder_errors, 115 | ) 116 | -------------------------------------------------------------------------------- /docs/source/advanced/packing.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Packing 5 | 6 | Packing (sometimes also called sequence packing) enables you to selectively compress multiple 7 | input samples into a single sample, for example depending on their length. 8 | 9 | This technique is commonly used with large language models when the input samples have very different 10 | lengths, leading to lots of padding and hence wasted compute. 11 | 12 | This section explains how you can pack samples together and utilize the full context length. 13 | 14 | ## How to pack samples on the fly 15 | 16 | To use packing, you need to implement the TaskEncoder methods {py:meth}`select_samples_to_pack` 17 | and {py:meth}`pack_selected_samples`. 18 | Furthermore, you need to initialize the loader with the `packing_buffer_size` argument set to a non-zero number. 19 | 20 | The `select_samples_to_pack` method will receive a list of samples (size according to the selected `packing_buffer_size`), 21 | and should partition those samples into groups that shall be packed together. Hence the function returns 22 | a list of lists of samples. 23 | 24 | For each group, the second method `pack_selected_samples` will be called. You need to implement how a group of 25 | samples will be mapped to a single sample. For LLMs, for example, this method might concatenate the input tokens. 26 | 27 | 28 | ```{admonition} Note 29 | :class: important 30 | You can set the `__restore_key__` of the packed sample to an empty tuple, since energon will set the correct 31 | restore key afterwards, based on the samples that went in. 32 | ``` 33 | 34 | ```{warning} 35 | To handle attention masks and tokenized inputs, you will want to operate on a different sample type. 36 | The `pack_selected_samples` method may return a different sample type that is expected as the input for the `batch` method. 37 | ``` 38 | 39 | It is important to mark custom functions like `encode_sample` and `pack_selected_samples` as `@stateless` to allow saving 40 | samples for packing. If augmentations happen, the method should be marked with 41 | `@stateless(restore_seeds=True)` to deterministically set the seeds based on the `TaskEncoder.current_sample_index`. 42 | You have to make sure the methods are actually stateless, meaning that they will produce the same output when invoked 43 | with the same input and random states. 44 | 45 | Example packing for a large language model, extending the example from the [](../basic/task_encoder) section: 46 | 47 | ```python 48 | class PackingCaptioningTaskEncoder(CaptioningTaskEncoder): 49 | """This class extends the CaptioningTaskEncoder and adds select_samples_to_pack and pack_selected_samples for packing samples 50 | efficiently on-the-fly.
51 | Set the `packing_buffer_size` of the get_(train|val)_dataset to an accordingly large number to get a 52 | properly sized input sample buffer with good diversity. 53 | """ 54 | 55 | @stateless(restore_seeds=True) 56 | def encode_sample(self, ...): 57 | # Added `stateless` decorator to allow saving samples for packing. Will set the seed 58 | # deterministically based on the self.current_sample_index. 59 | ... 60 | 61 | def select_samples_to_pack(self, samples: List[CaptioningSample]) -> List[List[CaptioningSample]]: 62 | # Do something intelligent here, e.g. sort by caption length and concat where possible. 63 | # This could be better, but it's just an example. 64 | samples.sort(key=lambda x: len(x.caption)) 65 | groups = [] 66 | while len(samples) > 0: 67 | batch = [] 68 | caption_len = 0 69 | while len(samples) > 0 and caption_len + len(samples[0].caption) < self.max_length: 70 | sample = samples.pop(0) 71 | batch.append(sample) 72 | caption_len += len(sample.caption) 73 | groups.append(batch) 74 | return groups 75 | 76 | @stateless 77 | def pack_selected_samples(self, samples: List[CaptioningSample]) -> CaptioningSample: 78 | # Construct a new CaptioningSample by concatenating the captions 79 | ... 80 | 81 | ``` 82 | -------------------------------------------------------------------------------- /src/megatron/energon/metadataset/dataset_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from dataclasses import dataclass 5 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union 6 | 7 | from megatron.energon.dataset_config import get_dataset_from_config 8 | from megatron.energon.epathlib import EPath 9 | from megatron.energon.flavors import BaseCoreDatasetFactory 10 | from megatron.energon.metadataset.loader_interface import DatasetBlendMode, DatasetLoaderInterface 11 | from megatron.energon.worker import WorkerConfig 12 | 13 | 14 | @dataclass 15 | class DatasetLoader(DatasetLoaderInterface): 16 | """Loads a dataset from a path.""" 17 | 18 | path: Union[str, EPath] 19 | split_part: Optional[str] = None 20 | subflavor: Optional[str] = None 21 | subflavors: Optional[Dict[str, Any]] = None 22 | shuffle_over_epochs_multiplier: int = 1 23 | dataset_config: str = "dataset.yaml" 24 | split_config: str = "split.yaml" 25 | 26 | def get_dataset( 27 | self, 28 | *, 29 | training: bool, 30 | split_part: Optional[str] = None, 31 | worker_config: WorkerConfig, 32 | subflavor: Optional[str] = None, 33 | subflavors: Optional[Dict[str, Any]] = None, 34 | shuffle_over_epochs: int = 1, 35 | split_config: Optional[str] = None, 36 | dataset_config: Optional[str] = None, 37 | **kwargs, 38 | ) -> BaseCoreDatasetFactory: 39 | """ 40 | Args: 41 | training: If true, apply training randomization. 42 | split_part: Default split part to use. 43 | worker_config: Worker configuration. 44 | shuffle_buffer_size: Size of the sample shuffle buffer (before task encoding). 45 | subflavor: Subflavor to use, might be overridden by inner datasets. 46 | subflavors: Subflavors to use, might be overridden by inner datasets. 47 | shuffle_over_epochs: Shuffle the dataset over this many epochs. 48 | **kwargs: Additional arguments to the dataset constructor. 
49 | 50 | Returns: 51 | The loaded dataset 52 | """ 53 | if self.split_part is not None: 54 | split_part = self.split_part 55 | if split_part is None: 56 | raise ValueError("Missing split part") 57 | if subflavor is None: 58 | subflavor = self.subflavor 59 | if self.subflavors is not None: 60 | subflavors = {**self.subflavors, **(subflavors or {})} 61 | if split_config is None: 62 | split_config = self.split_config 63 | if dataset_config is None: 64 | dataset_config = self.dataset_config 65 | return get_dataset_from_config( 66 | self.path, 67 | training=training, 68 | split_part=split_part, 69 | worker_config=worker_config, 70 | subflavor=subflavor, 71 | subflavors=subflavors, 72 | dataset_config=dataset_config, 73 | split_config=split_config, 74 | shuffle_over_epochs=shuffle_over_epochs, 75 | **kwargs, 76 | ) 77 | 78 | def get_datasets( 79 | self, 80 | *, 81 | training: bool, 82 | split_part: Union[Literal["train", "val", "test"], str], 83 | worker_config: WorkerConfig, 84 | subflavor: Optional[str] = None, 85 | subflavors: Optional[Dict[str, Any]] = None, 86 | shuffle_over_epochs_multiplier: int = 1, 87 | **kwargs, 88 | ) -> Tuple[DatasetBlendMode, List[Tuple[BaseCoreDatasetFactory, Union[float, int, None]]]]: 89 | return DatasetBlendMode.NONE, [ 90 | ( 91 | self.get_dataset( 92 | training=training, 93 | split_part=split_part, 94 | worker_config=worker_config, 95 | subflavor=subflavor, 96 | subflavors=subflavors, 97 | shuffle_over_epochs=shuffle_over_epochs_multiplier, 98 | **kwargs, 99 | ), 100 | None, 101 | ) 102 | ] 103 | -------------------------------------------------------------------------------- /docs/source/advanced/joining_datasets.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Joining Datasets 5 | 6 | Joining is needed when you have two or more corresponding datasets, each dataset having some parts to compose each sample. 7 | 8 | ```{image} ../images/joining.png 9 | :alt: Joining Datasets 10 | :class: bg-primary 11 | :width: 550px 12 | :align: center 13 | ``` 14 | 15 | Dataset joining is analogous to webdataset's "add column", i.e., merging two datasets which require the same files within the tar files. 16 | 17 | ## Example File Structure 18 | 19 | Here is an example structure of the datasets within the tar files: 20 | 21 | ``` 22 | ds1 23 | ├── .nv-meta 24 | │ ├── .info.yaml 25 | │ ├── split.yaml 26 | │ └── dataset.yaml 27 | ├── shard1.tar 28 | │ ├── 0001.jpg 29 | │ ├── 0002.jpg 30 | │ └── 0003.jpg 31 | ├── shard1.idx 32 | 33 | ds2 34 | ├── .nv-meta 35 | │ ├── .info.yaml 36 | │ ├── split.yaml 37 | │ └── dataset.yaml 38 | ├── shard1.tar 39 | │ ├── 0001.txt 40 | │ ├── 0002.txt 41 | │ └── 0003.txt 42 | ├── shard1.idx 43 | ``` 44 | 45 | In this example, `ds1/shard1.tar` and `ds2/shard1.tar` contain files with the same base names / keys. 46 | When joining, the files from both datasets are joined based on their keys, which must be in the same order. Each dataset must be [prepared](../basic/data_prep) (i.e. .nv-meta created). 47 | 48 | 49 | ## Minimal Example 50 | 51 | Below is a minimal meta dataset example for joining. 52 | 53 | `metadataset_minimal.yaml`: 54 | ```yaml 55 | __module__: megatron.energon 56 | __class__: MetadatasetV2 57 | splits: 58 | train: 59 | join: 60 | - path: ds1 61 | - path: ds2 62 | joiner: 63 | __module__: megatron.energon 64 | __class__: CaptioningSample 65 | ``` 66 | Note that the blending is omitted here, as it is optional. The joining is of course also optional. 
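Such a metadataset is then loaded like any other energon dataset path. A minimal sketch, assuming the YAML above is saved as `metadataset_minimal.yaml` next to `ds1` and `ds2` (the batch size and worker settings are placeholder values, not part of the example):

```python
from megatron.energon import WorkerConfig, get_loader, get_train_dataset

worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2)

train_ds = get_train_dataset(
    "metadataset_minimal.yaml",  # path to the metadataset YAML shown above
    batch_size=4,
    shuffle_buffer_size=100,
    max_samples_per_sequence=100,
    worker_config=worker_config,
)

for batch in get_loader(train_ds):
    # Each element was assembled by the joiner, here a CaptioningSample
    # built from the matching ds1/ds2 entries.
    break
```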
67 | 68 | ## Extensive Example 69 | 70 | Here is a more extensive example that shows multiple things at once: 71 | 72 | * Joining can be used inside blending 73 | * The datasets to be joined can have custom subflavors or dataset yamls specified 74 | * A custom "joiner" can be specified to define how samples are joined and what the resulting type is 75 | 76 | `metadataset_extended.yaml`: 77 | ```yaml 78 | __module__: megatron.energon 79 | __class__: MetadatasetV2 80 | splits: 81 | train: 82 | blend: 83 | - weight: 1 84 | join: 85 | - path: ds1 86 | dataset_config: dataset.yaml # If override is needed 87 | - path: ds2 88 | dataset_config: dataset.yaml 89 | subflavors: # If needed, will be merged(overriding) with parent subflavor 90 | ds2_extra: 2 91 | split_config: split.yaml 92 | joiner: 93 | __module__: my_module 94 | __class__: JoinedSample # Type should implement from_joined(ds1, ds2) 95 | split_config: split.yaml # Sets this for all joined datasets 96 | split_part: train # Sets this for all joined datasets 97 | subflavor: ds1 # Sets this for all joined datasets 98 | subflavors: # Sets this for all joined datasets (it will be merged with their individual subflavors) 99 | source: metadataset.yaml 100 | ``` 101 | 102 | ## Custom Join Type 103 | 104 | To define a custom join type, you can create a Python class as shown below in `my_module.py`: 105 | 106 | ```python 107 | from dataclasses import dataclass 108 | import torch 109 | from megatron.energon import Sample, TextSample 110 | 111 | @dataclass 112 | class JoinedSample(Sample): 113 | text1: torch.Tensor 114 | text2: torch.Tensor 115 | 116 | @staticmethod 117 | def from_joined(ds1: TextSample, ds2: TextSample) -> "JoinedSample": 118 | return JoinedSample.derive_from( 119 | ds1, 120 | text1=ds1.text, 121 | text2=ds2.text, 122 | ) 123 | ``` 124 | 125 | This class should implement the `from_joined` method to combine samples from `ds1` and `ds2`. 126 | Note: It is important to use `derive_from` with the first argument being the first sample, as this will guarantee that the state can be saved and restored. It ensures that all the internal keys of the sample are retained. 127 | 128 | 129 | -------------------------------------------------------------------------------- /src/megatron/energon/wrappers/concat_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from dataclasses import dataclass 5 | from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union 6 | 7 | from megatron.energon.flavors.base_dataset import ( 8 | MergedState, 9 | SavableDataset, 10 | State, 11 | add_sample_restore_key, 12 | ) 13 | from megatron.energon.worker import WorkerConfig 14 | from megatron.energon.wrappers.base import BaseWrapperDataset 15 | 16 | T_sample = TypeVar("T_sample") 17 | 18 | 19 | @dataclass 20 | class ConcatState(State): 21 | #: State of the inner datasets 22 | dataset_states: List[State] 23 | 24 | 25 | @dataclass 26 | class ConcatMergedState(MergedState): 27 | #: State of the inner datasets 28 | dataset_states: List[MergedState] 29 | 30 | 31 | class ConcatDataset(BaseWrapperDataset[T_sample], Generic[T_sample]): 32 | """ 33 | This dataset wrapper concatenates multiple iterable datasets together. The datasets must be 34 | finite, otherwise not all datasets can be sampled. This is only useful for validation / test 35 | datasets. 
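    A sketch of typical use, assuming `val_ds1` and `val_ds2` are two finite SavableDatasets built with the same worker_config:

        combined = ConcatDataset(val_ds1, val_ds2, worker_config=worker_config)
        # len(combined) == len(val_ds1) + len(val_ds2); iteration yields all of
        # val_ds1 first, then all of val_ds2.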
36 | """ 37 | 38 | datasets: Tuple[SavableDataset[T_sample], ...] 39 | 40 | def __init__( 41 | self, 42 | *datasets: SavableDataset[T_sample], 43 | worker_config: WorkerConfig, 44 | ): 45 | """Construct a concatenated dataset.""" 46 | super().__init__(datasets, worker_config=worker_config) 47 | self.datasets = datasets 48 | assert len(self) >= 0, "Datasets must be finite." 49 | 50 | def __len__(self): 51 | return sum(len(dataset) for dataset in self.datasets) 52 | 53 | def __iter__(self) -> Iterator[T_sample]: 54 | for ds_idx, dataset in enumerate(self.datasets): 55 | for sample in dataset: 56 | yield add_sample_restore_key( 57 | sample, 58 | ds_idx, 59 | src=self, 60 | ) 61 | 62 | def worker_has_samples(self) -> bool: 63 | return any(dataset.worker_has_samples() for dataset in self.datasets) 64 | 65 | def save_state(self) -> ConcatState: 66 | return ConcatState( 67 | dataset_states=[dataset.save_state() for dataset in self.datasets], 68 | ) 69 | 70 | def merge_states(self, states: List[ConcatState]) -> ConcatMergedState: 71 | assert all(s is None or isinstance(s, ConcatState) for s in states) 72 | assert all(s is None or len(s.dataset_states) == len(self.datasets) for s in states) 73 | return ConcatMergedState( 74 | dataset_states=[ 75 | dataset.merge_states( 76 | [None if s is None else s.dataset_states[ds_idx] for s in states] 77 | ) 78 | for ds_idx, dataset in enumerate(self.datasets) 79 | ], 80 | ) 81 | 82 | def restore_state(self, state: Optional[ConcatMergedState]) -> None: 83 | if state is None: 84 | for dataset in self.datasets: 85 | dataset.restore_state(None) 86 | else: 87 | assert isinstance(state, ConcatMergedState) 88 | assert len(self.datasets) == len(state.dataset_states) 89 | for dataset, dstate in zip(self.datasets, state.dataset_states): 90 | dataset.restore_state(dstate) 91 | 92 | def can_restore_sample(self) -> bool: 93 | return all(dataset.can_restore_sample() for dataset in self.datasets) 94 | 95 | def assert_can_restore(self) -> None: 96 | for dataset in self.datasets: 97 | dataset.assert_can_restore() 98 | 99 | def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample: 100 | id, ds_idx = index[:2] 101 | assert id == type(self).__name__ 102 | index = index[2:] 103 | assert isinstance(ds_idx, int) 104 | return add_sample_restore_key( 105 | self.datasets[ds_idx].restore_sample(index), 106 | ds_idx, 107 | src=self, 108 | ) 109 | 110 | def config(self) -> Dict[str, Any]: 111 | return { 112 | "type": type(self).__qualname__, 113 | "datasets": [dataset.config() for dataset in self.datasets], 114 | } 115 | 116 | def __str__(self): 117 | return f"ConcatDataset(datasets={self.datasets})" 118 | -------------------------------------------------------------------------------- /src/megatron/energon/tools/lint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import logging 5 | import traceback 6 | from typing import Optional 7 | 8 | import click 9 | import tqdm 10 | import yaml 11 | 12 | from megatron.energon import WorkerConfig, get_loader, get_val_dataset 13 | from megatron.energon.epathlib import EPath 14 | from megatron.energon.flavors import BaseWebdatasetFactory 15 | from megatron.energon.flavors.webdataset.itar import TarIndexReader 16 | 17 | 18 | def check_index(dataset_path: EPath): 19 | if not dataset_path.is_dir(): 20 | click.echo(f"{dataset_path} is not a directory, therefore the index will not be checked") 21 | return 22 | 23 | ok = True 24 | 25 | # Get info file 26 | info_file = dataset_path / ".nv-meta/.info.yaml" 27 | info = yaml.safe_load(info_file.read_text()) 28 | 29 | click.echo("Checking the index files...") 30 | shards = info["shard_counts"] 31 | for shard_file, length in shards.items(): 32 | with TarIndexReader(dataset_path / shard_file) as itar: 33 | l = len(itar) 34 | if l - 1 != length: 35 | ok = False 36 | print( 37 | f"Error in shard {shard_file}: Shard length in Info file {length} != {l - 1} (length in index)" 38 | ) 39 | 40 | return ok 41 | 42 | 43 | @click.command(name="lint") 44 | @click.argument( 45 | "path", 46 | type=click.Path(path_type=EPath), 47 | ) 48 | @click.option( 49 | "--split-parts", default="train,val,test", help="The splits to verify", show_default=True 50 | ) 51 | @click.option( 52 | "--dataset-config", default="dataset.yaml", help="Dataset config file name", show_default=True 53 | ) 54 | @click.option( 55 | "--split-config", default="split.yaml", help="Split config file name", show_default=True 56 | ) 57 | @click.option( 58 | "--parallel", default=1, help="Number of parallel workers", show_default=True, type=int 59 | ) 60 | def command(path: EPath, split_parts: str, dataset_config: str, split_config: str, parallel: int): 61 | """Check energon dataset for errors. 62 | 63 | The PATH should point to the folder with the dataset. 64 | The dataset must comply with the energon dataset format. 
See README.md for more details.""" 65 | 66 | path = path.absolute() 67 | 68 | # Check the tar file index 69 | if not check_index(path): 70 | raise click.ClickException("Validation failed with errors, see logs for details.") 71 | 72 | # Check the dataset 73 | failed = False 74 | 75 | ignore_list = [] 76 | 77 | def handler(exc: Exception, key: Optional[str] = None) -> None: 78 | nonlocal failed 79 | failed = True 80 | logging.exception(str(exc)) 81 | if key is not None: 82 | ignore_list.append(key) 83 | 84 | kwargs = {} 85 | if dataset_config != "dataset.yaml": 86 | kwargs["dataset_config"] = dataset_config 87 | if split_config != "split.yaml": 88 | kwargs["split_config"] = split_config 89 | 90 | worker_config = WorkerConfig(rank=0, world_size=1, num_workers=parallel) 91 | 92 | for split_part in split_parts.split(","): 93 | try: 94 | dataset = get_val_dataset( 95 | EPath(path), 96 | split_part=split_part, 97 | worker_config=worker_config, 98 | batch_size=1, 99 | handler=handler, 100 | **kwargs, 101 | ) 102 | except BaseWebdatasetFactory.EmptyDatasetError: 103 | click.echo(f"Skipping empty split part {split_part}") 104 | continue 105 | 106 | try: 107 | for _ in tqdm.tqdm(get_loader(dataset)): 108 | pass 109 | except InterruptedError: 110 | raise 111 | except BaseException: 112 | traceback.print_exc() 113 | raise click.ClickException("Validation failed with errors, see logs for details.") 114 | 115 | if failed: 116 | click.echo( 117 | "The following shards/samples failed (maybe set as dataset.yaml:ignore_list):", err=True 118 | ) 119 | for item in ignore_list: 120 | click.echo(f"- {item}", err=True) 121 | raise click.ClickException("Validation failed with errors, see logs for details.") 122 | 123 | 124 | if __name__ == "__main__": 125 | command() 126 | -------------------------------------------------------------------------------- /src/megatron/energon/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from megatron.energon.flavors import ( 5 | BaseCoreDatasetFactory, 6 | BaseWebdatasetFactory, 7 | CaptioningSample, 8 | CaptioningWebdataset, 9 | CrudeSample, 10 | CrudeWebdataset, 11 | DefaultDecoderWebdatasetFactory, 12 | DefaultGenericWebdatasetFactory, 13 | ImageClassificationSample, 14 | ImageClassificationWebdataset, 15 | ImageSample, 16 | ImageWebdataset, 17 | InterleavedSample, 18 | InterleavedWebdataset, 19 | JoinedWebdatasetFactory, 20 | MultiChoiceVQASample, 21 | MultiChoiceVQAWebdataset, 22 | OCRSample, 23 | OCRWebdataset, 24 | Sample, 25 | SavableDataset, 26 | SimilarityInterleavedSample, 27 | SimilarityInterleavedWebdataset, 28 | StandardWebdatasetFactory, 29 | TextSample, 30 | TextWebdataset, 31 | VidQASample, 32 | VidQAWebdataset, 33 | VQAOCRWebdataset, 34 | VQASample, 35 | VQAWebdataset, 36 | ) 37 | from megatron.energon.loader import get_loader, get_savable_loader 38 | from megatron.energon.metadataset import ( 39 | DatasetLoader, 40 | DatasetLoaderInterface, 41 | Metadataset, 42 | MetadatasetV2, 43 | load_dataset, 44 | ) 45 | from megatron.energon.savable_loader import SavableDataLoader 46 | from megatron.energon.task_encoder import ( 47 | AugmentTaskEncoder, 48 | Batch, 49 | DefaultTaskEncoder, 50 | TaskEncoder, 51 | batch_list, 52 | batch_pad_stack, 53 | batch_stack, 54 | generic_batch, 55 | get_train_dataset, 56 | get_val_dataset, 57 | get_val_datasets, 58 | stateless, 59 | ) 60 | from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys 61 | from megatron.energon.worker import WorkerConfig 62 | from megatron.energon.wrappers import ( 63 | BatchDataset, 64 | BlendDataset, 65 | ConcatDataset, 66 | EpochizeDataset, 67 | FilterDataset, 68 | GcDataset, 69 | GroupBatchDataset, 70 | IterMapDataset, 71 | LimitDataset, 72 | LogSampleDataset, 73 | MapDataset, 74 | MixBatchDataset, 75 | PackingDataset, 76 | RepeatDataset, 77 | ShuffleBufferDataset, 78 | SkipSample, 79 | concat_pad, 80 | generic_concat, 81 | homogeneous_concat_mix, 82 | ) 83 | 84 | __all__ = [ 85 | "AugmentTaskEncoder", 86 | "BaseCoreDatasetFactory", 87 | "BaseWebdatasetFactory", 88 | "Batch", 89 | "BatchDataset", 90 | "BlendDataset", 91 | "CaptioningSample", 92 | "CaptioningWebdataset", 93 | "ConcatDataset", 94 | "Cooker", 95 | "CrudeSample", 96 | "CrudeWebdataset", 97 | "DatasetLoader", 98 | "DatasetLoaderInterface", 99 | "DefaultDecoderWebdatasetFactory", 100 | "DefaultGenericWebdatasetFactory", 101 | "DefaultTaskEncoder", 102 | "EpochizeDataset", 103 | "FilterDataset", 104 | "GcDataset", 105 | "GroupBatchDataset", 106 | "ImageClassificationSample", 107 | "ImageClassificationWebdataset", 108 | "ImageSample", 109 | "ImageWebdataset", 110 | "InterleavedSample", 111 | "InterleavedWebdataset", 112 | "IterMapDataset", 113 | "LimitDataset", 114 | "LogSampleDataset", 115 | "MapDataset", 116 | "JoinedWebdatasetFactory", 117 | "Metadataset", 118 | "MetadatasetV2", 119 | "MixBatchDataset", 120 | "MultiChoiceVQASample", 121 | "MultiChoiceVQAWebdataset", 122 | "OCRSample", 123 | "OCRWebdataset", 124 | "RepeatDataset", 125 | "Sample", 126 | "SavableDataLoader", 127 | "SavableDataset", 128 | "SimilarityInterleavedSample", 129 | "SimilarityInterleavedWebdataset", 130 | "ShuffleBufferDataset", 131 | "SkipSample", 132 | "StandardWebdatasetFactory", 133 | "PackingDataset", 134 | "TaskEncoder", 135 | "TextSample", 136 | "TextWebdataset", 137 | "VidQASample", 138 | "VidQAWebdataset", 139 | "VQASample", 140 | "VQAWebdataset", 141 | "VQAOCRWebdataset", 142 | 
"WorkerConfig", 143 | "basic_sample_keys", 144 | "batch_list", 145 | "batch_pad_stack", 146 | "batch_stack", 147 | "concat_pad", 148 | "generic_batch", 149 | "generic_concat", 150 | "get_loader", 151 | "get_savable_loader", 152 | "get_train_dataset", 153 | "get_val_dataset", 154 | "get_val_datasets", 155 | "homogeneous_concat_mix", 156 | "load_dataset", 157 | "stateless", 158 | ] 159 | -------------------------------------------------------------------------------- /scripts/license_headers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Optional, Tuple 6 | 7 | import click 8 | 9 | 10 | @dataclass 11 | class HeaderUpdater: 12 | file_ext: str 13 | line_comment: Optional[str] = None 14 | comment_start: Optional[str] = None 15 | comment_end: Optional[str] = None 16 | 17 | UPDATE_IDENTIFIER = "Copyright" 18 | 19 | HEADER_LINES: Tuple[str, ...] = ( 20 | "Copyright (c) 2025, NVIDIA CORPORATION.", 21 | "SPDX-License-Identifier: BSD-3-Clause", 22 | ) 23 | 24 | _expected_lines: Tuple[str, ...] = () 25 | 26 | def __post_init__(self): 27 | if self.line_comment is not None: 28 | self._expected_lines = tuple(self.line_comment + line for line in self.HEADER_LINES) 29 | else: 30 | assert self.comment_start is not None and self.comment_end is not None 31 | if len(self.HEADER_LINES) >= 2: 32 | self._expected_lines = ( 33 | self.comment_start + self.HEADER_LINES[0], 34 | *self.HEADER_LINES[1:-1], 35 | self.HEADER_LINES[-1] + self.comment_end, 36 | ) 37 | else: 38 | assert len(self.HEADER_LINES) == 1 39 | self._expected_lines = ( 40 | self.comment_start + self.HEADER_LINES[0] + self.comment_end, 41 | ) 42 | 43 | def has_header(self, file: Path) -> bool: 44 | with file.open() as rf: 45 | num_lines = 0 46 | for line, expected in zip(rf, self._expected_lines): 47 | num_lines += 1 48 | if line.rstrip("\n") != expected: 49 | return False 50 | return num_lines == len(self._expected_lines) 51 | 52 | def fix_header(self, file: Path): 53 | contents = file.read_text() 54 | first_comment = self.line_comment if self.line_comment is not None else self.comment_start 55 | if contents.startswith(first_comment) and contents[len(first_comment) :].startswith( 56 | self.UPDATE_IDENTIFIER 57 | ): 58 | # Already has header, but want to update 59 | *header_lines, remainder = contents.split("\n", len(self._expected_lines)) 60 | new_contents = "\n".join(self._expected_lines) + "\n" + remainder 61 | else: 62 | # No header, add it 63 | new_contents = "\n".join(self._expected_lines) + "\n" + contents 64 | file.write_text(new_contents) 65 | 66 | 67 | headers = ( 68 | HeaderUpdater( 69 | file_ext=".py", 70 | line_comment="# ", 71 | ), 72 | HeaderUpdater( 73 | file_ext=".sh", 74 | line_comment="# ", 75 | ), 76 | # Do not add to yamls 77 | # HeaderUpdater( 78 | # file_ext=".yml", 79 | # line_comment="# ", 80 | # ), 81 | # HeaderUpdater( 82 | # file_ext=".yaml", 83 | # line_comment="# ", 84 | # ), 85 | HeaderUpdater( 86 | file_ext=".md", 87 | comment_start="", 89 | ), 90 | ) 91 | 92 | 93 | @click.command() 94 | @click.argument( 95 | "path", 96 | type=click.Path(exists=True, file_okay=False, dir_okay=True, readable=True, path_type=Path), 97 | ) 98 | @click.option( 99 | "--fix", is_flag=True, help="Automatically add missing/adapt existing license headers." 
100 | ) 101 | def main(path: Path, fix: bool) -> None: 102 | # Get all files to check 103 | success = True 104 | for header_update in headers: 105 | for filepath in path.rglob(f"*{header_update.file_ext}"): 106 | if not header_update.has_header(filepath): 107 | success = False 108 | print(f"Missing header in {filepath}") 109 | if fix: 110 | header_update.fix_header(filepath) 111 | print(f"Added header to {filepath}") 112 | 113 | if success: 114 | print("All files have the correct license header.") 115 | elif fix: 116 | print("Some license headers were fixed.") 117 | exit(1) 118 | else: 119 | print( 120 | "Some license headers are missing or incorrect. Run `python scripts/license_headers.py . --fix` to fix." 121 | ) 122 | exit(1) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /src/megatron/energon/wrappers/filter_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from dataclasses import dataclass 5 | from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union 6 | 7 | from megatron.energon.flavors.base_dataset import SavableDataset 8 | from megatron.energon.worker import WorkerConfig 9 | from megatron.energon.wrappers.base import ( 10 | BaseSingleWrapperDataset, 11 | BaseSingleWrapperMergedState, 12 | BaseSingleWrapperState, 13 | SampleIndex, 14 | ) 15 | 16 | T_sample = TypeVar("T_sample") 17 | 18 | 19 | @dataclass 20 | class FilterState(BaseSingleWrapperState): 21 | sample_index: int 22 | 23 | 24 | @dataclass 25 | class FilterMergedState(BaseSingleWrapperMergedState): 26 | sample_indexes: List[int] 27 | 28 | 29 | class FilterDataset(BaseSingleWrapperDataset[T_sample, T_sample], Generic[T_sample]): 30 | """This dataset wrapper applies a custom filter function to each sample and does not yield 31 | samples that are filtered out.""" 32 | 33 | filter_fn: Callable[[T_sample], bool] 34 | filter_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] 35 | _sample_index: SampleIndex 36 | 37 | def __init__( 38 | self, 39 | dataset: SavableDataset[T_sample], 40 | *, 41 | filter_fn: Callable[[T_sample], bool], 42 | filter_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, 43 | worker_config: WorkerConfig, 44 | ): 45 | """Construct a FilterDataset. 46 | 47 | Args: 48 | dataset: The input dataset to wrap 49 | filter_fn: The function to apply to each sample. If it returns `True`, the sample is 50 | accepted. 51 | filter_fn_config: Configuration for the filter function. If callable, it should return the 52 | configuration. Defaults to None. 53 | worker_config: Configuration for the workers.
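        A sketch of usage, assuming `ds` is an existing SavableDataset whose samples expose a `caption` field and `worker_config` is already set up:

            non_empty = FilterDataset(
                ds,
                filter_fn=lambda sample: len(sample.caption) > 0,  # keep only non-empty captions
                worker_config=worker_config,
            )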
54 | """ 55 | super().__init__(dataset, worker_config=worker_config) 56 | self.filter_fn = filter_fn 57 | self.filter_fn_config = filter_fn_config 58 | self._sample_index = SampleIndex(worker_config, src=self) 59 | 60 | def __len__(self): 61 | return len(self.dataset) 62 | 63 | def __iter__(self) -> Iterator[T_sample]: 64 | for sample in self.dataset: 65 | with self._sample_index.ctx(): 66 | filter_res = self.filter_fn(sample) 67 | if filter_res: 68 | yield sample 69 | 70 | def save_state(self) -> FilterState: 71 | return FilterState.extend( 72 | super().save_state(), 73 | sample_index=self._sample_index.save_state(), 74 | ) 75 | 76 | def merge_states(self, states: List[FilterState]) -> FilterMergedState: 77 | assert all(s is None or isinstance(s, FilterState) for s in states) 78 | return FilterMergedState.extend( 79 | super().merge_states(states), 80 | sample_indexes=self._sample_index.merge_states( 81 | [0 if state is None else state.sample_index for state in states] 82 | ), 83 | ) 84 | 85 | def restore_state(self, state: Optional[FilterMergedState]) -> None: 86 | super().restore_state(state) 87 | if state is None: 88 | self._sample_index.restore_state(None) 89 | else: 90 | assert isinstance(state, FilterMergedState) 91 | self._sample_index.restore_state(state.sample_indexes) 92 | 93 | def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample: 94 | return self.dataset.restore_sample(index) 95 | 96 | def config(self) -> Dict[str, Any]: 97 | return { 98 | "type": type(self).__qualname__, 99 | "dataset": self.dataset.config(), 100 | "filter_fn": self._function_config(self.filter_fn), 101 | **( 102 | { 103 | "filter_fn_config": ( 104 | self.filter_fn_config() 105 | if callable(self.filter_fn_config) 106 | else self.filter_fn_config 107 | ) 108 | } 109 | if self.filter_fn_config 110 | else {} 111 | ), 112 | } 113 | 114 | def __str__(self): 115 | return f"FilterDataset(filter_fn={self.filter_fn}, dataset={self.dataset})" 116 | -------------------------------------------------------------------------------- /src/megatron/energon/wrappers/shuffle_buffer_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from dataclasses import dataclass 5 | from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union 6 | 7 | from megatron.energon.flavors.base_dataset import SavableDataset 8 | from megatron.energon.rng import WorkerRng, WorkerRngMergedState, WorkerRngState 9 | from megatron.energon.worker import WorkerConfig 10 | from megatron.energon.wrappers.base import ( 11 | BaseSingleWrapperDataset, 12 | BaseSingleWrapperMergedState, 13 | BaseSingleWrapperState, 14 | ) 15 | from megatron.energon.wrappers.buffer import ( 16 | SampleBufferMergedState, 17 | SampleBufferState, 18 | SavableSampleBuffer, 19 | ) 20 | 21 | T_sample = TypeVar("T_sample") 22 | 23 | 24 | @dataclass 25 | class ShuffleBufferState(BaseSingleWrapperState): 26 | buffer: SampleBufferState 27 | rng: WorkerRngState 28 | 29 | 30 | @dataclass 31 | class ShuffleBufferMergedState(BaseSingleWrapperMergedState): 32 | buffer: SampleBufferMergedState 33 | rng: WorkerRngMergedState 34 | 35 | 36 | class ShuffleBufferDataset(BaseSingleWrapperDataset[T_sample, T_sample], Generic[T_sample]): 37 | """Shuffle buffer for the dataset.""" 38 | 39 | size: int 40 | _worker_rng: WorkerRng 41 | 42 | _active_buffer: SavableSampleBuffer[T_sample] 43 | 44 | def __init__( 45 | self, 46 | dataset: SavableDataset[T_sample], 47 | size: int, 48 | *, 49 | worker_config: WorkerConfig, 50 | ): 51 | """Create a shuffle buffer for the dataset.""" 52 | super().__init__(dataset, worker_config=worker_config) 53 | self.size = size 54 | self._worker_rng = WorkerRng(self.worker_config) 55 | self._active_buffer = SavableSampleBuffer(dataset, worker_config=worker_config) 56 | 57 | def __len__(self) -> int: 58 | return len(self.dataset) 59 | 60 | def __iter__(self) -> Iterator[T_sample]: 61 | self._active_buffer.worker_start() 62 | it = iter(self._active_buffer.append_iter()) 63 | while True: 64 | if len(self._active_buffer) >= self.size: 65 | pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) 66 | yield self._active_buffer.pop(pop_idx) 67 | else: 68 | try: 69 | next(it) 70 | except StopIteration: 71 | break 72 | while len(self._active_buffer) > 0: 73 | pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) 74 | yield self._active_buffer.pop(pop_idx) 75 | 76 | def save_state(self) -> ShuffleBufferState: 77 | return ShuffleBufferState.extend( 78 | super().save_state(), 79 | rng=self._worker_rng.save_state(), 80 | buffer=self._active_buffer.save_state(), 81 | ) 82 | 83 | def merge_states(self, states: List[Optional[ShuffleBufferState]]) -> ShuffleBufferMergedState: 84 | assert all(s is None or isinstance(s, ShuffleBufferState) for s in states) 85 | return ShuffleBufferMergedState.extend( 86 | super().merge_states(states), 87 | rng=self._worker_rng.merge_states([None if s is None else s.rng for s in states]), 88 | buffer=self._active_buffer.merge_states( 89 | [None if s is None else s.buffer for s in states] 90 | ), 91 | ) 92 | 93 | def restore_state(self, state: Optional[ShuffleBufferMergedState]) -> None: 94 | super().restore_state(state) 95 | if state is None: 96 | self._active_buffer.restore_state(None) 97 | self._worker_rng.restore_state(None) 98 | else: 99 | assert isinstance(state, ShuffleBufferMergedState) 100 | self._active_buffer.restore_state(state.buffer) 101 | self._worker_rng.restore_state(state.rng) 102 | 103 | def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample: 104 | return self._active_buffer.restore_sample(index) 105 | 106 | def 
config(self) -> Dict[str, Any]: 107 | return { 108 | "type": type(self).__qualname__, 109 | "dataset": self.dataset.config(), 110 | "size": self.size, 111 | "worker_config": self.worker_config.config(), 112 | } 113 | 114 | def __str__(self): 115 | return f"ShuffleBufferDataset(size={self.size}, dataset={self.dataset})" 116 | -------------------------------------------------------------------------------- /src/megatron/energon/dataset_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from pathlib import Path 5 | from typing import Any, Dict, Optional, Type, TypeVar, Union 6 | 7 | import yaml 8 | 9 | from megatron.energon.epathlib import EPath 10 | from megatron.energon.flavors import BaseCoreDatasetFactory, StandardWebdatasetFactory 11 | from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME 12 | from megatron.energon.typed_converter import raw_to_instance 13 | from megatron.energon.worker import WorkerConfig 14 | 15 | T = TypeVar("T") 16 | 17 | 18 | def load_config( 19 | path: Union[EPath, Dict[str, Any]], 20 | *, 21 | default_type: Type[T], 22 | strict: bool = True, 23 | default_kwargs: Optional[Dict[str, Any]] = None, 24 | ) -> T: 25 | """ 26 | Loads a config from a file or directly from a dictionary. 27 | 28 | Args: 29 | path: Path to the config to load or a dictionary containing the config. 30 | default_type: If set, this is the type to use if no type is specified in the config. 31 | strict: If true, don't allow additional attributes in the config. 32 | default_kwargs: Default kwargs to use, will be overridden by the config. 33 | 34 | Returns: 35 | The instantiated type. 36 | """ 37 | if isinstance(path, dict): 38 | data = path 39 | else: 40 | # Read the config from a file 41 | path = path.absolute() 42 | with path.open() as f: 43 | data: dict = yaml.safe_load(f) 44 | 45 | if default_kwargs is not None: 46 | new_data = default_kwargs.copy() 47 | new_data.update(data) 48 | data = new_data 49 | 50 | return raw_to_instance(data, default_type, strict=strict) 51 | 52 | 53 | T_sample = TypeVar("T_sample", covariant=True) 54 | 55 | 56 | def get_dataset_from_config( 57 | path: Union[EPath, Path, str], 58 | *, 59 | dataset_config: str = "dataset.yaml", 60 | split_config: str = "split.yaml", 61 | split_part: str = "train", 62 | training: bool = True, 63 | subflavor: Optional[str] = None, 64 | subflavors: Optional[Dict[str, Any]] = None, 65 | worker_config: WorkerConfig, 66 | sample_type: Optional[Type[T_sample]] = None, 67 | **kwargs, 68 | ) -> BaseCoreDatasetFactory[T_sample]: 69 | """ 70 | Gets a dataset from a config path. 71 | 72 | Args: 73 | path: Path to the folder where the `.nv-meta` folder is contained. 74 | dataset_config: Filename of the dataset config file (`path / '.nv-meta' / config`) 75 | split_config: Filename of the split config file (`path / '.nv-meta' / split_config`) 76 | split_part: Name of the split to load. 77 | training: If true, apply training randomization and loop the dataset. 78 | subflavor: Override the __subflavor__ property of each sample. 79 | subflavors: Merge-Override the __subflavors__ property of each sample. 80 | worker_config: If set, use this worker config instead of the default one. 81 | sample_type: Type of the samples to load, only used to ensure typing. 82 | **kwargs: Additional arguments to be passed to the dataset constructor. 
83 | 84 | Returns: 85 | The instantiated dataset 86 | """ 87 | path = EPath(path).absolute() 88 | if not (path / MAIN_FOLDER_NAME / ".info.yaml").is_file(): 89 | raise ValueError( 90 | f"Path {path} does not contain a {MAIN_FOLDER_NAME}/.info.yaml file. Did you forget to " 91 | f"prepare the dataset? Please check the documentation for an introduction to dataset " 92 | f"preparation." 93 | ) 94 | dataset: BaseCoreDatasetFactory[T_sample] = load_config( 95 | path / MAIN_FOLDER_NAME / dataset_config, 96 | default_kwargs=dict( 97 | path=path, 98 | split_config=split_config, 99 | split_part=split_part, 100 | training=training, 101 | subflavor=subflavor, 102 | worker_config=worker_config, 103 | **kwargs, 104 | ), 105 | default_type=StandardWebdatasetFactory, 106 | ) 107 | if dataset.subflavors is None: 108 | dataset.subflavors = subflavors 109 | elif subflavors is not None: 110 | dataset.subflavors.update(subflavors) 111 | if sample_type is not None: 112 | assert issubclass( 113 | dataset.__sample_type__, sample_type 114 | ), f"Sample of type {dataset.__sample_type__} is not a subclass of {sample_type}." 115 | return dataset 116 | -------------------------------------------------------------------------------- /src/megatron/energon/wrappers/repeat_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from dataclasses import dataclass 5 | from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar 6 | 7 | from megatron.energon.flavors.base_dataset import SavableDataset 8 | from megatron.energon.worker import WorkerConfig 9 | from megatron.energon.wrappers.base import ( 10 | BaseSingleWrapperDataset, 11 | BaseSingleWrapperMergedState, 12 | BaseSingleWrapperState, 13 | ) 14 | 15 | T_sample = TypeVar("T_sample") 16 | 17 | 18 | @dataclass 19 | class RepeatState(BaseSingleWrapperState): 20 | repetition: int 21 | 22 | 23 | @dataclass 24 | class RepeatMergedState(BaseSingleWrapperMergedState): 25 | repetition: List[int] 26 | 27 | 28 | class RepeatDataset(BaseSingleWrapperDataset[T_sample, T_sample], Generic[T_sample]): 29 | """This dataset repeats the inner dataset indefinitely or a specific number of repeats.""" 30 | 31 | repeats: Optional[int] 32 | _repetition: List[int] 33 | 34 | def __init__( 35 | self, 36 | dataset: SavableDataset[T_sample], 37 | *, 38 | repeats: Optional[int] = None, 39 | restart: bool = True, 40 | worker_config: WorkerConfig, 41 | ): 42 | """Construct a RepeatDataset. 43 | 44 | Args: 45 | dataset: The input dataset to repeat. 46 | repeats: Number of repeats, `None` for indefinitely repeating. 47 | restart: If true, restart the underlying dataset after iterating once through the 48 | repeats if repeats is set to an integer, but still stop iterating. 49 | worker_config: Configuration for the workers. 
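        For example (a sketch, with `ds` and `worker_config` as placeholders), `RepeatDataset(ds, repeats=2, worker_config=worker_config)` iterates the wrapped dataset twice and reports `2 * len(ds)` as its length, while `repeats=None` repeats it indefinitely.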
50 | """ 51 | super().__init__(dataset, worker_config=worker_config) 52 | self.repeats = repeats 53 | self.restart = restart 54 | self._repetition = [0] * max(self.worker_config.num_workers, 1) 55 | 56 | def __len__(self): 57 | if self.repeats is None: 58 | return len(self.dataset) 59 | return len(self.dataset) * self.repeats 60 | 61 | def __iter__(self) -> Iterator[T_sample]: 62 | worker_idx = self.worker_config.rank_worker_id() 63 | assert ( 64 | self.repeats is not None or self.dataset.worker_has_samples() 65 | ), "Cannot repeat empty dataset indefinitely" 66 | while self.repeats is None or self._repetition[worker_idx] < self.repeats: 67 | for sample in self.dataset: 68 | yield sample 69 | if self.worker_config.should_log(level=2): 70 | self.worker_config.worker_log( 71 | { 72 | "t": "RepeatDataset.repeat", 73 | "r": self.worker_config.rank, 74 | "w": self.worker_config.rank_worker_id(), 75 | "offset": self._repetition[worker_idx], 76 | "repeats": self.repeats, 77 | } 78 | ) 79 | self._repetition[worker_idx] += 1 80 | if self.restart: 81 | self._repetition[worker_idx] = 0 82 | else: 83 | # No more repeats 84 | self._repetition[worker_idx] = self.repeats 85 | 86 | def save_state(self) -> RepeatState: 87 | return RepeatState.extend( 88 | super().save_state(), 89 | repetition=self._repetition[self.worker_config.rank_worker_id()], 90 | ) 91 | 92 | def merge_states(self, states: List[RepeatState]) -> RepeatMergedState: 93 | assert all(s is None or isinstance(s, RepeatState) for s in states) 94 | return RepeatMergedState.extend( 95 | super().merge_states(states), 96 | repetition=[0 if state is None else state.repetition for state in states], 97 | ) 98 | 99 | def restore_state(self, state: Optional[RepeatMergedState]) -> None: 100 | super().restore_state(state) 101 | if state is None: 102 | self._repetition = [0] * max(self.worker_config.num_workers, 1) 103 | else: 104 | assert isinstance(state, RepeatMergedState) 105 | self._repetition = state.repetition 106 | 107 | def config(self) -> Dict[str, Any]: 108 | return { 109 | "type": type(self).__qualname__, 110 | "dataset": self.dataset.config(), 111 | "repeats": self.repeats, 112 | "worker_config": self.worker_config.config(), 113 | } 114 | 115 | def __str__(self): 116 | return f"RepeatDataset(repeats={self.repeats}, dataset={self.dataset})" 117 | -------------------------------------------------------------------------------- /src/megatron/energon/wrappers/log_sample_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from dataclasses import dataclass 5 | from typing import Any, Callable, Dict, Generic, Iterator, List, Literal, Optional, TypeVar 6 | 7 | from megatron.energon.flavors.base_dataset import SavableDataset 8 | from megatron.energon.worker import WorkerConfig 9 | from megatron.energon.wrappers.base import ( 10 | BaseSingleWrapperDataset, 11 | BaseSingleWrapperMergedState, 12 | BaseSingleWrapperState, 13 | ) 14 | 15 | T_sample = TypeVar("T_sample") 16 | 17 | 18 | def default_get_keys(batch: Any) -> Optional[List[str]]: 19 | """Default get_keys, which has some heuristics to find the sample keys.""" 20 | if isinstance(batch, list): 21 | batch = batch[0] 22 | if ( 23 | hasattr(batch, "__key__") 24 | and isinstance(batch.__key__, list) 25 | and all(isinstance(k, str) for k in batch.__key__) 26 | ): 27 | return batch.__key__ 28 | elif ( 29 | hasattr(batch, "__keys__") 30 | and isinstance(batch.__keys__, list) 31 | and all(isinstance(k, str) for k in batch.__keys__) 32 | ): 33 | return batch.__keys__ 34 | elif ( 35 | isinstance(batch, dict) 36 | and "__key__" in batch 37 | and all(isinstance(k, str) for k in batch["__key__"]) 38 | ): 39 | return batch["__key__"] 40 | elif ( 41 | isinstance(batch, dict) 42 | and "__keys__" in batch 43 | and all(isinstance(k, str) for k in batch["__keys__"]) 44 | ): 45 | return batch["__keys__"] 46 | elif ( 47 | isinstance(batch, dict) 48 | and "keys" in batch 49 | and all(isinstance(k, str) for k in batch["keys"]) 50 | ): 51 | return batch["keys"] 52 | return None 53 | 54 | 55 | @dataclass 56 | class LogSampleState(BaseSingleWrapperState): 57 | step: int 58 | 59 | 60 | @dataclass 61 | class LogSampleMergedState(BaseSingleWrapperMergedState): 62 | step: List[int] 63 | 64 | 65 | class LogSampleDataset(BaseSingleWrapperDataset[T_sample, T_sample], Generic[T_sample]): 66 | """This dataset logs every yielded sample to the debug logs.""" 67 | 68 | get_keys_fn: Callable[[T_sample], Optional[List[str]]] 69 | mode: Literal["train", "val"] 70 | _step: List[int] 71 | 72 | def __init__( 73 | self, 74 | dataset: SavableDataset[T_sample], 75 | mode: Literal["train", "val"], 76 | worker_config: WorkerConfig, 77 | get_keys_fn: Callable[[T_sample], Optional[List[str]]] = default_get_keys, 78 | ): 79 | """Construct the log sample dataset, which logs every yielded sample to the debug logs. 
80 | 81 | Args: 82 | dataset: The input dataset to wrap 83 | """ 84 | super().__init__(dataset, worker_config=worker_config) 85 | self.get_keys_fn = get_keys_fn 86 | self.mode = mode 87 | self._step = [0] * max(self.worker_config.num_workers, 1) 88 | 89 | def __len__(self): 90 | return len(self.dataset) 91 | 92 | def _log(self, sample: T_sample) -> None: 93 | if self.worker_config.should_log(level=1): 94 | log_entry = { 95 | "t": "yield_batch", 96 | "r": self.worker_config.rank, 97 | "w": self.worker_config.global_worker_id(), 98 | "m": self.mode, 99 | "idx": self._step[self.worker_config.rank_worker_id()], 100 | } 101 | keys = self.get_keys_fn(sample) 102 | if keys is not None: 103 | log_entry["keys"] = keys 104 | 105 | self.worker_config.worker_log(log_entry) 106 | 107 | def __iter__(self) -> Iterator[T_sample]: 108 | worker_id = self.worker_config.rank_worker_id() 109 | for sample in self.dataset: 110 | self._log(sample) 111 | self._step[worker_id] += 1 112 | yield sample 113 | 114 | def save_state(self) -> LogSampleState: 115 | return LogSampleState.extend( 116 | super().save_state(), 117 | step=self._step[self.worker_config.rank_worker_id()], 118 | ) 119 | 120 | def merge_states(self, states: List[Optional[LogSampleState]]) -> LogSampleMergedState: 121 | assert all(s is None or isinstance(s, LogSampleState) for s in states) 122 | return LogSampleMergedState.extend( 123 | super().merge_states(states), 124 | step=[0 if state is None else state.step for state in states], 125 | ) 126 | 127 | def restore_state(self, state: Optional[LogSampleMergedState]) -> None: 128 | super().restore_state(state) 129 | if state is None: 130 | self._step = [0] * max(self.worker_config.num_workers, 1) 131 | else: 132 | assert isinstance(state, LogSampleMergedState) 133 | self._step = state.step 134 | 135 | def config(self) -> Dict[str, Any]: 136 | # Transparent logger, it won't change the samples 137 | return self.dataset.config() 138 | 139 | def __str__(self): 140 | return f"LogSampleDataset(mode={self.mode}, get_keys_fn={self.get_keys_fn}, dataset={self.dataset})" 141 | -------------------------------------------------------------------------------- /src/megatron/energon/wrappers/gc_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import gc 5 | from typing import Any, Dict, Generic, Iterator, TypeVar 6 | 7 | import torch 8 | import torch.utils.data 9 | import torch.utils.data.dataloader 10 | from torch.distributed._shard.sharded_tensor import ShardedTensorBase 11 | from torch.distributed.distributed_c10d import reduce_op 12 | 13 | from megatron.energon.flavors.base_dataset import SavableDataset 14 | from megatron.energon.worker import WorkerConfig 15 | from megatron.energon.wrappers.base import BaseSingleWrapperDataset 16 | 17 | T_sample = TypeVar("T_sample") 18 | 19 | _frozen_cuda_tensors = set() 20 | _frozen_cuda_tensors_initialized = False 21 | 22 | 23 | class GcFreezeError(RuntimeError): 24 | pass 25 | 26 | 27 | def gc_init_worker(worker_id: int): 28 | """This function should be called by any forked worker process that uses CUDA. 29 | It should be called as early as possible in the worker process, ideally in 30 | the worker_init_fn of the DataLoader. 
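    A sketch of that wiring, assuming you construct a torch DataLoader yourself (`dataset` and `num_workers` are placeholders):

        loader = torch.utils.data.DataLoader(
            dataset,
            num_workers=2,
            worker_init_fn=gc_init_worker,  # called with the worker id in each forked worker
        )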
31 | 32 | By keeping a reference to all CUDA tensors in the worker process, we can 33 | prevent the forked tensors from being garbage collected.""" 34 | 35 | global _frozen_cuda_tensors_initialized, _frozen_cuda_tensors 36 | 37 | num_tensors = 0 38 | for o in gc.get_objects(): 39 | try: 40 | if o is not reduce_op: 41 | if isinstance(o, torch.Tensor): 42 | if isinstance(o, ShardedTensorBase) or o.is_cuda: 43 | # Calling .is_cuda or any hasattr on ShardedTensor will raise an error 44 | # Hence, o.is_cuda is only called if o is not a ShardedTensor (in the if above) 45 | 46 | _frozen_cuda_tensors.add(o) 47 | num_tensors += 1 48 | elif isinstance(o, torch.utils.data.dataloader._MultiProcessingDataLoaderIter): 49 | o._shutdown = True 50 | except ReferenceError: 51 | # Can happen if the object is a weakref proxy, don't care 52 | pass 53 | 54 | _frozen_cuda_tensors_initialized = True 55 | 56 | 57 | class GcDataset(BaseSingleWrapperDataset[T_sample, T_sample], Generic[T_sample]): 58 | """Applies a garbage collection step. This is needed, because python garbage collection 59 | does not work well with very large objects, such as tensors. This case happens, if there are 60 | a few hundred objects created and released every epoch (some of them being (large) tensors), 61 | where a lot of them are alive at the same time, but released later. In that case, those objects 62 | may end up in gc generation 2, where they may live until a lot of objects have been created, 63 | until automatic garbage collection of gen2 is actually triggered. To avoid this memory leak, 64 | `gc.collect()` is best to be called regularly. In addition, if `gc.freeze()` is used before the 65 | loop, it will remove the objects currently alive from garbage collection checks, thus making the 66 | gc faster. 67 | """ 68 | 69 | every_n_iter: int 70 | freeze: bool 71 | 72 | def __init__( 73 | self, 74 | dataset: SavableDataset[T_sample], 75 | *, 76 | worker_config: WorkerConfig, 77 | every_n_iter: int = 1, 78 | freeze: bool = True, 79 | ): 80 | """Construct a GcDataset, which applies garbage collection every `every_n_iter` iterations. 81 | 82 | Args: 83 | dataset: The input dataset to wrap 84 | every_n_iter: How often to perform garbage collection 85 | freeze: If true, run `gc.freeze()` before the loop, and `gc.unfreeze()` after the loop. 86 | This will speed up garbage collection, but will keep all initially alive objects 87 | alive until the end of the loop (i.e. if the dataset state was restored, that state 88 | will be saved as well). 89 | """ 90 | super().__init__(dataset, worker_config=worker_config) 91 | self.every_n_iter = every_n_iter 92 | self.freeze = freeze 93 | 94 | def __len__(self): 95 | return len(self.dataset) 96 | 97 | def __iter__(self) -> Iterator[T_sample]: 98 | in_worker = torch.utils.data.get_worker_info() is not None 99 | if in_worker and not _frozen_cuda_tensors_initialized: 100 | raise GcFreezeError( 101 | "You are using GcDataset with multiple workers, but forgot to call gc_init_worker() in at least one forked worker process." 
102 | ) 103 | 104 | if self.freeze: 105 | gc.collect() 106 | gc.freeze() 107 | try: 108 | iter = 0 109 | for sample in self.dataset: 110 | yield sample 111 | iter += 1 112 | if iter >= self.every_n_iter: 113 | gc.collect() 114 | iter = 0 115 | finally: 116 | if self.freeze: 117 | gc.unfreeze() 118 | 119 | def config(self) -> Dict[str, Any]: 120 | # This is transparent, no config to be saved (it does not affect the dataset) 121 | return self.dataset.config() 122 | 123 | def __str__(self): 124 | return f"GcDataset(every_n_iter={self.every_n_iter}, dataset={self.dataset}, freeze={self.freeze})" 125 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/webdataset/itar.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import contextlib 5 | import struct 6 | import tarfile 7 | from types import TracebackType 8 | from typing import BinaryIO, Optional, Type, Union 9 | 10 | from megatron.energon.epathlib import EPath 11 | from megatron.energon.retry_stream import RetryReadStream 12 | 13 | ITAR_SUFFIX = ".tar.idx" 14 | 15 | 16 | class TarIndexReader: 17 | def __init__(self, tar_path: Union[EPath, str]): 18 | tar_path = EPath(tar_path) 19 | self.itar = (tar_path.with_suffix(ITAR_SUFFIX)).open("rb") 20 | self._length = len(self) 21 | 22 | def __getitem__(self, index: int) -> int: 23 | if index >= self._length or index < 0: 24 | raise IndexError(f"Index {index} out of range") 25 | 26 | self.itar.seek(8 * index) 27 | return struct.unpack("Q", self.itar.read(8))[0] 28 | 29 | def __len__(self) -> int: 30 | self.itar.seek(0, 2) 31 | return self.itar.tell() // 8 32 | 33 | def close(self): 34 | self.itar.close() 35 | 36 | def __enter__(self): 37 | return self 38 | 39 | def __exit__(self, exc_type, exc_val, exc_tb): 40 | self.close() 41 | 42 | 43 | class TarIndexWriter: 44 | def __init__(self, tar_path: EPath): 45 | self.final_name = tar_path.with_suffix(ITAR_SUFFIX) 46 | self.tmp_name = tar_path.with_suffix(ITAR_SUFFIX + ".tmp") 47 | self.itar = self.tmp_name.open("wb") 48 | 49 | def append(self, offset: int): 50 | self.itar.write(struct.pack("Q", offset)) 51 | 52 | def close(self, finalize: bool = True): 53 | self.itar.close() 54 | if finalize: 55 | self.tmp_name.move(self.final_name) 56 | else: 57 | self.tmp_name.unlink() 58 | 59 | def __enter__(self): 60 | return self 61 | 62 | def __exit__(self, exc_type, exc_val, exc_tb): 63 | self.close(finalize=exc_val is None) 64 | 65 | 66 | class SubFileReader(BinaryIO): 67 | """A file-like object that reads a subfile (i.e. 
offset, size defined portion) of a larger 68 | file.""" 69 | 70 | def __init__(self, stream: BinaryIO, offset: int, size: int): 71 | self.offset = offset 72 | self._pos = 0 73 | self.size = size 74 | self.stream = stream 75 | self.stream.seek(self.offset) 76 | 77 | def read(self, n: int = -1) -> bytes: 78 | if n == -1: 79 | n = self.size - self._pos 80 | else: 81 | n = min(n, self.size - self._pos) 82 | if n == 0: 83 | return b"" 84 | read = self.stream.read(n) 85 | self._pos += len(read) 86 | return read 87 | 88 | def seek(self, offset: int, whence: int = 0) -> int: 89 | if whence == 0: 90 | self._pos = offset 91 | elif whence == 1: 92 | self._pos += offset 93 | elif whence == 2: 94 | self._pos = self.size + offset 95 | else: 96 | raise ValueError("Invalid whence value") 97 | self._pos = max(0, min(self._pos, self.size)) 98 | self.stream.seek(self.offset + self._pos) 99 | return self._pos 100 | 101 | def tell(self) -> int: 102 | return self._pos 103 | 104 | def __enter__(self) -> BinaryIO: 105 | return self 106 | 107 | def __exit__( 108 | self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType 109 | ) -> None: 110 | self.close() 111 | 112 | def close(self) -> None: 113 | self.stream.close() 114 | 115 | def isatty(self) -> bool: 116 | return False 117 | 118 | def seekable(self) -> bool: 119 | return True 120 | 121 | def writable(self) -> bool: 122 | return False 123 | 124 | 125 | def get_itar_byte_offset( 126 | path: Union[str, EPath], 127 | sample_offset: int = 0, 128 | ) -> int: 129 | """Gets the byte offset from sample offsets.""" 130 | if sample_offset == 0: 131 | return 0 132 | with TarIndexReader(path) as itar: 133 | return itar[sample_offset] 134 | 135 | 136 | @contextlib.contextmanager 137 | def open_itar(path: Union[str, EPath], byte_offset: int = 0, byte_size: Optional[int] = None): 138 | """ 139 | Open an indexed tarfile with offset and size. 140 | Args: 141 | path: Path to the tarfile to open 142 | byte_offset: Byte offset within the file 143 | byte_size: Size of the file to read 144 | 145 | Returns: 146 | The opened tarfile 147 | """ 148 | path = EPath(path) 149 | 150 | # TODO: if tar file startswith(b"\x1f\x8b\x08") -> Seekable gzip file 151 | with path.open("rb") as f: 152 | if f.read(3) == b"\x1f\x8b\x08": 153 | # Open as seekable tgz 154 | raise ValueError("Seekable tgz not supported yet") 155 | 156 | if byte_offset != 0 or byte_size is not None: 157 | if byte_size is None: 158 | byte_size = path.size() - byte_offset 159 | with RetryReadStream(path) as stream: 160 | with SubFileReader( 161 | stream, 162 | offset=byte_offset, 163 | size=byte_size, 164 | ) as fileobj: 165 | with tarfile.open(fileobj=fileobj, mode="r|") as f: 166 | yield f 167 | else: 168 | with RetryReadStream(path) as fileobj: 169 | with tarfile.open(fileobj=fileobj, mode="r|") as f: 170 | yield f 171 | -------------------------------------------------------------------------------- /docs/source/basic/save_restore.md: -------------------------------------------------------------------------------- 1 | 3 | 4 | # Save and Restore 5 | 6 | For long-running training jobs, you will usually need to stop and resume the training including the data loader. 7 | One of energon's unqiue features is the deterministic save and restore capability. 8 | 9 | At any iteration, you'll be able to store the overall state of the data loader across all ranks and accurately resume it later on, to continue where it left off. 10 | Below, we list a few different ways to achieve that. 
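Before going through the individual scenarios, here is a minimal sketch of what the deterministic resume guarantee means in practice for a single, non-distributed process. It uses the `get_my_loader()` helper defined below and assumes that your batch type exposes its sample keys as a `__keys__` field (any other field works for the comparison); the per-rank save/restore calls it uses are explained in scenario 1.

```python
# Minimal sketch of a deterministic save/restore round trip (single process).
# Assumption: batches expose a `__keys__` field; compare any other field if yours differ.
loader = get_my_loader()

it = iter(loader)
for _ in range(10):
    next(it)  # consume some batches

# Capture the loader state of this rank after 10 iterations
state = loader.save_state_rank()

# The batches we expect to see again after restoring
expected = [next(it) for _ in range(3)]

# Build a fresh loader and restore the saved state into it
resumed = get_my_loader()
resumed.restore_state_rank(state)

resumed_it = iter(resumed)
actual = [next(resumed_it) for _ in range(3)]

# With a deterministic restore, the same samples arrive in the same order
for a, b in zip(expected, actual):
    assert a.__keys__ == b.__keys__
```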
11 | 12 | ## Scenarios 13 | 14 | All these scenarios work in distributed and non-distributed settings. 15 | 16 | For simplicity, we are assuming the same loader and data for all scenarios from the [](../basic/quickstart): 17 | ```python 18 | from megatron.energon import get_train_dataset, get_savable_loader, WorkerConfig 19 | 20 | worker_config = WorkerConfig.default_worker_config() 21 | 22 | def get_my_loader(): 23 | return get_savable_loader(get_train_dataset( 24 | 'coyo-coco-dataset.yaml', 25 | batch_size=4, 26 | shuffle_buffer_size=100, 27 | max_samples_per_sequence=100, 28 | worker_config=worker_config, 29 | )) 30 | 31 | ``` 32 | 33 | ### 1. Save/Restore the State per Rank Separately 34 | 35 | In this scenario, each rank saves and restores its own state in an independent file. 36 | This is our recommended way, since it avoids transferring the data across ranks. 37 | 38 | ```python 39 | # Saving the state 40 | loader = get_my_loader() 41 | 42 | # Iterate for some steps 43 | for i, batch in zip(range(10), loader): 44 | print(batch) 45 | break 46 | 47 | # Save the state 48 | state = loader.save_state_rank() 49 | # Save the state on each rank 50 | # In this example, save the state using `torch.save`; this can of course be customized 51 | torch.save(state, f'dataloader_state_rank{worker_config.rank}.pth') 52 | ``` 53 | 54 | ```python 55 | # Restoring the state 56 | loader = get_my_loader() 57 | 58 | # Now, when restoring the state: 59 | state = torch.load(f'dataloader_state_rank{worker_config.rank}.pth') 60 | 61 | # Restore the state for the loader on each rank separately 62 | loader.restore_state_rank(state) 63 | ``` 64 | 65 | 66 | ### 2. Save/Restore the State on the Primary Rank Only 67 | 68 | In this scenario, the primary rank (usually rank 0) is responsible for saving the state. 69 | All ranks' states are collected (gathered) by one rank and can be stored in one file. 70 | When restoring, the state is scattered from the primary rank to all other ranks. 71 | This approach centralizes the state management, which can simplify the process and reduce the number of files stored. 72 | 73 | ```python 74 | # Saving the state 75 | loader = get_my_loader() 76 | 77 | # Iterate for some steps 78 | for i, batch in zip(range(10), loader): 79 | print(batch) 80 | break 81 | 82 | # Save the state to primary rank 0 83 | state = loader.save_state_global(dst_rank=0) 84 | if worker_config.rank == 0: 85 | # Only rank 0 has the state now; for the others, the state is None 86 | # In this example, save the state using `torch.save`; this can of course be customized 87 | torch.save(state, 'dataloader_state.pth') 88 | ``` 89 | 90 | ```python 91 | # Restoring the state 92 | loader = get_my_loader() 93 | 94 | # Load the state only on the primary rank 95 | if worker_config.rank == 0: 96 | state = torch.load('dataloader_state.pth') 97 | else: 98 | state = None 99 | 100 | # Restore the state for the loader, broadcasting from rank 0 101 | loader.restore_state_global(state, src_rank=0) 102 | ``` 103 | 104 | 105 | ```{admonition} Note 106 | :class: important 107 | Even though only one rank collects the states, all ranks need to execute the `loader.save_state_global()` and `loader.restore_state_global()` lines of code. 108 | ``` 109 | 110 | ### 3. Save the State on the Primary Rank, Restore on Ranks Separately 111 | 112 | In this scenario, the primary rank saves the state, but each rank restores the state separately. Each rank loads all saved states and selects the correct one.
This approach combines centralized saving with distributed restoring and is rather uncommon. 113 | 114 | Depending on the framework used for training, that framework may already handle the scattering/gathering of the states. In that case, refer to the first scenario using `save_state_rank`/`restore_state_rank`. 115 | 116 | ```python 117 | # Saving the state 118 | loader = get_my_loader() 119 | 120 | # Iterate for some steps 121 | for i, batch in zip(range(10), loader): 122 | print(batch) 123 | break 124 | 125 | # Save the state 126 | state = loader.save_state_global(dst_rank=0) 127 | if worker_config.rank == 0: 128 | # In this example, save the state using `torch.save`; this can of course be customized 129 | torch.save(state, 'dataloader_state.pth') 130 | ``` 131 | 132 | ```python 133 | # Restoring the state 134 | loader = get_my_loader() 135 | 136 | # Load on all ranks 137 | state = torch.load('dataloader_state.pth') 138 | 139 | # Restore the state for the loader on the current rank, using the all-ranks checkpoint 140 | loader.restore_state_global(state, src_rank=None) 141 | ``` 142 | 143 | ## Summary 144 | 145 | In each of these scenarios, ensure that the logic for saving and restoring the state is appropriately synchronized across ranks to maintain consistency. 146 | If you encounter torch distributed errors, the distributed calls are likely out of sync, or not all ranks execute them correctly. If unsure, debug using the first scenario, saving each rank separately. 147 | -------------------------------------------------------------------------------- /src/megatron/energon/wrappers/mix_batch_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import dataclasses 5 | from typing import Any, Callable, Dict, Generator, Generic, Iterator, List, Tuple, TypeVar, Union 6 | 7 | import torch 8 | 9 | from megatron.energon.flavors.base_dataset import SavableDataset 10 | from megatron.energon.worker import WorkerConfig 11 | from megatron.energon.wrappers.base import BaseSingleWrapperDataset 12 | from megatron.energon.wrappers.batch_dataset import BatchDataset 13 | from megatron.energon.wrappers.blend_dataset import BlendDataset 14 | 15 | T_batch_in = TypeVar("T_batch_in") 16 | T_batch = TypeVar("T_batch") 17 | 18 | 19 | def generic_concat(batch: List[Any]) -> Any: 20 | """Based on the types/shapes of the batch: Will either pad and stack, or return as list.
21 | Recurses structures (dict, dataclass, namedtuple) and applies the same logic to each field.""" 22 | if isinstance(batch[0], torch.Tensor): 23 | return concat_pad(batch) 24 | elif isinstance(batch[0], dict): 25 | return {key: generic_concat([sample[key] for sample in batch]) for key in batch[0].keys()} 26 | elif dataclasses.is_dataclass(batch[0]): 27 | return type(batch[0])( 28 | **{ 29 | field.name: generic_concat([getattr(sample, field.name) for sample in batch]) 30 | for field in dataclasses.fields(batch[0]) 31 | } 32 | ) 33 | elif isinstance(batch[0], tuple) and hasattr(batch[0], "_fields"): 34 | # NamedTuple 35 | return type(batch[0])( 36 | **{ 37 | field: generic_concat([getattr(sample, field) for sample in batch]) 38 | for field in batch[0]._fields 39 | } 40 | ) 41 | else: 42 | return batch 43 | 44 | 45 | def concat_pad(batch: List[Any]) -> Any: 46 | """Concat a batch of arbitrary-sized tensors padded with 0s.""" 47 | total_bs = sum(b.shape[0] for b in batch) 48 | max_size = [max(b.shape[dim] for b in batch) for dim in range(1, batch[0].ndim)] 49 | concat_tensor = batch[0].new_zeros((total_bs, *max_size)) 50 | b_idx = 0 51 | for b in batch: 52 | concat_tensor[(slice(b_idx, b_idx + b.shape[0]), *(slice(0, s) for s in b.shape[1:]))] = b 53 | b_idx += b.shape[0] 54 | # Pad all tensors to max_size 55 | return concat_tensor 56 | 57 | 58 | def homogeneous_concat_mix(samples: List[T_batch_in]) -> T_batch: 59 | """ 60 | Mixes a list of batches into a single batch. The default implementation is to concat the 61 | batches if they are all of the same type, otherwise return a list of batches. 62 | 63 | Args: 64 | samples: The samples to mix. 65 | 66 | Returns: 67 | The mixed batch. 68 | """ 69 | first_type = type(samples[0]) 70 | assert all(first_type == type(sample) for sample in samples) 71 | # All the same type -> concat batches 72 | return generic_concat(samples) 73 | 74 | 75 | class MixBatchDataset(BaseSingleWrapperDataset[T_batch_in, T_batch], Generic[T_batch_in, T_batch]): 76 | """ 77 | This dataset wrapper blends multiple iterable datasets together according to the given weights. 78 | The datasets may be infinite. This dataset is always infinite. 79 | Effectively combines :class:`megatron.energon.BlendDataset` and :class:`megatron.energon.BatchDataset`. 80 | """ 81 | 82 | def __init__( 83 | self, 84 | *dataset_weights: Tuple[SavableDataset[T_batch_in], float], 85 | batch_size: int, 86 | batch_mix_fn: Callable[ 87 | [List[T_batch_in]], Union[T_batch, Generator[T_batch, None, None]] 88 | ] = lambda x: x, 89 | worker_config: WorkerConfig, 90 | ): 91 | """Construct a MixBatchDataset. 92 | 93 | Args: 94 | dataset_weights: Each argument should be a tuple of (dataset, weight) with a weight 95 | between 0 and 1. The output samples are sampled from the input datasets with the 96 | given probabilities. The datasets should have a batch size of 1, otherwise the 97 | whole batches will be sampled. 98 | batch_size: The batch size to output. 99 | batch_mix_fn: A function that takes a list of samples from the input datasets and 100 | returns a batch sample. The default implementation returns a list of batches. 101 | For homogeneous datasets, it is recommended to use the 102 | :func:`megatron.energon.homogeneous_concat_mix` which concatenates the batches. May raise 103 | :exc:`megatron.energon.SkipSample` to skip a sample. May also return a generator, which 104 | will be iterated over to produce batches. 105 | worker_config: Configuration for the workers.
106 | """ 107 | super().__init__( 108 | BatchDataset( 109 | BlendDataset(*dataset_weights, worker_config=worker_config), 110 | batch_size=batch_size, 111 | batcher=batch_mix_fn, 112 | worker_config=worker_config, 113 | ), 114 | worker_config=worker_config, 115 | ) 116 | 117 | def __len__(self) -> int: 118 | return len(self.dataset) 119 | 120 | def __iter__(self) -> Iterator[T_batch]: 121 | yield from self.dataset 122 | 123 | def config(self) -> Dict[str, Any]: 124 | return { 125 | "type": type(self).__qualname__, 126 | "dataset": self.dataset.config(), 127 | } 128 | 129 | def __str__(self): 130 | return f"MixBatchDataset(dataset={self.dataset})" 131 | -------------------------------------------------------------------------------- /src/megatron/energon/wrappers/epochize_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from dataclasses import dataclass 5 | from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar 6 | 7 | from megatron.energon.flavors.base_dataset import SavableDataset 8 | from megatron.energon.worker import WorkerConfig 9 | from megatron.energon.wrappers.base import ( 10 | BaseSingleWrapperDataset, 11 | BaseSingleWrapperMergedState, 12 | BaseSingleWrapperState, 13 | ) 14 | 15 | T_sample = TypeVar("T_sample") 16 | 17 | 18 | @dataclass 19 | class EpochizeState(BaseSingleWrapperState): 20 | offset: int 21 | 22 | 23 | @dataclass 24 | class EpochizeMergedState(BaseSingleWrapperMergedState): 25 | offset: List[int] 26 | 27 | 28 | class EpochizeDataset(BaseSingleWrapperDataset[T_sample, T_sample], Generic[T_sample]): 29 | """ 30 | Uses the base dataset, and creates one epoch, which has length samples. Keeps the underlying 31 | dataset iterator alive over epochs (i.e. if it is an infinite dataset, it will keep the state). 32 | Repeats the underlying dataset if the iterator is exhausted. 33 | """ 34 | 35 | length: int 36 | _active_iter: Optional[Iterator[T_sample]] 37 | 38 | def __init__( 39 | self, 40 | dataset: SavableDataset[T_sample], 41 | length: int, 42 | worker_config: WorkerConfig, 43 | ): 44 | """ 45 | Create the epochized dataset. 46 | 47 | Args: 48 | dataset: The source dataset (possibly infinite) 49 | length: Number of samples to iterate before iteration stops (i.e. one epoch). When 50 | iteration continues, the original dataset iterator is resumed and does only restart 51 | if exhausted. 52 | worker_config: Configuration for the workers. 53 | """ 54 | super().__init__(dataset, worker_config=worker_config) 55 | self.length = length 56 | self._offset = [0] * max(self.worker_config.num_workers, 1) 57 | self._active_iter = None 58 | 59 | def __iter__(self) -> Iterator[T_sample]: 60 | # Compute the local length for this worker, i.e. 
all worker's lengths sum up to the total 61 | worker_idx = self.worker_config.rank_worker_id() 62 | 63 | if self.worker_config.num_workers <= 1: 64 | local_length = self.length 65 | else: 66 | local_length = self.length // self.worker_config.num_workers 67 | if self.worker_config.rank_worker_id() < self.length % self.worker_config.num_workers: 68 | local_length += 1 69 | 70 | if self.worker_config.should_log(level=2): 71 | self.worker_config.worker_log( 72 | { 73 | "t": "EpochizeDataset.epoch_start", 74 | "r": self.worker_config.rank, 75 | "w": self.worker_config.rank_worker_id(), 76 | "offset": self._offset[worker_idx], 77 | "local_length": local_length, 78 | "length": self.length, 79 | } 80 | ) 81 | 82 | offset_range = list(range(self._offset[worker_idx], local_length)) 83 | 84 | # Only iterate if there are samples to iterate 85 | if len(offset_range) > 0: 86 | if self._active_iter is None: 87 | self._active_iter = iter(self.dataset) 88 | 89 | for idx in offset_range: 90 | self._offset[worker_idx] = (idx + 1) % local_length 91 | try: 92 | sample = next(self._active_iter) 93 | except StopIteration: 94 | break 95 | yield sample 96 | 97 | if self.worker_config.should_log(level=2): 98 | self.worker_config.worker_log( 99 | { 100 | "t": "EpochizeDataset.epoch_end", 101 | "r": self.worker_config.rank, 102 | "w": self.worker_config.rank_worker_id(), 103 | "offset": self._offset[worker_idx], 104 | "local_length": local_length, 105 | "length": self.length, 106 | } 107 | ) 108 | 109 | def __len__(self) -> int: 110 | return self.length 111 | 112 | def save_state(self) -> EpochizeState: 113 | return EpochizeState.extend( 114 | super().save_state(), offset=self._offset[self.worker_config.rank_worker_id()] 115 | ) 116 | 117 | def merge_states(self, states: List[EpochizeState]) -> EpochizeMergedState: 118 | assert all(s is None or isinstance(s, EpochizeState) for s in states) 119 | return EpochizeMergedState.extend( 120 | super().merge_states(states), 121 | offset=[0 if state is None else state.offset for state in states], 122 | ) 123 | 124 | def restore_state(self, state: Optional[EpochizeMergedState]) -> None: 125 | super().restore_state(state) 126 | if state is None: 127 | self._offset = [0] * max(self.worker_config.num_workers, 1) 128 | else: 129 | assert isinstance(state, EpochizeMergedState) 130 | self._offset = state.offset 131 | 132 | def config(self) -> Dict[str, Any]: 133 | return { 134 | "type": type(self).__qualname__, 135 | "dataset": self.dataset.config(), 136 | "length": self.length, 137 | "worker_config": self.worker_config.config(), 138 | } 139 | 140 | def __str__(self): 141 | return f"EpochizeDataset(length={self.length}, dataset={self.dataset})" 142 | -------------------------------------------------------------------------------- /src/megatron/energon/wrappers/limit_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from dataclasses import dataclass 5 | from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar 6 | 7 | from megatron.energon.flavors.base_dataset import SavableDataset 8 | from megatron.energon.worker import WorkerConfig 9 | from megatron.energon.wrappers.base import ( 10 | BaseSingleWrapperDataset, 11 | BaseSingleWrapperMergedState, 12 | BaseSingleWrapperState, 13 | ) 14 | 15 | T_sample = TypeVar("T_sample") 16 | 17 | 18 | @dataclass 19 | class LimitState(BaseSingleWrapperState): 20 | offset: int 21 | 22 | 23 | @dataclass 24 | class LimitMergedState(BaseSingleWrapperMergedState): 25 | offset: List[int] 26 | 27 | 28 | class LimitDataset(BaseSingleWrapperDataset[T_sample, T_sample], Generic[T_sample]): 29 | """Limits the length of the dataset.""" 30 | 31 | dataset: SavableDataset[T_sample] 32 | length: int 33 | 34 | _current_offset: List[int] 35 | 36 | def __init__( 37 | self, 38 | dataset: SavableDataset[T_sample], 39 | length: int, 40 | *, 41 | reset_after_epoch: bool = False, 42 | worker_config: WorkerConfig, 43 | ): 44 | """ 45 | Limits the length of the dataset. 46 | 47 | Args: 48 | dataset: The dataset to limit 49 | length: The length to limit to 50 | reset_after_epoch: If true, reset the underlying dataset after one epoch. 51 | worker_config: Configuration for the workers. 52 | """ 53 | super().__init__(dataset, worker_config=worker_config) 54 | self.length = length 55 | self.reset_after_epoch = reset_after_epoch 56 | self._current_offset = [0] * max(self.worker_config.num_workers, 1) 57 | 58 | def __len__(self) -> int: 59 | return min(self.length, len(self.dataset)) 60 | 61 | def __iter__(self) -> Iterator[T_sample]: 62 | worker_id = self.worker_config.rank_worker_id() 63 | 64 | # Compute the local limit for this worker, i.e. 
all worker's limits sum up to the total 65 | if self.worker_config.num_workers <= 1: 66 | local_limit = self.length 67 | else: 68 | local_limit = self.length // self.worker_config.num_workers 69 | if worker_id < self.length % self.worker_config.num_workers: 70 | local_limit += 1 71 | 72 | if self.worker_config.should_log(level=2): 73 | self.worker_config.worker_log( 74 | { 75 | "t": "LimitDataset.start", 76 | "r": self.worker_config.rank, 77 | "w": worker_id, 78 | "offset": self._current_offset[worker_id], 79 | "local_limit": local_limit, 80 | "limit": self.length, 81 | } 82 | ) 83 | 84 | offset_range = list( 85 | range(self._current_offset[self.worker_config.rank_worker_id()], local_limit) 86 | ) 87 | # Only iterate self.dataset if there are samples to iterate 88 | if len(offset_range) > 0: 89 | for sample, offset in zip( 90 | self.dataset, 91 | offset_range, 92 | ): 93 | self._current_offset[worker_id] = offset + 1 94 | yield sample 95 | 96 | if self.worker_config.should_log(level=2): 97 | self.worker_config.worker_log( 98 | { 99 | "t": "LimitDataset.done", 100 | "r": self.worker_config.rank, 101 | "w": worker_id, 102 | "offset": self._current_offset[worker_id], 103 | "local_limit": local_limit, 104 | "limit": self.length, 105 | } 106 | ) 107 | 108 | # Reset the inner dataset 109 | self.dataset.restore_state(None) 110 | self._current_offset = [0] * max(self.worker_config.num_workers, 1) 111 | if self.reset_after_epoch: 112 | self.dataset.restore_state(None) 113 | 114 | def worker_has_samples(self) -> bool: 115 | return super().worker_has_samples() and self.length > 0 116 | 117 | def save_state(self) -> LimitState: 118 | return LimitState.extend( 119 | super().save_state(), 120 | offset=self._current_offset[self.worker_config.rank_worker_id()], 121 | ) 122 | 123 | def merge_states(self, states: List[LimitState]) -> LimitMergedState: 124 | assert all(s is None or isinstance(s, LimitState) for s in states) 125 | return LimitMergedState.extend( 126 | super().merge_states(states), 127 | offset=[0 if s is None else s.offset for s in states], 128 | ) 129 | 130 | def restore_state(self, state: Optional[LimitMergedState]) -> None: 131 | super().restore_state(state) 132 | if state is None: 133 | self._current_offset = [0] * max(self.worker_config.num_workers, 1) 134 | else: 135 | assert isinstance(state, LimitMergedState) 136 | self._current_offset = state.offset 137 | 138 | def config(self) -> Dict[str, Any]: 139 | return { 140 | "type": type(self).__qualname__, 141 | "dataset": self.dataset.config(), 142 | "length": self.length, 143 | "reset_after_epoch": self.reset_after_epoch, 144 | "worker_config": self.worker_config.config(), 145 | } 146 | 147 | def __str__(self): 148 | return f"LimitDataset(length={self.length}, dataset={self.dataset})" 149 | -------------------------------------------------------------------------------- /src/megatron/energon/flavors/webdataset/default_generic_webdataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | import dataclasses 5 | from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union 6 | 7 | from megatron.energon.epathlib import EPath 8 | from megatron.energon.flavors.base_dataset import SavableDataset 9 | from megatron.energon.flavors.webdataset.base_webdataset import BaseWebdatasetFactory 10 | from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME 11 | from megatron.energon.flavors.webdataset.field_access import field_access, split_field_access 12 | from megatron.energon.flavors.webdataset.structs import FilteredSample 13 | from megatron.energon.module_loader import ModuleLoader 14 | 15 | T_sample = TypeVar("T_sample", covariant=True) 16 | 17 | 18 | class DefaultGenericWebdatasetFactory(BaseWebdatasetFactory[T_sample], Generic[T_sample]): 19 | """ 20 | Default implementation of webdataset for generic samples and the generic config interface for use with dataset.yaml. 21 | """ 22 | 23 | _sample_loader: Callable[[Dict[str, Any]], Dict[str, Any]] 24 | 25 | def __init__( 26 | self, 27 | path: EPath, 28 | *, 29 | subflavor: Optional[str] = None, 30 | subflavors: Optional[Dict[str, Any]] = None, 31 | field_map: Optional[Dict[str, str]] = None, 32 | sample_loader: Optional[Union[str, Callable[[dict], dict]]] = None, 33 | part_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None, 34 | **kwargs, 35 | ): 36 | """ 37 | Factory for the webdataset sample loader and basic configuration options. 38 | 39 | Args: 40 | subflavor: Deprecated. Subflavor to set for all loaded samples. 41 | subflavors: Subflavors dictionary to set for all loaded samples. 42 | field_map: Mapping from the webdataset fields to the sample fields. 43 | sample_loader: Function to load the sample from the webdataset fields. May be a string 44 | in order to load a function from a module, or a callable directly. 45 | part_filter: Filter for the parts to load. May be a string in order to load a function 46 | from a module, or a callable directly. 47 | **kwargs: Args passed to parent constructor. 48 | """ 49 | assert (field_map is None) != ( 50 | sample_loader is None 51 | ), "Either field_map or sample_loader must be provided." 52 | if sample_loader is not None: 53 | assert ( 54 | part_filter is not None 55 | ), "part_filter must be provided if sample_loader is provided." 
56 | module_loader = ModuleLoader() 57 | if isinstance(sample_loader, str): 58 | sample_loader = module_loader.get_function( 59 | sample_loader, "sample_loader", relative_path=path / MAIN_FOLDER_NAME 60 | ) 61 | else: 62 | assert callable(sample_loader) 63 | sample_loader = sample_loader 64 | if isinstance(part_filter, list): 65 | parts = set(part_filter) 66 | part_filter = lambda part: part in parts 67 | elif isinstance(part_filter, str): 68 | part_filter = module_loader.get_function( 69 | part_filter, "part_filter", relative_path=path / MAIN_FOLDER_NAME 70 | ) 71 | else: 72 | assert callable(part_filter) 73 | self._sample_loader = sample_loader 74 | else: 75 | assert field_map is not None 76 | assert part_filter is None 77 | # Split field map fields by json[field][field] 78 | fields = {key: split_field_access(field) for key, field in field_map.items()} 79 | assert set(field.name for field in dataclasses.fields(self.__sample_type__)).issuperset( 80 | fields.keys() 81 | ) and set( 82 | field.name 83 | for field in dataclasses.fields(self.__sample_type__) 84 | if field.default is not dataclasses.MISSING 85 | and field.default_factory is not dataclasses.MISSING 86 | ).issubset( 87 | field_map.keys() 88 | ), f"field_map does not map to type {self.__sample_type__.__name__} fields" 89 | self._sample_loader = lambda sample: { 90 | k: field_access(sample, v) for k, v in fields.items() 91 | } 92 | parts = set(access[0] for options in fields.values() for access in options) 93 | part_filter = lambda part: part in parts 94 | inner_sample_loader = self._sample_loader 95 | self._sample_loader = lambda sample: { 96 | "__key__": sample["__key__"], 97 | **inner_sample_loader(sample), 98 | "__restore_key__": sample["__restore_key__"], 99 | "__subflavor__": self.subflavor, 100 | "__subflavors__": self.subflavors, 101 | } 102 | super().__init__(path, **kwargs, part_filter=part_filter) 103 | self.subflavor = subflavor 104 | self.subflavors = subflavors or {} 105 | 106 | def load_sample(self, sample: FilteredSample) -> T_sample: 107 | return self.__sample_type__(**self._sample_loader(sample)) 108 | 109 | def config(self) -> Dict[str, Any]: 110 | return dict( 111 | **super().config(), 112 | subflavor=self.subflavor, 113 | subflavors=self.subflavors, 114 | sample_loader=SavableDataset._function_config(self._sample_loader), 115 | ) 116 | -------------------------------------------------------------------------------- /src/megatron/energon/metadataset/join_dataset_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. 
2 | # SPDX-License-Identifier: BSD-3-Clause 3 | 4 | from dataclasses import dataclass 5 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union 6 | 7 | from megatron.energon.flavors import ( 8 | BaseCoreDatasetFactory, 9 | BaseWebdatasetFactory, 10 | JoinedWebdatasetFactory, 11 | Sample, 12 | ) 13 | from megatron.energon.metadataset.dataset_loader import DatasetLoader 14 | from megatron.energon.metadataset.loader_interface import DatasetBlendMode, DatasetLoaderInterface 15 | from megatron.energon.worker import WorkerConfig 16 | 17 | 18 | @dataclass 19 | class JoinDatasetLoader(DatasetLoaderInterface): 20 | """Loads a joined dataset from a path.""" 21 | 22 | datasets: Union[List[DatasetLoader], Dict[str, DatasetLoader]] 23 | joiner: Union[Type[Sample], Callable[..., Sample]] 24 | join_method: Literal["inner_match", "inner", "left"] = "inner_match" 25 | 26 | split_part: Optional[str] = None 27 | split_config: Optional[str] = None 28 | subflavor: Optional[str] = None 29 | subflavors: Optional[Dict[str, Any]] = None 30 | shuffle_over_epochs_multiplier: int = 1 31 | 32 | def get_dataset( 33 | self, 34 | *, 35 | training: bool, 36 | split_part: Optional[str] = None, 37 | worker_config: WorkerConfig, 38 | subflavor: Optional[str] = None, 39 | subflavors: Optional[Dict[str, Any]] = None, 40 | shuffle_over_epochs: int = 1, 41 | split_config: Optional[str] = None, 42 | **kwargs, 43 | ) -> BaseCoreDatasetFactory: 44 | """ 45 | Args: 46 | training: If true, apply training randomization. 47 | split_part: Default split part to use. 48 | worker_config: Worker configuration. 49 | shuffle_buffer_size: Size of the sample shuffle buffer (before task encoding). 50 | subflavor: Subflavor to use, might be overridden by inner datasets. 51 | subflavors: Subflavors to use, might be overridden by inner datasets. 52 | shuffle_over_epochs: Shuffle the dataset over this many epochs. 53 | **kwargs: Additional arguments to the dataset constructor. 
54 | 55 | Returns: 56 | The loaded dataset 57 | """ 58 | if self.split_config is not None: 59 | split_config = self.split_config 60 | if self.split_part is not None: 61 | split_part = self.split_part 62 | if split_part is None: 63 | raise ValueError("Missing split part") 64 | if subflavor is None: 65 | subflavor = self.subflavor 66 | if self.subflavors is not None: 67 | subflavors = {**self.subflavors, **(subflavors or {})} 68 | if isinstance(self.datasets, list): 69 | inner_datasets = [ 70 | dataset.get_dataset( 71 | training=training, 72 | split_part=split_part, 73 | worker_config=worker_config, 74 | subflavor=subflavor, 75 | subflavors=subflavors, 76 | shuffle_over_epochs=shuffle_over_epochs, 77 | split_config=split_config, 78 | **kwargs, 79 | ) 80 | for dataset in self.datasets 81 | ] 82 | assert all( 83 | isinstance(d, BaseWebdatasetFactory) for d in inner_datasets 84 | ), "Can only merge webdatasets efficiently" 85 | elif isinstance(self.datasets, dict): 86 | inner_datasets = { 87 | key: dataset.get_dataset( 88 | training=training, 89 | split_part=split_part, 90 | worker_config=worker_config, 91 | subflavor=subflavor, 92 | subflavors=subflavors, 93 | shuffle_over_epochs=shuffle_over_epochs, 94 | split_config=split_config, 95 | **kwargs, 96 | ) 97 | for key, dataset in self.datasets.items() 98 | } 99 | assert all( 100 | isinstance(d, BaseWebdatasetFactory) for d in inner_datasets.values() 101 | ), "Can only merge webdatasets efficiently" 102 | else: 103 | raise ValueError("Invalid join type") 104 | return JoinedWebdatasetFactory( 105 | inner_datasets=inner_datasets, 106 | training=training, 107 | worker_config=worker_config, 108 | shuffle_over_epochs=shuffle_over_epochs, 109 | join_method=self.join_method, 110 | joiner=self.joiner, 111 | **kwargs, 112 | ) 113 | 114 | def get_datasets( 115 | self, 116 | *, 117 | training: bool, 118 | split_part: Union[Literal["train", "val", "test"], str], 119 | worker_config: WorkerConfig, 120 | subflavor: Optional[str] = None, 121 | subflavors: Optional[Dict[str, Any]] = None, 122 | shuffle_over_epochs_multiplier: int = 1, 123 | **kwargs, 124 | ) -> Tuple[DatasetBlendMode, List[Tuple[BaseCoreDatasetFactory, Union[float, int, None]]]]: 125 | return DatasetBlendMode.NONE, [ 126 | ( 127 | self.get_dataset( 128 | training=training, 129 | split_part=split_part, 130 | worker_config=worker_config, 131 | subflavor=subflavor, 132 | subflavors=subflavors, 133 | shuffle_over_epochs=shuffle_over_epochs_multiplier, 134 | **kwargs, 135 | ), 136 | None, 137 | ) 138 | ] 139 | --------------------------------------------------------------------------------
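The `JoinDatasetLoader` above is normally instantiated from a metadataset configuration, but the dataclass can also be constructed directly. The following is a hedged sketch, not taken from the repository: the `ImageWithCaption` joiner type, the dataset paths, and the assumption that `DatasetLoader` (defined in `dataset_loader.py`, not shown here) accepts a `path` field are hypothetical and only illustrate how `datasets`, `joiner`, and `join_method` fit together.

```python
from dataclasses import dataclass

import torch

from megatron.energon.flavors import Sample
from megatron.energon.metadataset.dataset_loader import DatasetLoader
from megatron.energon.metadataset.join_dataset_loader import JoinDatasetLoader
from megatron.energon.worker import WorkerConfig


@dataclass
class ImageWithCaption(Sample):
    # Hypothetical joined sample type combining fields from both inner datasets
    image: torch.Tensor
    caption: str


# Assumption: DatasetLoader can be constructed from a dataset path (see dataset_loader.py)
join_loader = JoinDatasetLoader(
    datasets={
        "image": DatasetLoader(path="/data/images-wds"),
        "caption": DatasetLoader(path="/data/captions-wds"),
    },
    joiner=ImageWithCaption,
    join_method="inner_match",
)

# get_dataset() forwards the per-dataset factories into a JoinedWebdatasetFactory
factory = join_loader.get_dataset(
    training=True,
    split_part="train",
    worker_config=WorkerConfig.default_worker_config(),
)
```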