├── src ├── olmo_core │ ├── py.typed │ ├── __init__.py │ ├── kernels │ │ └── __init__.py │ ├── internal │ │ └── __init__.py │ ├── nn │ │ ├── bolmo │ │ │ ├── __init__.py │ │ │ ├── hf │ │ │ │ └── __init__.py │ │ │ └── embed.py │ │ ├── __init__.py │ │ ├── attention │ │ │ └── te_attn_api.py │ │ ├── conversion │ │ │ └── __init__.py │ │ ├── functional │ │ │ └── __init__.py │ │ ├── moe │ │ │ └── __init__.py │ │ ├── hf │ │ │ └── __init__.py │ │ ├── transformer │ │ │ └── __init__.py │ │ ├── buffer_cache.py │ │ ├── fla.py │ │ └── utils.py │ ├── launch │ │ ├── __init__.py │ │ └── utils.py │ ├── distributed │ │ ├── __init__.py │ │ └── parallel │ │ │ ├── context_parallel.py │ │ │ ├── expert_parallel.py │ │ │ ├── tensor_parallel.py │ │ │ └── data_parallel.py │ ├── aliases.py │ ├── eval │ │ ├── __init__.py │ │ ├── evaluator.py │ │ └── metrics.py │ ├── generate │ │ ├── generation_module │ │ │ ├── transformer │ │ │ │ └── __init__.py │ │ │ ├── __init__.py │ │ │ ├── generation_module.py │ │ │ └── config.py │ │ ├── __init__.py │ │ └── utils.py │ ├── version.py │ ├── doc_utils.py │ ├── optim │ │ ├── adam.py │ │ ├── __init__.py │ │ └── noop.py │ ├── train │ │ ├── callbacks │ │ │ ├── monkey_patcher.py │ │ │ ├── garbage_collector.py │ │ │ ├── __init__.py │ │ │ ├── list_checkpointer.py │ │ │ ├── config_saver.py │ │ │ ├── gpu_memory_monitor.py │ │ │ └── console_logger.py │ │ └── train_module │ │ │ ├── transformer │ │ │ └── __init__.py │ │ │ └── __init__.py │ ├── exceptions.py │ ├── ops │ │ └── __init__.py │ ├── data │ │ ├── mixes │ │ │ └── v3-small-ppl-validation.txt │ │ ├── types.py │ │ └── __init__.py │ ├── testing │ │ └── __init__.py │ └── float8 │ │ └── utils.py ├── test │ ├── __init__.py │ ├── nn │ │ ├── __init__.py │ │ ├── hf │ │ │ ├── __init__.py │ │ │ ├── config_test.py │ │ │ └── checkpoint_test.py │ │ ├── moe │ │ │ ├── __init__.py │ │ │ ├── router_test.py │ │ │ └── mlp_test.py │ │ ├── conversion │ │ │ └── __init__.py │ │ ├── functional │ │ │ ├── __init__.py │ │ │ └── cross_entropy_loss_test.py │ │ ├── transformer │ │ │ ├── __init__.py │ │ │ └── block_test.py │ │ ├── buffer_cache_test.py │ │ ├── layer_norm_test.py │ │ └── feed_forward_test.py │ ├── data │ │ ├── __init__.py │ │ ├── tokenizer_test.py │ │ ├── custom_data_loader_test.py │ │ ├── utils.py │ │ ├── mixes_test.py │ │ ├── fixtures.py │ │ └── custom_data_loader.py │ ├── float8 │ │ ├── __init__.py │ │ ├── utils_test.py │ │ └── ao_test.py │ ├── launch │ │ ├── __init__.py │ │ ├── beaker_test.py │ │ └── utils_test.py │ ├── ops │ │ └── __init__.py │ ├── optim │ │ ├── __init__.py │ │ ├── lion_test.py │ │ ├── skip_step_optimizer_test.py │ │ └── noop_test.py │ ├── train │ │ ├── __init__.py │ │ ├── train_module │ │ │ ├── __init__.py │ │ │ └── transformer │ │ │ │ ├── __init__.py │ │ │ │ └── config_test.py │ │ └── utils_test.py │ ├── distributed │ │ ├── __init__.py │ │ ├── checkpoint │ │ │ ├── __init__.py │ │ │ └── filesystem_test.py │ │ └── utils_test.py │ ├── examples │ │ ├── __init__.py │ │ └── huggingface │ │ │ └── __init__.py │ ├── generate │ │ ├── __init__.py │ │ └── generation_module │ │ │ ├── __init__.py │ │ │ └── transformer │ │ │ └── __init__.py │ ├── model_ladder_test.py │ ├── utils_test.py │ ├── conftest.py │ └── config_test.py ├── examples │ ├── __init__.py │ ├── llm │ │ └── __init__.py │ ├── moe │ │ └── __init__.py │ ├── bolmo │ │ └── __init__.py │ ├── ngpt │ │ └── __init__.py │ └── huggingface │ │ ├── __init__.py │ │ └── upload_checkpoint_to_hf.py └── scripts │ ├── beaker │ ├── get_full_image_name.sh │ ├── create_beaker_image.sh │ └── 
launch_test.py │ ├── train │ ├── OLMo3 │ │ └── README.md │ ├── README.md │ ├── nGPT-1B.py │ ├── Llama3-8B.py │ ├── OLMo2 │ │ └── OLMo2-13B.py │ ├── OLMoE-1B-7B.py │ └── small-moe.py │ ├── release │ ├── slack_notification.py │ ├── add_pr_comments_on_release.sh │ ├── release.sh │ ├── prepare_changelog.py │ └── release_notes.py │ ├── official │ └── README.md │ └── compare_wandb_configs.py ├── docs ├── .gitignore ├── source │ ├── _static │ │ ├── css │ │ │ └── custom.css │ │ ├── favicon.ico │ │ └── olmo-full-color.png │ ├── optim.rst │ ├── config.rst │ ├── float8.rst │ ├── data │ │ ├── mixes.rst │ │ ├── types.rst │ │ ├── utils.rst │ │ ├── tokenizer.rst │ │ ├── source_mixture.rst │ │ ├── collator.rst │ │ ├── numpy_dataset.rst │ │ ├── data_loader.rst │ │ └── index.rst │ ├── io.rst │ ├── eval │ │ ├── metrics.rst │ │ ├── evaluator.rst │ │ ├── lm_evaluator.rst │ │ └── index.rst │ ├── model_ladder.rst │ ├── utils.rst │ ├── nn │ │ ├── functional.rst │ │ ├── moe.rst │ │ ├── rope.rst │ │ ├── lm_head.rst │ │ ├── attention.rst │ │ ├── conversion.rst │ │ ├── layer_norm.rst │ │ ├── feed_forward.rst │ │ ├── transformer.rst │ │ ├── index.rst │ │ └── hf.rst │ ├── testing.rst │ ├── train │ │ ├── callbacks.rst │ │ ├── train_module.rst │ │ └── index.rst │ ├── exceptions.rst │ ├── distributed │ │ ├── parallel.rst │ │ ├── utils.rst │ │ ├── checkpoint.rst │ │ └── index.rst │ ├── launch.rst │ ├── examples │ │ ├── llm.rst │ │ └── huggingface.rst │ ├── overview │ │ ├── installation.rst │ │ └── introduction.rst │ └── index.rst ├── requirements.txt ├── Makefile └── make.bat ├── MANIFEST.in ├── .readthedocs.yaml ├── .github ├── dependabot.yml ├── RELEASE_PROCESS.md ├── workflows │ ├── pr_checks.yml │ └── docker.yml └── actions │ └── setup-venv │ └── action.yml ├── .gitignore ├── bolmo_scripts ├── launch_stage2_7b.sh ├── launch_stage2_1b.sh ├── launch_stage1_7b.sh └── launch_stage1_1b.sh └── Makefile /src/olmo_core/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/nn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | -------------------------------------------------------------------------------- /src/examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/examples/llm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/examples/moe/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/olmo_core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/data/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/float8/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/launch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/nn/hf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/nn/moe/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/optim/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/examples/bolmo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/examples/ngpt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/olmo_core/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/generate/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/examples/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/olmo_core/internal/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/olmo_core/nn/bolmo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/src/test/nn/conversion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/nn/functional/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/nn/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/distributed/checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/examples/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/train/train_module/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/generate/generation_module/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/test/train/train_module/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include src/olmo_core/data/mixes/*.txt 2 | -------------------------------------------------------------------------------- /src/test/generate/generation_module/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | --index-url https://download.pytorch.org/whl/cpu 2 | torch 3 | -------------------------------------------------------------------------------- /src/olmo_core/nn/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common :class:`torch.nn.Module` implementations. 3 | """ 4 | -------------------------------------------------------------------------------- /docs/source/optim.rst: -------------------------------------------------------------------------------- 1 | ``optim`` 2 | ========= 3 | 4 | .. automodule:: olmo_core.optim 5 | :members: 6 | -------------------------------------------------------------------------------- /src/olmo_core/launch/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | An API for launching experiments on various platforms. 3 | """ 4 | -------------------------------------------------------------------------------- /docs/source/config.rst: -------------------------------------------------------------------------------- 1 | ``config`` 2 | ========== 3 | 4 | .. 
automodule:: olmo_core.config 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/float8.rst: -------------------------------------------------------------------------------- 1 | ``float8`` 2 | ========== 3 | 4 | .. automodule:: olmo_core.float8 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/_static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/bolmo-core/HEAD/docs/source/_static/favicon.ico -------------------------------------------------------------------------------- /docs/source/data/mixes.rst: -------------------------------------------------------------------------------- 1 | ``data.mixes`` 2 | ============== 3 | 4 | .. automodule:: olmo_core.data.mixes 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/data/types.rst: -------------------------------------------------------------------------------- 1 | ``data.types`` 2 | ============== 3 | 4 | .. automodule:: olmo_core.data.types 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/data/utils.rst: -------------------------------------------------------------------------------- 1 | ``data.utils`` 2 | ============== 3 | 4 | .. automodule:: olmo_core.data.utils 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/io.rst: -------------------------------------------------------------------------------- 1 | ``io`` 2 | ====== 3 | 4 | .. automodule:: olmo_core.io 5 | :members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /src/olmo_core/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | APIs for distributed communication, bookkeeping, and checkpointing. 3 | """ 4 | -------------------------------------------------------------------------------- /docs/source/_static/olmo-full-color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/bolmo-core/HEAD/docs/source/_static/olmo-full-color.png -------------------------------------------------------------------------------- /docs/source/eval/metrics.rst: -------------------------------------------------------------------------------- 1 | ``eval.metrics`` 2 | ================ 3 | 4 | .. automodule:: olmo_core.eval.metrics 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/model_ladder.rst: -------------------------------------------------------------------------------- 1 | ``model_ladder`` 2 | ================ 3 | 4 | .. automodule:: olmo_core.model_ladder 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/utils.rst: -------------------------------------------------------------------------------- 1 | ``utils`` 2 | ========= 3 | 4 | .. automodule:: olmo_core.utils 5 | :members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /docs/source/data/tokenizer.rst: -------------------------------------------------------------------------------- 1 | ``data.tokenizer`` 2 | ================== 3 | 4 | .. 
automodule:: olmo_core.data.tokenizer 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/eval/evaluator.rst: -------------------------------------------------------------------------------- 1 | ``eval.evaluator`` 2 | ================== 3 | 4 | .. automodule:: olmo_core.eval.evaluator 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/nn/functional.rst: -------------------------------------------------------------------------------- 1 | ``nn.functional`` 2 | ================= 3 | 4 | .. automodule:: olmo_core.nn.functional 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/nn/moe.rst: -------------------------------------------------------------------------------- 1 | ``nn.moe`` 2 | ========== 3 | 4 | .. automodule:: olmo_core.nn.moe 5 | :members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /docs/source/nn/rope.rst: -------------------------------------------------------------------------------- 1 | ``nn.rope`` 2 | =========== 3 | 4 | .. automodule:: olmo_core.nn.rope 5 | :members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /docs/source/testing.rst: -------------------------------------------------------------------------------- 1 | ``testing`` 2 | =========== 3 | 4 | .. automodule:: olmo_core.testing 5 | :members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /docs/source/train/callbacks.rst: -------------------------------------------------------------------------------- 1 | ``train.callbacks`` 2 | =================== 3 | 4 | .. automodule:: olmo_core.train.callbacks 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/eval/lm_evaluator.rst: -------------------------------------------------------------------------------- 1 | ``eval.lm_evaluator`` 2 | ===================== 3 | 4 | .. automodule:: olmo_core.eval.lm_evaluator 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/exceptions.rst: -------------------------------------------------------------------------------- 1 | ``exceptions`` 2 | ============== 3 | 4 | .. automodule:: olmo_core.exceptions 5 | :members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /docs/source/nn/lm_head.rst: -------------------------------------------------------------------------------- 1 | ``nn.lm_head`` 2 | ============== 3 | 4 | .. automodule:: olmo_core.nn.lm_head 5 | :members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /src/olmo_core/aliases.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | PathOrStr = Union[Path, PathLike, str] 6 | -------------------------------------------------------------------------------- /docs/source/data/source_mixture.rst: -------------------------------------------------------------------------------- 1 | ``data.source_mixture`` 2 | ======================= 3 | 4 | .. 
automodule:: olmo_core.data.source_mixture 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/train/train_module.rst: -------------------------------------------------------------------------------- 1 | ``train.train_module`` 2 | ====================== 3 | 4 | .. automodule:: olmo_core.train.train_module 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/distributed/parallel.rst: -------------------------------------------------------------------------------- 1 | ``distributed.parallel`` 2 | ======================== 3 | 4 | .. automodule:: olmo_core.distributed.parallel 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/nn/attention.rst: -------------------------------------------------------------------------------- 1 | ``nn.attention`` 2 | ================ 3 | 4 | .. automodule:: olmo_core.nn.attention 5 | :members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /docs/source/data/collator.rst: -------------------------------------------------------------------------------- 1 | ``data.collator`` 2 | ================= 3 | 4 | .. automodule:: olmo_core.data.collator 5 | :members: 6 | :special-members: __call__ 7 | -------------------------------------------------------------------------------- /docs/source/nn/conversion.rst: -------------------------------------------------------------------------------- 1 | ``nn.conversion`` 2 | ================= 3 | 4 | .. automodule:: olmo_core.nn.conversion 5 | :members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /docs/source/nn/layer_norm.rst: -------------------------------------------------------------------------------- 1 | ``nn.layer_norm`` 2 | ================= 3 | 4 | .. automodule:: olmo_core.nn.layer_norm 5 | :members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /docs/source/nn/feed_forward.rst: -------------------------------------------------------------------------------- 1 | ``nn.feed_forward`` 2 | =================== 3 | 4 | .. automodule:: olmo_core.nn.feed_forward 5 | :members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /docs/source/distributed/utils.rst: -------------------------------------------------------------------------------- 1 | ``distributed.utils`` 2 | ========================== 3 | 4 | .. automodule:: olmo_core.distributed.utils 5 | :members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /docs/source/data/numpy_dataset.rst: -------------------------------------------------------------------------------- 1 | ``data.numpy_dataset`` 2 | ====================== 3 | 4 | .. automodule:: olmo_core.data.numpy_dataset 5 | :members: 6 | :special-members: __getitem__,__len__ 7 | -------------------------------------------------------------------------------- /docs/source/distributed/checkpoint.rst: -------------------------------------------------------------------------------- 1 | ``distributed.checkpoint`` 2 | ========================== 3 | 4 | .. 
automodule:: olmo_core.distributed.checkpoint 5 | :members: 6 | :member-order: bysource 7 | -------------------------------------------------------------------------------- /docs/source/launch.rst: -------------------------------------------------------------------------------- 1 | ``launch`` 2 | ========== 3 | 4 | .. automodule:: olmo_core.launch 5 | :members: 6 | 7 | Beaker 8 | ------ 9 | 10 | .. automodule:: olmo_core.launch.beaker 11 | :members: 12 | -------------------------------------------------------------------------------- /docs/source/data/data_loader.rst: -------------------------------------------------------------------------------- 1 | ``data.data_loader`` 2 | ==================== 3 | 4 | .. automodule:: olmo_core.data.data_loader 5 | :members: 6 | :special-members: __iter__,__len__ 7 | :private-members: _iter_batches 8 | -------------------------------------------------------------------------------- /docs/source/eval/index.rst: -------------------------------------------------------------------------------- 1 | ``eval`` 2 | ======== 3 | 4 | .. automodule:: olmo_core.eval 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: Submodules 9 | 10 | metrics 11 | evaluator 12 | lm_evaluator 13 | -------------------------------------------------------------------------------- /docs/source/train/index.rst: -------------------------------------------------------------------------------- 1 | ``train`` 2 | ========= 3 | 4 | .. automodule:: olmo_core.train 5 | :members: 6 | 7 | .. toctree:: 8 | :maxdepth: 2 9 | :caption: Submodules 10 | 11 | callbacks 12 | train_module 13 | -------------------------------------------------------------------------------- /docs/source/nn/transformer.rst: -------------------------------------------------------------------------------- 1 | ``nn.transformer`` 2 | ================== 3 | 4 | .. automodule:: olmo_core.nn.transformer 5 | :members: 6 | :exclude-members: TransformerDataParallelWrappingStrategy,TransformerActivationCheckpointingMode 7 | -------------------------------------------------------------------------------- /src/test/model_ladder_test.py: -------------------------------------------------------------------------------- 1 | from olmo_core.model_ladder import ModelSize 2 | 3 | 4 | def test_model_size_num_params(): 5 | assert ModelSize.size_190M.num_params == 190_000_000 6 | assert ModelSize.size_7B.num_params == 7_000_000_000 7 | -------------------------------------------------------------------------------- /docs/source/distributed/index.rst: -------------------------------------------------------------------------------- 1 | ``distributed`` 2 | =============== 3 | 4 | .. automodule:: olmo_core.distributed 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: Submodules 9 | 10 | checkpoint 11 | parallel 12 | utils 13 | -------------------------------------------------------------------------------- /src/scripts/beaker/get_full_image_name.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euo pipefail 4 | 5 | name=$1 6 | workspace=$2 7 | full_name=$(beaker workspace images "${workspace}" --format=json | jq -r ".[] | select(.name==\"${name}\") | .fullName") 8 | echo "${full_name}" 9 | -------------------------------------------------------------------------------- /src/scripts/train/OLMo3/README.md: -------------------------------------------------------------------------------- 1 | # Olmo3 Pretraining and Midtraining Configs 2 | 3 | Execution Order: 4 | 5 | 1. 
OLMo3-7B.py 6 | 2. OLMo3-7B-midtraining.py 7 | 3. OLMo3-7B-long-context.py 8 | 9 | OLMo3-7B-anneal.py was used for ad-hoc progress checks during pretraining. 10 | -------------------------------------------------------------------------------- /src/olmo_core/eval/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metrics and evaluator classes. 3 | """ 4 | 5 | from .evaluator import Evaluator 6 | from .lm_evaluator import LMEvaluator 7 | from .metrics import MeanMetric, Metric 8 | 9 | __all__ = ["Evaluator", "LMEvaluator", "Metric", "MeanMetric"] 10 | -------------------------------------------------------------------------------- /src/olmo_core/generate/generation_module/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import TransformerGenerationModuleConfig 2 | from .generation_module import TransformerGenerationModule 3 | 4 | __all__ = [ 5 | "TransformerGenerationModule", 6 | "TransformerGenerationModuleConfig", 7 | ] 8 | -------------------------------------------------------------------------------- /docs/source/data/index.rst: -------------------------------------------------------------------------------- 1 | ``data`` 2 | ======== 3 | 4 | .. automodule:: olmo_core.data 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: Submodules 9 | 10 | numpy_dataset 11 | source_mixture 12 | collator 13 | mixes 14 | tokenizer 15 | data_loader 16 | types 17 | utils 18 | -------------------------------------------------------------------------------- /docs/source/nn/index.rst: -------------------------------------------------------------------------------- 1 | ``nn`` 2 | ====== 3 | 4 | .. automodule:: olmo_core.nn 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: Submodules 9 | 10 | attention 11 | conversion 12 | feed_forward 13 | functional 14 | hf 15 | layer_norm 16 | lm_head 17 | moe 18 | rope 19 | transformer 20 | -------------------------------------------------------------------------------- /src/olmo_core/distributed/parallel/context_parallel.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from olmo_core.config import Config 4 | 5 | 6 | @dataclass 7 | class ContextParallelConfig(Config): 8 | """ 9 | Configuration class for context parallelism (CP). 10 | """ 11 | 12 | degree: int 13 | """ 14 | The CP degree. 15 | """ 16 | -------------------------------------------------------------------------------- /src/olmo_core/distributed/parallel/expert_parallel.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from olmo_core.config import Config 4 | 5 | 6 | @dataclass 7 | class ExpertParallelConfig(Config): 8 | """ 9 | Configuration class for expert parallelism (EP). 10 | """ 11 | 12 | degree: int 13 | """ 14 | The EP degree. 15 | """ 16 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/source/conf.py 5 | fail_on_warning: true 6 | 7 | build: 8 | os: ubuntu-22.04 9 | tools: 10 | python: "3.10" 11 | 12 | python: 13 | install: 14 | - requirements: docs/requirements.txt 15 | - method: pip 16 | path: . 
17 | extra_requirements: 18 | - all 19 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | labels: 8 | - dependencies/python 9 | open-pull-requests-limit: 5 10 | - package-ecosystem: "github-actions" 11 | directory: "/" 12 | schedule: 13 | interval: "weekly" 14 | labels: 15 | - dependencies/actions 16 | -------------------------------------------------------------------------------- /src/olmo_core/nn/attention/te_attn_api.py: -------------------------------------------------------------------------------- 1 | try: 2 | import transformer_engine.pytorch as te # type: ignore 3 | except ImportError: 4 | te = None 5 | 6 | 7 | def has_te_attn() -> bool: 8 | """Check if Transformer Engine attention is available.""" 9 | return te is not None 10 | 11 | 12 | TEDotProductAttention = te.DotProductAttention if te is not None else None 13 | -------------------------------------------------------------------------------- /src/test/data/tokenizer_test.py: -------------------------------------------------------------------------------- 1 | from olmo_core.data import TokenizerConfig 2 | 3 | 4 | def test_padded_vocab_size(): 5 | assert TokenizerConfig.dolma2().padded_vocab_size() == 100352 6 | assert TokenizerConfig.gpt_neox_olmo_dolma_v1_5().padded_vocab_size() == 50304 7 | 8 | 9 | def test_from_hf(): 10 | assert TokenizerConfig.from_hf("gpt2") == TokenizerConfig.gpt2() 11 | -------------------------------------------------------------------------------- /src/olmo_core/version.py: -------------------------------------------------------------------------------- 1 | _MAJOR = "2" 2 | _MINOR = "3" 3 | _PATCH = "0" 4 | _SUFFIX = "" 5 | 6 | VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) 7 | VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX) 8 | 9 | 10 | if __name__ == "__main__": 11 | import sys 12 | 13 | if sys.argv[-1] == "short": 14 | print(VERSION_SHORT) 15 | else: 16 | print(VERSION) 17 | -------------------------------------------------------------------------------- /src/olmo_core/nn/conversion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common logic for converting :mod:`olmo_core.nn` features to/from other formats (like Hugging Face). 
3 | """ 4 | 5 | from .state_converter import StateConverter 6 | from .state_mapping import StateMapping, StateMappingTemplate, TemplatePlaceholder 7 | 8 | __all__ = [ 9 | "StateConverter", 10 | "StateMapping", 11 | "StateMappingTemplate", 12 | "TemplatePlaceholder", 13 | ] 14 | -------------------------------------------------------------------------------- /src/olmo_core/generate/__init__.py: -------------------------------------------------------------------------------- 1 | from .generation_module import GenerationModule 2 | from .generation_module.config import GenerationConfig 3 | from .generation_module.transformer import ( 4 | TransformerGenerationModule, 5 | TransformerGenerationModuleConfig, 6 | ) 7 | 8 | __all__ = [ 9 | "GenerationConfig", 10 | "GenerationModule", 11 | "TransformerGenerationModule", 12 | "TransformerGenerationModuleConfig", 13 | ] 14 | -------------------------------------------------------------------------------- /docs/source/examples/llm.rst: -------------------------------------------------------------------------------- 1 | Train an LLM 2 | ============ 3 | 4 | The following snippets can be found in `src/examples/llm/ `_. 5 | The ``train.py`` script is meant to be launched via ``torchrun``. 6 | You can also use the ``python -m olmo_core.launch.beaker`` CLI to quickly launch this script on Beaker. 7 | 8 | .. tab:: ``train.py`` 9 | 10 | .. literalinclude:: ../../../src/examples/llm/train.py 11 | :language: py 12 | -------------------------------------------------------------------------------- /src/test/launch/beaker_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from beaker import Beaker 5 | 6 | from olmo_core.launch.beaker import OLMoCoreBeakerImage 7 | 8 | 9 | @pytest.mark.skipif( 10 | os.environ.get("BEAKER_TOKEN", "") == "", reason="Missing 'BEAKER_TOKEN' env var" 11 | ) 12 | @pytest.mark.parametrize("image", list(OLMoCoreBeakerImage)) 13 | def test_official_images_exist(image): 14 | beaker = Beaker.from_env(default_workspace="ai2/OLMo-core") 15 | beaker.image.get(image) 16 | -------------------------------------------------------------------------------- /src/test/float8/utils_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from olmo_core.float8.utils import cast_to_fp8, per_block_cast_to_fp8 4 | 5 | 6 | def test_cast_to_fp8(): 7 | x = torch.randn(2, 3, 3 * 128) 8 | x_fp8, s = cast_to_fp8(x) 9 | assert x_fp8.shape == x.shape 10 | assert s.shape == (2, 3, 3) 11 | 12 | 13 | def test_per_block_cast_to_fp8(): 14 | x = torch.randn(8, 3 * 128, 2 * 128) 15 | x_fp8, s = per_block_cast_to_fp8(x) 16 | assert x_fp8.shape == x.shape 17 | assert s.shape == (8, 3, 2) 18 | -------------------------------------------------------------------------------- /src/olmo_core/generate/generation_module/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import GenerationConfig 2 | from .generation_module import GenerationModule 3 | from .transformer.config import TransformerGenerationModuleConfig 4 | from .transformer.generation_module import BolmoTransformerGenerationModule, TransformerGenerationModule 5 | 6 | __all__ = [ 7 | "BolmoTransformerGenerationModule", 8 | "GenerationConfig", 9 | "GenerationModule", 10 | "TransformerGenerationModule", 11 | "TransformerGenerationModuleConfig", 12 | ] 13 | 
-------------------------------------------------------------------------------- /src/scripts/beaker/create_beaker_image.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -exuo pipefail 4 | 5 | source_image=$1 6 | beaker_image=$2 7 | beaker_workspace=$3 8 | timestamp=$(date "+%Y%m%d%H%M%S") 9 | 10 | beaker_user=$(beaker account whoami --format=json | jq -r '.[0].name') 11 | beaker image create "${source_image}" --name "${beaker_image}-tmp" --workspace "${beaker_workspace}" 12 | beaker image rename "${beaker_user}/${beaker_image}" "${beaker_image}-${timestamp}" || true 13 | beaker image rename "${beaker_user}/${beaker_image}-tmp" "${beaker_image}" 14 | -------------------------------------------------------------------------------- /src/olmo_core/nn/functional/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common ``nn`` function implementations. 3 | """ 4 | 5 | import torch 6 | 7 | from .cross_entropy_loss import * 8 | 9 | __all__ = [ 10 | "cross_entropy_loss", 11 | "fused_linear_cross_entropy_loss", 12 | "l2_normalize", 13 | ] 14 | 15 | 16 | def l2_normalize(x: torch.Tensor, dim: int = -1) -> torch.Tensor: 17 | # NOTE: could also use F.normalize(), but that doesn't work with DTensor at the moment. 18 | return x / torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32).type_as(x) 19 | -------------------------------------------------------------------------------- /src/olmo_core/doc_utils.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | T = TypeVar("T") 4 | 5 | 6 | def beta_feature(f: T) -> T: 7 | """ 8 | Mark a class or function as a beta feature. 9 | """ 10 | if f.__doc__ is None: 11 | f.__doc__ = "" 12 | 13 | f.__doc__ += """ 14 | 15 | .. warning:: 16 | This is a beta feature! The API is subject to change even with minor and patch releases. 17 | If you choose to use this feature please read the `CHANGELOG `_ 18 | before upgrading your version of this library. 19 | 20 | """ 21 | 22 | return f 23 | -------------------------------------------------------------------------------- /src/test/launch/utils_test.py: -------------------------------------------------------------------------------- 1 | from olmo_core.launch.utils import parse_git_remote_url 2 | 3 | 4 | def test_parse_git_remote_url(): 5 | # HTTPS format. 6 | assert parse_git_remote_url("https://github.com/allenai/OLMo-core.git") == ( 7 | "allenai", 8 | "OLMo-core", 9 | ) 10 | # SSH format. 11 | assert parse_git_remote_url("git@github.com:allenai/OLMo-core.git") == ( 12 | "allenai", 13 | "OLMo-core", 14 | ) 15 | # Username+password format. 16 | assert parse_git_remote_url("https://USERNAME:PASSWORD@github.com/allenai/OLMo-core.git") == ( 17 | "allenai", 18 | "OLMo-core", 19 | ) 20 | -------------------------------------------------------------------------------- /.github/RELEASE_PROCESS.md: -------------------------------------------------------------------------------- 1 | # GitHub Release Process 2 | 3 | ## Steps 4 | 5 | 1. Update the version in `src/olmo_core/version.py`. 6 | 2. Run the release script: 7 | 8 | ```bash 9 | ./src/scripts/release/release.sh 10 | ``` 11 | 12 | This will commit the changes to the CHANGELOG and `version.py` files and then create a new tag in git 13 | which will trigger a workflow on GitHub Actions that handles the rest. 
14 | 15 | ## Fixing a failed release 16 | 17 | If for some reason the GitHub Actions release workflow failed with an error that needs to be fixed, you'll have to delete the tag on GitHub. Once you've pushed a fix you can simply repeat the steps above. 18 | -------------------------------------------------------------------------------- /src/scripts/train/README.md: -------------------------------------------------------------------------------- 1 | ## Ai2 internal training scripts 2 | 3 | ## Notice❗ 4 | 5 | The scripts in this folder use an internal module (`olmo_core.internal`), which is subject to breaking changes without notice, 6 | and reference data paths that are only accessible to Ai2 employees, so they won't work out-of-the-box for external users. 7 | Instead, see [`src/scripts/official/`](https://github.com/allenai/OLMo-core/tree/main/src/scripts/official) for public versions. 8 | 9 | ## Usage 10 | 11 | Most of the scripts here have a consistent command-line API, but that's not guaranteed. 12 | Generally if you run the script without any arguments it will print out some usage information. 13 | -------------------------------------------------------------------------------- /src/olmo_core/optim/adam.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple, Type 3 | 4 | import torch 5 | 6 | from .config import OptimConfig 7 | 8 | 9 | @dataclass 10 | class AdamConfig(OptimConfig): # NOTE: omegaconf doesn't like "OptimConfig[torch.optim.AdamW]" 11 | """ 12 | Configuration class for building an :class:`torch.optim.Adam` optimizer. 13 | """ 14 | 15 | lr: float = 1e-3 16 | betas: Tuple[float, float] = (0.9, 0.999) 17 | eps: float = 1e-8 18 | foreach: Optional[bool] = None 19 | fused: Optional[bool] = None 20 | 21 | @classmethod 22 | def optimizer(cls) -> Type[torch.optim.Adam]: 23 | return torch.optim.Adam 24 | -------------------------------------------------------------------------------- /src/olmo_core/nn/moe/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | MoE layers. 3 | """ 4 | 5 | from .loss import MoELoadBalancingLossGranularity 6 | from .mlp import DroplessMoEMLP, MoEMLP 7 | from .moe import DroplessMoE, MoEBase, MoEConfig, MoEType 8 | from .router import ( 9 | MoELinearRouter, 10 | MoERouter, 11 | MoERouterConfig, 12 | MoERouterGatingFunction, 13 | MoERouterType, 14 | ) 15 | 16 | __all__ = [ 17 | "MoEBase", 18 | "DroplessMoE", 19 | "MoEConfig", 20 | "MoEType", 21 | "MoEMLP", 22 | "DroplessMoEMLP", 23 | "MoERouter", 24 | "MoELinearRouter", 25 | "MoERouterConfig", 26 | "MoERouterType", 27 | "MoERouterGatingFunction", 28 | "MoELoadBalancingLossGranularity", 29 | ] 30 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= -W 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. 
$(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /src/olmo_core/nn/bolmo/hf/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig 2 | 3 | from olmo_core.nn.bolmo.hf.configuration_bolmo import BolmoConfig 4 | from olmo_core.nn.bolmo.hf.modeling_bolmo import BolmoForCausalLM, BolmoModel 5 | from olmo_core.nn.bolmo.hf.tokenization_bolmo import BolmoTokenizer 6 | 7 | AutoTokenizer.register("bolmo", BolmoTokenizer) 8 | AutoConfig.register("bolmo", BolmoConfig) 9 | AutoModelForCausalLM.register(BolmoConfig, BolmoForCausalLM) 10 | AutoModel.register(BolmoConfig, BolmoModel) 11 | BolmoConfig.register_for_auto_class("AutoConfig") 12 | BolmoForCausalLM.register_for_auto_class("AutoModelForCausalLM") 13 | BolmoModel.register_for_auto_class("AutoModel") 14 | BolmoTokenizer.register_for_auto_class("AutoTokenizer") -------------------------------------------------------------------------------- /src/test/nn/buffer_cache_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from olmo_core.nn.buffer_cache import BufferCache 4 | 5 | 6 | def test_buffer_cache(): 7 | cache = BufferCache() 8 | cache["a"] = torch.tensor(1) 9 | cache["b"] = torch.tensor(2) 10 | assert set(cache) == {"a", "b"} 11 | assert len(cache) == 2 12 | assert cache["a"].item() == 1 13 | assert (x := cache.get("a")) is not None and x.item() == 1 14 | 15 | cache2 = cache.with_namespace("foo") 16 | assert len(cache2) == 0 17 | cache2["a"] = torch.tensor(3) 18 | assert len(cache2) == 1 19 | assert cache2["a"].item() == 3 20 | assert cache2["a"].item() == 3 21 | assert (x := cache2.get("a")) is not None and x.item() == 3 22 | assert len(cache) == 2 23 | assert cache["a"].item() == 1 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # build artifacts 2 | 3 | .eggs/ 4 | .mypy_cache 5 | *.egg-info/ 6 | build/ 7 | dist/ 8 | pip-wheel-metadata/ 9 | private/ 10 | tmp/ 11 | 12 | # dev tools 13 | 14 | .envrc 15 | .python-version 16 | .idea 17 | .venv/ 18 | .vscode/ 19 | /*.iml 20 | pyrightconfig.json 21 | .ruff.toml 22 | uv.lock 23 | 24 | 25 | # jupyter notebooks 26 | 27 | .ipynb_checkpoints 28 | 29 | 30 | # miscellaneous 31 | 32 | .cache/ 33 | doc/_build/ 34 | *.swp 35 | .DS_Store 36 | 37 | 38 | # python 39 | 40 | *.pyc 41 | *.pyo 42 | __pycache__ 43 | 44 | 45 | # testing and continuous integration 46 | 47 | .coverage 48 | .pytest_cache/ 49 | .benchmarks 50 | 51 | # documentation build artifacts 52 | 53 | docs/build 54 | site/ 55 | 56 | # runs 57 | /runs/ 58 | /wandb/ 59 | /scratch/ 60 | core 61 | /dataset-cache/ 62 | -------------------------------------------------------------------------------- /src/scripts/release/slack_notification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import requests 4 | 5 | from olmo_core.version import VERSION 6 | 7 | 8 | def parse_args() -> argparse.Namespace: 9 | parser = argparse.ArgumentParser( 10 | "slack-notifier", description="Send a release notification to a Slack channel." 
11 | ) 12 | parser.add_argument("webhook_url", type=str, help="The webhook URL for a Slack channel.") 13 | return parser.parse_args() 14 | 15 | 16 | def main(): 17 | args = parse_args() 18 | text = ( 19 | f"OLMo-core *v{VERSION}* is now out. See " 20 | f"https://github.com/allenai/OLMo-core/releases/tag/v{VERSION} for release notes." 21 | ) 22 | requests.post(args.webhook_url, json={"text": text}) 23 | 24 | 25 | if __name__ == "__main__": 26 | main() 27 | -------------------------------------------------------------------------------- /src/olmo_core/nn/hf/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for converting models between OLMo Core and Hugging Face formats. To configure the 3 | mappings between OLMo Core and Hugging Face, you may change the variables in 4 | :mod:`olmo_core.nn.hf.convert` (e.g. :data:`olmo_core.nn.hf.convert.HF_TO_OLMO_CORE_WEIGHT_MAPPINGS`). 5 | """ 6 | 7 | from .checkpoint import load_hf_model, save_hf_model 8 | from .config import get_hf_config 9 | from .convert import ( 10 | convert_state_from_hf, 11 | convert_state_to_hf, 12 | get_converter_from_hf, 13 | get_converter_to_hf, 14 | ) 15 | 16 | __all__ = [ 17 | "convert_state_from_hf", 18 | "convert_state_to_hf", 19 | "get_converter_from_hf", 20 | "get_converter_to_hf", 21 | "get_hf_config", 22 | "load_hf_model", 23 | "save_hf_model", 24 | ] 25 | -------------------------------------------------------------------------------- /src/olmo_core/train/callbacks/monkey_patcher.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | from dataclasses import dataclass 4 | 5 | from torch.distributed import DeviceMesh 6 | 7 | from .callback import Callback 8 | 9 | log = logging.getLogger(__name__) 10 | 11 | 12 | @dataclass 13 | class MonkeyPatcherCallback(Callback): 14 | """ 15 | While looking into performance issues with OLMo3 training, we discovered that 16 | `DeviceMesh.__getitem__()` can become a bottleneck because it gets called very often by FSDP and 17 | creates a new sub-mesh object each time. So this callback patches that method to cache 18 | the sub-meshes. 
19 | """ 20 | 21 | def pre_train(self): 22 | # Cache DeviceMesh.__get_item__ 23 | DeviceMesh.__getitem__ = functools.lru_cache(maxsize=None)(DeviceMesh.__getitem__) 24 | -------------------------------------------------------------------------------- /src/test/float8/ao_test.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | 3 | import pytest 4 | 5 | from olmo_core.float8.ao import AOCastConfig, AOFloat8LinearConfig, AOScalingType 6 | 7 | 8 | def has_torchao() -> bool: 9 | return importlib.util.find_spec("torchao") is not None 10 | 11 | 12 | @pytest.mark.skipif(not has_torchao(), reason="Requires torchao") 13 | def test_ao_float8_linear_config(): 14 | from torchao.float8.config import Float8LinearConfig, ScalingType 15 | 16 | assert isinstance(AOFloat8LinearConfig().to_ao_type(), Float8LinearConfig) 17 | assert AOFloat8LinearConfig(emulate=True).to_ao_type().emulate 18 | assert ( 19 | AOFloat8LinearConfig(cast_config_input=AOCastConfig(scaling_type=AOScalingType.disabled)) 20 | .to_ao_type() 21 | .cast_config_input.scaling_type 22 | == ScalingType.DISABLED 23 | ) 24 | -------------------------------------------------------------------------------- /.github/workflows/pr_checks.yml: -------------------------------------------------------------------------------- 1 | name: PR Checks 2 | 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.ref }} 5 | cancel-in-progress: true 6 | 7 | on: 8 | pull_request: 9 | branches: 10 | - main 11 | - v2 12 | paths: 13 | - 'src/**' 14 | 15 | jobs: 16 | changelog: 17 | name: CHANGELOG 18 | runs-on: ubuntu-latest 19 | if: github.event_name == 'pull_request' 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | with: 24 | fetch-depth: 0 25 | 26 | - name: Check that CHANGELOG has been updated 27 | run: | 28 | # If this step fails, this means you haven't updated the CHANGELOG.md 29 | # file with notes on your contribution. 30 | git diff --name-only $(git merge-base origin/main HEAD) | grep '^CHANGELOG.md$' && echo "Thanks for helping keep our CHANGELOG up-to-date!" 31 | -------------------------------------------------------------------------------- /src/olmo_core/exceptions.py: -------------------------------------------------------------------------------- 1 | class OLMoError(Exception): 2 | """ 3 | Base exception for OLMo custom error types. 4 | """ 5 | 6 | 7 | class OLMoInvalidRangeRequestError(OLMoError): 8 | pass 9 | 10 | 11 | class OLMoNetworkError(OLMoError): 12 | pass 13 | 14 | 15 | class OLMoEnvironmentError(OLMoError): 16 | pass 17 | 18 | 19 | class OLMoUserError(OLMoError): 20 | pass 21 | 22 | 23 | class OLMoCheckpointError(OLMoError): 24 | pass 25 | 26 | 27 | class OLMoConfigurationError(OLMoError): 28 | pass 29 | 30 | 31 | class OLMoCLIError(OLMoError): 32 | pass 33 | 34 | 35 | class OLMoThreadError(OLMoError): 36 | pass 37 | 38 | 39 | class BeakerExperimentFailedError(OLMoError): 40 | pass 41 | 42 | 43 | class BeakerInsufficientResourcesError(OLMoError): 44 | pass 45 | 46 | 47 | class OLMoUploadError(OLMoError): 48 | pass 49 | -------------------------------------------------------------------------------- /src/scripts/official/README.md: -------------------------------------------------------------------------------- 1 | # Official public training scripts 2 | 3 | Please check the config carefully before attempting to run them. You may need to adjust hyperparameters based on your hardware. 
4 | 5 | ## Usage 6 | 7 | Each Python training script in this directory has the same CLI, and they're intended to be launched directly with `torchrun` or, for Beaker users, through OLMo-core Beaker launch CLI: `python -m olmo_core.launch.beaker`. 8 | The scripts themselves take several required arguments as well as any number of config overrides in dot-notation. 9 | Run a script with the `--help` flag to see which arguments are required, and run with the `--dry-run` flag to see the full config that will be used. 10 | To override a field in the config such as the `data_loader`'s `prefetch_factor`, you could add the option `--data_loader.prefetch_factor=4` to your command-line options. 11 | -------------------------------------------------------------------------------- /src/olmo_core/generate/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.compile(dynamic=True) 5 | def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor: 6 | """ 7 | Compute log softmax probabilities for selected tokens. 8 | 9 | .. note:: 10 | torch.compile() performs an optimization that avoids materializing the full log softmax 11 | tensor when combined with gather operations, which can save significant memory compared 12 | to computing the full log softmax and then indexing. 13 | 14 | :param logits: The logits tensor of shape ``(..., vocab_size)``. 15 | :param index: The index tensor of shape ``(...)``. 16 | 17 | :returns: The log probabilities of shape ``(...)``. 18 | """ 19 | logprobs = torch.log_softmax(logits.float(), dim=-1) 20 | return torch.gather(logprobs, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) 21 | -------------------------------------------------------------------------------- /src/test/optim/lion_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | 5 | from olmo_core.optim import LionConfig 6 | from olmo_core.testing import DEVICES 7 | 8 | 9 | class Model(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | self.w1 = nn.Linear(8, 16) 13 | self.w2 = nn.Linear(16, 8) 14 | 15 | def forward(self, x: torch.Tensor) -> torch.Tensor: 16 | return self.w2(self.w1(x)) 17 | 18 | 19 | @pytest.mark.parametrize("device", DEVICES) 20 | def test_lion(device: torch.device): 21 | config = LionConfig() 22 | model = Model().train().to(device) 23 | optim = config.build(model) 24 | 25 | for group in optim.param_groups: 26 | assert "initial_lr" in group 27 | 28 | # Take a step. 
29 | optim.zero_grad(set_to_none=True) 30 | model(torch.randn(2, 8, device=device)).sum().backward() 31 | optim.step() 32 | -------------------------------------------------------------------------------- /src/test/train/train_module/transformer/config_test.py: -------------------------------------------------------------------------------- 1 | from olmo_core.distributed.parallel import PipelineScheduleType, PipelineSplitStyle 2 | from olmo_core.train.train_module.transformer import TransformerPipelineParallelConfig 3 | 4 | 5 | def test_generate_pipeline_split_points(): 6 | pp_config = TransformerPipelineParallelConfig( 7 | degree=2, schedule=PipelineScheduleType.single_1F1B, style=PipelineSplitStyle.loop 8 | ) 9 | assert pp_config.get_split_points(4) == [2] 10 | 11 | pp_config = TransformerPipelineParallelConfig( 12 | degree=4, schedule=PipelineScheduleType.single_1F1B, style=PipelineSplitStyle.loop 13 | ) 14 | assert pp_config.get_split_points(4) == [1, 2, 3] 15 | 16 | pp_config = TransformerPipelineParallelConfig( 17 | degree=2, schedule=PipelineScheduleType.interleaved_1F1B, style=PipelineSplitStyle.loop 18 | ) 19 | assert pp_config.get_split_points(4) == [1, 2, 3] 20 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /src/test/data/custom_data_loader_test.py: -------------------------------------------------------------------------------- 1 | from .custom_data_loader import CustomDataLoader 2 | 3 | 4 | def test_custom_data_loader(tmp_path): 5 | data_loader = CustomDataLoader( 6 | sequence_length=128, 7 | vocab_size=1024, 8 | work_dir=tmp_path, 9 | global_batch_size=512, 10 | total_batches=100, 11 | ) 12 | data_loader.reshuffle() 13 | 14 | batches_processed = 0 15 | for batch in data_loader: 16 | batches_processed += 1 17 | assert batch["input_ids"].numel() == 512 18 | if batches_processed > 10: 19 | break 20 | 21 | state_dict = data_loader.state_dict() 22 | data_loader.reset() 23 | data_loader.load_state_dict(state_dict) 24 | assert data_loader.batches_processed == batches_processed 25 | 26 | for batch in data_loader: 27 | batches_processed += 1 28 | assert batch["input_ids"].numel() == 512 29 | 30 | assert batches_processed == 100 31 | -------------------------------------------------------------------------------- /docs/source/nn/hf.rst: -------------------------------------------------------------------------------- 1 | ``nn.hf`` 2 | ============== 3 | 4 | .. 
automodule:: olmo_core.nn.hf 5 | :members: 6 | :member-order: bysource 7 | 8 | .. autodata:: olmo_core.nn.hf.convert.HF_TO_OLMO_CORE_WEIGHT_MAPPINGS 9 | :no-value: 10 | .. autodata:: olmo_core.nn.hf.convert.HF_TO_OLMO_CORE_MODULE_MAPPINGS 11 | :no-value: 12 | .. autodata:: olmo_core.nn.hf.convert.MODEL_TYPE_SPECIFIC_HF_TO_OLMO_CORE_WEIGHT_MAPPINGS 13 | :no-value: 14 | .. autodata:: olmo_core.nn.hf.convert.MODEL_TYPE_SPECIFIC_HF_TO_OLMO_CORE_MODULE_MAPPINGS 15 | :no-value: 16 | .. autodata:: olmo_core.nn.hf.convert.HF_TO_OLMO_CORE_TEMPLATE_MAPPINGS 17 | :no-value: 18 | .. autodata:: olmo_core.nn.hf.convert.OLMO_CORE_TO_HF_WEIGHT_MAPPINGS 19 | :no-value: 20 | .. autodata:: olmo_core.nn.hf.convert.OLMO_CORE_TO_HF_MODULE_MAPPINGS 21 | :no-value: 22 | .. autodata:: olmo_core.nn.hf.convert.OLMO_CORE_TO_HF_TEMPLATE_MAPPINGS 23 | :no-value: 24 | .. autodata:: olmo_core.nn.hf.convert.MODEL_TYPE_SPECIFIC_OLMO_CORE_TO_HF_TEMPLATE_MAPPINGS 25 | :no-value: -------------------------------------------------------------------------------- /src/olmo_core/ops/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AutoAuxiliaryLoss(torch.autograd.Function): 5 | """ 6 | An autograd function that triggers the backward pass for an auxiliary loss. 7 | """ 8 | 9 | @staticmethod 10 | def forward(ctx, activation: torch.Tensor, aux_loss: torch.Tensor): 11 | ctx.save_for_backward(aux_loss) 12 | return activation 13 | 14 | @staticmethod 15 | def backward(ctx, grad_output: torch.Tensor): 16 | (aux_loss,) = ctx.saved_tensors 17 | aux_loss_grad = torch.ones_like(aux_loss) 18 | return grad_output, aux_loss_grad 19 | 20 | 21 | def attach_auxiliary_loss(activation: torch.Tensor, aux_loss: torch.Tensor) -> torch.Tensor: 22 | """ 23 | Attach an auxiliary loss to an activation with an autograd function in order to trigger 24 | gradients for the aux loss in the backwards pass. 25 | 26 | :returns: The input activation unchanged. 27 | """ 28 | return AutoAuxiliaryLoss.apply(activation, aux_loss) # type: ignore[return-value] 29 | -------------------------------------------------------------------------------- /src/scripts/release/add_pr_comments_on_release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | repo_url=https://github.com/allenai/OLMo-core 6 | 7 | tags=$(git tag -l --sort=-version:refname 'v*' | head -n 2) 8 | current_tag=$(echo "$tags" | head -n 1) 9 | last_tag=$(echo "$tags" | tail -n 1) 10 | 11 | echo "Current release: $current_tag" 12 | echo "Last release: $last_tag" 13 | 14 | if [ -z "$last_tag" ]; then 15 | echo "No previous release, nothing to do" 16 | exit 0; 17 | fi 18 | 19 | commits_since_last_release=$(git log "${last_tag}..${current_tag}" --format=format:%H) 20 | 21 | echo "Commits/PRs since last release:" 22 | for commit in $commits_since_last_release; do 23 | pr_number=$(gh pr list --search "$commit" --state merged --json number --jq '.[-1].number') 24 | if [ -z "$pr_number" ]; then 25 | echo "$commit" 26 | else 27 | echo "$commit (PR #$pr_number)" 28 | gh pr comment "$pr_number" --body "This PR has been released in [${current_tag}](${repo_url}/releases/tag/${current_tag})." 
29 | fi 30 | done 31 | -------------------------------------------------------------------------------- /src/olmo_core/generate/generation_module/generation_module.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABCMeta, abstractmethod 3 | from typing import Any, Dict, Optional 4 | 5 | import torch.distributed as dist 6 | from torch.distributed.checkpoint.metadata import Metadata 7 | from torch.distributed.checkpoint.stateful import Stateful 8 | 9 | log = logging.getLogger(__name__) 10 | 11 | 12 | class GenerationModule(Stateful, metaclass=ABCMeta): 13 | @property 14 | def dp_process_group(self) -> Optional[dist.ProcessGroup]: 15 | """ 16 | Should return the data parallel process group if it's anything other than the default 17 | process group. 18 | """ 19 | return None 20 | 21 | def state_dict_to_load(self, metadata: Metadata) -> Dict[str, Any]: 22 | del metadata 23 | return self.state_dict() 24 | 25 | @abstractmethod 26 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 27 | """ 28 | Load a state dict. 29 | """ 30 | raise NotImplementedError 31 | -------------------------------------------------------------------------------- /src/test/data/utils.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from pathlib import Path 3 | from typing import Any, List, Tuple, Type, Union 4 | 5 | import numpy as np 6 | 7 | Mmaps = List[Tuple[Union[Path, PathLike[Any], str], Any]] 8 | 9 | 10 | def mk_mmaps( 11 | tmp_path: Path, 12 | prefix: str, 13 | num_files: int, 14 | size: int, 15 | dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint32, 16 | eos: int = 0, 17 | seq_length: int = 4, 18 | seed: int = 42, 19 | ) -> Mmaps: 20 | mmaps: Mmaps = [] 21 | for i in range(num_files): 22 | filepath = f"{tmp_path}/{prefix}_{i}.npy" 23 | np.random.seed(seed) 24 | data = np.random.randint(1, np.iinfo(dtype).max, size=size, dtype=dtype) 25 | data = np.append( 26 | np.insert(data, np.arange(seq_length + 1, len(data), seq_length), eos), eos 27 | ) 28 | mm = np.memmap(filepath, mode="w+", dtype=dtype, shape=(len(data),)) 29 | mm[:] = data 30 | mm.flush() 31 | mmaps.append((Path(filepath), data)) 32 | 33 | return mmaps 34 | -------------------------------------------------------------------------------- /docs/source/overview/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | Prior to installing OLMo-core you should install `PyTorch `_ according to the official instructions 5 | specific to your operating system and hardware. 6 | 7 | Then you can install OLMo-core from `PyPI `_ with:: 8 | 9 | pip install ai2-olmo-core 10 | 11 | There are a number of optional dependencies that must be installed to use certain functionality as well, including: 12 | 13 | - `flash-attn `_, `ring-flash-attn `_, and `TransformerEngine `_ for the corresponding attention backends. 14 | - `Liger-Kernel `_ for a low-memory "fused-linear" loss implementation. 15 | - `torchao `_ for float8 training (see :mod:`olmo_core.float8`). 16 | - `grouped_gemm `_ for dropless mixture-of-experts (MoE) models (see :mod:`olmo_core.nn.moe`). 
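For example, a CUDA environment that wants the flash attention backend, the fused-linear loss, and float8 training might install something like the following (the package names below are the upstream PyPI names; exact versions and build flags depend on your setup, and ``flash-attn`` in particular usually needs to be built against an existing PyTorch install)::

    pip install ai2-olmo-core
    pip install torchao                            # float8 training
    pip install liger-kernel                       # fused-linear loss
    pip install flash-attn --no-build-isolation    # flash attention backend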
17 | -------------------------------------------------------------------------------- /src/scripts/release/release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | # Make sure clone is up-to-date with remote. 6 | git pull > /dev/null 7 | git tag -l | xargs git tag -d > /dev/null 8 | git fetch -t > /dev/null 9 | 10 | TAG=$(python -c 'from olmo_core.version import VERSION; print("v" + VERSION)') 11 | 12 | # Make sure tag/release doesn't already exist. 13 | STATUS_CODE=$(curl -s -o /dev/null -w "%{http_code}" "https://github.com/allenai/OLMo-core/releases/tag/${TAG}") 14 | if [[ $STATUS_CODE == "200" ]]; then 15 | echo "Release tag ${TAG} already exists" 16 | exit 1 17 | fi 18 | 19 | python src/scripts/release/prepare_changelog.py 20 | 21 | read -rp "Creating new release for $TAG. Do you want to continue? [Y/n] " prompt 22 | 23 | if [[ $prompt == "y" || $prompt == "Y" || $prompt == "yes" || $prompt == "Yes" ]]; then 24 | git add -A 25 | git commit -m "(chore) prepare for release $TAG" || true && git push 26 | echo "Creating new git tag $TAG" 27 | git tag "$TAG" -m "$TAG" 28 | git push --tags 29 | else 30 | echo "Canceled" 31 | git checkout CHANGELOG.md 32 | exit 1 33 | fi 34 | -------------------------------------------------------------------------------- /src/olmo_core/data/mixes/v3-small-ppl-validation.txt: -------------------------------------------------------------------------------- 1 | c4_en-validation,eval-data/perplexity/v3_small_{TOKENIZER}/c4_en/val/part-0-00000.npy 2 | dolma_books-validation,eval-data/perplexity/v3_small_{TOKENIZER}/dolma_books/val/part-0-00000.npy 3 | dolma_common-crawl-validation,eval-data/perplexity/v3_small_{TOKENIZER}/dolma_common-crawl/val/part-0-00000.npy 4 | dolma_pes2o-validation,eval-data/perplexity/v3_small_{TOKENIZER}/dolma_pes2o/val/part-0-00000.npy 5 | dolma_reddit-validation,eval-data/perplexity/v3_small_{TOKENIZER}/dolma_reddit/val/part-0-00000.npy 6 | dolma_stack-validation,eval-data/perplexity/v3_small_{TOKENIZER}/dolma_stack/val/part-0-00000.npy 7 | dolma_wiki-validation,eval-data/perplexity/v3_small_{TOKENIZER}/dolma_wiki/val/part-0-00000.npy 8 | ice-validation,eval-data/perplexity/v3_small_{TOKENIZER}/ice/val/part-0-00000.npy 9 | m2d2_s2orc-validation,eval-data/perplexity/v3_small_{TOKENIZER}/m2d2_s2orc/val/part-0-00000.npy 10 | pile-validation,eval-data/perplexity/v3_small_{TOKENIZER}/pile/val/part-0-00000.npy 11 | wikitext_103-validation,eval-data/perplexity/v3_small_{TOKENIZER}/wikitext_103/val/part-0-00000.npy 12 | -------------------------------------------------------------------------------- /src/olmo_core/testing/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed import run_distributed_test 2 | from .utils import ( 3 | BACKENDS, 4 | DEVICES, 5 | FLASH_2_MARKS, 6 | FLASH_3_MARKS, 7 | GPU_MARKS, 8 | GROUPED_GEMM_MARKS, 9 | INIT_DEVICES, 10 | LOW_PRECISION_DTYPES, 11 | MULTI_GPU_MARKS, 12 | TE_MARKS, 13 | has_cuda, 14 | has_flash_attn_2, 15 | has_grouped_gemm, 16 | has_multiple_gpus, 17 | has_torchao, 18 | requires_flash_attn_2, 19 | requires_gpu, 20 | requires_grouped_gemm, 21 | requires_multi_gpu, 22 | requires_te, 23 | ) 24 | 25 | __all__ = [ 26 | "BACKENDS", 27 | "DEVICES", 28 | "FLASH_2_MARKS", 29 | "FLASH_3_MARKS", 30 | "TE_MARKS", 31 | "GPU_MARKS", 32 | "GROUPED_GEMM_MARKS", 33 | "INIT_DEVICES", 34 | "LOW_PRECISION_DTYPES", 35 | "MULTI_GPU_MARKS", 36 | "has_cuda", 37 | "has_flash_attn_2", 38 |
"has_grouped_gemm", 39 | "has_multiple_gpus", 40 | "has_torchao", 41 | "requires_flash_attn", 42 | "requires_te", 43 | "requires_gpu", 44 | "requires_grouped_gemm", 45 | "requires_multi_gpu", 46 | "run_distributed_test", 47 | ] 48 | -------------------------------------------------------------------------------- /src/scripts/release/prepare_changelog.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from pathlib import Path 3 | 4 | from olmo_core.version import VERSION 5 | 6 | 7 | def main() -> None: 8 | changelog = Path("CHANGELOG.md") 9 | 10 | with changelog.open() as f: 11 | lines = f.readlines() 12 | 13 | insert_index: int = -1 14 | for i in range(len(lines)): 15 | line = lines[i] 16 | if line.startswith("## Unreleased"): 17 | insert_index = i + 1 18 | elif line.startswith(f"## [v{VERSION}]"): 19 | print("CHANGELOG already up-to-date") 20 | return 21 | elif line.startswith("## [v"): 22 | break 23 | 24 | if insert_index < 0: 25 | raise RuntimeError("Couldn't find 'Unreleased' section") 26 | 27 | lines.insert(insert_index, "\n") 28 | lines.insert( 29 | insert_index + 1, 30 | f"## [v{VERSION}](https://github.com/allenai/OLMo-core/releases/tag/v{VERSION}) - " 31 | f"{datetime.now().strftime('%Y-%m-%d')}\n", 32 | ) 33 | 34 | with changelog.open("w") as f: 35 | f.writelines(lines) 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /src/test/distributed/utils_test.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import pytest 4 | import torch.distributed as dist 5 | 6 | import olmo_core.distributed.utils as dist_utils 7 | from olmo_core.testing import BACKENDS, run_distributed_test 8 | 9 | 10 | def scatter_object(): 11 | if dist.get_rank() == 0: 12 | x = ("abc", "def") 13 | else: 14 | x = ("abc", "abc") 15 | x = dist_utils.scatter_object(x) 16 | assert x == ("abc", "def") 17 | 18 | 19 | @pytest.mark.parametrize("backend", BACKENDS) 20 | def test_scatter_object(backend: str): 21 | run_distributed_test(scatter_object, backend=backend) 22 | 23 | 24 | @pytest.mark.parametrize("n, world_size", [(2, 1), (8, 64)]) 25 | def test_do_n_at_a_time(n: int, world_size: int): 26 | times_called = 0 27 | calling_ranks = set() 28 | 29 | def func(rank: int): 30 | nonlocal times_called 31 | times_called += 1 32 | calling_ranks.add(rank) 33 | 34 | for rank in range(world_size): 35 | dist_utils.do_n_at_a_time(partial(func, rank), n=n, world_size=world_size, local_rank=rank) 36 | 37 | assert times_called == world_size 38 | assert calling_ranks == set(range(world_size)) 39 | -------------------------------------------------------------------------------- /src/test/utils_test.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import pytest 4 | import torch 5 | 6 | from olmo_core.utils import apply_to_tensors, flatten_dict 7 | 8 | 9 | @dataclass 10 | class Foo: 11 | x: torch.Tensor 12 | 13 | 14 | @pytest.mark.parametrize( 15 | "container, tensor_count", 16 | [ 17 | (Foo(x=torch.rand(2, 2)), 1), 18 | ({"x": torch.rand(2, 2)}, 1), 19 | ((torch.rand(2, 2),), 1), 20 | ([torch.rand(2, 2)], 1), 21 | ({torch.rand(2, 2)}, 1), 22 | ({"x": {"x": torch.rand(2, 2), "y": torch.rand(1, 1)}}, 2), 23 | ((torch.rand(1) for _ in range(2)), 2), 24 | ], 25 | ) 26 | def test_apply_to_tensors(container, tensor_count): 27 | 
count = 0 28 | 29 | def count_tensors(x): 30 | nonlocal count 31 | if isinstance(x, torch.Tensor): 32 | count += 1 33 | 34 | apply_to_tensors(count_tensors, container) 35 | 36 | assert count == tensor_count 37 | 38 | 39 | def test_flatten_dict(): 40 | assert flatten_dict( 41 | { 42 | "a": {"foo": 1, "bar": {"baz": 2}}, 43 | "b": 2, 44 | } 45 | ) == { 46 | "a.foo": 1, 47 | "a.bar.baz": 2, 48 | "b": 2, 49 | } 50 | -------------------------------------------------------------------------------- /src/olmo_core/data/types.py: -------------------------------------------------------------------------------- 1 | from typing import Type, Union 2 | 3 | import numpy as np 4 | 5 | from olmo_core.config import StrEnum 6 | 7 | NumpyUIntTypes = Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] 8 | 9 | 10 | class LongDocStrategy(StrEnum): 11 | """ 12 | Specifies how to handle documents that are longer than the max sequence length when packing. 13 | """ 14 | 15 | truncate = "truncate" 16 | """ 17 | Long docs are truncated and the excess tokens are discarded. 18 | """ 19 | 20 | fragment = "fragment" 21 | """ 22 | Long docs are split into smaller docs so that no tokens are discarded, but you end up with 23 | fragmented docs. 24 | """ 25 | 26 | 27 | class NumpyDatasetDType(StrEnum): 28 | """ 29 | Supported numpy unsigned integer data types for datasets. 30 | """ 31 | 32 | uint8 = "uint8" 33 | uint16 = "uint16" 34 | uint32 = "uint32" 35 | uint64 = "uint64" 36 | 37 | def as_np_dtype(self) -> NumpyUIntTypes: 38 | """ 39 | Convert the enum value to its corresponding numpy dtype. 40 | 41 | Returns: 42 | The numpy unsigned integer dtype corresponding to this enum value. 43 | """ 44 | return getattr(np, str(self)) 45 | -------------------------------------------------------------------------------- /src/olmo_core/train/train_module/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import ( 2 | TransformerActivationCheckpointingConfig, 3 | TransformerActivationCheckpointingMode, 4 | TransformerContextParallelConfig, 5 | TransformerDataParallelConfig, 6 | TransformerDataParallelWrappingStrategy, 7 | TransformerExpertParallelConfig, 8 | TransformerPipelineParallelConfig, 9 | TransformerPipelineTrainModuleConfig, 10 | TransformerTensorParallelConfig, 11 | TransformerTrainModuleConfig, 12 | ) 13 | from .pipeline_train_module import TransformerPipelineTrainModule 14 | from .bolmo_train_module import TransformerBolmoTrainModule 15 | from .train_module import TransformerTrainModule 16 | 17 | __all__ = [ 18 | "TransformerBolmoTrainModule", 19 | "TransformerTrainModule", 20 | "TransformerTrainModuleConfig", 21 | "TransformerPipelineTrainModule", 22 | "TransformerPipelineTrainModuleConfig", 23 | "TransformerActivationCheckpointingConfig", 24 | "TransformerActivationCheckpointingMode", 25 | "TransformerDataParallelConfig", 26 | "TransformerDataParallelWrappingStrategy", 27 | "TransformerExpertParallelConfig", 28 | "TransformerTensorParallelConfig", 29 | "TransformerContextParallelConfig", 30 | "TransformerPipelineParallelConfig", 31 | ] 32 | -------------------------------------------------------------------------------- /src/test/nn/layer_norm_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from olmo_core.exceptions import OLMoConfigurationError 5 | from olmo_core.nn.layer_norm import ( 6 | FusedRMSNorm, 7 | L2Norm, 8 | LayerNormConfig, 9 | 
LayerNormType, 10 | RMSNorm, 11 | ) 12 | from olmo_core.testing import requires_flash_attn_2, requires_gpu 13 | 14 | 15 | @requires_gpu 16 | @requires_flash_attn_2 17 | @pytest.mark.parametrize("bias", [pytest.param(True, id="bias"), pytest.param(False, id="no-bias")]) 18 | @pytest.mark.parametrize( 19 | "dtype", [pytest.param(torch.float32, id="fp32"), pytest.param(torch.bfloat16, id="bf16")] 20 | ) 21 | def test_fused_rms_norm(bias, dtype): 22 | dim = 64 23 | norm = RMSNorm(size=dim, bias=bias, init_device="cuda") 24 | norm_fused = FusedRMSNorm(size=dim, bias=bias, init_device="cuda") 25 | 26 | x = torch.randn(4, dim, device="cuda", dtype=dtype) 27 | y1 = norm(x) 28 | y2 = norm_fused(x) 29 | torch.testing.assert_close(y1, y2) 30 | 31 | 32 | def test_layer_norm_builder_config(): 33 | norm = LayerNormConfig(name=LayerNormType.l2_norm).build(size=1024) 34 | assert isinstance(norm, L2Norm) 35 | 36 | with pytest.raises(OLMoConfigurationError): 37 | LayerNormConfig(name=LayerNormType.l2_norm, elementwise_affine=True).build(size=1024) 38 | -------------------------------------------------------------------------------- /src/test/nn/functional/cross_entropy_loss_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from olmo_core.nn.functional import cross_entropy_loss 5 | from olmo_core.testing import DEVICES 6 | 7 | 8 | @pytest.mark.parametrize("device", DEVICES) 9 | @pytest.mark.parametrize("reduction", ["sum", "mean"]) 10 | def test_cross_entropy_loss(device, reduction): 11 | vocab_size = 50257 12 | N = 32 13 | 14 | logits = torch.randn(N, vocab_size, device=device) 15 | labels = torch.randint(0, vocab_size, (N,), device=device) 16 | 17 | ce_loss, z_loss = cross_entropy_loss(logits, labels, reduction=reduction, compute_z_loss=True) 18 | assert ce_loss.shape == tuple() 19 | assert ce_loss.numel() == 1 20 | assert z_loss is not None 21 | assert z_loss.shape == tuple() 22 | assert z_loss.numel() == 1 23 | 24 | # Now add some masked values to logits and labels and make sure we get the same result. 
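The -100 labels mark padded positions that the loss should ignore, so the padded batch must produce exactly the same losses.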
25 | logits_padded = torch.cat([logits, torch.rand(3, vocab_size, device=device)], dim=0) 26 | labels_padded = torch.cat([labels, torch.tensor([-100] * 3, device=device)], dim=0) 27 | ce_loss1, z_loss1 = cross_entropy_loss( 28 | logits_padded, labels_padded, reduction=reduction, compute_z_loss=True 29 | ) 30 | torch.testing.assert_close(ce_loss, ce_loss1) 31 | torch.testing.assert_close(z_loss, z_loss1) 32 | -------------------------------------------------------------------------------- /src/olmo_core/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adam import AdamConfig 2 | from .adamw import AdamWConfig, SkipStepAdamW, SkipStepAdamWConfig 3 | from .config import INITIAL_LR_FIELD, LR_FIELD, OptimConfig, OptimGroupOverride 4 | from .lion import Lion, LionConfig, SkipStepLion, SkipStepLionConfig 5 | from .noop import NoOpConfig, NoOpOptimizer 6 | from .scheduler import ( 7 | WSD, 8 | WSDS, 9 | ConstantScheduler, 10 | ConstantWithWarmup, 11 | CosWithWarmup, 12 | HalfCosWithWarmup, 13 | InvSqrtWithWarmup, 14 | LinearWithWarmup, 15 | Scheduler, 16 | SchedulerUnits, 17 | SequentialScheduler, 18 | ) 19 | from .skip_step_optimizer import SkipStepOptimizer 20 | 21 | __all__ = [ 22 | "OptimConfig", 23 | "OptimGroupOverride", 24 | "SkipStepOptimizer", 25 | "AdamWConfig", 26 | "SkipStepAdamWConfig", 27 | "SkipStepAdamW", 28 | "AdamConfig", 29 | "LionConfig", 30 | "Lion", 31 | "SkipStepLionConfig", 32 | "SkipStepLion", 33 | "NoOpConfig", 34 | "NoOpOptimizer", 35 | "Scheduler", 36 | "SchedulerUnits", 37 | "ConstantScheduler", 38 | "ConstantWithWarmup", 39 | "CosWithWarmup", 40 | "HalfCosWithWarmup", 41 | "InvSqrtWithWarmup", 42 | "LinearWithWarmup", 43 | "SequentialScheduler", 44 | "WSD", 45 | "WSDS", 46 | "LR_FIELD", 47 | "INITIAL_LR_FIELD", 48 | ] 49 | -------------------------------------------------------------------------------- /src/olmo_core/train/train_module/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_module import ( 2 | BasicTrainModule, 3 | EvalBatchSizeUnit, 4 | EvalBatchSpec, 5 | TrainModule, 6 | ) 7 | from .transformer import ( 8 | TransformerActivationCheckpointingConfig, 9 | TransformerActivationCheckpointingMode, 10 | TransformerContextParallelConfig, 11 | TransformerDataParallelConfig, 12 | TransformerDataParallelWrappingStrategy, 13 | TransformerExpertParallelConfig, 14 | TransformerPipelineParallelConfig, 15 | TransformerPipelineTrainModule, 16 | TransformerPipelineTrainModuleConfig, 17 | TransformerTensorParallelConfig, 18 | TransformerTrainModule, 19 | TransformerTrainModuleConfig, 20 | TransformerBolmoTrainModule, 21 | ) 22 | 23 | __all__ = [ 24 | "TrainModule", 25 | "EvalBatchSpec", 26 | "EvalBatchSizeUnit", 27 | "BasicTrainModule", 28 | "TransformerTrainModule", 29 | "TransformerTrainModuleConfig", 30 | "TransformerPipelineTrainModule", 31 | "TransformerPipelineTrainModuleConfig", 32 | "TransformerActivationCheckpointingConfig", 33 | "TransformerActivationCheckpointingMode", 34 | "TransformerBolmoTrainModule", 35 | "TransformerDataParallelConfig", 36 | "TransformerDataParallelWrappingStrategy", 37 | "TransformerExpertParallelConfig", 38 | "TransformerTensorParallelConfig", 39 | "TransformerContextParallelConfig", 40 | "TransformerPipelineParallelConfig", 41 | ] 42 | -------------------------------------------------------------------------------- /src/olmo_core/nn/transformer/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .block import ( 2 | LayerNormScaledTransformerBlock, 3 | MoEHybridReorderedNormTransformerBlock, 4 | MoEHybridTransformerBlock, 5 | MoEHybridTransformerBlockBase, 6 | MoEReorderedNormTransformerBlock, 7 | MoETransformerBlock, 8 | NormalizedTransformerBlock, 9 | ReorderedNormTransformerBlock, 10 | TransformerBlock, 11 | TransformerBlockBase, 12 | ) 13 | from .config import ( 14 | TransformerActivationCheckpointingMode, 15 | TransformerBlockConfig, 16 | TransformerBlockType, 17 | TransformerConfig, 18 | TransformerDataParallelWrappingStrategy, 19 | TransformerType, 20 | ) 21 | from .init import InitMethod 22 | from .model import MoETransformer, NormalizedTransformer, Transformer 23 | 24 | __all__ = [ 25 | "TransformerType", 26 | "TransformerConfig", 27 | "Transformer", 28 | "NormalizedTransformer", 29 | "MoETransformer", 30 | "MoEHybridTransformerBlockBase", 31 | "MoEHybridTransformerBlock", 32 | "MoEHybridReorderedNormTransformerBlock", 33 | "TransformerBlockType", 34 | "TransformerBlockConfig", 35 | "TransformerBlockBase", 36 | "TransformerBlock", 37 | "ReorderedNormTransformerBlock", 38 | "LayerNormScaledTransformerBlock", 39 | "NormalizedTransformerBlock", 40 | "MoETransformerBlock", 41 | "MoEReorderedNormTransformerBlock", 42 | "TransformerDataParallelWrappingStrategy", 43 | "TransformerActivationCheckpointingMode", 44 | "InitMethod", 45 | ] 46 | -------------------------------------------------------------------------------- /src/scripts/beaker/launch_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Launch tests on Beaker. 3 | """ 4 | 5 | import sys 6 | from typing import List 7 | 8 | from rich import print 9 | 10 | from olmo_core.launch.beaker import BeakerLaunchConfig, OLMoCoreBeakerImage 11 | from olmo_core.utils import generate_uuid, prepare_cli_environment 12 | 13 | 14 | def build_config(command: List[str], overrides: List[str]) -> BeakerLaunchConfig: 15 | return BeakerLaunchConfig( 16 | name=f"olmo-core-test-{generate_uuid()[:8]}", 17 | budget="ai2/oe-base", 18 | cmd=command, 19 | task_name="test", 20 | workspace="ai2/OLMo-core", 21 | beaker_image=OLMoCoreBeakerImage.stable, 22 | clusters=[ 23 | "ai2/jupiter", 24 | "ai2/augusta", 25 | "ai2/ceres", 26 | ], 27 | num_nodes=1, 28 | num_gpus=2, 29 | shared_filesystem=True, 30 | # host_networking=False, 31 | ).merge(overrides) 32 | 33 | 34 | if __name__ == "__main__": 35 | if len(sys.argv) < 3 or "--" not in sys.argv: 36 | print(f"Usage: python {sys.argv[0]} [OVERRIDES...] -- [CMD...]") 37 | sys.exit(1) 38 | 39 | sep_index = sys.argv.index("--") 40 | overrides = sys.argv[1:sep_index] 41 | entrypoint = sys.argv[sep_index + 1] 42 | command = sys.argv[sep_index + 2 :] 43 | 44 | prepare_cli_environment() 45 | 46 | config = build_config(command, overrides) 47 | print(config) 48 | config.launch(follow=True, torchrun=False, entrypoint=entrypoint) 49 | -------------------------------------------------------------------------------- /src/olmo_core/float8/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | 6 | def cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 7 | """ 8 | Cast a tensor to FP8 with scaling factors calculated from contiguous blocks of 128 elements 9 | in last dimension of the tensor. The size of the last dimension must be divisible by 128. 
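Each block is scaled so that its absolute max maps to 448 (the ``float8_e4m3fn`` maximum), and the returned scaling factors are the per-block amax divided by 448.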
10 | 11 | :returns: The FP8 tensor and its scaling factors. 12 | """ 13 | assert x.dim() >= 2 14 | assert x.size(-1) % 128 == 0 15 | in_shape = x.shape 16 | x = x.view(*in_shape[:-1], -1, 128) 17 | x_amax = x.abs().float().amax(dim=-1).view(*in_shape[:-1], -1).clamp(1e-4) 18 | x = (x * (448.0 / x_amax.unsqueeze(-1))).to(torch.float8_e4m3fn) 19 | return x.view(in_shape), x_amax / 448.0 20 | 21 | 22 | def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 23 | """ 24 | Cast a tensor with shape ``(*, m, n)`` to FP8 with scaling factors calculated from interior 25 | 128 x 128 blocks. 26 | 27 | :returns: The FP8 tensor and its scaling factors. 28 | """ 29 | assert x.dim() >= 2 30 | assert x.size(-1) % 128 == 0 and x.size(-2) % 128 == 0 31 | in_shape = x.shape 32 | m, n = in_shape[-2:] 33 | x = x.view(*in_shape[:-2], m // 128, 128, n // 128, 128) 34 | x_amax = x.abs().float().amax(dim=(-3, -1), keepdim=True).clamp(1e-4) 35 | x = (x * (448.0 / x_amax)).to(torch.float8_e4m3fn) 36 | return x.view(in_shape).contiguous(), (x_amax / 448.0).view(*in_shape[:-2], m // 128, n // 128) 37 | -------------------------------------------------------------------------------- /docs/source/examples/huggingface.rst: -------------------------------------------------------------------------------- 1 | HuggingFace models 2 | ================== 3 | 4 | The OLMo-core :class:`~olmo_core.train.Trainer` can be used to fine-tune language models from HuggingFace's ``transformers`` library. 5 | 6 | One way to do this would be to manually apply a data parallel wrapper (like DDP or FSDP) to your ``AutoModelForCausalLM`` and then pass that model directly to the trainer. The downside with this approach is that you won't be able to take advantage of all of the optimizations in this library. 7 | 8 | Instead we recommend converting your HuggingFace checkpoint into a format that can be loaded into an equivalent OLMo-core :class:`~olmo_core.nn.transformer.Transformer` model, when possible, using the functions provided by :mod:`olmo_core.nn.hf`. 9 | 10 | Below is an example that shows how to convert an OLMo2 or Llama-3 checkpoint on HuggingFace into the right format for OLMo-core, and an example for how to convert OLMo-core checkpoints into HuggingFace formats. The mapping of OLMo Core and HF states can be configured using the constants in :mod:`olmo_core.nn.hf.convert` (see :mod:`olmo_core.nn.hf`). 11 | 12 | .. seealso:: 13 | See the `train a Llama model `_ example to learn how to use OLMo-core's training API to pretrain or fine-tune any Llama-like language model. 14 | 15 | .. tab:: ``src/examples/huggingface/convert_checkpoint_from_hf.py`` 16 | 17 | .. literalinclude:: ../../../src/examples/huggingface/convert_checkpoint_from_hf.py 18 | :language: py 19 | 20 | .. tab:: ``src/examples/huggingface/convert_checkpoint_to_hf.py`` 21 | 22 | .. 
literalinclude:: ../../../src/examples/huggingface/convert_checkpoint_to_hf.py 23 | :language: py 24 | -------------------------------------------------------------------------------- /src/olmo_core/train/callbacks/garbage_collector.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | from ...aliases import PathOrStr 7 | from .callback import Callback 8 | 9 | log = logging.getLogger(__name__) 10 | 11 | 12 | @dataclass 13 | class GarbageCollectorCallback(Callback): 14 | """ 15 | Disables automatic garbage collection during training and runs gen1 collection 16 | on a set schedule instead. 17 | 18 | .. important:: 19 | This callback gets added automatically in a distributed training setting if you 20 | don't explicitly configure it. 21 | If you want to override this callback you should subclass it. 22 | """ 23 | 24 | gc_interval: int = 1000 25 | enabled: bool = True 26 | _start_state: Optional[bool] = None 27 | 28 | def pre_train(self): 29 | if not self.enabled: 30 | return 31 | self._start_state = gc.isenabled() 32 | gc.disable() 33 | log.info(f"Automatic GC disabled for training, will run GC every {self.gc_interval} steps") 34 | 35 | def post_step(self): 36 | if not self.enabled: 37 | return 38 | if self.step % self.gc_interval == 0: 39 | if self.gc_interval > 10: 40 | log.info("Running garbage collection") 41 | gc.collect(1) 42 | 43 | def close(self): 44 | if not self.enabled: 45 | return 46 | if self._start_state: 47 | gc.enable() 48 | 49 | def post_checkpoint_saved(self, path: PathOrStr): 50 | del path 51 | if not self.enabled: 52 | return 53 | gc.collect(1) 54 | -------------------------------------------------------------------------------- /src/test/nn/hf/config_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers import Olmo2Config 3 | 4 | from olmo_core.nn.hf.config import get_hf_config 5 | from olmo_core.nn.transformer.config import TransformerConfig 6 | 7 | try: 8 | from transformers import FlexOlmoConfig # type: ignore 9 | except ImportError: 10 | FlexOlmoConfig = None 11 | 12 | 13 | def test_get_hf_config(): 14 | vocab_size = 4096 15 | model_config = TransformerConfig.olmo2_190M(vocab_size) 16 | model = model_config.build() 17 | 18 | hf_config = get_hf_config(model) 19 | assert isinstance(hf_config, Olmo2Config) 20 | assert hf_config.hidden_size == model_config.d_model 21 | assert hf_config.intermediate_size == 3072 22 | assert hf_config.num_hidden_layers == model_config.n_layers 23 | 24 | 25 | def test_get_hf_config_default_block(): 26 | vocab_size = 4096 27 | model_config = TransformerConfig.llama2_271M(vocab_size) 28 | model = model_config.build() 29 | 30 | with pytest.raises(NotImplementedError): 31 | get_hf_config(model) 32 | 33 | 34 | def test_get_hf_config_moe(): 35 | vocab_size = 4096 36 | model_config = TransformerConfig.smallmoe(vocab_size) 37 | model = model_config.build() 38 | 39 | if FlexOlmoConfig is None: 40 | pytest.skip("The installed transformers version does not support FlexOlmo") 41 | 42 | hf_config = get_hf_config(model) 43 | assert isinstance(hf_config, FlexOlmoConfig) 44 | assert hf_config.hidden_size == model_config.d_model 45 | assert model_config.block.feed_forward_moe is not None 46 | assert hf_config.intermediate_size == model_config.block.feed_forward_moe.hidden_size 47 | assert hf_config.num_hidden_layers == model_config.n_layers 48 
| -------------------------------------------------------------------------------- /src/test/data/mixes_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from olmo_core.data import DataMix, TokenizerName 4 | from olmo_core.io import file_exists 5 | 6 | 7 | def test_olmoe_mix(): 8 | from botocore.exceptions import NoCredentialsError 9 | 10 | paths, labels = DataMix.OLMoE_mix_0824.build("s3://ai2-llm", TokenizerName.dolma2) 11 | assert len(paths) == len(labels) 12 | assert ( 13 | paths[-1] 14 | == "s3://ai2-llm/preprocessed/olmo-mix/danyh-compiled-v1_7/documents/wiki/allenai/dolma2-tokenizer/part-1-00000.npy" 15 | ) 16 | 17 | try: 18 | assert file_exists(paths[-1]) 19 | except NoCredentialsError: 20 | pytest.skip("Requires AWS credentials") 21 | 22 | 23 | def test_dolma17_mix(): 24 | from botocore.exceptions import NoCredentialsError 25 | 26 | paths, labels = DataMix.dolma17.build("s3://ai2-llm", TokenizerName.gpt_neox_olmo_dolma_v1_5) 27 | assert len(paths) == len(labels) 28 | assert ( 29 | paths[-1] 30 | == "s3://ai2-llm/preprocessed/olmo-mix/v1_7-dd_ngram_dp_030-qc_cc_en_bin_001/cc_en_tail/gpt-neox-olmo-dolma-v1_5/part-092-00000.npy" 31 | ) 32 | 33 | try: 34 | assert file_exists(paths[-1]) 35 | except NoCredentialsError: 36 | pytest.skip("Requires AWS credentials") 37 | 38 | 39 | def test_v3_small_ppl_validation_mix(): 40 | from botocore.exceptions import NoCredentialsError 41 | 42 | paths, labels = DataMix.v3_small_ppl_validation.build("s3://ai2-llm", TokenizerName.dolma2) 43 | assert len(paths) == len(labels) 44 | assert ( 45 | paths[0] 46 | == "s3://ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/c4_en/val/part-0-00000.npy" 47 | ) 48 | assert labels[0] == "c4_en-validation" 49 | 50 | try: 51 | assert file_exists(paths[-1]) 52 | except NoCredentialsError: 53 | pytest.skip("Requires AWS credentials") 54 | -------------------------------------------------------------------------------- /src/test/optim/skip_step_optimizer_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from olmo_core.optim import SkipStepAdamWConfig 6 | from olmo_core.testing import DEVICES 7 | 8 | 9 | class MyModel(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | self.wte = nn.Embedding(1024, 16) 13 | self.fc1 = nn.Linear(16, 32) 14 | self.fc2 = nn.Linear(32, 16) 15 | 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | x = self.wte(x) 18 | x = self.fc1(x) 19 | x = torch.relu(x) 20 | return self.fc2(x) 21 | 22 | 23 | @pytest.mark.parametrize("device", DEVICES) 24 | def test_skip_step_optimizer(device: torch.device): 25 | """Test that skip step optimizer skips steps with outlier losses.""" 26 | model = MyModel().to(device) 27 | optim = SkipStepAdamWConfig(rolling_interval_length=2, sigma_factor=1).build(model) 28 | 29 | # Normal step - should not skip 30 | optim.zero_grad(set_to_none=True) 31 | loss = model(torch.randint(0, 128, (4, 8), device=device)).sum() 32 | optim.latest_loss = loss.detach() 33 | loss.backward() 34 | optim.step() 35 | assert torch.equal(optim.step_skipped.cpu().detach(), torch.tensor(False)) 36 | 37 | # Outlier step - should skip 38 | optim.zero_grad(set_to_none=True) 39 | loss = model(torch.randint(0, 128, (4, 8), device=device)).sum() 40 | optim.latest_loss = torch.tensor(1e9, device=device) # Outlier loss 41 | loss.backward() 42 | optim.step() 43 | assert 
torch.equal(optim.step_skipped.cpu().detach(), torch.tensor(True)) 44 | 45 | # Another normal step 46 | optim.zero_grad(set_to_none=True) 47 | loss = model(torch.randint(0, 128, (4, 8), device=device)).sum() 48 | optim.latest_loss = loss.detach() 49 | loss.backward() 50 | optim.step() 51 | assert torch.equal(optim.step_skipped.cpu().detach(), torch.tensor(False)) 52 | -------------------------------------------------------------------------------- /bolmo_scripts/launch_stage2_7b.sh: -------------------------------------------------------------------------------- 1 | NAME=stage2_bolmo_7b 2 | NUM_WORKERS=24 \ 3 | OLMO_ARCH=olmo3_7B \ 4 | SEQUENCE_LENGTH=4096 \ 5 | DATA_SOURCE=data_sources.txt \ 6 | LOCAL_MODEL_STYLE="hnet:xlstm" \ 7 | ADD_HASH_EMBEDDINGS=false \ 8 | ADD_EXPANDED_EMBEDDINGS=true \ 9 | LR_SCHEDULE=linear_with_warmup \ 10 | STAGE1_CKPT_PATH=/path/to/stage1/ckpt \ 11 | GLOBAL_MODEL_LEARNING_RATE=1.83e-5 \ 12 | SAVE_FOLDER=/path/to/save/folder/$NAME \ 13 | python3 src/examples/bolmo/train_stage2.py $NAME \ 14 | train_module.optim.lr=3.66e-5 \ 15 | data_loader.seed=1234 \ 16 | data_loader.global_batch_size=1572864 \ 17 | train_module.rank_microbatch_size=49152 \ 18 | train_module.bolmo_config.losses=[ce,boundary] \ 19 | train_module.bolmo_config.loss_weights=[1,4] \ 20 | train_module.bolmo_config.teacher_force_boundaries=false \ 21 | train_module.bolmo_config.do_alm_debiasing=false \ 22 | train_module.bolmo_config.merge_boundary_loss=false \ 23 | train_module.optim.weight_decay=0.1 \ 24 | train_module.optim.betas=[0.9,0.95] \ 25 | train_module.max_grad_norm=0.5 \ 26 | model.block.attention.use_flash=true \ 27 | model.local_encoder.n_layers=1 \ 28 | model.local_decoder.n_layers=4 \ 29 | model.local_decoder.hnet_smooth=false \ 30 | model.local_decoder.hnet_modulate=false \ 31 | model.local_encoder.boundary_predictor_lookahead=1 \ 32 | model.local_decoder.add_in_projection=true \ 33 | model.local_decoder.add_norm_onto_residual=false \ 34 | model.local_decoder.add_projected_patch_residuals=false \ 35 | model.local_encoder.block_config.feed_forward.hidden_size=5504 \ 36 | model.local_decoder.block_config.feed_forward.hidden_size=5504 \ 37 | model.local_encoder.d_model=4096 \ 38 | model.local_decoder.d_model=4096 \ 39 | trainer.callbacks.checkpointer.ephemeral_save_interval=1000 \ 40 | trainer.callbacks.checkpointer.save_interval=30000 \ 41 | trainer.callbacks.downstream_evaluator.eval_interval=150000 \ 42 | trainer.max_duration.value=150000 -------------------------------------------------------------------------------- /bolmo_scripts/launch_stage2_1b.sh: -------------------------------------------------------------------------------- 1 | NAME=stage2_bolmo_1b 2 | NUM_WORKERS=24 \ 3 | OLMO_ARCH=olmo2_1B_v2 \ 4 | SEQUENCE_LENGTH=4096 \ 5 | DATA_SOURCE=data_sources.txt \ 6 | LOCAL_MODEL_STYLE="hnet:xlstm" \ 7 | ADD_HASH_EMBEDDINGS=false \ 8 | ADD_EXPANDED_EMBEDDINGS=true \ 9 | LR_SCHEDULE=linear_with_warmup \ 10 | STAGE1_CKPT_PATH=/path/to/stage1/ckpt \ 11 | GLOBAL_MODEL_LEARNING_RATE=2.6e-5 \ 12 | SAVE_FOLDER=/path/to/save/folder/$NAME \ 13 | python3 src/examples/bolmo/train_stage2.py $NAME \ 14 | train_module.optim.lr=5.2e-5 \ 15 | data_loader.seed=1234 \ 16 | data_loader.global_batch_size=1572864 \ 17 | train_module.rank_microbatch_size=98304 \ 18 | train_module.bolmo_config.losses=[ce,boundary] \ 19 | train_module.bolmo_config.loss_weights=[1,4] \ 20 | train_module.bolmo_config.teacher_force_boundaries=false \ 21 | train_module.bolmo_config.do_alm_debiasing=false \ 22 | 
train_module.bolmo_config.merge_boundary_loss=false \ 23 | train_module.optim.weight_decay=0.1 \ 24 | train_module.optim.betas=[0.9,0.95] \ 25 | train_module.max_grad_norm=0.5 \ 26 | model.block.attention.use_flash=true \ 27 | model.local_encoder.n_layers=1 \ 28 | model.local_decoder.n_layers=4 \ 29 | model.local_decoder.hnet_smooth=false \ 30 | model.local_decoder.hnet_modulate=false \ 31 | model.local_encoder.boundary_predictor_lookahead=1 \ 32 | model.local_decoder.add_in_projection=true \ 33 | model.local_decoder.add_norm_onto_residual=false \ 34 | model.local_decoder.add_projected_patch_residuals=false \ 35 | model.local_encoder.block_config.feed_forward.hidden_size=2816 \ 36 | model.local_decoder.block_config.feed_forward.hidden_size=2816 \ 37 | model.local_encoder.d_model=2048 \ 38 | model.local_decoder.d_model=2048 \ 39 | trainer.callbacks.checkpointer.ephemeral_save_interval=1000 \ 40 | trainer.callbacks.checkpointer.save_interval=30000 \ 41 | trainer.callbacks.downstream_evaluator.eval_interval=150000 \ 42 | trainer.max_duration.value=150000 -------------------------------------------------------------------------------- /src/test/conftest.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from functools import partial 3 | from typing import Generator 4 | 5 | import pytest 6 | import torch 7 | import torch.nn as nn 8 | 9 | from olmo_core.io import clear_directory 10 | 11 | 12 | @pytest.fixture 13 | def bucket_name() -> str: 14 | return "ai2-olmo-testing" 15 | 16 | 17 | @pytest.fixture 18 | def gcs_bucket_name() -> str: 19 | return "olmo-core-testing" 20 | 21 | 22 | @pytest.fixture 23 | def unique_name() -> str: 24 | return uuid.uuid4().hex 25 | 26 | 27 | @pytest.fixture 28 | def s3_checkpoint_dir(bucket_name, unique_name) -> Generator[str, None, None]: 29 | from botocore.exceptions import NoCredentialsError 30 | 31 | folder = f"s3://{bucket_name}/checkpoints/{unique_name}" 32 | yield folder 33 | 34 | try: 35 | clear_directory(folder, force=True) 36 | except NoCredentialsError: 37 | pass 38 | 39 | 40 | @pytest.fixture 41 | def gcs_checkpoint_dir(gcs_bucket_name, unique_name) -> Generator[str, None, None]: 42 | from google.auth.exceptions import DefaultCredentialsError 43 | 44 | folder = f"gs://{gcs_bucket_name}/checkpoints/{unique_name}" 45 | yield folder 46 | 47 | try: 48 | clear_directory(folder, force=True) 49 | except DefaultCredentialsError: 50 | pass 51 | 52 | 53 | class TinyModel(nn.Module): 54 | def __init__(self, dim: int = 8): 55 | super().__init__() 56 | self.fc = nn.Sequential( 57 | nn.Linear(dim, dim * 2), 58 | nn.ReLU(), 59 | nn.Linear(dim * 2, dim), 60 | nn.ReLU(), 61 | nn.Linear(dim, dim), 62 | ) 63 | 64 | def forward(self, x): 65 | return self.fc(x) 66 | 67 | 68 | @pytest.fixture 69 | def tiny_model_factory(): 70 | return TinyModel 71 | 72 | 73 | @pytest.fixture 74 | def tiny_model(tiny_model_factory) -> TinyModel: 75 | return tiny_model_factory() 76 | 77 | 78 | @pytest.fixture 79 | def tiny_model_data_factory(): 80 | return partial(torch.rand, 2, 8) 81 | 82 | 83 | @pytest.fixture 84 | def tiny_model_data(tiny_model_data_factory) -> torch.Tensor: 85 | return tiny_model_data_factory() 86 | -------------------------------------------------------------------------------- /src/olmo_core/train/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trainer :class:`Callback` implementations. 
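Callbacks hook into the trainer's lifecycle (for example ``pre_train``, ``post_step``, and ``post_checkpoint_saved``) to add behavior such as logging, checkpointing, monitoring, and garbage collection.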
3 | """ 4 | 5 | from .batch_size_scheduler import BatchSizeSchedulerCallback 6 | from .beaker import BeakerCallback 7 | from .callback import Callback, CallbackConfig 8 | from .checkpointer import CheckpointerCallback, CheckpointRemovalStrategy 9 | from .comet import CometCallback, CometNotificationSetting 10 | from .config_saver import ConfigSaverCallback 11 | from .console_logger import ConsoleLoggerCallback 12 | from .evaluator_callback import ( 13 | DownstreamEvaluatorCallbackConfig, 14 | EvaluatorCallback, 15 | LMEvaluatorCallbackConfig, 16 | ) 17 | from .gap_monitor import GAPMonitorCallback 18 | from .garbage_collector import GarbageCollectorCallback 19 | from .gpu_memory_monitor import GPUMemoryMonitorCallback 20 | from .list_checkpointer import ListCheckpointerCallback 21 | from .monkey_patcher import MonkeyPatcherCallback 22 | from .profiler import ProfilerCallback 23 | from .sequence_length_scheduler import SequenceLengthSchedulerCallback 24 | from .slack_notifier import SlackNotificationSetting, SlackNotifierCallback 25 | from .speed_monitor import SpeedMonitorCallback 26 | from .wandb import WandBCallback 27 | 28 | __all__ = [ 29 | "Callback", 30 | "CallbackConfig", 31 | "CheckpointerCallback", 32 | "CheckpointRemovalStrategy", 33 | "CometCallback", 34 | "CometNotificationSetting", 35 | "ConfigSaverCallback", 36 | "ConsoleLoggerCallback", 37 | "EvaluatorCallback", 38 | "LMEvaluatorCallbackConfig", 39 | "DownstreamEvaluatorCallbackConfig", 40 | "GAPMonitorCallback", 41 | "GarbageCollectorCallback", 42 | "GPUMemoryMonitorCallback", 43 | "ProfilerCallback", 44 | "SlackNotifierCallback", 45 | "SlackNotificationSetting", 46 | "SequenceLengthSchedulerCallback", 47 | "SpeedMonitorCallback", 48 | "WandBCallback", 49 | "BeakerCallback", 50 | "BatchSizeSchedulerCallback", 51 | "MonkeyPatcherCallback", 52 | "ListCheckpointerCallback", 53 | ] 54 | 55 | __doc__ += "\n" 56 | for name in __all__[2:]: 57 | if name.endswith("Callback"): 58 | __doc__ += f"- :class:`{name}`\n" 59 | -------------------------------------------------------------------------------- /src/test/data/fixtures.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Type, Union 3 | 4 | import numpy as np 5 | 6 | from olmo_core.data import NumpyDatasetBase, NumpyFSLDatasetConfig, TokenizerConfig 7 | from olmo_core.data.source_mixture import ( 8 | SourceMixtureConfig, 9 | SourceMixtureDatasetConfig, 10 | SourceMixtureList, 11 | ) 12 | from olmo_core.data.types import NumpyDatasetDType 13 | 14 | from .utils import mk_mmaps 15 | 16 | 17 | def get_fsl_mixture( 18 | tmp_path: Path, 19 | dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint32, 20 | seed: int = 42, 21 | sequence_length: int = 4, 22 | num_tokens: int = 20 * 1000, 23 | eos: int = 0, 24 | ) -> NumpyDatasetBase: 25 | seed = 42 26 | mmap1 = mk_mmaps( 27 | tmp_path, "mmap1", 1, num_tokens * 2, dtype, eos=eos, seed=seed, seq_length=sequence_length 28 | ) 29 | mmap2 = mk_mmaps( 30 | tmp_path, "mmap2", 1, num_tokens * 2, dtype, eos=eos, seed=seed, seq_length=sequence_length 31 | ) 32 | 33 | tokenizer = TokenizerConfig( 34 | vocab_size=32_000, 35 | eos_token_id=eos, 36 | pad_token_id=-1, 37 | ) 38 | 39 | mixture_config = SourceMixtureDatasetConfig( 40 | requested_tokens=num_tokens, 41 | source_list=SourceMixtureList( 42 | [ 43 | SourceMixtureConfig( 44 | source_name="mmap1", 45 | paths=[str(i[0]) for i in mmap1], 46 | target_ratio=0.8, 47 | ), 48 | 
SourceMixtureConfig( 49 | source_name="mmap2", 50 | paths=[str(i[0]) for i in mmap2], 51 | target_ratio=0.2, 52 | ), 53 | ] 54 | ), 55 | seed=seed, 56 | global_batch_size=sequence_length * 32, 57 | ) 58 | 59 | ds = NumpyFSLDatasetConfig.from_src_mix( 60 | src_mix=mixture_config, 61 | sequence_length=sequence_length, 62 | tokenizer=tokenizer, 63 | dtype=NumpyDatasetDType.uint16, 64 | include_instance_metadata=False, 65 | ).build() 66 | ds.prepare() 67 | return ds 68 | -------------------------------------------------------------------------------- /src/test/train/utils_test.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import pytest 4 | import torch 5 | import torch.distributed as dist 6 | 7 | from olmo_core.testing import BACKENDS, run_distributed_test 8 | from olmo_core.train import ReduceType 9 | from olmo_core.train.utils import reduce_metrics 10 | from olmo_core.utils import get_default_device 11 | 12 | 13 | def run_reduce_metrics(): 14 | device = get_default_device() 15 | raw_metrics = { 16 | 0: { 17 | "train/CrossEntropyLoss": torch.tensor(2.0, device=device), 18 | "train/masked_instances": torch.tensor(1.0, device=device), 19 | "optim/total_grad_norm": torch.tensor(1.0, device=device), 20 | }, 21 | 1: { 22 | "train/CrossEntropyLoss": torch.tensor( 23 | 1.5 if dist.get_rank() == 0 else 2.5, device=device 24 | ), 25 | "train/masked_instances": torch.tensor( 26 | 0.0 if dist.get_rank() == 0 else 1.0, device=device 27 | ), 28 | "train/rank": torch.tensor(float(dist.get_rank()), device=device), 29 | "optim/weight_norm": torch.tensor(2.0 if dist.get_rank() == 0 else 3.0, device=device), 30 | }, 31 | } 32 | metrics_reduce_type = { 33 | "train/CrossEntropyLoss": ReduceType.mean, 34 | "train/rank": ReduceType.max, 35 | "train/masked_instances": ReduceType.sum, 36 | "optim/total_grad_norm": None, 37 | "optim/weight_norm": ReduceType.l2_norm, 38 | } 39 | 40 | metrics = reduce_metrics(raw_metrics, metrics_reduce_type, device) 41 | if dist.get_rank() == 0: 42 | assert metrics == { 43 | 0: { 44 | "train/CrossEntropyLoss": 2.0, 45 | "optim/total_grad_norm": 1.0, 46 | "train/masked_instances": 2.0, 47 | }, 48 | 1: { 49 | "train/CrossEntropyLoss": 2.0, 50 | "train/rank": 1.0, 51 | "train/masked_instances": 1.0, 52 | "optim/weight_norm": math.sqrt(13), 53 | }, 54 | } 55 | 56 | 57 | @pytest.mark.parametrize("backend", BACKENDS) 58 | def test_reduce_metrics(backend): 59 | run_distributed_test(run_reduce_metrics, backend=backend) 60 | -------------------------------------------------------------------------------- /src/olmo_core/train/callbacks/list_checkpointer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass, field 3 | from typing import List, Optional, Set 4 | 5 | from .checkpointer import CheckpointerCallback 6 | 7 | log = logging.getLogger(__name__) 8 | 9 | 10 | @dataclass 11 | class ListCheckpointerCallback(CheckpointerCallback): 12 | """ 13 | Save checkpoints only at specific steps provided in a list. 14 | 15 | Pass 'save_steps' as a sorted list of step numbers (integers) at which to save. 16 | All other base behavior (async save, removal) is preserved. 
17 | 18 | This is useful for saving at predetermined milestones, such as: 19 | - Period boundaries in WSD-S schedules (when LR = 0) 20 | - Specific token budgets 21 | - Other training milestones 22 | 23 | Example: 24 | save_steps = [100, 500, 1000, 2000] # save at these exact steps 25 | """ 26 | 27 | # Disable the interval behavior in the base class by setting a huge interval 28 | # (we fully override 'post_train_batch' so this is purely defensive). 29 | save_interval: int = 1_000_000_000 30 | 31 | # user-provided exact steps to save at 32 | save_steps: Optional[List[int]] = None 33 | 34 | _save_steps_set: Set[int] = field( 35 | default_factory=set, init=False, repr=False, metadata={"omegaconf_ignore": True} 36 | ) 37 | _last_saved_step: int = field(default=-1, init=False, repr=False) 38 | 39 | def __post_init__(self): 40 | super().__post_init__() 41 | if not self.save_steps: 42 | raise ValueError("'save_steps' must be provided (list of step indices to checkpoint).") 43 | self._save_steps_set = {int(s) for s in self.save_steps} 44 | 45 | def post_train_batch(self): 46 | if not self.enabled: 47 | return 48 | 49 | self._await_last_checkpoint(blocking=False) 50 | if not self.checkpoint_pending: 51 | self._remove_old_checkpoints() 52 | 53 | step = int(self.step) 54 | if step in self._save_steps_set and step != self._last_saved_step: 55 | # save checkpoint at this exact step! 56 | path = self._save_checkpoint() 57 | self._last_saved_step = step 58 | log.info(f"Saved WSD‑S boundary checkpoint at step={step} -> {path}") 59 | -------------------------------------------------------------------------------- /bolmo_scripts/launch_stage1_7b.sh: -------------------------------------------------------------------------------- 1 | NAME=stage1_bolmo_7b 2 | SEQUENCE_LENGTH=4096 \ 3 | DTYPE=float32 \ 4 | DATA_SOURCE=data_sources.txt \ 5 | OLMO_ARCH=olmo3_7B \ 6 | OLMO_CKPT_PATH=/path/to/olmo3/ckpt \ 7 | TRAIN_MODE=stage_1 \ 8 | LOCAL_MODEL_STYLE="hnet:xlstm" \ 9 | ADD_HASH_EMBEDDINGS=false \ 10 | ADD_EXPANDED_EMBEDDINGS=true \ 11 | EMBEDDING_INIT_PATH="" \ 12 | SAVE_FOLDER=/path/to/save/folder/$NAME \ 13 | python3 src/examples/bolmo/train_stage1.py $NAME \ 14 | train_module.bolmo_config.losses=[local_encoder,ce,local_decoder,boundary] \ 15 | train_module.bolmo_config.loss_weights=[1,1,1,4] \ 16 | train_module.bolmo_config.div_fn=kl \ 17 | train_module.bolmo_config.binarization_temp=5.0 \ 18 | train_module.bolmo_config.use_oracle_patch_reps=true \ 19 | train_module.bolmo_config.teacher_force_boundaries=true \ 20 | train_module.bolmo_config.encoder_loss_lookahead=4 \ 21 | train_module.bolmo_config.encoder_loss_no_lookahead_weight=0.0 \ 22 | train_module.bolmo_config.encoder_loss_lookahead_weights=[0.0,0.0,0.0,4.0] \ 23 | train_module.bolmo_config.do_alm_debiasing=true \ 24 | train_module.bolmo_config.merge_boundary_loss=false \ 25 | train_module.optim.weight_decay=0.1 \ 26 | train_module.max_grad_norm=0.5 \ 27 | train_module.optim.lr=5e-4 \ 28 | model.block.attention.use_flash=true \ 29 | model.local_encoder.n_layers=1 \ 30 | model.local_decoder.n_layers=4 \ 31 | model.local_decoder.hnet_smooth=false \ 32 | model.local_decoder.hnet_modulate=false \ 33 | model.local_encoder.boundary_predictor_lookahead=1 \ 34 | model.local_decoder.add_in_projection=true \ 35 | model.local_decoder.add_norm_onto_residual=false \ 36 | model.local_decoder.add_projected_patch_residuals=false \ 37 | model.local_encoder.block_config.feed_forward.hidden_size=5504 \ 38 | model.local_decoder.block_config.feed_forward.hidden_size=5504 \ 
39 | model.local_encoder.d_model=4096 \ 40 | model.local_decoder.d_model=4096 \ 41 | data_loader.global_batch_size=786432 \ 42 | train_module.rank_microbatch_size=49152 \ 43 | trainer.callbacks.checkpointer.ephemeral_save_interval=1000 \ 44 | trainer.callbacks.checkpointer.save_interval=75000 \ 45 | trainer.callbacks.downstream_evaluator.eval_interval=75000 \ 46 | trainer.max_duration.value=75000 -------------------------------------------------------------------------------- /bolmo_scripts/launch_stage1_1b.sh: -------------------------------------------------------------------------------- 1 | NAME=stage1_bolmo_1b 2 | SEQUENCE_LENGTH=4096 \ 3 | DTYPE=float32 \ 4 | DATA_SOURCE=data_sources.txt \ 5 | OLMO_ARCH=olmo2_1B_v2 \ 6 | OLMO_CKPT_PATH=/path/to/olmo2/ckpt \ 7 | TRAIN_MODE=stage_1 \ 8 | LOCAL_MODEL_STYLE="hnet:xlstm" \ 9 | ADD_HASH_EMBEDDINGS=false \ 10 | ADD_EXPANDED_EMBEDDINGS=true \ 11 | EMBEDDING_INIT_PATH="" \ 12 | SAVE_FOLDER=/path/to/save/folder/$NAME \ 13 | python3 src/examples/bolmo/train_stage1.py $NAME \ 14 | train_module.bolmo_config.losses=[local_encoder,ce,local_decoder,boundary] \ 15 | train_module.bolmo_config.loss_weights=[1,1,1,4] \ 16 | train_module.bolmo_config.div_fn=kl \ 17 | train_module.bolmo_config.binarization_temp=5.0 \ 18 | train_module.bolmo_config.use_oracle_patch_reps=true \ 19 | train_module.bolmo_config.teacher_force_boundaries=true \ 20 | train_module.bolmo_config.encoder_loss_lookahead=4 \ 21 | train_module.bolmo_config.encoder_loss_no_lookahead_weight=0.0 \ 22 | train_module.bolmo_config.encoder_loss_lookahead_weights=[0.0,0.0,0.0,4.0] \ 23 | train_module.bolmo_config.do_alm_debiasing=true \ 24 | train_module.bolmo_config.merge_boundary_loss=false \ 25 | train_module.optim.weight_decay=0.1 \ 26 | train_module.max_grad_norm=0.5 \ 27 | train_module.optim.lr=7e-4 \ 28 | model.block.attention.use_flash=true \ 29 | model.local_encoder.n_layers=1 \ 30 | model.local_decoder.n_layers=4 \ 31 | model.local_decoder.hnet_smooth=false \ 32 | model.local_decoder.hnet_modulate=false \ 33 | model.local_encoder.boundary_predictor_lookahead=1 \ 34 | model.local_decoder.add_in_projection=true \ 35 | model.local_decoder.add_norm_onto_residual=false \ 36 | model.local_decoder.add_projected_patch_residuals=false \ 37 | model.local_encoder.block_config.feed_forward.hidden_size=2816 \ 38 | model.local_decoder.block_config.feed_forward.hidden_size=2816 \ 39 | model.local_encoder.d_model=2048 \ 40 | model.local_decoder.d_model=2048 \ 41 | data_loader.global_batch_size=786432 \ 42 | train_module.rank_microbatch_size=49152 \ 43 | trainer.callbacks.checkpointer.ephemeral_save_interval=1000 \ 44 | trainer.callbacks.checkpointer.save_interval=75000 \ 45 | trainer.callbacks.downstream_evaluator.eval_interval=75000 \ 46 | trainer.max_duration.value=75000 -------------------------------------------------------------------------------- /src/olmo_core/distributed/parallel/tensor_parallel.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from functools import partial 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.distributed import DeviceMesh 9 | from torch.distributed.tensor import Placement, Shard, distribute_module 10 | from torch.distributed.tensor.parallel import SequenceParallel as _SequenceParallel 11 | 12 | from olmo_core.config import Config 13 | 14 | log = logging.getLogger(__name__) 15 | 16 | 17 | @dataclass 18 | class 
TensorParallelConfig(Config): 19 | """ 20 | Configuration class for tensor parallelism (TP). 21 | """ 22 | 23 | degree: int 24 | """ 25 | The TP degree. 26 | """ 27 | 28 | enable_async: bool = False 29 | """ 30 | Enable experimental async tensor parallelism. 31 | """ 32 | 33 | def maybe_enable_async_tp(self, tp_mesh: DeviceMesh): 34 | if self.enable_async: 35 | log.info("Enabling async tensor parallel") 36 | 37 | from torch.distributed._symmetric_memory import enable_symm_mem_for_group 38 | 39 | torch._inductor.config._micro_pipeline_tp = True # type: ignore 40 | enable_symm_mem_for_group(tp_mesh.get_group().group_name) 41 | 42 | 43 | class SequenceParallel(_SequenceParallel): 44 | def __init__( 45 | self, 46 | *, 47 | sequence_dim: int = 1, 48 | use_local_output: bool = False, 49 | output_layouts: Optional[Placement] = None, 50 | ): 51 | super().__init__(sequence_dim=sequence_dim, use_local_output=use_local_output) 52 | self.output_layouts = (output_layouts or Shard(sequence_dim),) 53 | 54 | @staticmethod 55 | def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): 56 | del mod, device_mesh 57 | if outputs.placements != output_layouts: 58 | outputs = outputs.redistribute(placements=output_layouts, async_op=True) 59 | return outputs.to_local() if use_local_output else outputs 60 | 61 | def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 62 | return distribute_module( 63 | module, 64 | device_mesh, 65 | self._replicate_module_fn, 66 | partial(self._prepare_input_fn, self.sequence_sharding), # type: ignore 67 | partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), # type: ignore 68 | ) 69 | -------------------------------------------------------------------------------- /src/scripts/compare_wandb_configs.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | 4 | import click 5 | 6 | from olmo_core.utils import flatten_dict, prepare_cli_environment 7 | 8 | log = logging.getLogger(__name__) 9 | run_path_re = re.compile(r"^[^/]+/[^/]+/[^/]+$") 10 | run_path_url = re.compile(r"^https?://wandb.ai/([^/]+)/([^/]+)/runs/([^/]+)") 11 | 12 | 13 | def parse_run_path(run_path: str) -> str: 14 | """For convenience, we allow run paths as well as URLs.""" 15 | run_path = run_path.strip("/") 16 | if run_path_re.match(run_path): 17 | return run_path 18 | 19 | m = run_path_url.match(run_path) 20 | if m is not None: 21 | entity, project, run_id = m.groups() 22 | return f"{entity}/{project}/{run_id}" 23 | 24 | raise ValueError(f"Could not parse '{run_path}'") 25 | 26 | 27 | @click.command() 28 | @click.argument( 29 | "left_run_path", 30 | type=str, 31 | ) 32 | @click.argument( 33 | "right_run_path", 34 | type=str, 35 | ) 36 | def main( 37 | left_run_path: str, 38 | right_run_path: str, 39 | ): 40 | import wandb 41 | 42 | api = wandb.Api() 43 | left_run = api.run(parse_run_path(left_run_path)) 44 | right_run = api.run(parse_run_path(right_run_path)) 45 | 46 | left_config = flatten_dict(left_run._attrs["rawconfig"]) 47 | right_config = flatten_dict(right_run._attrs["rawconfig"]) 48 | 49 | left_only_keys = left_config.keys() - right_config.keys() 50 | if len(left_only_keys) > 0: 51 | print("Settings only in left:") 52 | print("\n".join(f"\t{k}: {left_config[k]}" for k in sorted(left_only_keys))) 53 | print() 54 | 55 | right_only_keys = right_config.keys() - left_config.keys() 56 | if len(right_only_keys) > 0: 57 | print("Settings only in right:") 58 | 
print("\n".join(f"\t{k}: {right_config[k]}" for k in sorted(right_only_keys))) 59 | print() 60 | 61 | keys_with_differences = { 62 | k for k in left_config.keys() & right_config.keys() if left_config[k] != right_config[k] 63 | } 64 | if len(keys_with_differences) > 0: 65 | if len(left_only_keys) > 0 or len(right_only_keys) > 0: 66 | print("Settings with differences:") 67 | print( 68 | "\n".join( 69 | f"{k}\n\t{left_config[k]}\n\t{right_config[k]}\n" 70 | for k in sorted(keys_with_differences) 71 | ) 72 | ) 73 | else: 74 | print("No differences in shared settings.") 75 | 76 | 77 | if __name__ == "__main__": 78 | prepare_cli_environment() 79 | main() 80 | -------------------------------------------------------------------------------- /src/olmo_core/nn/buffer_cache.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from collections.abc import MutableMapping 3 | from typing import Dict, Optional 4 | 5 | import torch 6 | 7 | from olmo_core.utils import move_to_device 8 | 9 | 10 | class BufferCache(MutableMapping[str, torch.Tensor]): 11 | """ 12 | Cache for buffers such as attention biases that would normally be registered as module buffers. 13 | 14 | We avoid using buffers because we've run into various issues doing so with FSDP. 15 | In general it appears the way FSDP handles buffers is not well-defined. 16 | It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid 17 | since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into 18 | NaNs when they're synchronized due to casting or some other issue. 19 | 20 | :param namespace: Optional namespace for the cache. This allows you to have a separate sub-cache 21 | in a shared :class:`BufferCache` to avoid key collisions. See how this is used in the 22 | :class:`olmo_core.nn.rope.RotaryEmbeddingBase` class for an example. 23 | """ 24 | 25 | def __init__(self, namespace: str = ""): 26 | self._data: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict) 27 | self._namespace = namespace 28 | 29 | def __getitem__(self, key: str) -> torch.Tensor: 30 | return self._data[self._namespace][key] 31 | 32 | def __setitem__(self, key: str, value: torch.Tensor) -> None: 33 | self._data[self._namespace][key] = value 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._data[self._namespace][key] 37 | 38 | def __iter__(self): 39 | yield from self._data[self._namespace].keys() 40 | 41 | def __len__(self) -> int: 42 | return len(self._data[self._namespace]) 43 | 44 | def get_for_device(self, key: str, device: torch.device) -> Optional[torch.Tensor]: 45 | if (tensor := self.get(key)) is not None: 46 | if tensor.device != device: 47 | tensor = move_to_device(tensor, device) 48 | self[key] = tensor 49 | return tensor 50 | else: 51 | return None 52 | 53 | def with_namespace(self, namespace: str) -> "BufferCache": 54 | """ 55 | This creates a new :class:`BufferCache` object with a pointer to the same underlying data 56 | but with the given namespace. 57 | """ 58 | out = BufferCache(namespace=namespace) 59 | out._data = self._data 60 | return out 61 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. olmo-core documentation master file, created by 2 | sphinx-quickstart on Tue Sep 21 08:07:48 2021. 
3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | **OLMo-core** 7 | =============== 8 | 9 | **OLMo-core** is a Python library that provides building blocks for large-scale distributed training with PyTorch. 10 | 11 | To get started first install `PyTorch `_ according to the official instructions 12 | specific to your environment. Then you can install OLMo-core from PyPI with: 13 | 14 | .. code-block:: bash 15 | 16 | pip install ai2-olmo-core 17 | 18 | .. toctree:: 19 | :hidden: 20 | :maxdepth: 2 21 | :caption: Overview 22 | 23 | overview/introduction.rst 24 | overview/installation.rst 25 | 26 | .. toctree:: 27 | :hidden: 28 | :maxdepth: 2 29 | :caption: Guides 30 | 31 | guides/all_in_one_for_researchers.md 32 | guides/data_loading.rst 33 | guides/data_mixing.rst 34 | 35 | .. toctree:: 36 | :hidden: 37 | :maxdepth: 2 38 | :caption: Examples 39 | 40 | examples/huggingface.rst 41 | examples/llm.rst 42 | 43 | .. toctree:: 44 | :hidden: 45 | :maxdepth: 2 46 | :caption: API Reference 47 | 48 | config 49 | data/index 50 | distributed/index 51 | eval/index 52 | exceptions 53 | float8 54 | io 55 | launch 56 | model_ladder 57 | nn/index 58 | optim 59 | testing 60 | train/index 61 | utils 62 | 63 | .. toctree:: 64 | :hidden: 65 | :caption: Development 66 | 67 | License 68 | CHANGELOG 69 | GitHub Repository 70 | 71 | Team 72 | ---- 73 | 74 | **OLMo-core** is developed and maintained at 75 | `the Allen Institute for Artificial Intelligence (AI2) `_. 76 | AI2 is a non-profit institute with the mission to contribute to humanity through high-impact AI research and engineering. 77 | 78 | To learn more about who specifically contributed to this codebase, see 79 | `our contributors `_ page. 80 | 81 | License 82 | ------- 83 | 84 | **OLMo-core** is licensed under `Apache 2.0 `_. 85 | A full copy of the license can be found `on GitHub `_. 86 | 87 | Indices and tables 88 | ------------------ 89 | 90 | * :ref:`genindex` 91 | * :ref:`modindex` 92 | -------------------------------------------------------------------------------- /docs/source/overview/introduction.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | OLMo-core represents a major rewrite of the original training and modeling code from `OLMo `_ 5 | with a focus on performance and API stability. 6 | It aims to provide a standard set of robust tools that can be used by LLM researchers at `AI2 `_ and other organizations 7 | to build their research projects on. 8 | 9 | The library is centered around a highly efficient, yet flexible, :class:`~olmo_core.train.Trainer` and a :mod:`~olmo_core.launch` 10 | module that handles all of the boilerplate of launching experiments on `Beaker `_ 11 | or other platforms. It also comes with a simple, yet optimized, :class:`~olmo_core.nn.transformer.Transformer` 12 | model and many other useful :class:`torch.nn.Module` implementations. 13 | 14 | Most users will likely follow a workflow that looks like this: 15 | 16 | 1. Define the various components of an experiment through configuration classes. 17 | For example:: 18 | 19 | model_config = TransformerConfig.llama2_7B(...) 20 | train_module_config = TransformerTrainModuleConfig(...) 21 | data_config = NumpyFSLDatasetConfig(...) 22 | data_loader_config = NumpyDataLoaderConfig(...) 23 | trainer_config = TrainerConfig(...) 24 | 25 | 2. 
Build the corresponding components within a ``main()`` function at runtime and then call :meth:`Trainer.fit() `. 26 | For example:: 27 | 28 | def main(): 29 | model = model_config.build() 30 | train_module = train_module_config.build(model) 31 | data_loader = data_loader_config.build(data_config.build(), dp_process_group=train_module.dp_process_group) 32 | trainer = trainer_config.build(train_module, data_loader) 33 | 34 | trainer.fit() 35 | 36 | if __name__ == "__main__": 37 | prepare_training_environment(seed=SEED) 38 | main() 39 | teardown_training_environment() 40 | 41 | 3. Launch their training script with a :mod:`~olmo_core.launch` config, like the :class:`~olmo_core.launch.beaker.BeakerLaunchConfig`. 42 | For example:: 43 | 44 | launch_config = BeakerLaunchConfig(...) 45 | launch_config.launch(follow=True) 46 | 47 | Or simply launch their training script manually with ``torchrun``:: 48 | 49 | torchrun --nproc-per-node=8 train_script.py ... 50 | 51 | You can find a complete example of this workflow in the `Train an LLM <../examples/llm.html>`_ example. 52 | And for a more comprehensive overview, see the `All-in-one for researchers <../guides/all_in_one_for_researchers.html>`_. 53 | -------------------------------------------------------------------------------- /src/test/nn/feed_forward_test.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import pytest 4 | import torch 5 | from torch.distributed.tensor import Shard, init_device_mesh 6 | 7 | from olmo_core.distributed.checkpoint import ( 8 | load_model_and_optim_state, 9 | save_model_and_optim_state, 10 | ) 11 | from olmo_core.distributed.utils import get_rank, get_world_size 12 | from olmo_core.nn.feed_forward import FeedForward 13 | from olmo_core.testing import BACKENDS, run_distributed_test 14 | from olmo_core.utils import get_default_device, seed_all 15 | 16 | 17 | def _run_tensor_parallel_feed_forward( 18 | checkpoint_dir: str, inputs_path: str, outputs_path: str, ff_kwargs: Dict[str, Any] 19 | ): 20 | device = get_default_device() 21 | mesh = init_device_mesh(device.type, (get_world_size(),), mesh_dim_names=("tp",)) 22 | 23 | ff = FeedForward(init_device=device.type, **ff_kwargs) 24 | 25 | ff.apply_tp(mesh["tp"], output_layout=Shard(1), use_local_output=False) 26 | load_model_and_optim_state(checkpoint_dir, ff) 27 | 28 | # Input x is replicated across ranks, output y is sharded on the sequence dimension. 29 | x = torch.load(inputs_path, map_location=device) 30 | y_local = ff(x).to_local() 31 | 32 | # Backward to exercise graph in TP mode.
33 | y_local.sum().backward() 34 | 35 | # Check the local shard of the output is the same as the corresponding shard of the reference output 36 | y_ref = torch.load(outputs_path, map_location=device) 37 | rank, world_size = get_rank(), get_world_size() 38 | chunk = x.size(1) // world_size 39 | y_ref_local = y_ref[:, rank * chunk : (rank + 1) * chunk, :] 40 | torch.testing.assert_close(y_ref_local, y_local) 41 | 42 | 43 | @pytest.mark.parametrize("backend", BACKENDS) 44 | def test_tensor_parallel_feed_forward(backend: str, tmp_path): 45 | device = torch.device("cuda") if "nccl" in backend else torch.device("cpu") 46 | 47 | seed_all(0) 48 | d_model = 128 49 | hidden = 4 * d_model 50 | ff_kwargs: Dict[str, Any] = {"d_model": d_model, "hidden_size": hidden} 51 | ff = FeedForward(init_device=device.type, **ff_kwargs) 52 | 53 | bs, seq_len = 2, 64 54 | x = torch.randn(bs, seq_len, d_model, device=device) 55 | y = ff(x) 56 | 57 | outputs_path = tmp_path / "ff_y.pt" 58 | torch.save(y, outputs_path) 59 | inputs_path = tmp_path / "ff_x.pt" 60 | torch.save(x, inputs_path) 61 | checkpoint_dir = tmp_path / "checkpoint" 62 | save_model_and_optim_state(checkpoint_dir, ff) 63 | 64 | run_distributed_test( 65 | _run_tensor_parallel_feed_forward, 66 | backend=backend, 67 | start_method="spawn", 68 | func_args=(checkpoint_dir, inputs_path, outputs_path, ff_kwargs), 69 | ) 70 | -------------------------------------------------------------------------------- /.github/actions/setup-venv/action.yml: -------------------------------------------------------------------------------- 1 | name: Python virtualenv 2 | description: Set up a Python virtual environment with caching 3 | inputs: 4 | python-version: 5 | description: The Python version to use 6 | required: true 7 | cache-prefix: 8 | description: Update this to invalidate the cache 9 | required: true 10 | default: v0 11 | torch-version: 12 | description: The PyTorch version to install 13 | required: false 14 | default: '2.7.0' 15 | torchao-version: 16 | description: The torchao version to install 17 | required: false 18 | default: '0.9.0' 19 | channel: 20 | description: The channel to install from 21 | required: false 22 | default: 'whl/cpu' 23 | runs: 24 | using: composite 25 | steps: 26 | - name: Setup Python 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ inputs.python-version }} 30 | 31 | - shell: bash 32 | run: | 33 | # Install prerequisites. 34 | pip install --upgrade pip setuptools build wheel virtualenv 35 | 36 | - shell: bash 37 | run: | 38 | # Get the exact Python version to use in the cache key. 39 | echo "PYTHON_VERSION=$(python --version)" >> $GITHUB_ENV 40 | 41 | - uses: actions/cache@v4 42 | id: virtualenv-cache 43 | with: 44 | path: .venv 45 | key: ${{ inputs.cache-prefix }}-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ inputs.torch-version }}-${{ inputs.torchao-version }}-${{ inputs.channel }}-${{ hashFiles('*requirements.txt', '*pyproject.toml') }} 46 | restore-keys: | 47 | ${{ inputs.cache-prefix }}-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ inputs.torch-version }}-${{ inputs.torchao-version }}-${{ inputs.channel }} 48 | 49 | - if: steps.virtualenv-cache.outputs.cache-hit != 'true' 50 | shell: bash 51 | run: | 52 | # Set up virtual environment without cache hit. 53 | test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv 54 | . 
.venv/bin/activate 55 | pip install torch==${{ inputs.torch-version }} torchao==${{ inputs.torchao-version}} --index-url https://download.pytorch.org/${{ inputs.channel }} 56 | pip install -e .[all] 57 | 58 | - if: steps.virtualenv-cache.outputs.cache-hit == 'true' 59 | shell: bash 60 | run: | 61 | # Set up virtual environment from cache hit. 62 | . .venv/bin/activate 63 | pip install --no-deps -e .[all] 64 | 65 | - shell: bash 66 | run: | 67 | # Show environment info. 68 | . .venv/bin/activate 69 | echo "✓ Installed $(python --version) virtual environment to $(which python)" 70 | echo "========= Python packages ===========" 71 | pip freeze 72 | -------------------------------------------------------------------------------- /src/test/config_test.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Set, Tuple 3 | 4 | from olmo_core.config import Config 5 | 6 | 7 | def test_simple_config_as_dict(): 8 | @dataclass 9 | class MockConfig(Config): 10 | name: str = "default" 11 | x: Optional[int] = None 12 | 13 | c = MockConfig() 14 | assert c.as_dict() == dict(name="default", x=None) 15 | assert c.as_dict(exclude_none=True) == dict(name="default") 16 | 17 | 18 | def test_nested_configs(): 19 | @dataclass 20 | class Bar: 21 | x: int 22 | y: int 23 | _z: int = 0 24 | 25 | @dataclass 26 | class Foo(Config): 27 | bar: Bar 28 | z: str 29 | 30 | foo = Foo(bar=Bar(x=1, y=2), z="z") 31 | data = foo.as_dict() 32 | assert isinstance(data["bar"], dict) 33 | 34 | foo1 = Foo.from_dict(data) 35 | assert foo1 == foo 36 | assert isinstance(foo1.bar, Bar) 37 | 38 | foo2 = Foo.from_dict(data, overrides=["bar.x=0"]) 39 | assert foo2.bar.x == 0 40 | foo3 = foo2.merge(["bar.x=-1"]) 41 | assert foo3.bar.x == -1 42 | 43 | assert foo.as_dict(recurse=False) == {"z": "z", "bar": foo.bar} 44 | 45 | assert foo.as_config_dict() == { 46 | Config.CLASS_NAME_FIELD: "test.config_test.Foo", 47 | "z": "z", 48 | "bar": { 49 | Config.CLASS_NAME_FIELD: "test.config_test.Bar", 50 | "x": 1, 51 | "y": 2, 52 | }, 53 | } 54 | 55 | 56 | def test_json_safe_dump(): 57 | @dataclass 58 | class Foo(Config): 59 | x_list: List[int] 60 | x_tuple: Tuple[int, ...] 
61 | x_set: Set[str] 62 | 63 | foo = Foo(x_list=[0, 1], x_tuple=(0, 1), x_set={"a"}) 64 | assert foo.as_config_dict() == { 65 | Config.CLASS_NAME_FIELD: "test.config_test.Foo", 66 | "x_list": [0, 1], 67 | "x_tuple": [0, 1], 68 | "x_set": ["a"], 69 | } 70 | 71 | 72 | def test_non_strict_merge(): 73 | @dataclass 74 | class Bar(Config): 75 | x: int 76 | y: int 77 | 78 | @dataclass 79 | class Foo(Config): 80 | bar: Bar 81 | z: str 82 | 83 | foo = Foo(bar=Bar(x=1, y=2), z="a").merge(["--z=b", "--bar.x=0", "--baz.booz=0"], strict=False) 84 | assert foo.z == "b" 85 | assert foo.bar.x == 0 86 | 87 | 88 | def test_merge_with_prefix(): 89 | @dataclass 90 | class Bar(Config): 91 | x: int 92 | y: int 93 | 94 | @dataclass 95 | class Foo(Config): 96 | bar: Bar 97 | z: str 98 | 99 | foo = Foo(bar=Bar(x=1, y=2), z="a").merge(["--foo.z=b", "--foo.bar.x=0"], prefix="foo") 100 | assert foo.z == "b" 101 | assert foo.bar.x == 0 102 | -------------------------------------------------------------------------------- /src/olmo_core/nn/fla.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | 5 | import fla.layers 6 | import torch 7 | from torch import nn 8 | 9 | from olmo_core.config import Config, DType 10 | 11 | log = logging.getLogger(__name__) 12 | 13 | 14 | class FLA(nn.Module): 15 | def __init__(self, inner: fla.layers.ABCAttention): 16 | super().__init__() 17 | self.inner = inner 18 | 19 | self.kv_cache_manager = None 20 | 21 | def init_kv_cache_manager(self, batch_size: int): 22 | self.kv_cache_manager = FLACacheManager() 23 | 24 | def forward(self, x: torch.Tensor, **_kwargs) -> torch.Tensor: 25 | # FIXME: Right now we just ignore the kwargs. 
26 | 27 | if self.kv_cache_manager is not None: 28 | # Use the cache manager with past_key_values API 29 | cache = self.kv_cache_manager.cache 30 | 31 | # Call the inner FLA layer with cache 32 | out, _, new_cache = self.inner(x, past_key_values=cache, use_cache=True) 33 | 34 | # Update the cache manager 35 | self.kv_cache_manager.cache = new_cache 36 | 37 | return out 38 | else: 39 | return self.inner(x)[0] # returns out, ?, cache 40 | 41 | 42 | class FLACacheManager(nn.Module): 43 | def __init__(self): 44 | super().__init__() 45 | 46 | self.zero_cache() 47 | 48 | def zero_cache(self): 49 | from fla.models.utils import Cache 50 | # FLA layers manage their own cache through the Cache object 51 | # We just store a reference to it 52 | self.cache = Cache() 53 | 54 | def reallocate(self, batch_size: int): 55 | from fla.models.utils import Cache 56 | 57 | self.cache = Cache() 58 | 59 | def is_reusable(self, batch_size: int) -> bool: 60 | # FLA library doesn't provide a simple way to check cache compatibility 61 | # So we just recreate it to be safe 62 | return False 63 | 64 | def reset(self, batch_size: int): 65 | if self.is_reusable(batch_size): 66 | self.zero_cache() 67 | else: 68 | log.debug("Unreusable FLA cache, reallocating") 69 | self.reallocate(batch_size) 70 | 71 | 72 | @dataclass 73 | class FLAConfig(Config): 74 | name: str 75 | fla_layer_kwargs: dict = field(default_factory=dict) 76 | dtype: DType = DType.float32 77 | 78 | def build(self, d_model: int, init_device) -> FLA: 79 | layer = getattr(fla.layers, self.name)( 80 | hidden_size=d_model, 81 | layer_idx=0, # for cache 82 | **self.fla_layer_kwargs, 83 | ).to(device=init_device, dtype=self.dtype.as_pt()) 84 | 85 | return FLA(layer) 86 | 87 | def num_params(self): 88 | raise NotImplementedError() -------------------------------------------------------------------------------- /src/olmo_core/nn/utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict, Tuple, Type 3 | 4 | import torch 5 | from torch.distributed.tensor.parallel import ( 6 | ColwiseParallel, 7 | PrepareModuleInput, 8 | RowwiseParallel, 9 | ) 10 | 11 | 12 | def _get_custom_checkpoint_policy(meta: Dict[str, int]): 13 | # Adapted from 14 | # https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py 15 | from torch.utils.checkpoint import CheckpointPolicy 16 | 17 | _save_list = { 18 | torch.ops.aten.mm.default, # type: ignore 19 | torch.ops.aten._scaled_dot_product_efficient_attention.default, # type: ignore 20 | torch.ops.aten._scaled_dot_product_flash_attention.default, # type: ignore 21 | torch.ops._c10d_functional.reduce_scatter_tensor.default, # type: ignore 22 | # for low precision training, it's useful to always save 23 | # the result of max(abs(tensor)) 24 | torch.ops.aten.abs.default, # type: ignore 25 | torch.ops.aten.max.default, # type: ignore 26 | } 27 | 28 | def _custom_policy(ctx, func, *args, **kwargs): 29 | del args, kwargs 30 | mode = "recompute" if ctx.is_recompute else "forward" 31 | mm_count_key = f"{mode}_mm_count" 32 | if func == torch.ops.aten.mm.default: # type: ignore 33 | meta[mm_count_key] += 1 34 | # Saves output of all compute ops, except every second mm 35 | to_save = func in _save_list and not ( 36 | func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 # type: ignore 37 | ) 38 | return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE 39 | 40 | return _custom_policy 
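# A minimal usage sketch (assuming a PyTorch version whose `torch.utils.checkpoint.checkpoint`
# accepts a `context_fn` argument, which requires `use_reentrant=False`): the context function
# defined just below is the piece that gets passed as `context_fn`, so that only the ops the
# policy above marks MUST_SAVE are stashed for the backward pass, e.g.:
#
#     out = torch.utils.checkpoint.checkpoint(
#         block, x, use_reentrant=False,
#         context_fn=selective_checkpointing_context_fn,
#     )
#
# Here `block` and `x` stand in for a transformer block and its input activations.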
41 | 42 | 43 | def selective_checkpointing_context_fn(): 44 | from torch.utils.checkpoint import create_selective_checkpoint_contexts 45 | 46 | meta: Dict[str, int] = defaultdict(int) 47 | return create_selective_checkpoint_contexts(_get_custom_checkpoint_policy(meta)) 48 | 49 | 50 | def get_tp_wrappers( 51 | float8_enabled: bool, 52 | ) -> Tuple[Type[RowwiseParallel], Type[ColwiseParallel], Type[PrepareModuleInput]]: 53 | if not float8_enabled: 54 | return ( 55 | RowwiseParallel, 56 | ColwiseParallel, 57 | PrepareModuleInput, 58 | ) 59 | else: 60 | # TODO (epwalsh): once float8 configuration supports delayed scaling, 61 | # add a check here to enforce supported float8 all-gather configurations. 62 | from torchao.float8.float8_tensor_parallel import ( # type: ignore 63 | Float8ColwiseParallel, 64 | Float8RowwiseParallel, 65 | PrepareFloat8ModuleInput, 66 | ) 67 | 68 | return ( 69 | Float8RowwiseParallel, 70 | Float8ColwiseParallel, 71 | PrepareFloat8ModuleInput, 72 | ) 73 | -------------------------------------------------------------------------------- /src/scripts/release/release_notes.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | """ 4 | Prepares markdown release notes for GitHub releases. 5 | """ 6 | 7 | import os 8 | from typing import List, Optional 9 | 10 | import packaging.version 11 | 12 | TAG = os.environ["TAG"] 13 | 14 | ADDED_HEADER = "### Added 🎉" 15 | CHANGED_HEADER = "### Changed ⚠️" 16 | FIXED_HEADER = "### Fixed ✅" 17 | REMOVED_HEADER = "### Removed 👋" 18 | 19 | 20 | def get_change_log_notes() -> str: 21 | in_current_section = False 22 | current_section_notes: List[str] = [] 23 | with open("CHANGELOG.md") as changelog: 24 | for line in changelog: 25 | if line.startswith("## "): 26 | if line.startswith("## Unreleased"): 27 | continue 28 | if line.startswith(f"## [{TAG}]"): 29 | in_current_section = True 30 | continue 31 | break 32 | if in_current_section: 33 | if line.startswith("### Added"): 34 | line = ADDED_HEADER + "\n" 35 | elif line.startswith("### Changed"): 36 | line = CHANGED_HEADER + "\n" 37 | elif line.startswith("### Fixed"): 38 | line = FIXED_HEADER + "\n" 39 | elif line.startswith("### Removed"): 40 | line = REMOVED_HEADER + "\n" 41 | current_section_notes.append(line) 42 | assert current_section_notes 43 | return "## What's new\n\n" + "".join(current_section_notes).strip() + "\n" 44 | 45 | 46 | def get_commit_history() -> str: 47 | new_version = packaging.version.parse(TAG) 48 | 49 | # Pull all tags. 50 | os.popen("git fetch --tags") 51 | 52 | # Get all tags sorted by version, latest first. 53 | all_tags = os.popen("git tag -l --sort=-version:refname 'v*'").read().split("\n") 54 | 55 | # Out of `all_tags`, find the latest previous version so that we can collect all 56 | # commits between that version and the new version we're about to publish. 57 | # Note that we ignore pre-releases unless the new version is also a pre-release. 
58 | last_tag: Optional[str] = None 59 | for tag in all_tags: 60 | if not tag.strip(): # could be blank line 61 | continue 62 | version = packaging.version.parse(tag) 63 | if new_version.pre is None and version.pre is not None: 64 | continue 65 | if version < new_version: 66 | last_tag = tag 67 | break 68 | if last_tag is not None: 69 | commits = os.popen(f"git log {last_tag}..{TAG} --oneline --first-parent").read() 70 | else: 71 | commits = os.popen("git log --oneline --first-parent").read() 72 | return "## Commits\n\n" + commits 73 | 74 | 75 | def main(): 76 | print(get_change_log_notes()) 77 | print(get_commit_history()) 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /src/examples/huggingface/upload_checkpoint_to_hf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import sys 5 | 6 | from huggingface_hub import HfApi, login 7 | from tqdm import tqdm 8 | 9 | 10 | def upload_to_branch(local_checkpoint_dir: str, repo_id: str, step: int, token: str): 11 | login(token=token) 12 | api = HfApi() 13 | total_tokens = step * 2048 * 4096 14 | tokens_b = math.ceil(total_tokens / 1_000_000_000) 15 | branch = f"stage2-ingredient3-step{step}-tokens{tokens_b}B" 16 | print(f"Creating and uploading to branch: {branch}") 17 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" 18 | try: 19 | api.create_branch(repo_id=repo_id, branch=branch, token=token) 20 | print(f"Created new branch: {branch}") 21 | except Exception as e: 22 | print(f"Branch might already exist: {e}") 23 | files_to_upload = [] 24 | for root, _, files in os.walk(local_checkpoint_dir): 25 | for file in files: 26 | local_path = os.path.join(root, file) 27 | repo_path = os.path.relpath(local_path, local_checkpoint_dir) 28 | files_to_upload.append((local_path, repo_path)) 29 | 30 | print(f"\nStarting upload of {len(files_to_upload)} files...") 31 | 32 | for local_path, repo_path in tqdm(files_to_upload, desc="Uploading files"): 33 | try: 34 | print(f"\nUploading: {repo_path}") 35 | api.upload_file( 36 | path_or_fileobj=local_path, 37 | path_in_repo=repo_path, 38 | repo_id=repo_id, 39 | token=token, 40 | revision=branch, 41 | ) 42 | print(f"Successfully uploaded {repo_path}") 43 | except Exception as e: 44 | print(f"Error uploading {repo_path}: {e}") 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser(description="Upload checkpoint to Hugging Face Hub") 49 | parser.add_argument( 50 | "--local_checkpoint_dir", 51 | type=str, 52 | required=True, 53 | help="Local directory containing checkpoint files", 54 | ) 55 | parser.add_argument( 56 | "--repo_id", 57 | type=str, 58 | required=True, 59 | help='Hugging Face repo ID (e.g., "allenai/OLMo-2-0325-32B")', 60 | ) 61 | parser.add_argument("--step", type=int, required=True, help="Step number") 62 | parser.add_argument("--token", type=str, required=True, help="Hugging Face API token") 63 | args = parser.parse_args() 64 | 65 | print("Starting upload process...") 66 | if not os.path.exists(args.local_checkpoint_dir): 67 | print("Error: Directory not found!") 68 | sys.exit(1) 69 | else: 70 | print(f"Found directory. 
Contents: {os.listdir(args.local_checkpoint_dir)}") 71 | upload_to_branch( 72 | local_checkpoint_dir=args.local_checkpoint_dir, 73 | repo_id=args.repo_id, 74 | step=args.step, 75 | token=args.token, 76 | ) 77 | -------------------------------------------------------------------------------- /src/olmo_core/distributed/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from typing import Optional, Tuple 4 | 5 | from olmo_core.config import Config, DType, StrEnum 6 | from olmo_core.distributed.utils import get_num_nodes 7 | from olmo_core.exceptions import OLMoConfigurationError 8 | 9 | log = logging.getLogger(__name__) 10 | 11 | 12 | class DPMeshDimName(StrEnum): 13 | """ 14 | ``DeviceMesh`` dimension names for data parallelism. 15 | """ 16 | 17 | replicate = "dp_replicate" 18 | """ 19 | The device mesh dimension over which the model is replicated. 20 | """ 21 | shard = "dp_shard" 22 | """ 23 | The device mesh dimension over which the model is sharded. 24 | """ 25 | 26 | 27 | class DataParallelType(StrEnum): 28 | fsdp = "fsdp" 29 | hsdp = "hsdp" 30 | ddp = "ddp" 31 | 32 | 33 | @dataclass 34 | class DataParallelConfig(Config): 35 | name: DataParallelType 36 | param_dtype: Optional[DType] = None 37 | reduce_dtype: DType = DType.float32 38 | num_replicas: Optional[int] = None 39 | shard_degree: Optional[int] = None 40 | 41 | def get_replicate_and_shard_degree(self, dp_world_size: int) -> Tuple[int, int]: 42 | """ 43 | Defaults to one replica per node, with the shard degree set to the number of gpus per node. 44 | 45 | :param dp_world_size: The data parallel world size. 46 | :return: A tuple of (num_replicas, shard_degree) 47 | """ 48 | if self.num_replicas is None and self.shard_degree is None: 49 | return get_num_nodes(), dp_world_size // get_num_nodes() 50 | elif self.num_replicas is not None and self.shard_degree is not None: 51 | return _check_num_replicas(self.num_replicas, dp_world_size), _check_shard_degree( 52 | self.shard_degree, dp_world_size 53 | ) 54 | elif self.num_replicas is not None: 55 | return ( 56 | _check_num_replicas(self.num_replicas, dp_world_size), 57 | dp_world_size // self.num_replicas, 58 | ) 59 | else: 60 | assert self.shard_degree is not None 61 | return dp_world_size // self.shard_degree, _check_shard_degree( 62 | self.shard_degree, dp_world_size 63 | ) 64 | 65 | 66 | def _check_num_replicas(num_replicas: int, dp_world_size: int) -> int: 67 | if dp_world_size % num_replicas != 0: 68 | raise OLMoConfigurationError( 69 | f"data parallel world size ({dp_world_size}) must be " 70 | f"divisible by 'num_replicas' ({num_replicas})" 71 | ) 72 | return num_replicas 73 | 74 | 75 | def _check_shard_degree(shard_degree: int, dp_world_size: int) -> int: 76 | if dp_world_size % shard_degree != 0: 77 | raise OLMoConfigurationError( 78 | f"data parallel world size ({dp_world_size}) must be " 79 | f"divisible by 'shard_degree' ({shard_degree})" 80 | ) 81 | return shard_degree 82 | -------------------------------------------------------------------------------- /src/olmo_core/eval/evaluator.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Any, Dict, Iterable, Iterator, Optional 3 | 4 | import torch 5 | 6 | from ..data import DataLoaderBase 7 | 8 | 9 | class Evaluator(metaclass=ABCMeta): 10 | """ 11 | Base class for in-loop evaluators. 12 | 13 | .. 
seealso:: 14 | This can be used with an :class:`~olmo_core.train.callbacks.EvaluatorCallback` to run an 15 | evaluator within the training loop. 16 | 17 | :param name: A name to assign to the evaluator. 18 | :param batches: Generates batches for the evaluator. These should at least include the 19 | "input_ids" field, but can contain any other arbitrary fields as well. 20 | :param device: The device to compute/reduce metrics on. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | *, 26 | name: str, 27 | batches: Iterable[Dict[str, Any]], 28 | device: Optional[torch.device] = None, 29 | ): 30 | self.name = name 31 | self.batches = batches 32 | self.device = device 33 | 34 | def __iter__(self) -> Iterator[Dict[str, Any]]: 35 | """ 36 | Iterator over the evaluator's batches. 37 | """ 38 | if isinstance(self.batches, DataLoaderBase): 39 | self.batches.reshuffle(in_memory=True) 40 | for batch in self.batches: 41 | yield batch 42 | if isinstance(self.batches, DataLoaderBase): 43 | self.batches.reset() 44 | 45 | @property 46 | def total_batches(self) -> Optional[int]: 47 | """ 48 | Get the total number of batches in an eval loop if it's known ahead of time. 49 | """ 50 | try: 51 | return len(self.batches) # type: ignore 52 | except TypeError: 53 | return None 54 | 55 | @abstractmethod 56 | def update_metrics( 57 | self, batch: Dict[str, Any], ce_loss: Optional[torch.Tensor], logits: Optional[torch.Tensor] 58 | ) -> None: 59 | """ 60 | Update metrics with from the ``batch`` just processed and the corresponding ``logits``. 61 | 62 | :param batch: A batch generated from :data:`batches`. 63 | :param ce_loss: The cross-entropy loss per token (un-reduced) of the batch. This will 64 | have shape ``(batch_size, (seq_len - 1))``. 65 | :param logits: The logits generated from the forward pass of the model. 66 | """ 67 | raise NotImplementedError 68 | 69 | @abstractmethod 70 | def compute_metrics(self) -> Dict[str, torch.Tensor]: 71 | """ 72 | Compute the final value of the metrics for the current evaluation loop. 73 | The metrics returned should already be reduced, if needed. 74 | """ 75 | raise NotImplementedError 76 | 77 | @abstractmethod 78 | def reset_metrics(self) -> None: 79 | """ 80 | Reset metrics. Should be called after :meth:`compute_metrics()`. 81 | """ 82 | raise NotImplementedError 83 | -------------------------------------------------------------------------------- /src/olmo_core/launch/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import Optional, Tuple 4 | 5 | import requests 6 | 7 | from olmo_core.config import Config 8 | 9 | GIT_REPO_URL_ENV_VAR = "REPO_URL" 10 | GIT_REF_ENV_VAR = "GIT_REF" 11 | GIT_BRANCH_ENV_VAR = "GIT_BRANCH" 12 | 13 | 14 | def parse_git_remote_url(url: str) -> Tuple[str, str]: 15 | """ 16 | Parse a git remote URL into a GitHub (account, repo) pair. 17 | 18 | :raises InvalidRemoteError: If the URL can't be parsed correctly. 
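    For example, with an illustrative remote (both SSH and HTTPS forms are handled)::

        >>> parse_git_remote_url("git@github.com:allenai/OLMo-core.git")
        ('allenai', 'OLMo-core')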
19 | """ 20 | if "github.com" not in url: 21 | raise ValueError(f"Remote ('{url}') must point to a GitHub repo") 22 | try: 23 | account, repo = url.split("github.com", 1)[-1].strip("/:").split(".git")[0].split("/") 24 | except ValueError: 25 | raise ValueError(f"Failed to parse GitHub repo path from remote '{url}'") 26 | return account, repo 27 | 28 | 29 | @dataclass 30 | class GitConfig(Config): 31 | repo_url: str 32 | ref: str 33 | branch: Optional[str] = None 34 | 35 | @property 36 | def is_dirty(self) -> bool: 37 | from git.exc import InvalidGitRepositoryError 38 | from git.repo import Repo 39 | 40 | try: 41 | repo = Repo(".") 42 | return repo.is_dirty() 43 | except InvalidGitRepositoryError: 44 | return False 45 | 46 | @property 47 | def is_public(self) -> bool: 48 | response = requests.get(self.repo_url) 49 | if response.status_code not in {200, 404}: 50 | response.raise_for_status() 51 | return response.status_code == 200 52 | 53 | @classmethod 54 | def from_env(cls) -> Optional["GitConfig"]: 55 | from git.exc import InvalidGitRepositoryError 56 | from git.repo import Repo 57 | 58 | try: 59 | repo = Repo(".") 60 | except InvalidGitRepositoryError: 61 | return None 62 | 63 | git_ref = os.environ.get(GIT_REF_ENV_VAR, str(repo.commit())) 64 | remote = repo.remote() 65 | 66 | # Try to find a remote based on the current tracking branch. 67 | try: 68 | branch = repo.active_branch 69 | except TypeError: 70 | branch = None 71 | 72 | if branch is not None: 73 | branch = branch.tracking_branch() 74 | 75 | branch_name = os.environ.get(GIT_BRANCH_ENV_VAR) 76 | if branch is not None: 77 | remote = repo.remote(branch.remote_name) 78 | if branch_name is None: 79 | assert branch.name.startswith(branch.remote_name + "/") 80 | branch_name = branch.name.replace(branch.remote_name + "/", "", 1) 81 | 82 | if (repo_url := os.environ.get(GIT_REPO_URL_ENV_VAR)) is None: 83 | account, repo_name = parse_git_remote_url(remote.url) 84 | repo_url = f"https://github.com/{account}/{repo_name}" 85 | 86 | return cls( 87 | repo_url=repo_url, 88 | ref=git_ref, 89 | branch=branch_name, 90 | ) 91 | -------------------------------------------------------------------------------- /src/test/nn/hf/checkpoint_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import torch.distributed.checkpoint.state_dict as dist_cp_sd 5 | from transformers import AutoModelForCausalLM, Olmo2Config 6 | 7 | from olmo_core.nn.hf.checkpoint import load_hf_model, save_hf_model 8 | from olmo_core.nn.transformer.config import TransformerConfig 9 | 10 | 11 | def test_load_hf_model(tmp_path: Path): 12 | vocab_size = 200 13 | padded_vocab_size = 256 14 | model_config = TransformerConfig.olmo2_190M(padded_vocab_size) 15 | 16 | hf_config = Olmo2Config( 17 | vocab_size=vocab_size, 18 | hidden_size=model_config.d_model, 19 | intermediate_size=3072, 20 | num_hidden_layers=model_config.n_layers, 21 | num_attention_heads=12, 22 | rope_theta=500_000, 23 | rms_norm_eps=1e-6, 24 | ) 25 | hf_model = AutoModelForCausalLM.from_config(hf_config) 26 | hf_model.save_pretrained(tmp_path / "hf") 27 | 28 | model = model_config.build() 29 | 30 | state_dict_options = dist_cp_sd.StateDictOptions( 31 | flatten_optimizer_state_dict=True, cpu_offload=True 32 | ) 33 | model_state_dict = dist_cp_sd.get_model_state_dict(model, options=state_dict_options) 34 | load_hf_model( 35 | tmp_path / "hf", 36 | model_state_dict, 37 | num_embeddings=padded_vocab_size, 38 | ) 39 | 
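    # Load the weights pulled from the HF checkpoint (now in `model_state_dict`) into the
    # OLMo-core model so its logits can be compared against the HF model below.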
model.load_state_dict(model_state_dict) 40 | 41 | rand_input = torch.randint(0, vocab_size, (2, 3)) 42 | with torch.no_grad(): 43 | hf_logits, *_ = hf_model(input_ids=rand_input, return_dict=False) 44 | 45 | model.eval() 46 | with torch.no_grad(): 47 | logits = model(input_ids=rand_input) 48 | 49 | assert hf_logits.shape[-1] == vocab_size 50 | assert logits.shape[-1] == padded_vocab_size 51 | torch.testing.assert_close(hf_logits, logits[..., :vocab_size]) 52 | 53 | 54 | def test_save_hf_model(tmp_path: Path): 55 | vocab_size = 200 56 | padded_vocab_size = 256 57 | model_config = TransformerConfig.olmo2_190M(padded_vocab_size) 58 | model = model_config.build() 59 | 60 | state_dict_options = dist_cp_sd.StateDictOptions( 61 | flatten_optimizer_state_dict=True, cpu_offload=True 62 | ) 63 | model_state_dict = dist_cp_sd.get_model_state_dict(model, options=state_dict_options) 64 | save_hf_model( 65 | tmp_path / "hf", 66 | model_state_dict, 67 | model, 68 | vocab_size=vocab_size, 69 | ) 70 | model.load_state_dict(model_state_dict) 71 | 72 | hf_model = AutoModelForCausalLM.from_pretrained(tmp_path / "hf") 73 | 74 | rand_input = torch.randint(0, vocab_size, (2, 3)) 75 | with torch.no_grad(): 76 | hf_logits, *_ = hf_model(input_ids=rand_input, return_dict=False) 77 | 78 | model.eval() 79 | with torch.no_grad(): 80 | logits = model(input_ids=rand_input) 81 | 82 | assert hf_logits.shape[-1] == vocab_size 83 | assert logits.shape[-1] == padded_vocab_size 84 | torch.testing.assert_close(hf_logits, logits[..., :vocab_size]) 85 | -------------------------------------------------------------------------------- /.github/workflows/docker.yml: -------------------------------------------------------------------------------- 1 | name: Docker 2 | 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.ref }} 5 | cancel-in-progress: true 6 | 7 | on: 8 | workflow_dispatch: 9 | # TODO: disabled for now because it takes too long in CI 10 | # pull_request: 11 | # branches: 12 | # - main 13 | # paths: 14 | # - 'Makefile' 15 | # - 'pyproject.toml' 16 | # - 'src/olmo_core/version.py' 17 | # - 'src/Dockerfile' 18 | # - '.github/workflows/docker.yml' 19 | push: 20 | # branches: 21 | # - main 22 | tags: 23 | - 'v*.*.*' 24 | 25 | jobs: 26 | docker: 27 | name: CUDA ${{ matrix.cuda }} ${{ matrix.target }} 28 | runs-on: ubuntu-latest-m 29 | timeout-minutes: 60 30 | env: 31 | BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} 32 | DOCKER_HUB_TOKEN: ${{ secrets.DOCKER_HUB_TOKEN }} 33 | DOCKER_HUB_USER: ${{ secrets.DOCKER_HUB_USER }} 34 | strategy: 35 | fail-fast: false 36 | matrix: 37 | cuda: ["12.6.3", "12.8.1"] # NOTE: check tags for options: https://hub.docker.com/r/nvidia/cuda/tags 38 | torch: 39 | - version: 2.7.0 40 | channel: whl/test 41 | steps: 42 | - uses: actions/checkout@v4 43 | 44 | - name: Setup Python 45 | uses: actions/setup-python@v5 46 | with: 47 | python-version: '3.11' 48 | 49 | - name: Set env vars 50 | run: | 51 | echo "BEAKER_WORKSPACE=$(make get-beaker-workspace)" >> $GITHUB_ENV 52 | 53 | - name: Authenticate with Beaker 54 | uses: allenai/setup-beaker@v2 55 | if: env.BEAKER_TOKEN != '' 56 | with: 57 | token: ${{ env.BEAKER_TOKEN }} 58 | workspace: ${{ env.BEAKER_WORKSPACE }} 59 | 60 | - name: Authenticate with Docker Hub 61 | if: env.DOCKER_HUB_TOKEN != '' 62 | run: | 63 | echo ${{ env.DOCKER_HUB_TOKEN }} | docker login -u ${{ env.DOCKER_HUB_USER }} --password-stdin 64 | 65 | - name: Build image 66 | run: | 67 | #rm -rf /opt/hostedtoolcache # clear up some disk space 68 | make docker-image \ 69 | 
CUDA_VERSION=${{ matrix.cuda }} \ 70 | TORCH_VERSION=${{ matrix.torch.version }} \ 71 | INSTALL_CHANNEL=${{ matrix.torch.channel }} 72 | 73 | - name: Push to GHCR 74 | if: startsWith(github.ref, 'refs/tags/') 75 | run: | 76 | echo ${{ secrets.GITHUB_TOKEN }} | docker login ghcr.io -u ${{ github.actor }} --password-stdin 77 | make ghcr-image \ 78 | CUDA_VERSION=${{ matrix.cuda }} \ 79 | TORCH_VERSION=${{ matrix.torch.version }} \ 80 | INSTALL_CHANNEL=${{ matrix.torch.channel }} 81 | 82 | - name: Push to Beaker 83 | if: env.BEAKER_TOKEN != '' && startsWith(github.ref, 'refs/tags/') 84 | run: | 85 | make beaker-image \ 86 | CUDA_VERSION=${{ matrix.cuda }} \ 87 | TORCH_VERSION=${{ matrix.torch.version }} \ 88 | INSTALL_CHANNEL=${{ matrix.torch.channel }} 89 | -------------------------------------------------------------------------------- /src/olmo_core/train/callbacks/config_saver.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, Optional 5 | 6 | from olmo_core.aliases import PathOrStr 7 | from olmo_core.data import NumpyDataLoaderBase 8 | from olmo_core.distributed.utils import get_rank 9 | 10 | from .beaker import BeakerCallback 11 | from .callback import Callback 12 | from .comet import CometCallback 13 | from .wandb import WandBCallback 14 | 15 | log = logging.getLogger(__name__) 16 | 17 | DEFAULT_DATA_PATHS_FNAME = "data_paths.txt" 18 | 19 | 20 | @dataclass 21 | class ConfigSaverCallback(Callback): 22 | """ 23 | A callback that writes an arbitrary JSON-serializable config dictionary (:data:`config`) to every checkpoint 24 | directory written during training. It will also set the config to save for other callbacks, including 25 | the :class:`WandBCallback`, :class:`CometCallback`, and others, if not already set. 26 | 27 | .. important:: The :data:`config` should be set *after* initializing the trainer and attaching all 28 | other callbacks. 29 | """ 30 | 31 | fname: str = "config.json" 32 | save_data_paths: Optional[bool] = None 33 | data_paths_fname: Optional[str] = None 34 | 35 | _config: Optional[Dict[str, Any]] = None 36 | 37 | @property 38 | def config(self) -> Optional[Dict[str, Any]]: 39 | """ 40 | The JSON config dictionary to record. 
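        Setting this also propagates the config to any :class:`WandBCallback`, :class:`CometCallback`,
        or :class:`BeakerCallback` attached to the trainer that doesn't already have one
        (see the setter below).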
41 | """ 42 | return self._config 43 | 44 | @config.setter 45 | def config(self, config: Dict[str, Any]): 46 | self._config = config 47 | for callback_name, callback in self.trainer.callbacks.items(): 48 | if ( 49 | isinstance(callback, (WandBCallback, CometCallback, BeakerCallback)) 50 | and callback.config is None 51 | ): 52 | log.info( 53 | f"Setting config for '{callback_name}' callback of type '{callback.__class__.__name__}'" 54 | ) 55 | callback.config = config 56 | 57 | def post_checkpoint_saved(self, path: PathOrStr): 58 | if get_rank() != 0: 59 | return 60 | 61 | if self.config is None: 62 | log.warning(f"Config not set on {self.__class__.__name__}, doing nothing") 63 | else: 64 | self.trainer.write_file(self.fname, json.dumps(self.config), dir=path) 65 | 66 | if self.save_data_paths is not False: 67 | if isinstance(self.trainer.data_loader, NumpyDataLoaderBase): 68 | ds = self.trainer.data_loader.dataset 69 | all_paths = "\n".join(str(p) for p in ds.paths) 70 | self.trainer.write_file( 71 | self.data_paths_fname or DEFAULT_DATA_PATHS_FNAME, all_paths, dir=path 72 | ) 73 | elif self.save_data_paths: 74 | log.warning( 75 | f"Unable to save paths for data loader of type '{self.trainer.data_loader.__class__.__name__}' (not implemented)" 76 | ) 77 | -------------------------------------------------------------------------------- /src/test/nn/moe/router_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from olmo_core.nn.moe.router import MoELinearRouter, MoERouterGatingFunction 5 | from olmo_core.testing import DEVICES 6 | 7 | 8 | @pytest.mark.parametrize("device", DEVICES) 9 | @pytest.mark.parametrize( 10 | "uniform_expert_assignment", 11 | [ 12 | pytest.param(True, id="uniform"), 13 | pytest.param(False, id="computed"), 14 | ], 15 | ) 16 | @pytest.mark.parametrize( 17 | "gating_function", 18 | [ 19 | pytest.param(MoERouterGatingFunction.softmax, id="softmax"), 20 | pytest.param(MoERouterGatingFunction.sigmoid, id="sigmoid"), 21 | ], 22 | ) 23 | def test_router( 24 | device: torch.device, uniform_expert_assignment: bool, gating_function: MoERouterGatingFunction 25 | ): 26 | router = MoELinearRouter( 27 | d_model=128, 28 | num_experts=4, 29 | jitter_eps=0.1, 30 | top_k=2, 31 | normalize_expert_weights=True, 32 | uniform_expert_assignment=uniform_expert_assignment, 33 | gating_function=gating_function, 34 | ).to(device) 35 | 36 | x = torch.randn((2, 4, 128), device=device) 37 | weights, indices, bz_per_expert, _ = router(x) 38 | 39 | assert weights.shape == (2, 4, 2) 40 | assert indices.shape == (2, 4, 2) 41 | assert bz_per_expert.shape == (4,) 42 | 43 | 44 | @pytest.mark.parametrize("device", DEVICES) 45 | def test_router_with_bias_gamma(device: torch.device): 46 | router1 = MoELinearRouter( 47 | d_model=128, 48 | num_experts=4, 49 | top_k=2, 50 | bias_gamma=0.001, 51 | ).to(device) 52 | router1.reset_parameters() 53 | 54 | assert router1.score_bias is not None 55 | assert router1.score_bias.nonzero().sum().item() == 0 # type: ignore 56 | assert router1.score_bias_batch_size_per_expert is not None 57 | assert router1.score_bias_batch_size_per_expert.nonzero().sum().item() == 0 58 | 59 | router2 = MoELinearRouter( 60 | d_model=128, 61 | num_experts=4, 62 | top_k=2, 63 | ).to(device) 64 | router2.reset_parameters() 65 | state_dict = router1.state_dict() 66 | del state_dict["score_bias"] 67 | router2.load_state_dict(state_dict) 68 | 69 | x = torch.randn((2, 4, 128), device=device) 70 | 71 | # At this point, 
the output should be exactly the same as it would be without a bias gamma. 72 | weights1, indices1, bz_per_expert1, _ = router1(x) 73 | weights2, indices2, bz_per_expert2, _ = router2(x) 74 | torch.testing.assert_close(weights1, weights2) 75 | torch.testing.assert_close(indices1, indices2) 76 | torch.testing.assert_close(bz_per_expert1, bz_per_expert2) 77 | 78 | assert router1.batch_size_per_expert.sum().item() == 8 * 2 79 | 80 | # Update the biases and check. 81 | router1.post_batch() 82 | assert router1.score_bias.nonzero().sum().item() > 0 # type: ignore 83 | assert router1.score_bias_batch_size_per_expert is not None 84 | assert router1.score_bias_batch_size_per_expert.nonzero().sum().item() == 0 85 | -------------------------------------------------------------------------------- /src/olmo_core/train/callbacks/gpu_memory_monitor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from typing import ClassVar, Optional 4 | 5 | import torch 6 | 7 | from .callback import Callback 8 | 9 | log = logging.getLogger(__name__) 10 | 11 | 12 | @dataclass 13 | class GPUMemoryMonitorCallback(Callback): 14 | """ 15 | Adds metrics for GPU memory statistics. 16 | """ 17 | 18 | priority: ClassVar[int] = -1 19 | device_id: Optional[int] = None 20 | _num_alloc_retries: int = 0 21 | 22 | @property 23 | def device(self) -> torch.device: 24 | return ( 25 | torch.device("cuda") 26 | if self.device_id is None 27 | else torch.device(f"cuda:{self.device_id}") 28 | ) 29 | 30 | @property 31 | def device_name(self) -> str: 32 | return torch.cuda.get_device_name(self.device) 33 | 34 | @property 35 | def device_capacity(self) -> int: 36 | return torch.cuda.get_device_properties(self.device).total_memory 37 | 38 | def pre_train(self): 39 | torch.cuda.reset_peak_memory_stats() 40 | torch.cuda.empty_cache() 41 | log.info( 42 | f"GPU capacity: {self.device_name} with {self._to_gib(self.device_capacity):.2f}GiB memory " 43 | f"of which {self._to_gib(torch.cuda.memory_allocated()):.2f}GiB is currently allocated and " 44 | f"{self._to_gib(torch.cuda.memory_reserved()):.2f}GiB is currently reserved." 
45 | ) 46 | 47 | def post_step(self): 48 | cuda_info = torch.cuda.memory_stats(self.device) 49 | 50 | max_active = cuda_info["active_bytes.all.peak"] 51 | max_active_gib = self._to_gib(max_active) 52 | max_active_pct = self._to_pct(max_active) 53 | self.trainer.record_metric("system/GPU active mem (GiB)", max_active_gib) 54 | self.trainer.record_metric("system/GPU active mem (%)", max_active_pct) 55 | 56 | max_reserved = cuda_info["reserved_bytes.all.peak"] 57 | max_reserved_gib = self._to_gib(max_reserved) 58 | max_reserved_pct = self._to_pct(max_reserved) 59 | self.trainer.record_metric("system/GPU reserved mem (GiB)", max_reserved_gib) 60 | self.trainer.record_metric("system/GPU reserved mem (%)", max_reserved_pct) 61 | 62 | num_retries = cuda_info["num_alloc_retries"] 63 | if num_retries > self._num_alloc_retries: 64 | log.warning(f"{num_retries} CUDA memory allocation retries.") 65 | self._num_alloc_retries = num_retries 66 | 67 | num_ooms = cuda_info["num_ooms"] 68 | if num_ooms > 0: 69 | log.warning(f"{num_ooms} CUDA OOM errors thrown.") 70 | 71 | torch.cuda.reset_peak_memory_stats() 72 | 73 | def _to_pct(self, memory: float) -> float: 74 | return 100 * memory / self.device_capacity 75 | 76 | def _to_gib(self, memory_in_bytes: int) -> float: 77 | # NOTE: GiB (gibibyte) is 1024, vs GB is 1000 78 | _gib_in_bytes = 1024 * 1024 * 1024 79 | memory_in_gib = memory_in_bytes / _gib_in_bytes 80 | return memory_in_gib 81 | -------------------------------------------------------------------------------- /src/olmo_core/generate/generation_module/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import TYPE_CHECKING, List, Optional 3 | 4 | from olmo_core.config import Config 5 | 6 | if TYPE_CHECKING: 7 | pass 8 | 9 | 10 | @dataclass 11 | class GenerationConfig(Config): 12 | """Configuration for text generation.""" 13 | 14 | pad_token_id: int 15 | """Padding token ID.""" 16 | 17 | eos_token_id: int 18 | """End of sequence token ID.""" 19 | 20 | max_length: Optional[int] = None 21 | """Maximum length of input + newly generated tokens.""" 22 | 23 | max_new_tokens: Optional[int] = None 24 | """Maximum number of new tokens to generate. If provided, this takes precedence over max_length.""" 25 | 26 | do_sample: bool = True 27 | """Whether to use sampling for generation. If False, greedy decoding is used. This overrides temperature, top_k, and top_p.""" 28 | 29 | temperature: float = 0.0 30 | """Temperature for sampling. If 0, this is equivalent to greedy selection.""" 31 | 32 | top_k: int = -1 33 | """Top-k sampling. Only consider the top k tokens with the highest probabilities. -1 means no filtering.""" 34 | 35 | top_p: float = 1.0 36 | """Top-p (nucleus) sampling. Only consider the smallest set of tokens whose cumulative probability exceeds this threshold. 1.0 means no filtering.""" 37 | 38 | use_cache: bool = True 39 | """Whether to use an inference cache (e.g. a kv-cache) for generation.""" 40 | 41 | stop_token_ids: Optional[List[int]] = None 42 | """Tokens to stop generation at. If provided, the generation will stop when any of these tokens are generated.""" 43 | 44 | until: Optional[List[str]] = None 45 | """Strings to stop generation at. 
If provided, the generation will stop when any of these strings are generated.""" 46 | 47 | def __post_init__(self): 48 | self.validate() 49 | 50 | def validate(self): 51 | """Validate the generation configuration.""" 52 | if self.pad_token_id < 0: 53 | raise ValueError(f"pad_token_id must be non-negative, got {self.pad_token_id}") 54 | if self.eos_token_id < 0: 55 | raise ValueError(f"eos_token_id must be non-negative, got {self.eos_token_id}") 56 | if self.pad_token_id == self.eos_token_id: 57 | raise ValueError( 58 | f"pad_token_id and eos_token_id must be different, got {self.pad_token_id} and {self.eos_token_id}" 59 | ) 60 | if self.max_length is not None and self.max_length <= 0: 61 | raise ValueError(f"max_length must be positive, got {self.max_length}") 62 | if self.max_new_tokens is not None and self.max_new_tokens <= 0: 63 | raise ValueError(f"max_new_tokens must be positive, got {self.max_new_tokens}") 64 | if self.temperature < 0: 65 | raise ValueError(f"temperature must be non-negative, got {self.temperature}") 66 | if self.top_k <= 0 and self.top_k != -1: 67 | raise ValueError(f"top_k must be positive or -1, got {self.top_k}") 68 | if self.top_p <= 0.0 or self.top_p > 1.0: 69 | raise ValueError(f"top_p must be in (0, 1], got {self.top_p}") 70 | -------------------------------------------------------------------------------- /src/test/nn/moe/mlp_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | from olmo_core.distributed.parallel import ( 5 | ExpertParallelConfig, 6 | build_expert_parallel_mesh, 7 | get_ep_mesh, 8 | ) 9 | from olmo_core.distributed.utils import get_local_tensor 10 | from olmo_core.nn.moe.mlp import DroplessMoEMLP, MoEMLP 11 | from olmo_core.testing import ( 12 | requires_gpu, 13 | requires_grouped_gemm, 14 | requires_multi_gpu, 15 | run_distributed_test, 16 | ) 17 | from olmo_core.utils import get_default_device 18 | 19 | 20 | @requires_gpu 21 | def test_mlp(): 22 | mlp = MoEMLP( 23 | d_model=128, hidden_size=256, num_experts=2, init_device="cuda", dtype=torch.bfloat16 24 | ) 25 | x = torch.randn(2, 3, 128, device="cuda", dtype=torch.bfloat16) 26 | out = mlp(x) 27 | assert out.shape == (2, 3, 128) 28 | 29 | 30 | @requires_gpu 31 | @requires_grouped_gemm 32 | def test_dropless_mlp(): 33 | mlp = DroplessMoEMLP( 34 | d_model=128, hidden_size=256, num_experts=2, init_device="cuda", dtype=torch.bfloat16 35 | ) 36 | x = torch.randn(5, 128, device="cuda", dtype=torch.bfloat16) 37 | tokens_per_expert = torch.tensor([3, 2], device="cuda") 38 | out = mlp(x, tokens_per_expert) 39 | assert out.shape == (5, 128) 40 | 41 | 42 | def run_mlp_with_expert_parallelism(): 43 | world_mesh = build_expert_parallel_mesh(ExpertParallelConfig(degree=dist.get_world_size())) 44 | ep_mesh = get_ep_mesh(world_mesh) 45 | 46 | mlp = MoEMLP( 47 | d_model=128, 48 | hidden_size=256, 49 | num_experts=dist.get_world_size() * 2, 50 | init_device="meta", 51 | dtype=torch.bfloat16, 52 | ) 53 | mlp.apply_ep(ep_mesh) 54 | mlp.to_empty(device=get_default_device()) 55 | assert get_local_tensor(mlp.w1).shape == (2 * 128, 256) 56 | 57 | x = torch.randn(2, 3, 128, device="cuda", dtype=torch.bfloat16) 58 | out = mlp(x) 59 | 60 | assert out.shape == (2, 3, 128) 61 | 62 | 63 | @requires_multi_gpu 64 | def test_mlp_with_expert_parallelism(): 65 | run_distributed_test(run_mlp_with_expert_parallelism, backend="nccl", start_method="spawn") 66 | 67 | 68 | def run_dropless_mlp_with_expert_parallelism(): 69 | 
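    # Mirror of `run_mlp_with_expert_parallelism` above, but exercising the grouped-GEMM
    # `DroplessMoEMLP`; with the EP degree equal to the world size, each rank again holds
    # num_experts // world_size = 2 experts locally.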
world_mesh = build_expert_parallel_mesh(ExpertParallelConfig(degree=dist.get_world_size())) 70 | ep_mesh = get_ep_mesh(world_mesh) 71 | 72 | mlp = DroplessMoEMLP( 73 | d_model=128, 74 | hidden_size=256, 75 | num_experts=dist.get_world_size() * 2, 76 | init_device="meta", 77 | dtype=torch.bfloat16, 78 | ) 79 | mlp.apply_ep(ep_mesh) 80 | mlp.to_empty(device=get_default_device()) 81 | assert get_local_tensor(mlp.w1).shape == (2 * 256, 128) 82 | 83 | x = torch.randn(5, 128, device="cuda", dtype=torch.bfloat16) 84 | tokens_per_expert = torch.tensor([2, 3], device="cuda") 85 | out = mlp(x, tokens_per_expert) 86 | 87 | assert out.shape == (5, 128) 88 | 89 | 90 | @requires_multi_gpu 91 | @requires_grouped_gemm 92 | def test_dropless_mlp_with_expert_parallelism(): 93 | run_distributed_test( 94 | run_dropless_mlp_with_expert_parallelism, backend="nccl", start_method="spawn" 95 | ) 96 | -------------------------------------------------------------------------------- /src/olmo_core/eval/metrics.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Optional, Union 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | from ..distributed.utils import all_reduce_value 8 | from ..utils import get_default_device 9 | 10 | __all__ = ["Metric", "MeanMetric"] 11 | 12 | 13 | class Metric(metaclass=ABCMeta): 14 | """ 15 | Base class for evaluation metrics. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | device: Optional[torch.device] = None, 21 | process_group: Optional[dist.ProcessGroup] = None, 22 | ): 23 | self.device = device if device is not None else get_default_device() 24 | self.process_group = process_group 25 | 26 | @abstractmethod 27 | def update(self, *args, **kwargs) -> None: 28 | """ 29 | Update the metric. 30 | """ 31 | raise NotImplementedError 32 | 33 | @abstractmethod 34 | def compute(self) -> torch.Tensor: 35 | """ 36 | Compute the metric. 37 | """ 38 | raise NotImplementedError 39 | 40 | @abstractmethod 41 | def reset(self) -> None: 42 | """ 43 | Reset the metric. 44 | """ 45 | raise NotImplementedError 46 | 47 | def as_tensor(self, value: Union[float, torch.Tensor]) -> torch.Tensor: 48 | if not isinstance(value, torch.Tensor): 49 | value = torch.tensor(value, dtype=torch.float32) 50 | return value.to(device=self.device, non_blocking=self.device.type != "cpu") 51 | 52 | 53 | class MeanMetric(Metric): 54 | """ 55 | Computes the mean over a stream of values. 56 | """ 57 | 58 | def __init__( 59 | self, 60 | device: Optional[torch.device] = None, 61 | process_group: Optional[dist.ProcessGroup] = None, 62 | ): 63 | super().__init__(device=device, process_group=process_group) 64 | self.weighted_sum = torch.tensor(0.0, device=self.device) 65 | self.weight = torch.tensor(0.0, device=self.device) 66 | 67 | def update( 68 | self, value: Union[float, torch.Tensor], weight: Union[float, torch.Tensor] = 1.0 69 | ) -> None: 70 | """ 71 | :param value: The latest value to update the metric with. Could be a tensor of values. 72 | :param weight: The corresponding weight(s) for the value. Should be the same shape as ``value``. 73 | """ 74 | value = self.as_tensor(value) 75 | weight = torch.broadcast_to(self.as_tensor(weight), value.shape) 76 | if value.numel() == 0: 77 | return 78 | self.weighted_sum += (value * weight).sum() 79 | self.weight += weight.sum() 80 | 81 | def compute(self) -> torch.Tensor: 82 | """ 83 | Computes the mean over the values and weights given. 
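The local weighted sum and total weight are all-reduced across the process group
(if one was given) before dividing, so every rank sees the same global mean.
A minimal single-process sketch (no process group; assumes the all-reduce is a
no-op outside of distributed runs)::

    metric = MeanMetric(device=torch.device("cpu"))
    metric.update(torch.tensor([1.0, 3.0]))  # weighted sum 4.0, total weight 2.0
    metric.compute()  # tensor(2.)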
84 | """ 85 | weighted_sum = all_reduce_value( 86 | self.weighted_sum, device=self.device, group=self.process_group 87 | ) 88 | weight = all_reduce_value(self.weight, device=self.device, group=self.process_group) 89 | return weighted_sum / weight 90 | 91 | def reset(self) -> None: 92 | self.weighted_sum.zero_() 93 | self.weight.zero_() 94 | -------------------------------------------------------------------------------- /src/olmo_core/optim/noop.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Type 3 | 4 | import torch 5 | 6 | from .config import OptimConfig 7 | from .skip_step_optimizer import SkipStepOptimizer 8 | 9 | 10 | class NoOpOptimizer(SkipStepOptimizer): 11 | """ 12 | A no-op optimizer that performs no parameter updates but maintains all step skipping logic. 13 | 14 | This optimizer is useful for gathering statistics from training without actually modifying 15 | the model parameters. It tracks losses and gradient norms, computes step factors based on 16 | rolling statistics, but does not apply any updates to the model. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | params, 22 | lr: float = 1e-3, 23 | rolling_interval_length: int = 128, 24 | sigma_factor: int = 6, 25 | ) -> None: 26 | defaults = dict(lr=lr) 27 | super().__init__( 28 | params, 29 | defaults, 30 | rolling_interval_length=rolling_interval_length, 31 | sigma_factor=sigma_factor, 32 | ) 33 | self._step_skipped: Optional[torch.Tensor] = None 34 | 35 | @property 36 | def step_skipped(self) -> torch.Tensor: 37 | if self._step_skipped is not None: 38 | return self._step_skipped 39 | else: 40 | return torch.tensor(0.0) 41 | 42 | @torch.no_grad() 43 | def step(self, closure=None) -> None: 44 | if closure is not None: 45 | with torch.enable_grad(): 46 | closure() 47 | 48 | # Compute step factor to maintain step skipping logic 49 | step_factor = self.get_step_factor() 50 | self._step_skipped = 1 - step_factor 51 | 52 | # Iterate through parameters to maintain optimizer structure 53 | # but perform no updates 54 | for group in self.param_groups: 55 | for p in group["params"]: 56 | if p.grad is None: 57 | continue 58 | 59 | # Initialize state if needed (for consistency) 60 | state = self.state[p] 61 | if len(state) == 0: 62 | state["step"] = torch.zeros((), dtype=torch.float32, device=p.device) 63 | 64 | # Increment step counter 65 | state["step"] += step_factor 66 | 67 | 68 | @dataclass 69 | class NoOpConfig(OptimConfig): 70 | """ 71 | Configuration class for building a :class:`NoOpOptimizer`. 72 | 73 | This optimizer performs no parameter updates but maintains step skipping logic 74 | for gathering statistics during training. 75 | """ 76 | 77 | lr: float = 1e-3 78 | """Learning rate (not used for updates, but maintained for compatibility).""" 79 | 80 | rolling_interval_length: int = 128 81 | """ 82 | The length of the rolling interval to use for computing the mean and standard deviation 83 | of the loss and gradient norm. 84 | """ 85 | 86 | sigma_factor: int = 6 87 | """ 88 | The number of standard deviations above the mean loss/grad norm to skip a step. 
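Because this optimizer never applies updates, the threshold only affects the
skipped-step statistics it reports, not the model.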
89 | """ 90 | 91 | @classmethod 92 | def optimizer(cls) -> Type[NoOpOptimizer]: 93 | return NoOpOptimizer 94 | -------------------------------------------------------------------------------- /src/olmo_core/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset, data loaders, and config builders for use with the :class:`~olmo_core.train.Trainer`. 3 | 4 | Overview 5 | -------- 6 | 7 | For text-based data you should prepare your data by writing token IDs to numpy arrays on disk, using the 8 | `Dolma toolkit `_ for example. 9 | Then configure and build your dataset using one of the 10 | :class:`~olmo_core.data.numpy_dataset.NumpyDatasetConfigBase` builders (for example 11 | ``NumpyFSLDatasetConfig``), build your data loader with the 12 | :class:`~olmo_core.data.data_loader.NumpyDataLoaderConfig` 13 | builder, and pass it to :meth:`TrainerConfig.build() `. 14 | """ 15 | 16 | from .collator import DataCollator, PaddingDirection, ByteDataCollator 17 | from .data_loader import ( 18 | DataLoaderBase, 19 | NumpyDataLoaderBase, 20 | NumpyDataLoaderConfig, 21 | NumpyFSLDataLoader, 22 | NumpyVSLDataLoader, 23 | TextDataLoaderBase, 24 | ) 25 | from .mixes import DataMix, DataMixBase 26 | from .numpy_dataset import ( 27 | InstanceFilterConfig, 28 | NumpyDatasetBase, 29 | NumpyDatasetConfig, 30 | NumpyByteFSLDataset, 31 | NumpyByteFSLDatasetConfig, 32 | NumpyBytePaddedFSLDataset, 33 | NumpyBytePaddedFSLDatasetConfig, 34 | NumpyFSLDataset, 35 | NumpyFSLDatasetBase, 36 | NumpyFSLDatasetConfig, 37 | NumpyInterleavedFSLDatasetConfig, 38 | NumpyPackedFSLDataset, 39 | NumpyPackedFSLDatasetConfig, 40 | NumpyPaddedFSLDataset, 41 | NumpyPaddedFSLDatasetConfig, 42 | NumpyVSLDataset, 43 | NumpyVSLDatasetConfig, 44 | VSLCurriculum, 45 | VSLCurriculumConfig, 46 | VSLCurriculumType, 47 | VSLGrowLinearCurriculum, 48 | VSLGrowP2Curriculum, 49 | VSLGrowthCurriculum, 50 | VSLNaturalCurriculum, 51 | ) 52 | from .tokenizer import ByteTokenizerConfig, ByteTokenizer, TokenizerConfig, TokenizerName 53 | from .types import LongDocStrategy, NumpyDatasetDType 54 | 55 | __all__ = [ 56 | "NumpyDatasetBase", 57 | "NumpyFSLDatasetBase", 58 | "NumpyFSLDataset", 59 | "NumpyByteFSLDataset", 60 | "NumpyPaddedFSLDataset", 61 | "NumpyBytePaddedFSLDataset", 62 | "NumpyPackedFSLDataset", 63 | "NumpyVSLDataset", 64 | "VSLCurriculum", 65 | "VSLNaturalCurriculum", 66 | "VSLGrowthCurriculum", 67 | "VSLGrowP2Curriculum", 68 | "VSLGrowLinearCurriculum", 69 | "NumpyDatasetConfig", 70 | "NumpyFSLDatasetConfig", 71 | "NumpyByteFSLDatasetConfig", 72 | "NumpyPaddedFSLDatasetConfig", 73 | "NumpyBytePaddedFSLDatasetConfig", 74 | "NumpyPackedFSLDatasetConfig", 75 | "NumpyInterleavedFSLDatasetConfig", 76 | "NumpyVSLDatasetConfig", 77 | "InstanceFilterConfig", 78 | "VSLCurriculumType", 79 | "VSLCurriculumConfig", 80 | "NumpyDatasetDType", 81 | "TokenizerConfig", 82 | "ByteDataCollator", 83 | "ByteTokenizerConfig", 84 | "ByteTokenizer", 85 | "TokenizerName", 86 | "DataMixBase", 87 | "DataMix", 88 | "DataCollator", 89 | "PaddingDirection", 90 | "DataLoaderBase", 91 | "TextDataLoaderBase", 92 | "NumpyDataLoaderBase", 93 | "NumpyFSLDataLoader", 94 | "NumpyVSLDataLoader", 95 | "NumpyDataLoaderConfig", 96 | "LongDocStrategy", 97 | ] 98 | -------------------------------------------------------------------------------- /src/olmo_core/nn/bolmo/embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | primes = [ 5 | 
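# Large fixed primes, one per hash function: rolling_polynomial_hash() below uses
# primes[hash_func_nb] as the base of its polynomial hash.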
1000000007, 6 | 5915587277, 7 | 1500450271, 8 | 3267000013, 9 | 5754853343, 10 | 4093082899, 11 | 9576890767, 12 | 3628273133, 13 | 2860486313, 14 | 5463458053, 15 | 3367900313, 16 | ] 17 | 18 | 19 | def rolling_polynomial_hash(t, hash_func_nb: int = 0): 20 | # DIVERGENCE FROM BLT: avoid sync 21 | prime_powers = primes[hash_func_nb] ** torch.arange(t.shape[-1], dtype=torch.int64, device=t.device) 22 | return torch.sum(t * prime_powers, dim=-1) 23 | 24 | 25 | def byte_group_hash_function( 26 | x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 27 | ): 28 | """ 29 | Returns a hash of the input x and maps it to a value in the range [0, max_hash]. 30 | 31 | expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. 32 | returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. 33 | 34 | Note: max hash can make a big difference on the number of collisions. 35 | """ 36 | with torch.no_grad(): 37 | bs, seq_len = x.shape 38 | 39 | prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) 40 | x = torch.cat([prefix, x], dim=1) 41 | windows = x.unfold(1, group_size, 1) 42 | hashes = rolling_polynomial_hash(windows, hash_func_nb) 43 | hash_values_range = hashes % max_hash 44 | hash_values_range.requires_grad = False 45 | return hash_values_range 46 | 47 | 48 | def add_hash_embeddings( 49 | embeddings: torch.Tensor, 50 | tokens: torch.Tensor, 51 | encoder_hash_tok_embeddings: nn.ModuleList, 52 | encoder_hash_byte_group_nb_functions: int, 53 | encoder_hash_byte_group_size: list, 54 | encoder_hash_byte_group_vocab: list, 55 | ) -> torch.Tensor: 56 | """ 57 | Compute embeddings using hash token embeddings. 58 | 59 | Args: 60 | embeddings: Input embeddings tensor of shape (batch_size, seq_len, d_model) 61 | tokens: Input tokens tensor 62 | encoder_hash_tok_embedding: ModuleList of hash token embeddings 63 | encoder_hash_byte_group_nb_functions: Number of hash functions 64 | encoder_hash_byte_group_size: List of byte group sizes 65 | encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings 66 | 67 | Returns: 68 | torch.Tensor: Embeddings tensor augmented with hash token embeddings, shape (batch_size, seq_len, d_model) 69 | """ 70 | out_embeddings = embeddings 71 | 72 | hash_embed_idx = 0 73 | for byte_group_size in encoder_hash_byte_group_size: 74 | for func_nb in range(encoder_hash_byte_group_nb_functions): 75 | hash_ids = byte_group_hash_function( 76 | tokens, 77 | byte_group_size, 78 | hash_func_nb=func_nb, 79 | max_hash=encoder_hash_byte_group_vocab[hash_embed_idx], 80 | ) 81 | hash_tok_embedding = encoder_hash_tok_embeddings[hash_embed_idx] 82 | out_embeddings = out_embeddings + hash_tok_embedding(hash_ids) 83 | hash_embed_idx += 1 84 | 85 | assert hash_embed_idx == len(encoder_hash_tok_embeddings) 86 | return out_embeddings -------------------------------------------------------------------------------- /src/olmo_core/train/callbacks/console_logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass, field 3 | from fnmatch import fnmatch 4 | from typing import Dict, List, Optional 5 | 6 | from olmo_core.utils import format_float, format_timedelta 7 | 8 | from .callback import Callback 9 | 10 | log = logging.getLogger(__name__) 11 | 12 | 13 | @dataclass 14 | class ConsoleLoggerCallback(Callback): 15 | """ 16 | Logs progress and a subset of metrics to the console. 17 | 18 | .. 
important:: 19 | This callback gets added automatically if you don't explicitly configure it. 20 | If you want to override this callback you should subclass it. 21 | """ 22 | 23 | log_interval: int = 1 24 | """ 25 | How often, in steps, to log progress to the console. 26 | """ 27 | 28 | metrics_log_interval: Optional[int] = None 29 | """ 30 | How often, in steps, to log metrics to the console. If not set, defaults to :data:`log_interval`. 31 | """ 32 | 33 | metrics: List[str] = field( 34 | default_factory=lambda: [ 35 | "train/CE loss", 36 | "train/PPL", 37 | "train/Z loss", 38 | "train/load balancing loss", 39 | "train/router Z loss", 40 | "train/block */load imbalance", 41 | "system/*", 42 | "optim/total grad norm", 43 | "optim/step skipped", 44 | "optim/LR*", 45 | "throughput/*", 46 | ] 47 | ) 48 | """ 49 | Metrics to log to the console. Wildcards are supported. 50 | """ 51 | 52 | def post_step(self): 53 | if self._should_log_metrics(self.step): 54 | # Will log to console from `self.log_metrics()`. 55 | return 56 | 57 | if self.step % self.log_interval != 0: 58 | return 59 | 60 | log.info(self._get_progress_marker(self.step)) 61 | 62 | def log_metrics(self, step: int, metrics: Dict[str, float]): 63 | if not self._should_log_metrics(step): 64 | return 65 | 66 | prefix = self._get_progress_marker(step, include_eta=True) 67 | log.info( 68 | f"{prefix}\n" 69 | + "\n".join( 70 | [ 71 | f" {name}={format_float(value)}" 72 | for name, value in metrics.items() 73 | if any(fnmatch(name, pat) for pat in self.metrics) 74 | ] 75 | ) 76 | ) 77 | 78 | def _get_progress_marker(self, step: int, include_eta: bool = False) -> str: 79 | if include_eta and (eta := self.trainer.training_progress.time_remaining) is not None: 80 | eta_str = format_timedelta(eta).replace(", ", "") 81 | if self.trainer.hard_stop: 82 | eta_str = f"{eta_str}(hard stop)" 83 | return ( 84 | f"[step={step}/{self.trainer.max_steps},epoch={self.trainer.epoch},eta={eta_str}]" 85 | ) 86 | else: 87 | return f"[step={step}/{self.trainer.max_steps},epoch={self.trainer.epoch}]" 88 | 89 | def _should_log_metrics(self, step: int) -> bool: 90 | metrics_log_interval = self.metrics_log_interval or self.log_interval 91 | if step == 1 or (step > 1 and step % metrics_log_interval == 0): 92 | return True 93 | else: 94 | return False 95 | -------------------------------------------------------------------------------- /src/scripts/train/nGPT-1B.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a 1B nGPT model. Run this script without any arguments to see usage info. 
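The usage info comes from the shared experiment entrypoint: this script only assembles a config
with ``build_config`` and hands it to ``main()`` from ``olmo_core.internal.experiment``.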
3 | """ 4 | 5 | from functools import partial 6 | 7 | from olmo_core.config import DType 8 | from olmo_core.distributed.parallel import DataParallelType 9 | from olmo_core.float8 import Float8Config 10 | from olmo_core.internal.experiment import CommonComponents, build_config, main 11 | from olmo_core.nn.transformer import TransformerConfig 12 | from olmo_core.optim import AdamConfig, CosWithWarmup 13 | from olmo_core.train import TrainerConfig 14 | from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback 15 | from olmo_core.train.train_module import ( 16 | TransformerDataParallelConfig, 17 | TransformerTrainModuleConfig, 18 | ) 19 | 20 | SEQUENCE_LENGTH = 4096 21 | GLOBAL_BATCH_SIZE = 1024 * 4096 22 | 23 | 24 | def build_model_config(common: CommonComponents) -> TransformerConfig: 25 | return TransformerConfig.ngpt_1B( 26 | vocab_size=common.tokenizer.padded_vocab_size(), 27 | ) 28 | 29 | 30 | def build_train_module_config(common: CommonComponents) -> TransformerTrainModuleConfig: 31 | return TransformerTrainModuleConfig( 32 | rank_microbatch_size=4 * 4096, # TODO: can we increase this? 33 | max_sequence_length=common.max_sequence_length, 34 | optim=AdamConfig( 35 | lr=4e-4, 36 | betas=(0.9, 0.95), 37 | fused=True, 38 | ), 39 | compile_model=True, 40 | dp_config=TransformerDataParallelConfig( 41 | name=DataParallelType.hsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 42 | ), 43 | float8_config=Float8Config(enabled=False), 44 | z_loss_multiplier=1e-5, 45 | max_grad_norm=1.0, 46 | scheduler=CosWithWarmup(warmup_steps=0), 47 | ) 48 | 49 | 50 | def build_trainer_config(common: CommonComponents) -> TrainerConfig: 51 | return ( 52 | TrainerConfig( 53 | save_folder=common.save_folder, 54 | save_overwrite=True, 55 | metrics_collect_interval=10, 56 | cancel_check_interval=1, 57 | ) 58 | .with_callback( 59 | "checkpointer", 60 | CheckpointerCallback( 61 | save_interval=10_000, 62 | ephemeral_save_interval=1000, 63 | save_async=True, 64 | ), 65 | ) 66 | .with_callback( 67 | "comet", 68 | CometCallback( 69 | name=common.run_name, 70 | workspace="ai2", 71 | project="OLMo-core-1B", 72 | enabled=True, 73 | cancel_check_interval=10, 74 | ), 75 | ) 76 | .with_callback( 77 | "wandb", 78 | WandBCallback( 79 | name=common.run_name, 80 | entity="ai2-llm", 81 | project="OLMo-core-1B", 82 | enabled=False, 83 | cancel_check_interval=10, 84 | ), 85 | ) 86 | ) 87 | 88 | 89 | if __name__ == "__main__": 90 | config_builder = partial( 91 | build_config, 92 | global_batch_size=GLOBAL_BATCH_SIZE, 93 | max_sequence_length=SEQUENCE_LENGTH, 94 | model_config_builder=build_model_config, 95 | train_module_config_builder=build_train_module_config, 96 | trainer_config_builder=build_trainer_config, 97 | ) 98 | main(config_builder=config_builder) 99 | -------------------------------------------------------------------------------- /src/scripts/train/Llama3-8B.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a Llama 8B OLMo model. Run this script without any arguments to see usage info. 
3 | """ 4 | 5 | import logging 6 | from functools import partial 7 | 8 | from olmo_core.config import DType 9 | from olmo_core.distributed.parallel import DataParallelType 10 | from olmo_core.float8 import Float8Config 11 | from olmo_core.internal.experiment import CommonComponents, build_config, main 12 | from olmo_core.nn.transformer import TransformerConfig 13 | from olmo_core.optim import AdamWConfig, CosWithWarmup 14 | from olmo_core.train import TrainerConfig 15 | from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback 16 | from olmo_core.train.train_module import ( 17 | TransformerDataParallelConfig, 18 | TransformerTrainModuleConfig, 19 | ) 20 | 21 | SEQUENCE_LENGTH = 4096 22 | GLOBAL_BATCH_SIZE = 1024 * 4096 23 | 24 | log = logging.getLogger(__name__) 25 | 26 | 27 | def build_model_config(common: CommonComponents) -> TransformerConfig: 28 | return TransformerConfig.llama3_8B(vocab_size=common.tokenizer.padded_vocab_size()) 29 | 30 | 31 | def build_train_module_config(common: CommonComponents) -> TransformerTrainModuleConfig: 32 | return TransformerTrainModuleConfig( 33 | rank_microbatch_size=2 * 4096, 34 | max_sequence_length=common.max_sequence_length, 35 | optim=AdamWConfig( 36 | lr=3e-4, 37 | weight_decay=0.1, 38 | betas=(0.9, 0.95), 39 | fused=True, 40 | ), 41 | compile_model=True, 42 | dp_config=TransformerDataParallelConfig( 43 | name=DataParallelType.hsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 44 | ), 45 | float8_config=Float8Config(enabled=False), 46 | z_loss_multiplier=1e-5, 47 | max_grad_norm=1.0, 48 | scheduler=CosWithWarmup(warmup_steps=2000), 49 | ) 50 | 51 | 52 | def build_trainer_config(common: CommonComponents) -> TrainerConfig: 53 | return ( 54 | TrainerConfig( 55 | save_folder=common.save_folder, 56 | save_overwrite=True, 57 | metrics_collect_interval=10, 58 | cancel_check_interval=1, 59 | ) 60 | .with_callback( 61 | "checkpointer", 62 | CheckpointerCallback( 63 | save_interval=10_000, 64 | ephemeral_save_interval=250, 65 | save_async=True, 66 | ), 67 | ) 68 | .with_callback( 69 | "comet", 70 | CometCallback( 71 | name=common.run_name, 72 | workspace="ai2", 73 | project="Llama-8B", 74 | enabled=True, 75 | cancel_check_interval=10, 76 | ), 77 | ) 78 | .with_callback( 79 | "wandb", 80 | WandBCallback( 81 | name=common.run_name, 82 | entity="ai2-llm", 83 | project="Llama-8B", 84 | enabled=False, 85 | cancel_check_interval=10, 86 | ), 87 | ) 88 | ) 89 | 90 | 91 | if __name__ == "__main__": 92 | config_builder = partial( 93 | build_config, 94 | global_batch_size=GLOBAL_BATCH_SIZE, 95 | max_sequence_length=SEQUENCE_LENGTH, 96 | model_config_builder=build_model_config, 97 | train_module_config_builder=build_train_module_config, 98 | trainer_config_builder=build_trainer_config, 99 | ) 100 | main(config_builder=config_builder) 101 | -------------------------------------------------------------------------------- /src/test/distributed/checkpoint/filesystem_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.distributed as dist 4 | import torch.distributed.checkpoint as distcp 5 | from torch.distributed.tensor import Shard, distribute_tensor, init_device_mesh 6 | 7 | from olmo_core.distributed.checkpoint.filesystem import ( 8 | RemoteFileSystemReader, 9 | RemoteFileSystemWriter, 10 | ) 11 | from olmo_core.io import dir_is_empty 12 | from olmo_core.testing import BACKENDS, run_distributed_test 13 | from olmo_core.utils import 
get_default_device 14 | 15 | 16 | def run_save_and_load_with_dtensors(dir, throttle: bool = False): 17 | mesh = init_device_mesh(get_default_device().type, (dist.get_world_size(),)) 18 | 19 | x_full = torch.randn(4, 4, device=get_default_device()) 20 | y_full = torch.randn(4, 8, device=get_default_device()) 21 | # Make sure these tensors are the same across all ranks. We could scatt 22 | dist.broadcast(x_full, 0) 23 | dist.broadcast(y_full, 0) 24 | 25 | # Shard the tensors. 26 | x = distribute_tensor(x_full, mesh, [Shard(dim=0)]) 27 | y = distribute_tensor(y_full, mesh, [Shard(dim=0)]) 28 | 29 | # Save the sharded tensors. 30 | distcp.state_dict_saver.save( 31 | {"x": x, "y": y}, 32 | checkpoint_id=dir, 33 | storage_writer=RemoteFileSystemWriter(dir, thread_count=2, throttle_uploads=throttle), 34 | ) 35 | 36 | # Now create new sharded copies with a different sharding strategy and load the checkpoint. 37 | x_loaded = distribute_tensor(torch.zeros_like(x_full), mesh, [Shard(dim=1)]) 38 | y_loaded = distribute_tensor(torch.zeros_like(y_full), mesh, [Shard(dim=1)]) 39 | distcp.state_dict_loader.load( 40 | {"x": x_loaded, "y": y_loaded}, 41 | checkpoint_id=dir, 42 | storage_reader=RemoteFileSystemReader(dir, thread_count=2), 43 | ) 44 | 45 | # Make sure the loaded tensors match the original tensors. 46 | x_full_loaded = x_loaded.full_tensor() 47 | y_full_loaded = y_loaded.full_tensor() 48 | torch.testing.assert_close(x_full, x_full_loaded) 49 | torch.testing.assert_close(y_full, y_full_loaded) 50 | 51 | 52 | @pytest.mark.parametrize("backend", BACKENDS) 53 | def test_save_and_load_locally_with_dtensors(backend, tmp_path): 54 | run_distributed_test( 55 | run_save_and_load_with_dtensors, 56 | backend=backend, 57 | func_args=(tmp_path,), 58 | start_method="spawn", 59 | ) 60 | 61 | 62 | @pytest.mark.parametrize("backend", BACKENDS) 63 | @pytest.mark.parametrize("throttle", [True, False]) 64 | def test_save_and_load_remotely_to_s3_with_dtensors(backend, s3_checkpoint_dir, throttle): 65 | from botocore.exceptions import NoCredentialsError 66 | 67 | try: 68 | dir_is_empty(s3_checkpoint_dir) 69 | except NoCredentialsError: 70 | pytest.skip("Requires AWS credentials") 71 | 72 | run_distributed_test( 73 | run_save_and_load_with_dtensors, 74 | backend=backend, 75 | func_args=(s3_checkpoint_dir, throttle), 76 | start_method="spawn", # NOTE: forking causes a crash with boto3 77 | ) 78 | 79 | 80 | @pytest.mark.parametrize("backend", BACKENDS) 81 | @pytest.mark.parametrize("throttle", [True, False]) 82 | def test_save_and_load_remotely_to_gcs_with_dtensors(backend, gcs_checkpoint_dir, throttle): 83 | from google.auth.exceptions import DefaultCredentialsError 84 | 85 | try: 86 | dir_is_empty(gcs_checkpoint_dir) 87 | except DefaultCredentialsError: 88 | pytest.skip("Requires authentication with Google Cloud") 89 | 90 | run_distributed_test( 91 | run_save_and_load_with_dtensors, 92 | backend=backend, 93 | func_args=(gcs_checkpoint_dir, throttle), 94 | start_method="spawn", # NOTE: forking causes a crash with boto3 95 | ) 96 | -------------------------------------------------------------------------------- /src/scripts/train/OLMo2/OLMo2-13B.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a 13B OLMo model. Run this script without any arguments to see usage info. 
3 | """ 4 | 5 | import logging 6 | from functools import partial 7 | 8 | from olmo_core.config import DType 9 | from olmo_core.distributed.parallel import DataParallelType 10 | from olmo_core.float8 import Float8Config 11 | from olmo_core.internal.experiment import CommonComponents, build_config, main 12 | from olmo_core.nn.transformer import TransformerConfig 13 | from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride 14 | from olmo_core.train import TrainerConfig 15 | from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback 16 | from olmo_core.train.train_module import ( 17 | TransformerDataParallelConfig, 18 | TransformerTrainModuleConfig, 19 | ) 20 | 21 | SEQUENCE_LENGTH = 4096 22 | GLOBAL_BATCH_SIZE = 2048 * SEQUENCE_LENGTH 23 | 24 | log = logging.getLogger(__name__) 25 | 26 | 27 | def build_model_config(common: CommonComponents) -> TransformerConfig: 28 | return TransformerConfig.olmo2_13B(vocab_size=common.tokenizer.padded_vocab_size()) 29 | 30 | 31 | def build_train_module_config(common: CommonComponents) -> TransformerTrainModuleConfig: 32 | return TransformerTrainModuleConfig( 33 | rank_microbatch_size=1 * 4096, 34 | max_sequence_length=common.max_sequence_length, 35 | optim=AdamWConfig( 36 | lr=3e-4, 37 | weight_decay=0.1, 38 | betas=(0.9, 0.95), 39 | group_overrides=[ 40 | OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0)) 41 | ], 42 | fused=True, 43 | ), 44 | compile_model=True, 45 | dp_config=TransformerDataParallelConfig( 46 | name=DataParallelType.hsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 47 | ), 48 | float8_config=Float8Config(enabled=False), 49 | z_loss_multiplier=1e-5, 50 | max_grad_norm=1.0, 51 | scheduler=CosWithWarmup(warmup_steps=2000), 52 | ) 53 | 54 | 55 | def build_trainer_config(common: CommonComponents) -> TrainerConfig: 56 | return ( 57 | TrainerConfig( 58 | save_folder=common.save_folder, 59 | save_overwrite=True, 60 | metrics_collect_interval=10, 61 | cancel_check_interval=1, 62 | ) 63 | .with_callback( 64 | "checkpointer", 65 | CheckpointerCallback( 66 | save_interval=10_000, 67 | ephemeral_save_interval=250, 68 | save_async=True, 69 | ), 70 | ) 71 | .with_callback( 72 | "comet", 73 | CometCallback( 74 | name=common.run_name, 75 | workspace="ai2", 76 | project="OLMo-core-13B", 77 | enabled=True, 78 | cancel_check_interval=10, 79 | ), 80 | ) 81 | .with_callback( 82 | "wandb", 83 | WandBCallback( 84 | name=common.run_name, 85 | entity="ai2-llm", 86 | project="OLMo-core-13B", 87 | enabled=False, 88 | cancel_check_interval=10, 89 | ), 90 | ) 91 | ) 92 | 93 | 94 | if __name__ == "__main__": 95 | config_builder = partial( 96 | build_config, 97 | global_batch_size=GLOBAL_BATCH_SIZE, 98 | max_sequence_length=SEQUENCE_LENGTH, 99 | model_config_builder=build_model_config, 100 | train_module_config_builder=build_train_module_config, 101 | trainer_config_builder=build_trainer_config, 102 | ) 103 | main(config_builder=config_builder) 104 | -------------------------------------------------------------------------------- /src/test/data/custom_data_loader.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any, Dict, Iterable, List, Optional 3 | 4 | import torch 5 | 6 | from olmo_core.aliases import PathOrStr 7 | from olmo_core.data import DataCollator, TextDataLoaderBase 8 | 9 | 10 | class CustomDataLoader(TextDataLoaderBase): 11 | """ 12 | An example custom data loader that generates random token IDs. 
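It overrides ``total_batches``, ``state_dict()`` / ``load_state_dict()`` for checkpointing,
``reshuffle()`` to regenerate the data each epoch, ``get_mock_batch()``, and ``_iter_batches()``,
which yields this rank's slice of each global batch.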
13 | """ 14 | 15 | def __init__( 16 | self, 17 | *, 18 | sequence_length: int, 19 | vocab_size: int, 20 | work_dir: PathOrStr, 21 | global_batch_size: int, 22 | dp_world_size: int = 1, 23 | dp_rank: int = 0, 24 | fs_local_rank: int = 0, 25 | seed: int = 0, 26 | total_batches: int = 2048, 27 | ): 28 | super().__init__( 29 | collator=DataCollator(pad_token_id=vocab_size - 1), 30 | work_dir=work_dir, 31 | global_batch_size=global_batch_size, 32 | dp_world_size=dp_world_size, 33 | dp_rank=dp_rank, 34 | fs_local_rank=fs_local_rank, 35 | ) 36 | assert self.rank_batch_size % sequence_length == 0 37 | self.sequence_length = sequence_length 38 | self.vocab_size = vocab_size 39 | self.seed = seed 40 | self._total_batches = total_batches 41 | self._dataset: Optional[List[torch.Tensor]] 42 | 43 | @property 44 | def total_batches(self) -> int: 45 | return self._total_batches 46 | 47 | def state_dict(self) -> Dict[str, Any]: 48 | return { 49 | "batches_processed": self.batches_processed, 50 | "seed": self.seed, 51 | "epoch": self._epoch, 52 | } 53 | 54 | def load_state_dict(self, state_dict: Dict[str, Any]): 55 | self.batches_processed = state_dict["batches_processed"] 56 | self.seed = state_dict["seed"] 57 | self._epoch = state_dict["epoch"] 58 | 59 | def reshuffle(self, epoch: Optional[int] = None, **kwargs): 60 | del kwargs # unused 61 | 62 | # Set current epoch. 63 | if epoch is None: 64 | epoch = 1 if self._epoch is None else self._epoch + 1 65 | self._epoch = epoch 66 | 67 | # Generate data. 68 | rng = random.Random(self.seed + self.epoch) 69 | instances_per_batch = self.global_batch_size // self.sequence_length 70 | total_instances = instances_per_batch * self.total_batches 71 | self._dataset = [ 72 | torch.arange(start=start_idx, end=start_idx + self.sequence_length) 73 | for start_idx in ( 74 | rng.randint(0, self.vocab_size - self.sequence_length - 2) 75 | for _ in range(total_instances) 76 | ) 77 | ] 78 | 79 | def get_mock_batch(self) -> Dict[str, Any]: 80 | num_instances = self.rank_batch_size // self.sequence_length 81 | input_ids = torch.randint(0, self.vocab_size, (num_instances, self.sequence_length)) 82 | return {"input_ids": input_ids} 83 | 84 | def _iter_batches(self) -> Iterable[Dict[str, Any]]: 85 | assert self._dataset is not None, "did you forget to call 'reshuffle()'?" 86 | 87 | # Get global batch instance indices. Shape: (total batches, instances per batch) 88 | instances_per_batch = self.global_batch_size // self.sequence_length 89 | indices = torch.arange(len(self._dataset)).view(self.total_batches, instances_per_batch) 90 | 91 | # Offset by batches processed so far. 92 | indices = indices[self.batches_processed :] 93 | 94 | for batch_indices in indices: 95 | # Slice batch indices up by rank to create data parallel micro-batches. 96 | local_batch_indices = batch_indices[self.dp_rank :: self.dp_world_size] 97 | yield self.collator([self._dataset[idx] for idx in local_batch_indices]) 98 | -------------------------------------------------------------------------------- /src/scripts/train/OLMoE-1B-7B.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a 1B-7B OLMoE model (mixture of experts). 3 | Run this script without any arguments to see usage info. 
4 | """ 5 | 6 | from functools import partial 7 | 8 | from olmo_core.config import DType 9 | from olmo_core.distributed.parallel import DataParallelType 10 | from olmo_core.internal.experiment import CommonComponents, build_config, main 11 | from olmo_core.nn.transformer import TransformerConfig 12 | from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride 13 | from olmo_core.train import TrainerConfig 14 | from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback 15 | from olmo_core.train.train_module import ( 16 | TransformerDataParallelConfig, 17 | TransformerDataParallelWrappingStrategy, 18 | TransformerTrainModuleConfig, 19 | ) 20 | 21 | SEQUENCE_LENGTH = 4096 22 | GLOBAL_BATCH_SIZE = 1024 * SEQUENCE_LENGTH 23 | 24 | 25 | def build_model_config(common: CommonComponents) -> TransformerConfig: 26 | return TransformerConfig.olmoe_1B_7B(vocab_size=common.tokenizer.padded_vocab_size()) 27 | 28 | 29 | def build_train_module_config(common: CommonComponents) -> TransformerTrainModuleConfig: 30 | return TransformerTrainModuleConfig( 31 | rank_microbatch_size=2 * common.max_sequence_length, 32 | max_sequence_length=common.max_sequence_length, 33 | optim=AdamWConfig( 34 | lr=4e-4, 35 | weight_decay=0.1, 36 | betas=(0.9, 0.95), 37 | group_overrides=[ 38 | OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0)) 39 | ], 40 | fused=True, 41 | ), 42 | compile_model=True, 43 | dp_config=TransformerDataParallelConfig( 44 | name=DataParallelType.fsdp, 45 | param_dtype=DType.bfloat16, 46 | reduce_dtype=DType.float32, 47 | wrapping_strategy=TransformerDataParallelWrappingStrategy.full, 48 | ), 49 | z_loss_multiplier=1e-5, 50 | max_grad_norm=1.0, 51 | scheduler=CosWithWarmup(warmup_steps=2000), 52 | ) 53 | 54 | 55 | def build_trainer_config(common: CommonComponents) -> TrainerConfig: 56 | return ( 57 | TrainerConfig( 58 | save_folder=common.save_folder, 59 | save_overwrite=True, 60 | metrics_collect_interval=10, 61 | cancel_check_interval=1, 62 | ) 63 | .with_callback( 64 | "checkpointer", 65 | CheckpointerCallback( 66 | save_interval=10_000, 67 | ephemeral_save_interval=1000, 68 | save_async=True, 69 | ), 70 | ) 71 | .with_callback( 72 | "comet", 73 | CometCallback( 74 | name=common.run_name, 75 | workspace="ai2", 76 | project="OLMo-core-1B", 77 | enabled=True, 78 | cancel_check_interval=10, 79 | ), 80 | ) 81 | .with_callback( 82 | "wandb", 83 | WandBCallback( 84 | name=common.run_name, 85 | entity="ai2-llm", 86 | project="OLMo-core-1B", 87 | enabled=False, 88 | cancel_check_interval=10, 89 | ), 90 | ) 91 | ) 92 | 93 | 94 | if __name__ == "__main__": 95 | config_builder = partial( 96 | build_config, 97 | global_batch_size=GLOBAL_BATCH_SIZE, 98 | max_sequence_length=SEQUENCE_LENGTH, 99 | model_config_builder=build_model_config, 100 | train_module_config_builder=build_train_module_config, 101 | trainer_config_builder=build_trainer_config, 102 | ) 103 | main(config_builder=config_builder) 104 | -------------------------------------------------------------------------------- /src/scripts/train/small-moe.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a small MoE model (mixture of experts). 3 | Run this script without any arguments to see usage info. 
4 | """ 5 | 6 | from functools import partial 7 | 8 | from olmo_core.config import DType 9 | from olmo_core.distributed.parallel import DataParallelType 10 | from olmo_core.internal.experiment import CommonComponents, build_config, main 11 | from olmo_core.nn.transformer import TransformerConfig 12 | from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride 13 | from olmo_core.train import TrainerConfig 14 | from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback 15 | from olmo_core.train.train_module import ( 16 | TransformerDataParallelConfig, 17 | TransformerDataParallelWrappingStrategy, 18 | TransformerTrainModuleConfig, 19 | ) 20 | 21 | SEQUENCE_LENGTH = 4096 22 | GLOBAL_BATCH_SIZE = 512 * 4096 23 | 24 | 25 | def build_model_config(common: CommonComponents) -> TransformerConfig: 26 | return TransformerConfig.small_hybrid_moe(vocab_size=common.tokenizer.padded_vocab_size()) 27 | 28 | 29 | def build_train_module_config(common: CommonComponents) -> TransformerTrainModuleConfig: 30 | return TransformerTrainModuleConfig( 31 | rank_microbatch_size=8 * 4096, 32 | max_sequence_length=common.max_sequence_length, 33 | optim=AdamWConfig( 34 | lr=4e-4, 35 | weight_decay=0.1, 36 | betas=(0.9, 0.95), 37 | group_overrides=[ 38 | OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0)) 39 | ], 40 | fused=True, 41 | ), 42 | compile_model=True, 43 | dp_config=TransformerDataParallelConfig( 44 | name=DataParallelType.fsdp, 45 | param_dtype=DType.bfloat16, 46 | reduce_dtype=DType.float32, 47 | wrapping_strategy=TransformerDataParallelWrappingStrategy.full, 48 | ), 49 | z_loss_multiplier=1e-5, 50 | max_grad_norm=1.0, 51 | scheduler=CosWithWarmup(warmup_steps=2000), 52 | ) 53 | 54 | 55 | def build_trainer_config(common: CommonComponents) -> TrainerConfig: 56 | return ( 57 | TrainerConfig( 58 | save_folder=common.save_folder, 59 | save_overwrite=True, 60 | metrics_collect_interval=10, 61 | cancel_check_interval=1, 62 | ) 63 | .with_callback( 64 | "checkpointer", 65 | CheckpointerCallback( 66 | save_interval=10_000, 67 | ephemeral_save_interval=1000, 68 | save_async=True, 69 | ), 70 | ) 71 | .with_callback( 72 | "comet", 73 | CometCallback( 74 | name=common.run_name, 75 | workspace="ai2", 76 | project="OLMo-core-1B", 77 | enabled=True, 78 | cancel_check_interval=10, 79 | ), 80 | ) 81 | .with_callback( 82 | "wandb", 83 | WandBCallback( 84 | name=common.run_name, 85 | entity="ai2-llm", 86 | project="OLMo-core-1B", 87 | enabled=False, 88 | cancel_check_interval=10, 89 | ), 90 | ) 91 | ) 92 | 93 | 94 | if __name__ == "__main__": 95 | config_builder = partial( 96 | build_config, 97 | global_batch_size=GLOBAL_BATCH_SIZE, 98 | max_sequence_length=SEQUENCE_LENGTH, 99 | model_config_builder=build_model_config, 100 | train_module_config_builder=build_train_module_config, 101 | trainer_config_builder=build_trainer_config, 102 | include_default_evals=False, 103 | ) 104 | main(config_builder=config_builder) 105 | -------------------------------------------------------------------------------- /src/test/optim/noop_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from olmo_core.optim import NoOpConfig, SkipStepAdamWConfig 6 | from olmo_core.testing import DEVICES 7 | 8 | 9 | class TinyModel(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | self.wte = nn.Embedding(128, 8) 13 | self.fc = nn.Linear(8, 8) 14 | 15 | def forward(self, 
x: torch.Tensor) -> torch.Tensor: 16 | x = self.wte(x) 17 | return self.fc(x) 18 | 19 | 20 | @pytest.mark.parametrize("device", DEVICES) 21 | def test_noop_vs_zero_lr_adamw(device: torch.device): 22 | """Test that NoOpOptimizer produces the same output as SkipStepAdamW with lr=0.""" 23 | torch.manual_seed(42) 24 | 25 | # Create two identical models 26 | model1 = TinyModel().to(device) 27 | model2 = TinyModel().to(device) 28 | model2.load_state_dict(model1.state_dict()) 29 | 30 | # Create optimizers: SkipStepAdamW with lr=0 and NoOpOptimizer 31 | optim1 = SkipStepAdamWConfig( 32 | lr=0.0, 33 | rolling_interval_length=128, 34 | sigma_factor=6, 35 | ).build(model1) 36 | 37 | optim2 = NoOpConfig( 38 | lr=1e-3, # lr doesn't matter for NoOp 39 | rolling_interval_length=128, 40 | sigma_factor=6, 41 | ).build(model2) 42 | 43 | # Run both models for 10 steps 44 | for step in range(10): 45 | # Set the same seed for both models to ensure same input 46 | torch.manual_seed(100 + step) 47 | x = torch.randint(0, 128, (4, 8), device=device) 48 | 49 | # Model 1 with SkipStepAdamW (lr=0) 50 | optim1.zero_grad(set_to_none=True) 51 | out1 = model1(x) 52 | loss1 = out1.sum() 53 | optim1.latest_loss = loss1.detach() 54 | loss1.backward() 55 | optim1.step() 56 | 57 | # Model 2 with NoOpOptimizer 58 | optim2.zero_grad(set_to_none=True) 59 | out2 = model2(x) 60 | loss2 = out2.sum() 61 | optim2.latest_loss = loss2.detach() 62 | loss2.backward() 63 | optim2.step() 64 | 65 | # Verify that the models produce the same output 66 | torch.manual_seed(200 + step) 67 | test_input = torch.randint(0, 128, (2, 4), device=device) 68 | 69 | with torch.no_grad(): 70 | test_out1 = model1(test_input) 71 | test_out2 = model2(test_input) 72 | 73 | assert torch.allclose( 74 | test_out1, test_out2, atol=1e-6 75 | ), f"Step {step}: Outputs differ between SkipStepAdamW(lr=0) and NoOpOptimizer" 76 | 77 | # Verify final model parameters are identical 78 | for (name1, param1), (name2, param2) in zip( 79 | model1.named_parameters(), model2.named_parameters() 80 | ): 81 | assert name1 == name2 82 | assert torch.equal(param1, param2), f"Parameter {name1} differs between models" 83 | 84 | 85 | @pytest.mark.parametrize("device", DEVICES) 86 | def test_noop_no_parameter_updates(device: torch.device): 87 | """Test that NoOpOptimizer doesn't update any parameters.""" 88 | torch.manual_seed(42) 89 | 90 | model = TinyModel().to(device) 91 | optim = NoOpConfig(rolling_interval_length=2, sigma_factor=6).build(model) 92 | 93 | # Store initial parameters 94 | initial_params = {name: param.clone() for name, param in model.named_parameters()} 95 | 96 | # Run for 10 steps 97 | for step in range(10): 98 | optim.zero_grad(set_to_none=True) 99 | x = torch.randint(0, 128, (4, 8), device=device) 100 | out = model(x) 101 | loss = out.sum() 102 | optim.latest_loss = loss.detach() 103 | loss.backward() 104 | optim.step() 105 | 106 | # Verify parameters haven't changed 107 | for name, param in model.named_parameters(): 108 | assert torch.equal( 109 | param, initial_params[name] 110 | ), f"Parameter {name} was modified by NoOpOptimizer" 111 | -------------------------------------------------------------------------------- /src/test/nn/transformer/block_test.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Type 2 | 3 | import pytest 4 | import torch 5 | from torch.distributed.tensor import Shard, init_device_mesh 6 | 7 | from olmo_core.distributed.checkpoint import ( 8 | load_model_and_optim_state, 
9 | save_model_and_optim_state, 10 | ) 11 | from olmo_core.distributed.utils import get_rank, get_world_size 12 | from olmo_core.nn.attention import AttentionConfig, AttentionType 13 | from olmo_core.nn.feed_forward import FeedForwardConfig 14 | from olmo_core.nn.layer_norm import LayerNormConfig 15 | from olmo_core.nn.transformer.block import ( 16 | ReorderedNormTransformerBlock, 17 | TransformerBlock, 18 | ) 19 | from olmo_core.testing import BACKENDS, run_distributed_test 20 | from olmo_core.utils import get_default_device, seed_all 21 | 22 | 23 | def _build_block( 24 | block_cls: Type[TransformerBlock], 25 | *, 26 | d_model: int, 27 | init_device: str, 28 | attn_kwargs: Dict[str, Any], 29 | ) -> TransformerBlock: 30 | attn_cfg = AttentionConfig(**attn_kwargs) 31 | ff_cfg = FeedForwardConfig(hidden_size=4 * d_model) 32 | ln_cfg = LayerNormConfig() 33 | return block_cls( 34 | d_model=d_model, 35 | block_idx=0, 36 | n_layers=1, 37 | attention=attn_cfg, 38 | feed_forward=ff_cfg, 39 | layer_norm=ln_cfg, 40 | init_device=init_device, 41 | ) 42 | 43 | 44 | def _run_tensor_parallel_block( 45 | checkpoint_dir: str, 46 | inputs_path: str, 47 | outputs_path: str, 48 | block_cls: Type[TransformerBlock], 49 | d_model: int, 50 | attn_kwargs: Dict[str, Any], 51 | ): 52 | device = get_default_device() 53 | mesh = init_device_mesh(device.type, (get_world_size(),), mesh_dim_names=("tp",)) 54 | 55 | block = _build_block( 56 | block_cls, d_model=d_model, init_device=device.type, attn_kwargs=attn_kwargs 57 | ) 58 | 59 | # Shard sequence dim in/out like the transformer model does. 60 | block.apply_tp(mesh["tp"], input_layout=Shard(1)) 61 | load_model_and_optim_state(checkpoint_dir, block) 62 | 63 | x = torch.load(inputs_path, map_location=device) 64 | rank, world_size = get_rank(), get_world_size() 65 | chunk = x.size(1) // world_size 66 | x_local = x[:, rank * chunk : (rank + 1) * chunk, :] 67 | y_local = block(x_local) 68 | 69 | # Backward to exercise graph in TP mode. 
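# (Only the forward outputs are compared against the single-process reference below;
# gradients themselves are not checked.)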
70 | y_local.sum().backward() 71 | 72 | y_ref = torch.load(outputs_path, map_location=device) 73 | y_ref_local = y_ref[:, rank * chunk : (rank + 1) * chunk, :] 74 | torch.testing.assert_close(y_ref_local, y_local.to_local()) 75 | 76 | 77 | @pytest.mark.parametrize("backend", BACKENDS) 78 | @pytest.mark.parametrize( 79 | "attn_kwargs", 80 | [ 81 | pytest.param(dict(n_heads=8), id="default"), 82 | pytest.param(dict(n_heads=8, rope=None, bias=False), id="no-bias"), 83 | ], 84 | ) 85 | @pytest.mark.parametrize("block_cls", [TransformerBlock, ReorderedNormTransformerBlock]) 86 | def test_tensor_parallel_transformer_block( 87 | backend: str, block_cls: Type[TransformerBlock], attn_kwargs: Dict[str, Any], tmp_path 88 | ): 89 | device = torch.device("cuda") if "nccl" in backend else torch.device("cpu") 90 | 91 | seed_all(0) 92 | d_model = 128 93 | attn_kwargs = {**attn_kwargs, "name": AttentionType.default, "use_flash": False} 94 | 95 | block = _build_block( 96 | block_cls, d_model=d_model, init_device=device.type, attn_kwargs=attn_kwargs 97 | ) 98 | 99 | bs, seq_len = 2, 64 100 | x = torch.randn(bs, seq_len, d_model, device=device) 101 | y = block(x) 102 | 103 | outputs_path = tmp_path / "block_y.pt" 104 | torch.save(y, outputs_path) 105 | inputs_path = tmp_path / "block_x.pt" 106 | torch.save(x, inputs_path) 107 | checkpoint_dir = tmp_path / "checkpoint" 108 | save_model_and_optim_state(checkpoint_dir, block) 109 | 110 | run_distributed_test( 111 | _run_tensor_parallel_block, 112 | backend=backend, 113 | start_method="spawn", 114 | func_args=(checkpoint_dir, inputs_path, outputs_path, block_cls, d_model, attn_kwargs), 115 | ) 116 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY : checks 2 | checks : style-check lint-check type-check 3 | 4 | .PHONY : style-check 5 | style-check : 6 | @echo "======== running isort... ========" 7 | @isort --check . 8 | @echo "======== running black... ========" 9 | @black --check . 10 | 11 | .PHONY : lint-check 12 | lint-check : 13 | @echo "======== running ruff... =========" 14 | @ruff check . 15 | 16 | .PHONY : type-check 17 | type-check : 18 | @echo "======== running mypy... =========" 19 | @mypy src/ 20 | 21 | .PHONY : style 22 | style: 23 | @echo "======== formatting with isort... ========" 24 | @isort . 25 | @echo "======== formatting with black... ========" 26 | @black . 27 | 28 | .PHONY : docs 29 | docs : 30 | rm -rf docs/build/ 31 | sphinx-autobuild -b html --watch src/olmo_core/ --watch README.md docs/source/ docs/build/ 32 | 33 | .PHONY : build 34 | build : 35 | rm -rf *.egg-info/ 36 | python -m build 37 | 38 | #################################################################################################### 39 | # Docker build 40 | #################################################################################################### 41 | 42 | #-----------------# 43 | # Build variables # 44 | #-----------------# 45 | 46 | # NOTE: When upgrading dependency versions (like for torch) make sure: 47 | # * The corresponding versions specified in 'pyproject.toml' include the new version. 48 | # * The versions installed in '.github/actions/setup-venv/action.yml' match if necessary. 49 | # NOTE: See https://hub.docker.com/r/nvidia/cuda/tags?name=devel-ubuntu22.04 for available CUDA versions. 50 | CUDA_VERSION = 12.8.1 51 | CUDA_VERSION_PATH=cu$(shell echo $(CUDA_VERSION) | cut -d"." -f1-2 | tr -d .) 
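# e.g. CUDA_VERSION=12.8.1 gives CUDA_VERSION_PATH=cu128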
52 | PYTHON_VERSION = 3.11 53 | TORCH_VERSION = 2.8.0 54 | TORCH_VERSION_SHORT = $(shell echo $(TORCH_VERSION) | tr -d .) 55 | INSTALL_CHANNEL = whl 56 | GROUPED_GEMM_VERSION = "grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@main" 57 | FLASH_ATTN_VERSION = 2.8.2 58 | FLASH_ATTN_3_SHA = "1ceaa984b2f348caea18b39a98458d33b4ea7a09" 59 | TE_VERSION = 2.6.0.post1 60 | RING_FLASH_ATTN_VERSION = 0.1.8 61 | LIGER_KERNEL_VERSION = 0.6.2 62 | 63 | #--------------# 64 | # Build naming # 65 | #--------------# 66 | 67 | VERSION = $(shell python src/olmo_core/version.py) 68 | VERSION_SHORT = $(shell python src/olmo_core/version.py short) 69 | IMAGE_SUFFIX = $(shell date "+%Y-%m-%d") 70 | IMAGE_TAG = tch$(TORCH_VERSION_SHORT)$(CUDA_VERSION_PATH)-$(IMAGE_SUFFIX) 71 | 72 | .PHONY : docker-image 73 | docker-image : 74 | docker build -f src/Dockerfile \ 75 | --build-arg BUILDKIT_INLINE_CACHE=1 \ 76 | --build-arg CUDA_VERSION=$(CUDA_VERSION) \ 77 | --build-arg CUDA_VERSION_PATH=$(CUDA_VERSION_PATH) \ 78 | --build-arg PYTHON_VERSION=$(PYTHON_VERSION) \ 79 | --build-arg TORCH_VERSION=$(TORCH_VERSION) \ 80 | --build-arg INSTALL_CHANNEL=$(INSTALL_CHANNEL) \ 81 | --build-arg GROUPED_GEMM_VERSION=$(GROUPED_GEMM_VERSION) \ 82 | --build-arg FLASH_ATTN_VERSION=$(FLASH_ATTN_VERSION) \ 83 | --build-arg FLASH_ATTN_3_SHA=$(FLASH_ATTN_3_SHA) \ 84 | --build-arg TE_VERSION=$(TE_VERSION) \ 85 | --build-arg RING_FLASH_ATTN_VERSION=$(RING_FLASH_ATTN_VERSION) \ 86 | --build-arg LIGER_KERNEL_VERSION=$(LIGER_KERNEL_VERSION) \ 87 | --target release \ 88 | -t olmo-core:$(IMAGE_TAG) . 89 | echo "Built image 'olmo-core:$(IMAGE_TAG)', size: $$(docker inspect -f '{{ .Size }}' olmo-core:$(IMAGE_TAG) | numfmt --to=si)" 90 | 91 | .PHONY : ghcr-image 92 | ghcr-image : docker-image 93 | docker tag olmo-core:$(IMAGE_TAG) ghcr.io/allenai/olmo-core:$(IMAGE_TAG) 94 | docker push ghcr.io/allenai/olmo-core:$(IMAGE_TAG) 95 | docker tag olmo-core:$(IMAGE_TAG) ghcr.io/allenai/olmo-core:latest 96 | docker push ghcr.io/allenai/olmo-core:latest 97 | 98 | BEAKER_WORKSPACE = ai2/OLMo-core 99 | BEAKER_USER = $(shell beaker account whoami --format=json | jq -r '.[0].name') 100 | 101 | .PHONY : beaker-image 102 | beaker-image : docker-image 103 | ./src/scripts/beaker/create_beaker_image.sh olmo-core:$(IMAGE_TAG) olmo-core-$(IMAGE_TAG) $(BEAKER_WORKSPACE) 104 | 105 | .PHONY : get-beaker-workspace 106 | get-beaker-workspace : 107 | @echo $(BEAKER_WORKSPACE) 108 | 109 | .PHONY : get-full-beaker-image-name 110 | get-full-beaker-image-name : 111 | @./src/scripts/beaker/get_full_image_name.sh $(IMAGE_TAG) $(BEAKER_WORKSPACE) 112 | --------------------------------------------------------------------------------