├── .devcontainer ├── README.md └── nvidia │ ├── devcontainer.json │ └── post-create.sh ├── .dockerignore ├── .flake8 ├── .github ├── dependabot.yml └── workflows │ ├── docker-base-image.yaml │ ├── gpt2_small_itest.yaml │ ├── launch_small_fast.yaml │ ├── publish_dev.yaml │ ├── run_entry_tests.yaml │ ├── run_pre_commit.yaml │ ├── run_ray_tests.yaml │ ├── run_tests.yaml │ └── tpu_unit_tests.yaml ├── .gitignore ├── .idea └── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── config ├── backpack.yaml ├── backpack_nano.yaml ├── data │ ├── dclm_gpt_neo.yaml │ ├── dolma_llama.yaml │ ├── dolma_llama_euwest.yaml │ ├── dolma_olmo_paloma.yaml │ ├── fineweb_llama_txt.yaml │ ├── marin_dolma.yaml │ ├── openwebtext_source.yaml │ ├── pile_mixture.yaml │ ├── pile_source_old.yaml │ ├── pubmed_source.yaml │ ├── redpajama_1b_source.yaml │ ├── redpajama_1t_source.yaml │ ├── rpv1_llama.yaml │ └── wikitext_source.yaml ├── gemma_2b.yaml ├── gpt2_1536.yaml ├── gpt2_1536_sophiah.yaml ├── gpt2_20b.yaml ├── gpt2_7b.yaml ├── gpt2_hyena_nano.yaml ├── gpt2_large.yaml ├── gpt2_large_sophia_h.yaml ├── gpt2_medium.yaml ├── gpt2_micro.yaml ├── gpt2_nano.yaml ├── gpt2_nano_fp8.yaml ├── gpt2_nano_harness.yaml ├── gpt2_nano_mixture.yaml ├── gpt2_nano_skip.yaml ├── gpt2_nano_tb.yaml ├── gpt2_small.yaml ├── gpt2_small_fast.yaml ├── gpt2_small_fast_batch_schedule.yaml ├── gpt2_small_fast_ema.yaml ├── gpt2_small_fast_eval.yaml ├── gpt2_small_fast_fp8.yaml ├── gpt2_small_fast_mix.yaml ├── gpt2_small_fast_mix_chat.yaml ├── gpt2_small_fast_pile.yaml ├── gpt2_small_fast_public.yaml ├── gpt2_small_fast_skip.yaml ├── gpt2_small_fast_sophia_h.yaml ├── gpt2_small_fast_sophiah.yaml ├── gpt2_small_fast_wiki.yaml ├── gpt2_small_itest.yaml ├── gpt2_small_pile.yaml ├── gpt2_small_pile_mixture.yaml ├── gpt2_small_sophiah.yaml ├── gpt2_xl.yaml ├── harness │ ├── eval_llama3.yaml │ └── harness_nano.yaml ├── llama2_3b_pretrain.yaml ├── llama2_7b.yaml ├── llama2_7b_continued.yaml ├── llama2_nano.yaml ├── llama2_small_fast_mix.yaml ├── llama3_small_fast.yaml ├── llama_1b_with_fineweb_txt.yaml ├── llama_1b_with_olmo_config.yaml ├── llama_7b_tulu.yaml ├── llama_7b_with_dclm.yaml ├── llama_7b_with_olmo_config.yaml ├── llama_7b_with_olmo_config_euwest4.yaml ├── llama_7b_with_olmo_config_uswest4.yaml ├── llama_small_fast.yaml ├── llama_small_fast_remat.yaml ├── lora │ └── mpt_biomed.yaml ├── lora_llama2.yaml ├── mistral_7b.yaml ├── mixtral_8x7b.yaml ├── olmo │ └── olmo_7b_repro.yaml ├── olmo2_sft.yaml ├── optim │ ├── sophia-h_large.yaml │ ├── sophia-h_medium.yaml │ ├── sophia-h_small.yaml │ └── sophia-h_xl.yaml ├── sft_hf_llama3_ckpt.yaml ├── sft_llama3.1_tulu3.yaml ├── sft_llama3.1_tulu3_mixture.yaml ├── sft_llama3_mixture.yaml ├── sft_llama3_mixture_generic.yaml ├── sft_llama3_mixture_reason.yaml ├── sft_llama3_openthoughts.yaml ├── sft_tootsie_mixture.yaml ├── train_lm_llama3_tulu_sft.yaml └── whisper_tiny_librispeech.yaml ├── docker ├── nvidia │ └── Dockerfile └── tpu │ ├── Dockerfile.base │ ├── Dockerfile.cluster │ └── Dockerfile.incremental ├── docs ├── Fine-Tuning.md ├── Getting-Started-GPU.md ├── Getting-Started-TPU-VM.md ├── Getting-Started-Training.md ├── Hardware-Agnostic-Training.md ├── Installation.md ├── Levanter-1.0-Release.md ├── LoRA.md ├── Performance-Guide.md ├── Training-On-Your-Data.md ├── css │ ├── custom.css │ ├── friendly.css │ └── mkdocstrings.css ├── design │ ├── Data-Loader-Design.md │ ├── Multiple-Data-Mixture.md │ └── Ray-Job-Manager.md ├── dev │ ├── GPU-Docker-Dev.md 
│ └── Port-Models.md ├── faq.md ├── figures │ ├── bitwise_repro_curve.png │ ├── data_parallel_mesh.png │ ├── data_parallel_mesh_replicated.png │ ├── device_mesh_1d.png │ ├── device_mesh_1d_zero.png │ ├── device_mesh_2d.png │ ├── device_mesh_2d_batch_partitioned.png │ ├── device_mesh_2d_data_replicated.png │ ├── device_mesh_2d_data_replicated_mlp_partitioned.png │ ├── device_mesh_2d_intermediate_fully_partitioned.png │ ├── device_mesh_2d_zero.png │ ├── finetune_func_cm_full_weight.png │ ├── finetune_func_cm_lora.png │ ├── helm-gsm8k-results.png │ ├── helm-instance-example.png │ ├── lora-diagram.png │ ├── palm_mfu_table.png │ ├── resumed_curve.png │ ├── stopped_curve.png │ └── token_probabilities.mov ├── guides │ ├── Direct-Cache-Construction.md │ ├── Training-Data-Guide.md │ └── Training-On-Audio-Data.md ├── index.md ├── javascripts │ └── mathjax.js ├── reference │ ├── Configuration.md │ ├── Data-Formats.md │ └── Trackers.md ├── requirements.txt └── tutorials │ └── Fine-Tuning-Semantic-Parsing.md ├── examples ├── alpaca-lora │ ├── alpaca-lora-llama2.yaml │ ├── alpaca-lora.yaml │ ├── alpaca_lora.py │ ├── code-alpaca-lora.yaml │ ├── hf_lora_inference.py │ └── peft_inference.ipynb ├── alpaca │ ├── alpaca-llama2.yaml │ ├── alpaca.py │ └── alpaca.yaml ├── gsm8k-lora │ ├── gsm8k-llama2.yaml │ └── gsm8k_lora.py └── sft │ ├── alpaca-llama-sft.yaml │ ├── alpaca-llama.yaml │ ├── dolly-llama.yaml │ ├── oasst-llama.yaml │ └── tulu-llama-sft.yaml ├── infra ├── babysit-tpu-vm ├── babysit-tpu-vm.sh ├── cluster │ ├── job-cluster.yaml │ └── push_cluster_docker.sh ├── helpers │ ├── gen-id.sh │ ├── parse-tpu-creation-args.sh │ ├── setup-tpu-vm-tests.sh │ └── setup-tpu-vm.sh ├── launch.py ├── launch.sh ├── launch_on_ray.py ├── push_docker.py ├── run-slurm.sh ├── run.sh └── spin-up-vm.sh ├── mkdocs.yml ├── pyproject.toml ├── scripts ├── clean_old_checkpoints.py ├── gcs_bulk_delete.py ├── launch_gpt2_small_fast_gpu.sh ├── launch_gpt2_small_fast_supervised_tpu.sh ├── launch_gpt2_small_fast_tpu.sh ├── launch_gpt2_small_itest_tpu.sh ├── loss_history.py └── preproc │ └── split-pile-shards.py ├── src └── levanter │ ├── __init__.py │ ├── analysis │ ├── __init__.py │ ├── entropy.py │ ├── tree_stats.py │ └── visualization.py │ ├── callbacks │ ├── __init__.py │ ├── _core.py │ ├── _metrics.py │ └── watch.py │ ├── checkpoint.py │ ├── compat │ ├── __init__.py │ └── hf_checkpoints.py │ ├── config.py │ ├── data │ ├── __init__.py │ ├── _preprocessor.py │ ├── _prp.py │ ├── audio.py │ ├── dataset.py │ ├── loader.py │ ├── metrics_monitor.py │ ├── mixture.py │ ├── packing.py │ ├── passthrough_tokenizer.py │ ├── permutation.py │ ├── sharded_datasource.py │ ├── text.py │ └── utils.py │ ├── distributed.py │ ├── eval.py │ ├── eval_harness.py │ ├── grad_accum.py │ ├── infra │ ├── __init__.py │ ├── cli_helpers.py │ ├── docker.py │ ├── ray_tpu.py │ └── tpus.py │ ├── lora.py │ ├── main │ ├── cache_dataset.py │ ├── eval_lm.py │ ├── export_lm_to_hf.py │ ├── lora_lm.py │ ├── sft.py │ ├── sft_mixture.py │ ├── train_asr.py │ ├── train_lm.py │ └── viz_logprobs.py │ ├── models │ ├── __init__.py │ ├── asr_model.py │ ├── attention.py │ ├── backpack.py │ ├── flash_attention.py │ ├── gemma.py │ ├── gpt2.py │ ├── gpt2_hyena.py │ ├── hyena.py │ ├── llama.py │ ├── lm_model.py │ ├── loss.py │ ├── mistral.py │ ├── mixtral.py │ ├── olmo.py │ ├── qwen.py │ ├── rotary.py │ └── whisper.py │ ├── optim │ ├── __init__.py │ ├── config.py │ ├── model_averaging.py │ ├── skipstep.py │ ├── sophia.py │ └── util.py │ ├── schedule.py │ ├── shapes.py │ ├── store │ ├── 
__init__.py │ ├── cache.py │ ├── jagged_array.py │ └── tree_store.py │ ├── tensorstore_serialization.py │ ├── tracker │ ├── __init__.py │ ├── helpers.py │ ├── histogram.py │ ├── tensorboard.py │ ├── tracker.py │ ├── tracker_fns.py │ └── wandb.py │ ├── trainer.py │ ├── trainer_state.py │ ├── utils │ ├── __init__.py │ ├── activation.py │ ├── background_iterable.py │ ├── cloud_utils.py │ ├── datetime_utils.py │ ├── flop_utils.py │ ├── fsspec_utils.py │ ├── hf_utils.py │ ├── index.py │ ├── jax_utils.py │ ├── json_utils.py │ ├── logging.py │ ├── py_utils.py │ ├── ray_utils.py │ ├── stat_utils.py │ ├── thread_utils.py │ ├── tree_utils.py │ └── types.py │ └── visualization.py └── tests ├── __init__.py ├── data └── hero_data.npy ├── gpt2_test.py ├── gpt2_tokenizer_config.json ├── requirements.txt ├── test_attention.py ├── test_audio.py ├── test_background_iterable.py ├── test_backpack.py ├── test_checkpoint.py ├── test_config.py ├── test_datetime_utils.py ├── test_distributed.py ├── test_eval_harness.py ├── test_eval_lm.py ├── test_export_to_hf.py ├── test_flash_attention.py ├── test_gemma.py ├── test_grad_accum.py ├── test_hf_checkpoints.py ├── test_hf_gpt2_serialize.py ├── test_hf_utils.py ├── test_histogram.py ├── test_hyena.py ├── test_jagged_array.py ├── test_jax_utils.py ├── test_llama.py ├── test_llama3.py ├── test_logging.py ├── test_lora.py ├── test_loss.py ├── test_mistral.py ├── test_mixtral.py ├── test_mixture.py ├── test_new_cache.py ├── test_new_loader.py ├── test_newdataset.py ├── test_olmo.py ├── test_optimizer_config.py ├── test_packing.py ├── test_prp.py ├── test_py_utils.py ├── test_qwen2.py ├── test_scheduler.py ├── test_sft.py ├── test_sharded_dataset.py ├── test_skip_step.py ├── test_sophia.py ├── test_supervised.py ├── test_tensorboard.py ├── test_tensorstore_serialization.py ├── test_text.py ├── test_torch_serialization.py ├── test_tracker.py ├── test_train_asr.py ├── test_train_lm.py ├── test_tree_store.py ├── test_utils.py ├── test_varying_mixture.py ├── test_viz_lm.py ├── test_weight_decay_mask.py ├── tiny_test_corpus.py └── whisper_test.py /.devcontainer/README.md: -------------------------------------------------------------------------------- 1 | See https://containers.dev/ 2 | -------------------------------------------------------------------------------- /.devcontainer/nvidia/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Levanter on NVIDIA", 3 | "image": "nvcr.io/nvidia/jax:25.01-py3", 4 | "runArgs": [ 5 | "--gpus", 6 | "all", 7 | "--network", 8 | "host", 9 | "--shm-size", 10 | "16G" 11 | ], 12 | "postCreateCommand": ".devcontainer/nvidia/post-create.sh" 13 | } 14 | -------------------------------------------------------------------------------- /.devcontainer/nvidia/post-create.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -o errexit 3 | set -o nounset 4 | 5 | # Install uv 6 | curl -LsSf https://astral.sh/uv/install.sh | sh 7 | source $HOME/.local/bin/env 8 | 9 | # Install Levanter in editable mode. 10 | # the system flags are only OK because we are using a dedicated container. 
11 | uv pip install --system --break-system-packages --editable ".[test]" 12 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | 3 | scratch 4 | cache 5 | new-cache 6 | wandb 7 | checkpoints 8 | 9 | # Byte-compiled / optimized / DLL files 10 | **/__pycache__/ 11 | **/*$py.class 12 | 13 | # build stuff 14 | dist/ 15 | build/ 16 | 17 | # Unit test / coverage reports 18 | htmlcov/ 19 | .tox/ 20 | .nox/ 21 | .coverage 22 | .coverage.* 23 | .cache 24 | nosetests.xml 25 | coverage.xml 26 | *.cover 27 | *.py,cover 28 | .hypothesis/ 29 | .pytest_cache/ 30 | 31 | # Translations 32 | *.mo 33 | *.pot 34 | 35 | # Django stuff: 36 | *.log 37 | local_settings.py 38 | db.sqlite3 39 | db.sqlite3-journal 40 | 41 | # Flask stuff: 42 | instance/ 43 | .webassets-cache 44 | 45 | # Scrapy stuff: 46 | .scrapy 47 | 48 | # Sphinx documentation 49 | docs/_build/ 50 | docs/figures/ 51 | 52 | # PyBuilder 53 | target/ 54 | 55 | # Jupyter Notebook 56 | .ipynb_checkpoints 57 | 58 | # IPython 59 | profile_default/ 60 | ipython_config.py 61 | 62 | # pyenv 63 | .python-version 64 | 65 | # pipenv 66 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 67 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 68 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 69 | # install all needed dependencies. 70 | #Pipfile.lock 71 | 72 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 73 | __pypackages__/ 74 | 75 | # Celery stuff 76 | celerybeat-schedule 77 | celerybeat.pid 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # Environments 83 | .env 84 | .venv 85 | env/ 86 | venv/ 87 | ENV/ 88 | env.bak/ 89 | venv.bak/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | .dmypy.json 104 | dmypy.json 105 | 106 | # Pyre type checker 107 | .pyre/ 108 | 109 | # JetBrains 110 | .idea/ 111 | 112 | # dataset cache files 113 | **/*.parquet 114 | **/ledger.json 115 | 116 | *.jaxpr 117 | 118 | # local execution commands 119 | local_*.sh 120 | .aider* 121 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .git 3 | max-line-length = 120 4 | ignore = E203, E501, W503, W605, F821, E266, E731 5 | per-file-ignores = 6 | */__init__.py: F401 7 | examples/*.py: E402 8 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /.github/workflows/docker-base-image.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Push Docker TPU Images 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Run Tests"] 6 | types: 7 | - completed 8 | branches: [main] 9 | workflow_dispatch: 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Checkout code 17 | uses: actions/checkout@v3 18 | 19 | - name: Set up Docker Buildx 20 | uses: docker/setup-buildx-action@v2 21 | 22 | - name: Cache Docker layers 23 | uses: actions/cache@v3 24 | with: 25 | path: /tmp/.buildx-cache 26 | key: ${{ runner.os }}-buildx-${{ github.sha }} 27 | restore-keys: | 28 | ${{ runner.os }}-buildx- 29 | 30 | - name: Get current date 31 | id: date 32 | run: echo "DATE=$(date +'%Y%m%d')" >> $GITHUB_ENV 33 | 34 | - name: Login to GitHub Container Registry 35 | uses: docker/login-action@v2 36 | with: 37 | registry: ghcr.io 38 | username: ${{ github.actor }} 39 | password: ${{ secrets.DOCKER_PUSH_TOKEN }} 40 | 41 | - name: Build and Push Docker image 42 | run: | 43 | docker buildx build --file docker/tpu/Dockerfile.base --tag ghcr.io/${{ github.repository_owner }}/levanter-base:latest --tag ghcr.io/${{ github.repository_owner }}/levanter-base:${{ env.DATE }} --push . 44 | 45 | - name: Build and Push Incremental Docker image 46 | run: | 47 | docker buildx build --file docker/tpu/Dockerfile.incremental --tag ghcr.io/${{ github.repository_owner }}/levanter-tpu:latest --tag ghcr.io/${{ github.repository_owner }}/levanter-tpu:${{ env.DATE }} --push . 48 | -------------------------------------------------------------------------------- /.github/workflows/gpt2_small_itest.yaml: -------------------------------------------------------------------------------- 1 | name: GPT-2 Small Integration Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main # Trigger on merges/pushes to main 7 | workflow_dispatch: # Allow manual triggering 8 | 9 | jobs: 10 | integration_test: 11 | runs-on: ubuntu-latest 12 | env: 13 | TPU_ZONE: "us-central2-b" # Matching the launch script 14 | WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} 15 | 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@v3 # Use a more recent version 19 | 20 | - name: Set up Google Cloud SDK 21 | uses: google-github-actions/setup-gcloud@v1 22 | with: 23 | project_id: ${{ secrets.GCP_PROJECT_ID }} 24 | 25 | - name: Authenticate to Google Cloud 26 | uses: google-github-actions/auth@v1 27 | with: 28 | credentials_json: ${{ secrets.GCP_SA_KEY }} 29 | 30 | - name: Configure Google Cloud 31 | run: | 32 | gcloud config set project ${{ secrets.GCP_PROJECT_ID }} 33 | 34 | - name: Run GPT-2 Small Integration Test 35 | run: | 36 | # The launch script handles TPU creation and deletion 37 | bash scripts/launch_gpt2_small_itest_tpu.sh 38 | env: 39 | USER: ci-runner-${{ github.run_id }} # Set a unique user for TPU naming in the script 40 | # Ensure other necessary env vars for the script are set if any 41 | 42 | # infra/launch.py with --foreground should handle cleanup. 
43 | # If not, a manual cleanup step would be needed here similar to tpu_unit_tests.yaml, 44 | # but it would need to know the TPU_NAME used by launch.py. 45 | # For now, relying on launch.py for cleanup. 46 | # - name: Cleanup TPU 47 | # if: ${{ always() }} 48 | # run: | 49 | # # This would require launch.py to output the TPU name or use a predictable one 50 | # # For example, if TPU_NAME was consistently $(whoami)-levanter-itest-32 as in the script 51 | # TPU_NAME_IN_SCRIPT="ci-runner-${{ github.run_id }}-levanter-itest-32" 52 | # echo "Attempting to delete TPU: $TPU_NAME_IN_SCRIPT in zone ${TPU_ZONE}" 53 | # gcloud compute tpus tpu-vm delete $TPU_NAME_IN_SCRIPT --zone ${TPU_ZONE} --quiet || echo "TPU deletion failed or TPU did not exist." 54 | -------------------------------------------------------------------------------- /.github/workflows/launch_small_fast.yaml: -------------------------------------------------------------------------------- 1 | name: Launch Llama 2 Small Fast 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Build and Push Docker TPU Images"] 6 | types: 7 | - completed 8 | branches: [main, "experiment/*"] 9 | # pull_request: 10 | workflow_dispatch: 11 | 12 | jobs: 13 | test: 14 | if: (github.event.pull_request.head.repo.full_name == github.repository) 15 | runs-on: ubuntu-latest 16 | env: 17 | TPU_ZONE: "us-central2-b" 18 | TPU_TYPE: "v4-32" 19 | 20 | steps: 21 | - name: Checkout code 22 | uses: actions/checkout@v2 23 | 24 | - name: Set up Google Cloud SDK 25 | uses: google-github-actions/setup-gcloud@v1 26 | with: 27 | project_id: ${{ secrets.GCP_PROJECT_ID }} 28 | 29 | - name: Authenticate to Google Cloud 30 | uses: google-github-actions/auth@v1 31 | with: 32 | credentials_json: ${{ secrets.GCP_SA_KEY }} 33 | 34 | - name: Configure Google Cloud 35 | run: | 36 | gcloud config set project ${{ secrets.GCP_PROJECT_ID }} 37 | REGION=${TPU_ZONE%-*} 38 | echo "$REGION" 39 | gcloud auth configure-docker $REGION-docker.pkg.dev 40 | 41 | - name: Install locally 42 | run: | 43 | python -m pip install --upgrade pip 44 | pip install -e .[test] "jax[cpu]==0.4.38" 45 | 46 | - name: Launch Small Fast TPU Train LM job 47 | run: | 48 | export TPU_NAME=small-fast-${{ github.run_id }} 49 | export WANDB_API_KEY=${{ secrets.WANDB_API_KEY }} 50 | export RUN_ID=small_fast_${{ github.run_id }} 51 | export HF_TOKEN=${{ secrets.HF_TOKEN }} 52 | 53 | cat > .config <> $GITHUB_ENV 31 | echo "Calculated version with build number: $FULL_VERSION" 32 | - name: Update pyproject.toml version 33 | run: | 34 | # replace the version in pyproject.toml 35 | sed -i "s/version = \".*\"/version = \"$FULL_VERSION\"/g" pyproject.toml 36 | 37 | - name: Build package 38 | run: | 39 | python -m pip install --upgrade pip 40 | pip install build 41 | python -m build 42 | 43 | - name: Upload package 44 | uses: actions/upload-artifact@v4 45 | with: 46 | name: package 47 | path: dist/ 48 | 49 | 50 | # cf https://test.pypi.org/manage/project/levanter/settings/publishing/ 51 | publish-dev: 52 | runs-on: ubuntu-latest 53 | needs: 54 | - build-package 55 | permissions: 56 | id-token: write 57 | steps: 58 | - name: Retrieve release distributions 59 | uses: actions/download-artifact@v4 60 | with: 61 | name: package 62 | path: dist/ 63 | 64 | - name: Publish release distributions to PyPI 65 | uses: pypa/gh-action-pypi-publish@release/v1 66 | 67 | 68 | -------------------------------------------------------------------------------- /.github/workflows/run_entry_tests.yaml: 
-------------------------------------------------------------------------------- 1 | name: Run entry tests 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | if: github.event_name == 'push' || (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.10"] 12 | jax-version: ["0.5.0"] 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install flake8 pytest 24 | pip install .[test] "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" 25 | pip install soundfile librosa 26 | - name: Run entry tests with pytest 27 | run: | 28 | JAX_PLATFORMS="cpu" PYTHONPATH=$(pwd)/tests:$(pwd)/src:$(pwd):. pytest -s tests -m entry 29 | -------------------------------------------------------------------------------- /.github/workflows/run_pre_commit.yaml: -------------------------------------------------------------------------------- 1 | name: Pre-Commit 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | if: github.event_name == 'push' || (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) 8 | 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: ["3.10"] 13 | jax-version: ["0.4.38"] 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install flake8 pytest pre-commit 25 | pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" 26 | - name: "Run Pre-commit" 27 | run: | 28 | pre-commit run --all-files --show-diff-on-failure 29 | 30 | -------------------------------------------------------------------------------- /.github/workflows/run_ray_tests.yaml: -------------------------------------------------------------------------------- 1 | name: Run tests that use ray 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.10"] 12 | jax-version: ["0.4.38"] 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install flake8 pytest 24 | pip install .[test] "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" 25 | pip install soundfile librosa 26 | - name: Run ray tests with pytest 27 | run: | 28 | PYTHONPATH=$(pwd)/tests:$(pwd)/src:$(pwd):. 
pytest tests -m ray 29 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yaml: -------------------------------------------------------------------------------- 1 | name: Run tests 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | if: github.event_name == 'push' || (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) 8 | 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: ["3.10"] 13 | jax-version: ["0.5"] 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install .[test] "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" 25 | pip install -r ./tests/requirements.txt 26 | - name: Test with pytest 27 | run: | 28 | PYTHONPATH=tests:src:. pytest tests -m "not entry and not slow and not ray" 29 | -------------------------------------------------------------------------------- /.github/workflows/tpu_unit_tests.yaml: -------------------------------------------------------------------------------- 1 | name: CI with GCP TPU 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | test: 7 | if: (github.event.pull_request.head.repo.full_name == github.repository) 8 | runs-on: ubuntu-latest 9 | env: 10 | TPU_ZONE: "us-central2-b" 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v2 15 | 16 | - name: Set up Google Cloud SDK 17 | uses: google-github-actions/setup-gcloud@v1 18 | with: 19 | project_id: ${{ secrets.GCP_PROJECT_ID }} 20 | 21 | - name: Authenticate to Google Cloud 22 | uses: google-github-actions/auth@v1 23 | with: 24 | credentials_json: ${{ secrets.GCP_SA_KEY }} 25 | 26 | - name: Configure Google Cloud 27 | run: | 28 | gcloud config set project ${{ secrets.GCP_PROJECT_ID }} 29 | 30 | - name: Create VM 31 | run: | 32 | export TPU_NAME=ci-run-${{ github.run_id }} 33 | eval "$(ssh-agent -s)" 34 | TRUE_SHA=${{ github.event.pull_request.head.sha }} 35 | bash infra/spin-up-vm.sh $TPU_NAME -z ${TPU_ZONE} -t v4-8 --preemptible -s infra/helpers/setup-tpu-vm-tests.sh -b ${TRUE_SHA} --retries 1 36 | # infra/babysit-tpu-vm.sh $TPU_NAME -z ${{ TPU_ZONE }} -t v4-8 --preemptible -s infra/helpers/setup-tpu-vm-tests.sh -b ${{ github.sha }} --retries 1 -- \ 37 | # PYTHONPATH=$PYTHONPATH:levanter/tests bash levanter/infra/run.sh pytest levanter/tests -m "not entry" 38 | 39 | - name: Run most tests 40 | run: | 41 | export TPU_NAME=ci-run-${{ github.run_id }} 42 | gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "JAX_TRACEBACK_FILTERING=off PYTHONPATH=$PYTHONPATH:levanter/tests CI=1 bash levanter/infra/run.sh pytest levanter/tests -m 'not entry and not ray'" 43 | # Something's wrong with these 44 | # 45 | # - name: Run forked tests 46 | # run: | 47 | # export TPU_NAME=ci-run-${{ github.run_id }} 48 | # gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "PYTHONPATH=$PYTHONPATH:levanter/tests bash levanter/infra/run.sh pytest --forked levanter/tests -m 'entry'" 49 | # 50 | - name: Cleanup 51 | if: ${{ always() }} 52 | run: | 53 | export TPU_NAME=ci-run-${{ github.run_id }} 54 | echo gcloud compute tpus tpu-vm delete $TPU_NAME --zone ${TPU_ZONE} --quiet 55 | gcloud compute tpus tpu-vm delete $TPU_NAME --zone ${TPU_ZONE} --quiet 56 | 
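
The test workflows above (run_entry_tests.yaml, run_ray_tests.yaml, run_tests.yaml, and the TPU job in tpu_unit_tests.yaml) partition the suite with pytest markers (`entry`, `ray`, `slow`) via `-m` expressions. As a minimal sketch of how such markers are typically registered so those filters run without "unknown marker" warnings — the marker names come from the workflows above, but this conftest.py layout is an assumption, not necessarily the repository's actual configuration:

```python
# conftest.py (illustrative sketch, not the repository's actual file)


def pytest_configure(config):
    # Register the custom markers the CI workflows filter on, so that
    # `pytest tests -m "not entry and not slow and not ray"` runs cleanly
    # and `pytest tests -m ray` selects only the Ray-based tests.
    for name, help_text in [
        ("entry", "end-to-end entry-point tests (run_entry_tests.yaml)"),
        ("ray", "tests that spin up Ray (run_ray_tests.yaml)"),
        ("slow", "long-running tests excluded from the default CI run"),
    ]:
        config.addinivalue_line("markers", f"{name}: {help_text}")
```

Individual tests would then opt in with decorators such as `@pytest.mark.ray`, which is what lets the marker expressions in the workflows above select disjoint subsets of the suite.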
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /scratch 2 | 3 | # Configuration for TPU launches/secrets 4 | .levanter.yaml 5 | .levanter.yaml 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # JetBrains 138 | .idea/ 139 | 140 | # vscode 141 | .vscode 142 | 143 | # Wandb stuff 144 | /wandb 145 | 146 | # dataset cache files 147 | /cache 148 | *.parquet 149 | ledger.json 150 | 151 | /checkpoints 152 | *.jaxpr 153 | 154 | local_*.sh 155 | 156 | # aider 157 | .aider* 158 | 159 | .benchmarks 160 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | 10 | 11 | # GitHub Copilot persisted chat sessions 12 | /copilot/chatSessions 13 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | exclude: ".git" 4 | default_stages: 5 | - pre-commit 6 | fail_fast: true 7 | 8 | repos: 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v5.0.0 11 | hooks: 12 | - id: trailing-whitespace 13 | - id: end-of-file-fixer 14 | - id: check-yaml 15 | args: ['--unsafe'] 16 | - id: check-toml 17 | - id: check-merge-conflict 18 | - id: check-added-large-files 19 | 20 | - repo: https://github.com/psf/black 21 | rev: 22.3.0 22 | hooks: 23 | - id: black 24 | 25 | - repo: https://github.com/timothycrosley/isort 26 | rev: 5.11.5 27 | hooks: 28 | - id: isort 29 | 30 | - repo: https://github.com/PyCQA/flake8 31 | rev: 6.1.0 32 | hooks: 33 | - id: flake8 34 | additional_dependencies: [flake8-isort] 35 | 36 | - repo: https://github.com/pre-commit/mirrors-mypy 37 | rev: 'v1.4.1' 38 | hooks: 39 | - id: mypy 40 | args: [--ignore-missing-imports] 41 | additional_dependencies: [wandb==0.17.8, types-PyYAML] 42 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file for MkDocs projects 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the version of Python and other tools you might need 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.11" 12 | mkdocs: 13 | configuration: mkdocs.yml 14 | # Optionally declare the Python requirements required to build your docs 15 | python: 16 | install: 17 | - requirements: docs/requirements.txt 18 | -------------------------------------------------------------------------------- /config/backpack.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: 
backpack 4 | hidden_dim: 768 5 | num_heads: 12 6 | num_layers: 12 7 | seq_len: 512 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | num_senses: 16 11 | sense_intermediate_scale: 4 12 | trainer: 13 | tracker: 14 | project: "levanter" 15 | tags: [ "openwebtext", "backpack" ] 16 | 17 | mp: p=f32,c=bfloat16 18 | 19 | num_train_steps: 50000 20 | train_batch_size: 1024 21 | model_axis_size: 1 22 | 23 | optimizer: 24 | learning_rate: 6E-4 25 | weight_decay: 0.1 26 | -------------------------------------------------------------------------------- /config/backpack_nano.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | id: dlwh/wikitext_103_detokenized 3 | model: 4 | type: backpack 5 | hidden_dim: 32 6 | num_heads: 4 7 | num_layers: 2 8 | seq_len: 512 9 | gradient_checkpointing: true 10 | scale_attn_by_inverse_layer_idx: true 11 | num_senses: 16 12 | sense_intermediate_scale: 4 13 | trainer: 14 | mp: f32 15 | 16 | num_train_steps: 100 17 | train_batch_size: 32 18 | model_axis_size: 1 19 | 20 | optimizer: 21 | learning_rate: 6E-4 22 | weight_decay: 0.1 23 | -------------------------------------------------------------------------------- /config/data/dclm_gpt_neo.yaml: -------------------------------------------------------------------------------- 1 | cache_dir: "gs://marin-us-central2/tokenized/gpt_neox/" 2 | tokenizer: "EleutherAI/gpt-neox-20b" 3 | cache_options: 4 | batch_size: 256 5 | num_shard_groups: 1024 6 | stop_strategy: restart 7 | shuffle: 100000 8 | configs: 9 | "dclm": 10 | train_urls: 11 | - gs://marin-us-central2/raw/dclm/v2024-07-09-baseline-dedup/**/*.zstd 12 | # these are just for eval 13 | "paloma/4chan": 14 | validation_urls: 15 | - gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz 16 | "paloma/c4_100_domains": 17 | validation_urls: 18 | - gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz 19 | "paloma/c4_en": 20 | validation_urls: 21 | - gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz 22 | "paloma/dolma-v1_5": 23 | validation_urls: 24 | - gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz 25 | "paloma/dolma_100_programing_languages": 26 | validation_urls: 27 | - gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz 28 | "paloma/dolma_100_subreddits": 29 | validation_urls: 30 | - gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz 31 | "paloma/falcon-refinedweb": 32 | validation_urls: 33 | - gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz 34 | "paloma/gab": 35 | validation_urls: 36 | - gs://levanter-data/paloma/gab/val/val*.jsonl.gz 37 | "paloma/m2d2_s2orc_unsplit": 38 | validation_urls: 39 | - gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz 40 | "paloma/m2d2_wikipedia_unsplit": 41 | validation_urls: 42 | - gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz 43 | "paloma/manosphere_meta_sep": 44 | validation_urls: 45 | - gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz 46 | "paloma/mc4": 47 | validation_urls: 48 | - gs://levanter-data/paloma/mc4/val/val*.jsonl.gz 49 | "paloma/ptb": 50 | validation_urls: 51 | - gs://levanter-data/paloma/ptb/val/val*.jsonl.gz 52 | "paloma/redpajama": 53 | validation_urls: 54 | - gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz 55 | "paloma/twitterAAE_HELM_fixed": 56 | validation_urls: 57 | - gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz 58 | "paloma/wikitext_103": 59 | validation_urls: 60 | - 
gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz 61 | train_weights: 62 | dclm: 1.0 63 | paloma/4chan: 0.0 64 | paloma/c4_100_domains: 0.0 65 | paloma/c4_en: 0.0 66 | paloma/dolma-v1_5: 0.0 67 | paloma/dolma_100_programing_languages: 0.0 68 | paloma/dolma_100_subreddits: 0.0 69 | paloma/falcon-refinedweb: 0.0 70 | paloma/gab: 0.0 71 | paloma/m2d2_s2orc_unsplit: 0.0 72 | paloma/m2d2_wikipedia_unsplit: 0.0 73 | paloma/manosphere_meta_sep: 0.0 74 | paloma/mc4: 0.0 75 | paloma/ptb: 0.0 76 | paloma/redpajama: 0.0 77 | paloma/twitterAAE_HELM_fixed: 0.0 78 | paloma/wikitext_103: 0.0 79 | -------------------------------------------------------------------------------- /config/data/openwebtext_source.yaml: -------------------------------------------------------------------------------- 1 | train_urls: 2 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 3 | validation_urls: 4 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 5 | cache_dir: "gs://levanter-data/tokenized/openwebtext/" 6 | tokenizer: "gpt2" 7 | cache_options: 8 | batch_size: 1024 9 | num_shard_groups: 64 10 | -------------------------------------------------------------------------------- /config/data/pile_source_old.yaml: -------------------------------------------------------------------------------- 1 | train_urls: 2 | - gs://levanter-data/pile/train/{00..29}.jsonl.zst 3 | validation_urls: 4 | - gs://levanter-data/pile/val.jsonl.zst 5 | cache_dir: "gs://levanter-data/tokenized/pile-old/" 6 | tokenizer: "EleutherAI/gpt-neox-20b" 7 | -------------------------------------------------------------------------------- /config/data/pubmed_source.yaml: -------------------------------------------------------------------------------- 1 | train_urls: 2 | - "gs://pubmed-mosaic/pubmed-sharded/pubmedRandomized_train.{1..128}-of-128.jsonl.gz" 3 | validation_urls: 4 | - "gs://pubmed-mosaic/pubmed-sharded/pubmedRandomized_val.{1..8}-of-8.jsonl.gz" 5 | cache_dir: "gs://pubmed-mosaic/tokenized/pubmed-sharded/" 6 | -------------------------------------------------------------------------------- /config/data/redpajama_1b_source.yaml: -------------------------------------------------------------------------------- 1 | id: togethercomputer/RedPajama-Data-1T-Sample 2 | cache_dir: gs://levanter-data/tokenized/redpajama-sample/ 3 | tokenizer: EleutherAI/gpt-neox-20b 4 | splits: 5 | - train 6 | -------------------------------------------------------------------------------- /config/data/redpajama_1t_source.yaml: -------------------------------------------------------------------------------- 1 | id: togethercomputer/RedPajama-Data-1T 2 | cache_dir: gs://levanter-data/tokenized/redpajama/ 3 | tokenizer: EleutherAI/gpt-neox-20b 4 | splits: 5 | - train 6 | -------------------------------------------------------------------------------- /config/data/wikitext_source.yaml: -------------------------------------------------------------------------------- 1 | id: dlwh/wikitext_103_detokenized 2 | cache_dir: "gs://levanter-data/tokenized/wikitext" 3 | -------------------------------------------------------------------------------- /config/gemma_2b.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_urls: 3 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 4 | validation_urls: 5 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..1}-of-8.jsonl.gz" 6 | cache_dir: 
"gs://wasabi-tpu-training/openwebtext-mini" 7 | tokenizer: "google/gemma-2b" 8 | model: 9 | type: gemma 10 | initialize_from_hf: "google/gemma-2b" 11 | use_hf_model_config: true 12 | trainer: 13 | checkpointer: 14 | base_path: "gs://wasabi-tpu-training/gemma-2b/" 15 | tracker: 16 | type: wandb 17 | project: "levanter" 18 | tags: ["openwebtext", "gemma"] 19 | 20 | mp: p=bfloat16,c=bfloat16 21 | train_batch_size: 16 # set for v5e-16 TPU 22 | num_train_steps: 100000 23 | steps_per_eval: 50 24 | tensor_parallel_axes: ["mlp", "heads"] 25 | fsdp_axis: "embed" 26 | batch_axis: "batch" 27 | optimizer: 28 | learning_rate: 1.2E-5 # set low for fine-tuning 29 | weight_decay: 0.1 30 | min_lr_ratio: 0.1 31 | -------------------------------------------------------------------------------- /config/gpt2_1536.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 1536 5 | num_heads: 24 6 | num_layers: 48 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | tracker: 12 | project: "levanter" 13 | tags: [ "openwebtext", "gpt2"] 14 | 15 | mp: p=f32,c=bfloat16 16 | model_axis_size: 1 17 | per_device_eval_parallelism: 8 18 | optimizer: 19 | learning_rate: 1E-4 20 | weight_decay: 0.1 21 | min_lr_ratio: 0.1 22 | -------------------------------------------------------------------------------- /config/gpt2_1536_sophiah.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_urls: 3 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 4 | validation_urls: 5 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 6 | cache_dir: "gs://levanter-data/tokenized/openwebtext/" 7 | tokenizer: "gpt2" 8 | model: 9 | type: gpt2 10 | hidden_dim: 1536 11 | num_heads: 24 12 | num_layers: 48 13 | seq_len: 1024 14 | gradient_checkpointing: true 15 | scale_attn_by_inverse_layer_idx: true 16 | trainer: 17 | tracker: 18 | project: "levanter" 19 | tags: [ "openwebtext", "gpt2"] 20 | 21 | mp: p=f32,c=bfloat16 22 | model_axis_size: 1 23 | optimizer: 24 | type: sophia-h 25 | learning_rate: 2E-4 26 | weight_decay: 0.2 27 | min_lr_ratio: 0.1 28 | gamma: 0.01 29 | warmup: 2000 30 | -------------------------------------------------------------------------------- /config/gpt2_20b.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/pile_source_old.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 6144 5 | num_heads: 64 6 | num_layers: 44 7 | seq_len: 2048 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | attn_pdrop: 0.0 11 | resid_pdrop: 0.0 12 | use_bias: false 13 | trainer: 14 | tracker: 15 | project: "levanter" 16 | tags: ["pile", "gpt2"] 17 | 18 | mp: p=f32,c=bfloat16 19 | 20 | 21 | per_device_eval_parallelism: 4 22 | 23 | train_batch_size: 1024 24 | num_train_steps: 100000 25 | steps_per_eval: 500 26 | 27 | axis_resources: 28 | batch: "data" 29 | vocab: "model" 30 | mlp: "model" 31 | heads: "model" 32 | # ZERO-3 33 | parameter_axis_resources: 34 | embed: "data" 35 | 36 | optimizer: 37 | learning_rate: 1.2E-4 38 | weight_decay: 0.1 39 | min_lr_ratio: 0.1 40 | -------------------------------------------------------------------------------- /config/gpt2_7b.yaml: -------------------------------------------------------------------------------- 1 | data: !include 
data/pile_source_old.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 4096 5 | num_heads: 32 6 | num_layers: 32 7 | seq_len: 2048 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | attn_pdrop: 0.0 11 | resid_pdrop: 0.0 12 | trainer: 13 | tracker: 14 | project: "levanter" 15 | tags: ["pile", "gpt2"] 16 | 17 | mp: p=f32,c=bfloat16 18 | 19 | model_axis_size: 1 20 | per_device_parallelism: -1 21 | per_device_eval_parallelism: -1 22 | 23 | train_batch_size: 1024 24 | num_train_steps: 100000 25 | steps_per_eval: 500 26 | optimizer: 27 | learning_rate: 1.2E-4 28 | weight_decay: 0.1 29 | min_lr_ratio: 0.1 30 | -------------------------------------------------------------------------------- /config/gpt2_hyena_nano.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | id: dlwh/wikitext_103_detokenized 3 | 4 | model: 5 | type: gpt2_hyena 6 | hyena: 7 | hidden_dim: 32 # Default: 768 8 | filter_order: 16 # Default: 64 9 | num_layers: 2 # Default: 12 10 | 11 | trainer: 12 | num_train_steps: 100 13 | require_accelerator: false 14 | 15 | checkpointer: 16 | keep: 17 | - every: 50 18 | save_interval: 5m 19 | 20 | per_device_parallelism: -1 21 | train_batch_size: 32 22 | 23 | tensor_parallel_axes: ["hyena_filter_order", "mlp"] 24 | fsdp_axis: "embed" 25 | batch_axis: "batch" 26 | -------------------------------------------------------------------------------- /config/gpt2_large.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 1280 5 | num_heads: 20 6 | num_layers: 36 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | tracker: 12 | project: "levanter" 13 | tags: [ "openwebtext", "gpt2"] 14 | 15 | mp: p=f32,c=bfloat16 16 | model_axis_size: 1 17 | per_device_parallelism: -1 18 | optimizer: 19 | learning_rate: 2E-4 20 | weight_decay: 0.1 21 | -------------------------------------------------------------------------------- /config/gpt2_large_sophia_h.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 1280 5 | num_heads: 20 6 | num_layers: 36 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | wandb: 12 | project: "levanter" 13 | tags: [ "openwebtext", "gpt2", "sophia-h"] 14 | 15 | num_train_steps: 200000 16 | mp: p=f32,c=bfloat16 17 | 18 | optimizer: 19 | type: sophia-h 20 | learning_rate: 1.7E-4 21 | weight_decay: 0.2 22 | -------------------------------------------------------------------------------- /config/gpt2_medium.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 1024 5 | num_heads: 16 6 | num_layers: 24 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | tracker: 12 | project: "levanter" 13 | tags: [ "openwebtext", "gpt2"] 14 | 15 | mp: p=f32,c=bfloat16 16 | model_axis_size: 1 17 | optimizer: 18 | learning_rate: 3E-4 19 | weight_decay: 0.1 20 | min_lr_ratio: 0.1 21 | -------------------------------------------------------------------------------- /config/gpt2_micro.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | id: 
dlwh/wikitext_103_detokenized 3 | model: 4 | type: gpt2 5 | hidden_dim: 128 6 | num_heads: 8 7 | num_layers: 4 8 | trainer: 9 | tracker: 10 | project: "levanter" 11 | tags: [ "openwebtext", "gpt2"] 12 | 13 | mp: p=f32,c=bfloat16 14 | num_train_steps: 100 15 | per_device_eval_parallelism: 1 16 | train_batch_size: 32 17 | -------------------------------------------------------------------------------- /config/gpt2_nano.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | id: dlwh/wikitext_103_detokenized 3 | model: 4 | type: gpt2 5 | hidden_dim: 32 6 | num_heads: 4 7 | num_layers: 2 8 | trainer: 9 | mp: f32 10 | num_train_steps: 100 11 | 12 | checkpointer: 13 | keep: 14 | - every: 50 15 | save_interval: 5m 16 | 17 | per_device_parallelism: -1 18 | train_batch_size: 32 19 | 20 | tensor_parallel_axes: ["mlp", "heads"] 21 | fsdp_axis: "embed" 22 | batch_axis: "batch" 23 | -------------------------------------------------------------------------------- /config/gpt2_nano_fp8.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | id: dlwh/wikitext_103_detokenized 3 | model: 4 | type: gpt2 5 | hidden_dim: 32 6 | num_heads: 4 7 | num_layers: 2 8 | trainer: 9 | mp: f32 10 | quantization: 11 | fp8: true 12 | num_train_steps: 100 13 | 14 | checkpointer: 15 | keep: 16 | - every: 50 17 | save_interval: 5m 18 | 19 | per_device_parallelism: -1 20 | train_batch_size: 32 21 | 22 | tensor_parallel_axes: ["mlp", "heads"] 23 | fsdp_axis: "embed" 24 | batch_axis: "batch" 25 | -------------------------------------------------------------------------------- /config/gpt2_nano_harness.yaml: -------------------------------------------------------------------------------- 1 | eval_harness: 2 | task_spec: 3 | - mmlu 4 | # - task: mmlu 5 | # task_alias: mmlu_0shot 6 | # num_fewshot: 0 7 | max_examples: 32 8 | eval_harness_steps: 50 9 | data: 10 | id: dlwh/wikitext_103_detokenized 11 | model: 12 | type: gpt2 13 | hidden_dim: 32 14 | num_heads: 4 15 | num_layers: 2 16 | trainer: 17 | mp: f32 18 | num_train_steps: 100 19 | 20 | checkpointer: 21 | keep: 22 | - every: 50 23 | save_interval: 5m 24 | 25 | per_device_parallelism: -1 26 | train_batch_size: 4 27 | 28 | tensor_parallel_axes: ["mlp", "heads"] 29 | fsdp_axis: "embed" 30 | batch_axis: "batch" 31 | -------------------------------------------------------------------------------- /config/gpt2_nano_mixture.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | cache_dir: data_mix_cache 3 | configs: 4 | wikitext: 5 | id: dlwh/wikitext_103_detokenized 6 | w2: 7 | id: dlwh/wikitext_103_detokenized 8 | cache_dir: wikitext2_cache 9 | train_weights: 10 | - [0, {"wikitext": 0.8, "w2": 0.2}] 11 | - [100, {"w2": 0.5, "wikitext": 0.5}] 12 | model: 13 | type: gpt2 14 | hidden_dim: 32 15 | num_heads: 4 16 | num_layers: 2 17 | trainer: 18 | mp: f32 19 | num_train_steps: 100 20 | 21 | checkpointer: 22 | keep: 23 | - every: 50 24 | save_interval: 5m 25 | 26 | per_device_eval_parallelism: 16 27 | train_batch_size: 32 28 | 29 | tensor_parallel_axes: ["mlp", "heads"] 30 | fsdp_axis: "embed" 31 | batch_axis: "batch" 32 | -------------------------------------------------------------------------------- /config/gpt2_nano_skip.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | id: dlwh/wikitext_103_detokenized 3 | model: 4 | type: gpt2 5 | hidden_dim: 32 6 | num_heads: 4 7 | num_layers: 2 8 | trainer: 9 | 
mp: f32 10 | num_train_steps: 100 11 | 12 | checkpointer: 13 | keep: 14 | - every: 50 15 | save_interval: 5m 16 | 17 | per_device_parallelism: -1 18 | train_batch_size: 32 19 | 20 | tensor_parallel_axes: ["mlp", "heads"] 21 | fsdp_axis: "embed" 22 | batch_axis: "batch" 23 | 24 | 25 | optimizer: 26 | skip_bad_steps: true 27 | -------------------------------------------------------------------------------- /config/gpt2_nano_tb.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | id: dlwh/wikitext_103_detokenized 3 | model: 4 | type: gpt2 5 | hidden_dim: 32 6 | num_heads: 4 7 | num_layers: 2 8 | trainer: 9 | mp: f32 10 | num_train_steps: 100 11 | 12 | checkpointer: 13 | keep: 14 | - every: 50 15 | save_interval: 5m 16 | 17 | per_device_parallelism: -1 18 | train_batch_size: 32 19 | 20 | tensor_parallel_axes: ["mlp", "heads"] 21 | fsdp_axis: "embed" 22 | batch_axis: "batch" 23 | tracker: 24 | type: tensorboard 25 | logdir: tb_logs/ 26 | -------------------------------------------------------------------------------- /config/gpt2_small.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 768 5 | num_heads: 12 6 | num_layers: 12 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | attn_backend: jax_flash 11 | trainer: 12 | tracker: 13 | project: "levanter" 14 | tags: [ "openwebtext", "gpt2"] 15 | 16 | mp: p=f32,c=bfloat16 17 | model_axis_size: 1 18 | per_device_parallelism: -1 19 | 20 | train_batch_size: 512 21 | optimizer: 22 | learning_rate: 6E-4 23 | weight_decay: 0.1 24 | min_lr_ratio: 0.1 25 | -------------------------------------------------------------------------------- /config/gpt2_small_fast.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 768 5 | num_heads: 12 6 | num_layers: 12 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | tracker: 12 | - type: wandb 13 | project: "levanter" 14 | tags: [ "openwebtext", "gpt2", "itest"] 15 | 16 | mp: p=f32,c=bfloat16 17 | model_axis_size: 1 18 | per_device_parallelism: -1 19 | 20 | train_batch_size: 256 21 | num_train_steps: 20000 22 | 23 | # tensor_parallel_axes: ["position", "key_position"] 24 | # tensor_parallel_axes: ["heads", "mlp"] 25 | optimizer: 26 | learning_rate: 1E-3 27 | weight_decay: 0.1 28 | warmup: 0.01 29 | -------------------------------------------------------------------------------- /config/gpt2_small_fast_batch_schedule.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | configs: 3 | owt: 4 | train_urls: 5 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 6 | validation_urls: 7 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 8 | wikitext: 9 | id: dlwh/wikitext_103_detokenized 10 | train_weights: 11 | - [0, {"owt": 0.8, "wikitext": 0.2}] 12 | - [1000, {"owt": 0.5, "wikitext": 0.5}] 13 | tokenizer: gpt2 14 | cache_dir: "gs://levanter-data/tokenized/data_mix" 15 | mixture_block_size: 128 16 | model: 17 | type: gpt2 18 | hidden_dim: 768 19 | num_heads: 12 20 | num_layers: 12 21 | seq_len: 1024 22 | gradient_checkpointing: true 23 | scale_attn_by_inverse_layer_idx: true 24 | trainer: 25 | tracker: 26 | - type: 
wandb 27 | project: "levanter" 28 | tags: [ "openwebtext", "gpt2", "itest"] 29 | 30 | mp: p=f32,c=bfloat16 31 | model_axis_size: 1 32 | per_device_parallelism: -1 33 | per_device_eval_parallelism: 128 34 | 35 | train_batch_size: 36 | - start: 0 37 | value: 64 38 | - start: 64 39 | value: 128 40 | - start: 256 41 | value: 64 42 | - start: 1024 43 | value: 256 44 | num_train_steps: 20000 45 | allow_nondivisible_batch_size: true 46 | 47 | # tensor_parallel_axes: ["position", "key_position"] 48 | # tensor_parallel_axes: ["heads", "mlp"] 49 | optimizer: 50 | learning_rate: 1E-3 51 | weight_decay: 0.1 52 | warmup: 0.01 53 | -------------------------------------------------------------------------------- /config/gpt2_small_fast_ema.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 768 5 | num_heads: 12 6 | num_layers: 12 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | tracker: 12 | - type: wandb 13 | project: "levanter" 14 | tags: [ "openwebtext", "gpt2", "itest"] 15 | 16 | mp: p=f32,c=bfloat16 17 | model_averaging: 18 | type: ema 19 | beta: 0.995 20 | 21 | model_axis_size: 1 22 | per_device_parallelism: -1 23 | 24 | train_batch_size: 256 25 | num_train_steps: 20000 26 | 27 | # tensor_parallel_axes: ["position", "key_position"] 28 | # tensor_parallel_axes: ["heads", "mlp"] 29 | optimizer: 30 | learning_rate: 1E-3 31 | weight_decay: 0.1 32 | warmup: 0.01 33 | decay: 200 # no decay b/c EMA 34 | lr_schedule: inv 35 | -------------------------------------------------------------------------------- /config/gpt2_small_fast_eval.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | configs: 3 | owt: 4 | train_urls: 5 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 6 | validation_urls: 7 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 8 | wikitext: 9 | id: dlwh/wikitext_103_detokenized 10 | train_weights: 11 | owt: 0.6 12 | wikitext: 0.4 13 | tokenizer: gpt2 14 | cache_dir: "gs://levanter-data/tokenized/data_mix" 15 | 16 | eval_harness: 17 | task_spec: ["piqa", "hellaswag"] 18 | max_examples: 2048 19 | eval_harness_steps: 1000 20 | 21 | model: 22 | type: gpt2 23 | hidden_dim: 768 24 | num_heads: 12 25 | num_layers: 12 26 | seq_len: 1024 27 | gradient_checkpointing: true 28 | scale_attn_by_inverse_layer_idx: true 29 | trainer: 30 | tracker: 31 | project: "levanter" 32 | tags: [ "openwebtext+wiki", "gpt2", "itest"] 33 | 34 | mp: p=f32,c=bfloat16 35 | model_axis_size: 1 36 | 37 | train_batch_size: 256 38 | num_train_steps: 20000 39 | optimizer: 40 | learning_rate: 1E-3 41 | weight_decay: 0.1 42 | warmup: 0.01 43 | -------------------------------------------------------------------------------- /config/gpt2_small_fast_fp8.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 768 5 | num_heads: 12 6 | num_layers: 12 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | tracker: 12 | - type: wandb 13 | project: "levanter" 14 | tags: [ "openwebtext", "gpt2", "itest"] 15 | 16 | mp: p=f32,c=bfloat16 17 | quantization: 18 | fp8: true 19 | model_axis_size: 1 20 | per_device_parallelism: -1 21 | 22 | train_batch_size: 256 23 | 
num_train_steps: 20000 24 | optimizer: 25 | learning_rate: 1E-3 26 | weight_decay: 0.1 27 | warmup: 0.01 28 | -------------------------------------------------------------------------------- /config/gpt2_small_fast_mix.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | configs: 3 | owt: 4 | train_urls: 5 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 6 | validation_urls: 7 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 8 | wikitext: 9 | id: dlwh/wikitext_103_detokenized 10 | train_weights: 11 | owt: 0.6 12 | wikitext: 0.4 13 | tokenizer: gpt2 14 | cache_dir: "gs://marin-us-central2/scratch/dlwh/gpt2_small_fast_mix" 15 | model: 16 | type: gpt2 17 | hidden_dim: 768 18 | num_heads: 12 19 | num_layers: 12 20 | seq_len: 1024 21 | gradient_checkpointing: true 22 | scale_attn_by_inverse_layer_idx: true 23 | trainer: 24 | tracker: 25 | project: "levanter" 26 | tags: [ "openwebtext+wiki", "gpt2", "itest"] 27 | 28 | mp: p=f32,c=bfloat16 29 | model_axis_size: 1 30 | 31 | train_batch_size: 256 32 | num_train_steps: 20000 33 | optimizer: 34 | learning_rate: 1E-3 35 | weight_decay: 0.1 36 | warmup: 0.01 37 | -------------------------------------------------------------------------------- /config/gpt2_small_fast_mix_chat.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | configs: 3 | owt: 4 | train_urls: 5 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 6 | validation_urls: 7 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 8 | wikitext: 9 | id: dlwh/wikitext_103_detokenized 10 | tulu: 11 | id: allenai/tulu-3-sft-mixture 12 | format: 13 | type: "chat" 14 | train_weights: 15 | owt: 0.6 16 | wikitext: 0.3 17 | tulu: 0.1 18 | tokenizer: stanford-crfm/marin-tokenizer 19 | cache_dir: "gs://marin-us-central2/scratch/dlwh/marin_small_fast_mix" 20 | model: 21 | type: gpt2 22 | hidden_dim: 768 23 | num_heads: 12 24 | num_layers: 12 25 | seq_len: 1024 26 | gradient_checkpointing: true 27 | scale_attn_by_inverse_layer_idx: true 28 | trainer: 29 | tracker: 30 | project: "levanter" 31 | tags: [ "openwebtext+wiki", "gpt2", "itest"] 32 | 33 | mp: p=f32,c=bfloat16 34 | model_axis_size: 1 35 | 36 | train_batch_size: 256 37 | num_train_steps: 20000 38 | optimizer: 39 | learning_rate: 1E-3 40 | weight_decay: 0.1 41 | warmup: 0.01 42 | -------------------------------------------------------------------------------- /config/gpt2_small_fast_pile.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/pile_mixture.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 768 5 | num_heads: 12 6 | num_layers: 12 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | tracker: 12 | project: "levanter" 13 | tags: [ "pile", "gpt2", "itest"] 14 | 15 | mp: p=f32,c=bfloat16 16 | model_axis_size: 1 17 | 18 | train_batch_size: 256 19 | num_train_steps: 20000 20 | optimizer: 21 | learning_rate: 1E-3 22 | weight_decay: 0.1 23 | -------------------------------------------------------------------------------- /config/gpt2_small_fast_public.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_urls: 3 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 4 | validation_urls: 5 | - 
"gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 6 | cache_dir: "gs://levanter-data/tokenized/openwebtext/" 7 | tokenizer: "gpt2" 8 | model: 9 | type: gpt2 10 | hidden_dim: 768 11 | num_heads: 12 12 | num_layers: 12 13 | seq_len: 1024 14 | gradient_checkpointing: true 15 | scale_attn_by_inverse_layer_idx: true 16 | trainer: 17 | tracker: 18 | - type: wandb 19 | project: "levanter" 20 | tags: [ "openwebtext", "gpt2", "itest"] 21 | 22 | mp: p=f32,c=bfloat16 23 | model_axis_size: 1 24 | per_device_parallelism: -1 25 | 26 | train_batch_size: 256 27 | num_train_steps: 10000 28 | ray: 29 | auto_start_cluster: false 30 | optimizer: 31 | learning_rate: 1E-3 32 | weight_decay: 0.1 33 | warmup: 0.01 34 | -------------------------------------------------------------------------------- /config/gpt2_small_fast_skip.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 768 5 | num_heads: 12 6 | num_layers: 12 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | tracker: 12 | - type: wandb 13 | project: "levanter" 14 | tags: [ "openwebtext", "gpt2", "itest"] 15 | 16 | mp: p=f32,c=bfloat16 17 | model_axis_size: 1 18 | per_device_parallelism: -1 19 | 20 | train_batch_size: 256 21 | num_train_steps: 20000 22 | 23 | # tensor_parallel_axes: ["position", "key_position"] 24 | # tensor_parallel_axes: ["heads", "mlp"] 25 | optimizer: 26 | learning_rate: 1E-3 27 | weight_decay: 0.1 28 | warmup: 0.01 29 | skip_bad_steps: true 30 | -------------------------------------------------------------------------------- /config/gpt2_small_fast_sophia_h.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 768 5 | num_heads: 12 6 | num_layers: 12 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | wandb: 12 | project: "levanter" 13 | tags: [ "openwebtext", "gpt2", "itest", "sophia-h"] 14 | 15 | mp: p=f32,c=bfloat16 16 | model_axis_size: 1 17 | 18 | train_batch_size: 256 19 | num_train_steps: 20000 20 | optimizer: 21 | type: sophia-h 22 | learning_rate: .85E-3 23 | weight_decay: 0.2 24 | -------------------------------------------------------------------------------- /config/gpt2_small_fast_sophiah.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 768 5 | num_heads: 12 6 | num_layers: 12 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | wandb: 12 | project: "levanter" 13 | tags: [ "openwebtext", "gpt2", "itest"] 14 | 15 | mp: p=f32,c=bfloat16 16 | model_axis_size: 1 17 | per_device_parallelism: -1 18 | 19 | train_batch_size: 256 20 | num_train_steps: 20000 21 | optimizer: 22 | type: sophia-h 23 | learning_rate: 0.8E-3 24 | weight_decay: 0.1 25 | warmup: 0.01 26 | gamma: 0.005 27 | -------------------------------------------------------------------------------- /config/gpt2_small_fast_wiki.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | id: dlwh/wikitext_103_detokenized 3 | model: 4 | type: gpt2 5 | hidden_dim: 768 6 | num_heads: 12 7 | num_layers: 12 8 | seq_len: 1024 9 | gradient_checkpointing: true 10 
| scale_attn_by_inverse_layer_idx: true 11 | trainer: 12 | tracker: 13 | project: "levanter" 14 | tags: [ "openwebtext", "gpt2", "itest"] 15 | 16 | mp: p=f32,c=bfloat16 17 | model_axis_size: 1 18 | per_device_parallelism: -1 19 | 20 | train_batch_size: 256 21 | # this is deliberately very small for testing 22 | num_train_steps: 200 23 | optimizer: 24 | learning_rate: 1E-3 25 | weight_decay: 0.1 26 | warmup: 0.01 27 | -------------------------------------------------------------------------------- /config/gpt2_small_itest.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 768 5 | num_heads: 12 6 | num_layers: 12 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | tracker: 12 | - type: wandb 13 | project: "levanter-itest" 14 | tags: [ "openwebtext", "gpt2", "itest"] 15 | 16 | mp: p=f32,c=bfloat16 17 | model_axis_size: 1 18 | per_device_parallelism: -1 19 | 20 | train_batch_size: 256 21 | num_train_steps: 5000 22 | 23 | # tensor_parallel_axes: ["position", "key_position"] 24 | # tensor_parallel_axes: ["heads", "mlp"] 25 | optimizer: 26 | learning_rate: 1E-3 27 | weight_decay: 0.1 28 | warmup: 0.01 29 | -------------------------------------------------------------------------------- /config/gpt2_small_pile.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/pile_source_old.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 768 5 | num_heads: 12 6 | num_layers: 12 7 | seq_len: 2048 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | tracker: 12 | project: "levanter" 13 | tags: [ "pile", "gpt2"] 14 | 15 | mp: p=f32,c=bfloat16 16 | model_axis_size: 1 17 | 18 | train_batch_size: 256 19 | num_train_steps: 50000 20 | optimizer: 21 | learning_rate: 6e-4 22 | weight_decay: 0.1 23 | -------------------------------------------------------------------------------- /config/gpt2_small_pile_mixture.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/pile_mixture.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 768 5 | num_heads: 12 6 | num_layers: 12 7 | seq_len: 2048 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | tracker: 12 | project: "levanter" 13 | tags: [ "pile", "gpt2"] 14 | 15 | mp: p=f32,c=bfloat16 16 | model_axis_size: 1 17 | 18 | train_batch_size: 256 19 | num_train_steps: 50000 20 | optimizer: 21 | learning_rate: 6e-4 22 | weight_decay: 0.1 23 | -------------------------------------------------------------------------------- /config/gpt2_small_sophiah.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 768 5 | num_heads: 12 6 | num_layers: 12 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | tracker: 12 | project: "levanter" 13 | tags: [ "openwebtext", "gpt2", "sophia-h"] 14 | 15 | mp: p=f32,c=bfloat16 16 | model_axis_size: 1 17 | 18 | train_batch_size: 512 19 | optimizer: !include optim/sophia-h_small.yaml 20 | -------------------------------------------------------------------------------- /config/gpt2_xl.yaml: -------------------------------------------------------------------------------- 1 | data: !include 
data/openwebtext_source.yaml 2 | model: 3 | type: gpt2 4 | hidden_dim: 1600 5 | num_heads: 25 6 | num_layers: 48 7 | seq_len: 1024 8 | gradient_checkpointing: true 9 | scale_attn_by_inverse_layer_idx: true 10 | trainer: 11 | tracker: 12 | project: "levanter" 13 | tags: [ "openwebtext", "gpt2"] 14 | mp: p=f32,c=bfloat16 15 | optimizer: 16 | learning_rate: 1E-4 17 | weight_decay: 0.1 18 | min_lr_ratio: 0.1 19 | -------------------------------------------------------------------------------- /config/harness/eval_llama3.yaml: -------------------------------------------------------------------------------- 1 | eval_harness: 2 | task_spec: 3 | - task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios 4 | num_fewshot: 10 5 | # - task: agieval_lsat_ar # 3-shot tests in legal domain 6 | # num_fewshot: 3 7 | # - task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science 8 | # num_fewshot: 10 9 | # - task: arc_challenge # a (harder) version of arc_easy 10 | # num_fewshot: 10 11 | # - task: boolq # answer yes/no questions based on a passage 12 | # num_fewshot: 10 13 | # - task: copa # use causal reasoning to predict the correct outcome of a given scenario 14 | # num_fewshot: 0 15 | # - task: hellaswag # 4-way multiple choice commonsense reasoning dataset 16 | # num_fewshot: 0 17 | # task_alias: hellaswag_0shot 18 | # - task: hellaswag # 4-way multiple choice commonsense reasoning dataset 19 | # num_fewshot: 10 20 | # task_alias: hellaswag_10shot 21 | # - task: lambada # predict the endings of text passages 22 | # num_fewshot: 0 23 | # - task: openbookqa # 4-way multiple choice question answering task that requires multi-step reasoning 24 | # num_fewshot: 0 25 | # - task: piqa # answer questions based on a passage 26 | # num_fewshot: 10 27 | # - task: wsc273 # Winograd Schema Challenge 28 | # num_fewshot: 0 29 | # - task: winogrande # Winograd challenge, extended to more domains 30 | # num_fewshot: 0 31 | # requires generation 32 | ## - task: squadv2 # reading comprehension benchmark 33 | # num_fewshot: 10 34 | max_eval_length: 4096 35 | tokenizer: meta-llama/Meta-Llama-3-8B 36 | model: 37 | type: llama 38 | #checkpoint_path: gs://marin-us-central2/checkpoints/dclm_baseline_1b_1x_replication_nov12_3404462497seed-b68241/hf/step-54930 39 | checkpoint_path: meta-llama/Meta-Llama-3-8B 40 | checkpoint_is_hf: true 41 | trainer: 42 | mp: f32 43 | profiler: true 44 | 45 | per_device_parallelism: -1 46 | train_batch_size: 512 47 | 48 | tensor_parallel_axes: ["mlp", "heads"] 49 | fsdp_axis: "embed" 50 | batch_axis: "batch" 51 | ray: 52 | auto_start_cluster: false 53 | -------------------------------------------------------------------------------- /config/harness/harness_nano.yaml: -------------------------------------------------------------------------------- 1 | eval_harness: 2 | # task_spec: ["hellaswag"] 3 | task_spec: 4 | # - mmlu 5 | - task: mmlu 6 | num_fewshot: 1 7 | task_alias: mmlu_1shot 8 | tokenizer: "gpt2" 9 | model: 10 | type: gpt2 11 | hidden_dim: 32 12 | num_heads: 4 13 | num_layers: 2 14 | trainer: 15 | mp: f32 16 | num_train_steps: 100 17 | profiler: true 18 | 19 | checkpointer: 20 | keep: 21 | - every: 50 22 | save_interval: 5m 23 | 24 | per_device_parallelism: -1 25 | train_batch_size: 32 26 | 27 | tensor_parallel_axes: ["mlp", "heads"] 28 | fsdp_axis: "embed" 29 | batch_axis: "batch" 30 | -------------------------------------------------------------------------------- /config/llama2_3b_pretrain.yaml: 
-------------------------------------------------------------------------------- 1 | data: !include data/rpv1_llama.yaml 2 | model: 3 | type: llama # Llama2-3.4B 4 | seq_len: 4096 5 | hidden_dim: 4096 6 | intermediate_dim: 8640 7 | num_layers: 26 8 | num_heads: 32 9 | attn_backend: jax_flash 10 | flash_attention_block_size: 2048 11 | trainer: 12 | tracker: 13 | type: wandb 14 | project: "levanter" 15 | tags: ["redpajama", "llama"] 16 | 17 | mp: p=f32,c=bfloat16 18 | train_batch_size: 1024 19 | per_device_parallelism: 8 20 | per_device_eval_parallelism: 16 # set a larger batch size for eval 21 | num_train_steps: 250000 # 3,000,000,000,000 / 4096 / 1024 22 | steps_per_eval: 1000 23 | max_eval_batches: 50 24 | tensor_parallel_axes: ["mlp", "heads"] 25 | fsdp_axis: "embed" 26 | batch_axis: "batch" 27 | optimizer: 28 | learning_rate: 3E-4 # same as Llama2-7B 29 | weight_decay: 0.1 30 | beta1: 0.9 31 | beta2: 0.95 32 | epsilon: 1E-5 33 | warmup: 2000 34 | min_lr_ratio: 0.1 35 | -------------------------------------------------------------------------------- /config/llama2_7b.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_urls: 3 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 4 | validation_urls: 5 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 6 | cache_dir: "gs://levanter-data/tokenized/openwebtext_llama/" 7 | tokenizer: "meta-llama/Llama-2-70b-hf" 8 | model: 9 | activation_function: silu 10 | attn_backend: null 11 | cross_entropy_block_size: null 12 | flash_attention_block_size: null 13 | gradient_checkpointing: true 14 | hidden_dim: 4096 15 | initializer_range: 0.02 16 | intermediate_dim: 14336 17 | layer_norm_epsilon: 1.0e-05 18 | num_heads: 32 19 | num_kv_heads: 8 20 | num_layers: 32 21 | reference_checkpoint: meta-llama/Llama-2-7b-hf 22 | rope: 23 | factor: 1.0 24 | theta: 10000 25 | type: default 26 | scan_layers: true 27 | seq_len: 4096 28 | tie_word_embeddings: false 29 | type: llama 30 | upcast_attn: false 31 | use_bias: false 32 | use_flash_attention: true 33 | use_layer_norm_weight: true 34 | optimizer: 35 | beta1: 0.9 36 | beta2: 0.95 37 | cooldown: null 38 | cycle_length: 10000 39 | cycles: null 40 | decay: 0.1 41 | default_weight_decay_mask: null 42 | epsilon: 1.0e-08 43 | haps: null 44 | learning_rate: 0.001 45 | lr_schedule: inv 46 | max_grad_norm: 1.0 47 | min_lr_ratio: 0.1 48 | rewarmup: 0.0 49 | type: adam 50 | warmup: 1000 51 | weight_decay: 0.05 52 | weight_decay_modules: null 53 | trainer: 54 | axis_resources: {} 55 | batch_axis: batch 56 | checkpointer: 57 | append_run_id_to_base_path: false 58 | base_path: gs://levanter-checkpoints/checkpoints/llama-8b-tootsie-0.001-19ad63/checkpoints 59 | keep: 60 | - every: 20000 61 | save_interval: 10m 62 | quantization: null 63 | fsdp_axis: embed 64 | id: llama-8b-tootsie-0.001-19ad63 65 | initialize_from: null 66 | jax_config: 67 | jax_softmax_custom_jvp: true 68 | jax_threefry_partitionable: true 69 | load_checkpoint: null 70 | load_checkpoint_path: null 71 | log_dir: logs 72 | max_eval_batches: null 73 | model_axis_size: 1 74 | mp: compute=bfloat16,params=float32,output=bfloat16 75 | num_train_steps: 10000 76 | parameter_axis_resources: {} 77 | per_device_eval_parallelism: 2 78 | per_device_parallelism: 2 79 | profiler: false 80 | profiler_num_steps: 100 81 | profiler_perfetto_link: false 82 | profiler_start_step: 5 83 | ray: 84 | address: null 85 | auto_start_cluster: false 86 | start_workers: 
false 87 | # replica_dcn_axis_size: 2 88 | # replica_ici_axis_size: 1 89 | require_accelerator: true 90 | seed: 0 91 | shutdown_at_exit: false 92 | steps_per_eval: 10000 93 | tensor_parallel_axes: null 94 | tracker: 95 | entity: null 96 | group: null 97 | id: null 98 | mode: null 99 | name: null 100 | project: levanter 101 | resume: allow 102 | save_code: true 103 | save_xla_dumps: false 104 | tags: 105 | - llama-8b-test 106 | - llama 107 | - 8b 108 | - wsd-s 109 | type: wandb 110 | train_batch_size: 1024 111 | wandb: null 112 | use_hf_model_config: false 113 | -------------------------------------------------------------------------------- /config/llama2_7b_continued.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | id: EleutherAI/pile 3 | tokenizer: meta-llama/Llama-2-7b-hf 4 | model: 5 | type: llama 6 | initialize_from_hf: true 7 | use_hf_model_config: true 8 | trainer: 9 | tracker: 10 | type: wandb 11 | project: "levanter" 12 | tags: ["pile", "llama2"] 13 | 14 | mp: p=f32,c=bfloat16 15 | 16 | model_axis_size: 1 17 | per_device_eval_parallelism: 4 18 | 19 | train_batch_size: 1024 20 | num_train_steps: 10000 21 | steps_per_eval: 500 22 | optimizer: 23 | learning_rate: 1.2e-4 24 | weight_decay: 0.0 25 | -------------------------------------------------------------------------------- /config/llama2_nano.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_urls: 3 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 4 | validation_urls: 5 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 6 | cache_dir: "gs://levanter-data/tokenized/openwebtext_llama/" 7 | tokenizer: "meta-llama/Llama-2-70b-hf" 8 | model: 9 | type: llama 10 | hidden_dim: 32 11 | num_heads: 4 12 | num_kv_heads: 4 13 | num_layers: 2 14 | trainer: 15 | tracker: 16 | project: "levanter" 17 | tags: ["openwebtext", "llama"] 18 | mp: p=f32 19 | train_batch_size: 32 20 | num_train_steps: 100 21 | steps_per_eval: 50 22 | tensor_parallel_axes: ["mlp", "heads"] 23 | fsdp_axis: "embed" 24 | batch_axis: "batch" 25 | -------------------------------------------------------------------------------- /config/llama3_small_fast.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_urls: 3 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 4 | validation_urls: 5 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 6 | cache_dir: "gs://levanter-data/tokenized/openwebtext_llama3/" 7 | tokenizer: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF" 8 | model: 9 | type: llama 10 | hidden_dim: 768 11 | intermediate_dim: 2048 12 | num_heads: 12 13 | num_kv_heads: 12 14 | num_layers: 12 15 | seq_len: 1024 16 | gradient_checkpointing: true 17 | trainer: 18 | tracker: 19 | - type: wandb 20 | project: "levanter" 21 | tags: [ "openwebtext", "llama", "itest"] 22 | 23 | mp: p=f32,c=bfloat16 24 | model_axis_size: 1 25 | per_device_parallelism: -1 26 | 27 | train_batch_size: 256 28 | num_train_steps: 20000 29 | optimizer: 30 | learning_rate: 1E-3 31 | weight_decay: 0.1 32 | warmup: 0.01 33 | -------------------------------------------------------------------------------- /config/llama_1b_with_fineweb_txt.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/fineweb_llama_txt.yaml 2 | model: # 1B class model 3 | type: llama 4 | 
seq_len: 2048 5 | hidden_dim: 2048 6 | intermediate_dim: 8192 7 | num_layers: 16 8 | num_heads: 16 9 | num_kv_heads: 16 10 | use_flash_attention: True 11 | flash_attention_block_size: 1024 12 | trainer: 13 | tracker: 14 | type: wandb 15 | project: "marin" 16 | tags: ["fineweb", "llama"] 17 | 18 | mp: p=f32,c=bfloat16 19 | train_batch_size: 1024 20 | num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 21 | steps_per_eval: 1000 22 | tensor_parallel_axes: ["mlp", "heads"] 23 | fsdp_axis: "embed" 24 | batch_axis: "batch" 25 | optimizer: 26 | learning_rate: 4E-4 27 | weight_decay: 0.1 28 | min_lr_ratio: 0.1 29 | warmup: 5000 30 | -------------------------------------------------------------------------------- /config/llama_1b_with_olmo_config.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/dolma_olmo_paloma.yaml 2 | model: # 1B class model 3 | type: llama 4 | seq_len: 2048 5 | hidden_dim: 2048 6 | intermediate_dim: 8192 7 | num_layers: 16 8 | num_heads: 16 9 | num_kv_heads: 16 10 | use_flash_attention: True 11 | flash_attention_block_size: 1024 12 | trainer: 13 | tracker: 14 | type: wandb 15 | project: "marin" 16 | tags: ["dolma", "olmo", "llama"] 17 | 18 | mp: p=f32,c=bfloat16 19 | train_batch_size: 1024 20 | num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 21 | steps_per_eval: 1000 22 | tensor_parallel_axes: ["mlp", "heads"] 23 | fsdp_axis: "embed" 24 | batch_axis: "batch" 25 | optimizer: 26 | learning_rate: 4E-4 27 | weight_decay: 0.1 28 | min_lr_ratio: 0.1 29 | warmup: 5000 30 | -------------------------------------------------------------------------------- /config/llama_7b_tulu.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_urls: 3 | - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-000.jsonl.gz" 4 | - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-001.jsonl.gz" 5 | - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-002.jsonl.gz" 6 | cache_dir: "gs://marin-us-central2/tokenized/OLMo-1B/tuluv2_sft/" 7 | tokenizer: "allenai/OLMo-1B" 8 | model: # 7B class model 9 | type: llama 10 | seq_len: 2048 11 | hidden_dim: 4096 12 | intermediate_dim: 11008 13 | num_layers: 32 14 | num_heads: 32 15 | num_kv_heads: 32 16 | use_flash_attention: True 17 | flash_attention_block_size: 512 18 | use_bias: false 19 | use_layer_norm_weight: false 20 | trainer: 21 | tracker: 22 | type: wandb 23 | project: "marin" 24 | tags: ["dolma", "olmo", "llama"] 25 | 26 | mp: p=f32,c=bfloat16 27 | train_batch_size: 256 28 | num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 29 | steps_per_eval: 1000 30 | tensor_parallel_axes: ["mlp", "heads"] 31 | fsdp_axis: "embed" 32 | batch_axis: "batch" 33 | optimizer: 34 | learning_rate: 4E-4 35 | weight_decay: 0.1 36 | min_lr_ratio: 0.1 37 | warmup: 5000 38 | 39 | epoch: 3 40 | -------------------------------------------------------------------------------- /config/llama_7b_with_dclm.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/dclm_gpt_neo.yaml 2 | model: # 7B class model 3 | type: llama 4 | seq_len: 2048 5 | hidden_dim: 4096 6 | intermediate_dim: 11008 7 | num_layers: 32 8 | num_heads: 32 9 | num_kv_heads: 32 10 | use_flash_attention: True 11 | trainer: 12 | tracker: 13 | type: wandb 14 | entity: "stanford-mercury" 15 | project: "marin" 16 | tags: ["dclm", 
"7B", "llama"] 17 | 18 | mp: p=f32,c=bfloat16 19 | train_batch_size: 2048 20 | num_train_steps: 480000 # 2T / 4M 21 | steps_per_eval: 1000 22 | tensor_parallel_axes: ["mlp", "heads"] 23 | fsdp_axis: "embed" 24 | batch_axis: "batch" 25 | optimizer: 26 | learning_rate: 4e-4 27 | weight_decay: 0.1 28 | min_lr_ratio: 0.1 29 | beta1: 0.9 30 | beta2: 0.95 31 | warmup: 5000 32 | 33 | z_loss_weight: 5e-6 34 | -------------------------------------------------------------------------------- /config/llama_7b_with_olmo_config.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/dolma_olmo_paloma.yaml 2 | model: # 7B class model 3 | type: llama 4 | seq_len: 2048 5 | hidden_dim: 4096 6 | intermediate_dim: 11008 7 | num_layers: 32 8 | num_heads: 32 9 | num_kv_heads: 32 10 | use_flash_attention: True 11 | flash_attention_block_size: 1024 12 | trainer: 13 | tracker: 14 | type: wandb 15 | project: "marin" 16 | tags: ["dolma", "olmo", "llama"] 17 | 18 | mp: p=f32,c=bfloat16 19 | train_batch_size: 2048 20 | num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 21 | steps_per_eval: 1000 22 | tensor_parallel_axes: ["mlp", "heads"] 23 | fsdp_axis: "embed" 24 | batch_axis: "batch" 25 | optimizer: 26 | learning_rate: 4E-4 27 | weight_decay: 0.1 28 | min_lr_ratio: 0.1 29 | warmup: 0.01 30 | -------------------------------------------------------------------------------- /config/llama_7b_with_olmo_config_euwest4.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/dolma_llama_euwest.yaml 2 | model: # 7B class model 3 | type: llama 4 | seq_len: 2048 5 | hidden_dim: 4096 6 | intermediate_dim: 11008 7 | num_layers: 32 8 | num_heads: 32 9 | num_kv_heads: 32 10 | use_flash_attention: True 11 | flash_attention_block_size: 1024 12 | trainer: 13 | tracker: 14 | type: wandb 15 | project: "marin" 16 | tags: ["dolma", "olmo", "llama"] 17 | checkpointer: 18 | keep: 19 | - every: 1 20 | until: 2 21 | - every: 5 22 | until: 30 23 | - every: 50 24 | until: 1000 25 | - every: 1000 26 | until: 40000 27 | mp: p=f32,c=bfloat16 28 | train_batch_size: 2048 29 | num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 30 | steps_per_eval: 1000 31 | tensor_parallel_axes: ["mlp", "heads"] 32 | fsdp_axis: "embed" 33 | batch_axis: "batch" 34 | optimizer: 35 | learning_rate: 4E-4 36 | weight_decay: 0.1 37 | min_lr_ratio: 0.1 38 | warmup: 0.01 39 | -------------------------------------------------------------------------------- /config/llama_7b_with_olmo_config_uswest4.yaml: -------------------------------------------------------------------------------- 1 | data: !include data/dolma_llama.yaml 2 | model: # 7B class model 3 | type: llama 4 | seq_len: 2048 5 | hidden_dim: 4096 6 | intermediate_dim: 11008 7 | num_layers: 32 8 | num_heads: 32 9 | num_kv_heads: 32 10 | use_flash_attention: True 11 | flash_attention_block_size: 1024 12 | trainer: 13 | tracker: 14 | type: wandb 15 | project: "marin" 16 | tags: ["dolma", "olmo", "llama"] 17 | checkpointer: 18 | keep: 19 | - every: 1 20 | until: 2 21 | - every: 5 22 | until: 30 23 | - every: 50 24 | until: 1000 25 | - every: 1000 26 | until: 40000 27 | 28 | mp: p=f32,c=bfloat16 29 | train_batch_size: 2048 30 | num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 31 | steps_per_eval: 1000 32 | tensor_parallel_axes: ["mlp", "heads"] 33 | fsdp_axis: "embed" 34 | batch_axis: "batch" 35 | optimizer: 36 | learning_rate: 4E-4 37 | weight_decay: 0.1 38 | 
min_lr_ratio: 0.1 39 | warmup: 0.01 40 | -------------------------------------------------------------------------------- /config/llama_small_fast.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_urls: 3 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 4 | validation_urls: 5 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 6 | cache_dir: "gs://levanter-data/tokenized/openwebtext_llama/" 7 | tokenizer: "meta-llama/Llama-2-70b-hf" 8 | model: 9 | type: llama 10 | hidden_dim: 768 11 | intermediate_dim: 2048 12 | num_heads: 12 13 | num_kv_heads: 12 14 | num_layers: 12 15 | seq_len: 1024 16 | gradient_checkpointing: true 17 | trainer: 18 | tracker: 19 | - type: wandb 20 | project: "levanter" 21 | tags: [ "openwebtext", "llama", "itest"] 22 | 23 | mp: p=f32,c=bfloat16 24 | model_axis_size: 1 25 | per_device_parallelism: -1 26 | 27 | train_batch_size: 256 28 | num_train_steps: 20000 29 | optimizer: 30 | learning_rate: 1E-3 31 | weight_decay: 0.1 32 | warmup: 0.01 33 | -------------------------------------------------------------------------------- /config/llama_small_fast_remat.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_urls: 3 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 4 | validation_urls: 5 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 6 | cache_dir: "gs://levanter-data/tokenized/openwebtext_llama/" 7 | tokenizer: "meta-llama/Llama-2-70b-hf" 8 | model: 9 | type: llama 10 | hidden_dim: 768 11 | intermediate_dim: 2048 12 | num_heads: 12 13 | num_kv_heads: 12 14 | num_layers: 12 15 | seq_len: 1024 16 | gradient_checkpointing: "nested" 17 | trainer: 18 | tracker: 19 | - type: wandb 20 | project: "levanter" 21 | tags: [ "openwebtext", "llama", "itest"] 22 | 23 | mp: p=f32,c=bfloat16 24 | model_axis_size: 1 25 | per_device_parallelism: -1 26 | 27 | train_batch_size: 256 28 | num_train_steps: 20000 29 | optimizer: 30 | learning_rate: 1E-3 31 | weight_decay: 0.1 32 | warmup: 0.01 33 | -------------------------------------------------------------------------------- /config/lora/mpt_biomed.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_urls: 3 | - "gs://pubmed-mosaic/pubmed-sharded/pubmedRandomized_train.{1..128}-of-128.jsonl.gz" 4 | validation_urls: 5 | - "gs://pubmed-mosaic/pubmed-sharded/pubmedRandomized_val.{1..8}-of-8.jsonl.gz" 6 | cache_dir: "gs://pubmed-mosaic/tokenized/pubmed-sharded-neox/" 7 | tokenizer: "EleutherAI/gpt-neox-20b" 8 | initialize_from_hf: "mosaicml/mpt-7b@68e1a8e0ebb9b30f3c45c1ef6195980f29063ae2" 9 | lora: 10 | r: 32 11 | alpha: 32.0 12 | target_modules: ["Wqkv"] 13 | trainer: 14 | tracker: 15 | type: wandb 16 | project: "levanter" 17 | tags: ["mpt", "lora", "pubmed"] 18 | 19 | mp: p=f32,c=bfloat16 20 | 21 | model_axis_size: 1 22 | per_device_parallelism: 4 23 | per_device_eval_parallelism: 4 24 | 25 | train_batch_size: 1024 26 | num_train_steps: 1000 27 | steps_per_eval: 50 28 | optimizer: 29 | learning_rate: 1.2e-3 30 | weight_decay: 0.1 31 | -------------------------------------------------------------------------------- /config/lora_llama2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | # you should set a data.id or train_urls and validation_urls 3 | # id: math-ai/AutoMathText 4 | # You may also 
want to set cache_dir if using more than one machine 5 | # cache_dir: 6 | tokenizer: "meta-llama/Llama-2-70b-hf" 7 | initialize_from_hf: "meta-llama/Llama-2-7b-hf" 8 | peft_save_path: "lora_llama2" 9 | max_train_length: 2048 # train on sequences of this length to reduce memory usage 10 | trainer: 11 | mp: p=f32,c=bfloat16 12 | wandb: 13 | project: "levanter-lora" 14 | tags: ["lora", "llama2"] 15 | num_train_steps: 5000 # tune to suit your needs 16 | train_batch_size: -1 # set to -1 so effective bs is per_device * num_devices (no grad accum) 17 | per_device_parallelism: 4 # set for a 40GB device, but can go up a lot if you have multiple devices or a bigger device 18 | 19 | # if using model parallelism, this is useful: 20 | tensor_parallel_axes: ["mlp", "heads"] 21 | optimizer: 22 | learning_rate: 3e-4 23 | weight_decay: 0.1 24 | -------------------------------------------------------------------------------- /config/mistral_7b.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_urls: 3 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 4 | validation_urls: 5 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 6 | cache_dir: "gs://levanter-data/tokenized/openwebtext_llama/" 7 | tokenizer: "mistralai/Mistral-7B-v0.1" 8 | model: 9 | type: mistral 10 | # TODO: uncomment this once we resolve the resource exhaustion issue 11 | # initialize_from_hf: "mistralai/Mistral-7B-v0.1" 12 | # use_hf_model_config: true 13 | trainer: 14 | wandb: 15 | project: "levanter" 16 | tags: ["openwebtext", "mistral"] 17 | 18 | mp: p=f32,c=bfloat16 19 | train_batch_size: 256 # set for v4-64 TPU 20 | num_train_steps: 1000 21 | steps_per_eval: 50 22 | tensor_parallel_axes: ["mlp", "heads"] 23 | fsdp_axis: "embed" 24 | batch_axis: "batch" 25 | optimizer: 26 | learning_rate: 1.2E-5 # set low for fine-tuning 27 | weight_decay: 0.1 28 | min_lr_ratio: 0.1 29 | -------------------------------------------------------------------------------- /config/mixtral_8x7b.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_urls: 3 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" 4 | validation_urls: 5 | - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" 6 | cache_dir: "gs://levanter-data/tokenized/openwebtext_llama/" 7 | tokenizer: "mistralai/Mixtral-8x7B-v0.1" 8 | model: 9 | type: mixtral 10 | # TODO: uncomment this once we resolve the resource exhaustion issue 11 | # initialize_from_hf: "mistralai/Mistral-7B-v0.1" 12 | # use_hf_model_config: true 13 | trainer: 14 | tracker: 15 | - type: wandb 16 | project: "levanter" 17 | tags: [ "openwebtext", "llama", "itest"] 18 | 19 | mp: p=f32,c=bfloat16 20 | train_batch_size: 256 # set for v4-64 TPU 21 | num_train_steps: 1000 22 | steps_per_eval: 50 23 | tensor_parallel_axes: ["mlp", "heads"] 24 | fsdp_axis: "embed" 25 | batch_axis: "batch" 26 | axis_resources: 27 | token: "data" 28 | optimizer: 29 | learning_rate: 1.2E-5 # set low for fine-tuning 30 | weight_decay: 0.1 31 | min_lr_ratio: 0.1 32 | -------------------------------------------------------------------------------- /config/olmo2_sft.yaml: -------------------------------------------------------------------------------- 1 | # Olmo2 SFT Configuration 2 | 3 | dataset_type: chat_jsonl 4 | 5 | # Config for supervised datasets 6 | supervised_data: 7 | tulu: 8 | cache_dir: 
"gs://marin-us-central2/tokenized/tulu_sft_v3_olmo2tokenizer-8cb4bb" 9 | train_urls: 10 | - "gs://marin-us-central2/documents/allenai--tulu-3-sft-mixture-55e9fd6-27c6a7/**/*.jsonl.gz" 11 | 12 | # Set all weight to the SFT dataset 13 | mixture_weights: 14 | tulu: 1 15 | 16 | mixture_block_size: 2048 17 | stop_strategy: restart 18 | 19 | max_seq_len: 4096 20 | tokenizer: "allenai/OLMo-2-1124-7B-SFT" 21 | model: # Olmo2 7B model config 22 | type: olmo2 23 | seq_len: 4096 24 | hidden_dim: 4096 25 | intermediate_dim: 11008 26 | num_layers: 32 27 | num_heads: 32 28 | num_kv_heads: 32 29 | use_flash_attention: True 30 | flash_attention_block_size: 512 31 | use_bias: false 32 | use_layer_norm_weight: true 33 | initializer_range: 0.02 34 | layer_norm_epsilon: 1e-6 35 | activation_function: "silu" 36 | attention_bias: false 37 | upcast_attn: True 38 | rope: 39 | type: "default" 40 | theta: 500000 41 | 42 | trainer: 43 | seed: 0 44 | tracker: 45 | type: wandb 46 | project: "marin" 47 | tags: ["dolma", "olmo", "mixture"] 48 | wandb: 49 | project: "marin" 50 | name: "olmo2_sft_1e-5" 51 | 52 | mp: p=f32,c=bfloat16 53 | train_batch_size: 128 54 | num_train_steps: 3000 55 | steps_per_eval: 500 56 | tensor_parallel_axes: ["mlp", "heads"] 57 | fsdp_axis: "embed" 58 | batch_axis: "batch" 59 | checkpointer: 60 | base_path: "gs://marin-us-central2/checkpoints/olmo2_sft/seed_0/" 61 | 62 | optimizer: 63 | learning_rate: 1e-5 64 | weight_decay: 0.0 # Keep at 0.0, OLMo 2 doesn't use weight decay on embeddings 65 | min_lr_ratio: 0.0 66 | lr_schedule: "linear" 67 | warmup: 0.03 68 | 69 | 70 | hf_save_steps: 1000 71 | hf_save_path: "gs://marin-us-central2/checkpoints/olmo2_sft/hf/seed_0/" 72 | 73 | initialize_from_hf: True 74 | model_name_or_path: "allenai/OLMo-2-1124-7B" 75 | 76 | messages_field: "messages" 77 | input_role: "user" 78 | output_role: "assistant" 79 | -------------------------------------------------------------------------------- /config/optim/sophia-h_large.yaml: -------------------------------------------------------------------------------- 1 | type: sophia-h 2 | learning_rate: 3E-4 3 | weight_decay: 0.2 4 | min_lr_ratio: 0.1 5 | gamma: 0.01 6 | # sophia needs a minimum amount of warmup or it doesn't do well 7 | warmup: 2000 8 | -------------------------------------------------------------------------------- /config/optim/sophia-h_medium.yaml: -------------------------------------------------------------------------------- 1 | type: sophia-h 2 | learning_rate: 4E-4 3 | weight_decay: 0.2 4 | min_lr_ratio: 0.1 5 | gamma: 0.01 6 | # sophia needs a minimum amount of warmup or it doesn't do well 7 | warmup: 2000 8 | -------------------------------------------------------------------------------- /config/optim/sophia-h_small.yaml: -------------------------------------------------------------------------------- 1 | type: sophia-h 2 | learning_rate: 6E-4 3 | weight_decay: 0.2 4 | min_lr_ratio: 0.1 5 | gamma: 0.01 6 | # sophia needs a minimum amount of warmup or it doesn't do well 7 | warmup: 2000 8 | -------------------------------------------------------------------------------- /config/optim/sophia-h_xl.yaml: -------------------------------------------------------------------------------- 1 | type: sophia-h 2 | learning_rate: 1.2E-4 3 | weight_decay: 0.2 4 | min_lr_ratio: 0.1 5 | gamma: 0.01 6 | # sophia needs a minimum amount of warmup or it doesn't do well 7 | warmup: 2000 8 | -------------------------------------------------------------------------------- /config/sft_hf_llama3_ckpt.yaml: 
-------------------------------------------------------------------------------- 1 | # Model configuration 2 | model: 3 | type: llama 4 | seq_len: 4096 5 | hidden_dim: 4096 6 | intermediate_dim: 14336 7 | num_layers: 32 8 | num_heads: 32 9 | num_kv_heads: 8 10 | use_flash_attention: true 11 | flash_attention_block_size: 512 12 | use_bias: false 13 | use_layer_norm_weight: true 14 | initializer_range: 0.02 15 | rope: 16 | type: "llama3" 17 | 18 | # need to set this! 19 | tokenizer: "meta-llama/Meta-Llama-3.1-8B" 20 | -------------------------------------------------------------------------------- /config/sft_llama3.1_tulu3.yaml: -------------------------------------------------------------------------------- 1 | dataset_type: chat_jsonl 2 | chat_train_urls: 3 | - "gs://marin-us-central2/documents/allenai--tulu-v2-sft-mixture-0ba27c/data/**/*.jsonl.gz" 4 | supervised_data: 5 | # cache_dir before trying sequence packing 6 | cache_dir: "gs://marin-us-central2/tokenized/tulu_3_sft_mixture_llama3_instruct_tokenizer-8f8ba3" 7 | 8 | max_seq_len: 4096 9 | tokenizer: "meta-llama/Meta-Llama-3.1-8B" 10 | model: # 8B llama3 class model 11 | type: llama 12 | seq_len: 4096 13 | hidden_dim: 4096 14 | intermediate_dim: 14336 15 | num_layers: 32 16 | num_heads: 32 17 | num_kv_heads: 8 18 | use_flash_attention: True 19 | flash_attention_block_size: 512 20 | use_bias: false 21 | use_layer_norm_weight: true 22 | initializer_range: 0.02 23 | rope: 24 | type: "llama3" 25 | 26 | trainer: 27 | seed: 0 28 | tracker: 29 | type: wandb 30 | project: "marin" 31 | tags: ["dolma", "olmo", "llama"] 32 | wandb: 33 | project: "marin" 34 | name: "llama3.1_tulu3_sft" 35 | 36 | mp: p=f32,c=bfloat16 37 | # same as 606 sft in marin 38 | train_batch_size: 128 39 | # number of steps until we hit stop iteration 40 | num_train_steps: 1845 41 | steps_per_eval: 1000 42 | tensor_parallel_axes: ["mlp", "heads"] 43 | fsdp_axis: "embed" 44 | batch_axis: "batch" 45 | checkpointer: 46 | base_path: "gs://marin-us-central2/checkpoints/llama_3.1_tulusft/seed_0/" 47 | optimizer: 48 | learning_rate: 5e-6 49 | weight_decay: 0.0 50 | min_lr_ratio: 0.0 51 | lr_schedule: "linear" 52 | warmup: 0.03 53 | 54 | hf_save_steps: 500 55 | hf_save_path: "gs://marin-us-central2/checkpoints/llama_3.1_tulusft/hf/seed_0/" 56 | 57 | initialize_from_hf: True 58 | model_name_or_path: "meta-llama/Llama-3.1-8B" 59 | epoch: 0 60 | -------------------------------------------------------------------------------- /config/sft_llama3_openthoughts.yaml: -------------------------------------------------------------------------------- 1 | dataset_type: chat_jsonl 2 | chat_train_urls: 3 | - "gs://marin-us-central2/documents/open-thoughts--OpenThoughts-114k-216e29/data/**/*.jsonl.gz" 4 | supervised_data: 5 | # cache_dir before trying sequence packing 6 | cache_dir: "gs://marin-us-central2/tokenized/openthoughts_llama3_tokenizer-9edd80" 7 | #cache_dir: "gs://marin-us-central2/tokenized/tulu_sft_v3_llama3_tokenizer_retrypack-bca8bd/" 8 | 9 | max_seq_len: 4096 10 | tokenizer: "meta-llama/Meta-Llama-3.1-8B" 11 | model: # 8B llama3 class model 12 | type: llama 13 | seq_len: 4096 14 | hidden_dim: 4096 15 | intermediate_dim: 14336 16 | num_layers: 32 17 | num_heads: 32 18 | num_kv_heads: 8 19 | use_flash_attention: True 20 | flash_attention_block_size: 512 21 | use_bias: false 22 | use_layer_norm_weight: true 23 | initializer_range: 0.02 24 | rope: 25 | type: "llama3" 26 | trainer: 27 | seed: 1 28 | tracker: 29 | type: wandb 30 | project: "marin" 31 | tags: ["dolma", "olmo", "llama"] 
32 | 33 | mp: p=f32,c=bfloat16 34 | # same as 606 sft in marin 35 | train_batch_size: 128 36 | # number of steps until we hit stop iteration 37 | num_train_steps: 802 38 | steps_per_eval: 1000 39 | tensor_parallel_axes: ["mlp", "heads"] 40 | fsdp_axis: "embed" 41 | batch_axis: "batch" 42 | optimizer: 43 | learning_rate: 5e-6 44 | weight_decay: 0.0 45 | min_lr_ratio: 0.0 46 | lr_schedule: "linear" 47 | warmup: 0.03 48 | 49 | 50 | hf_save_steps: 801 51 | hf_save_path: "gs://levanter-checkpoints/marin/tulusft_openthoughtsft/" 52 | 53 | epoch: 0 54 | -------------------------------------------------------------------------------- /config/train_lm_llama3_tulu_sft.yaml: -------------------------------------------------------------------------------- 1 | # Config adapted from sft_tootsie_mixture.yaml for use with train_lm.py 2 | # Trains only on the tulu dataset using the chat format. 3 | 4 | data: 5 | # Using LMMixtureDatasetConfig structure like gpt2_small_fast_mix_chat.yaml 6 | configs: 7 | tulu: 8 | id: allenai/tulu-3-sft-mixture 9 | format: 10 | type: "chat" 11 | train_weights: 12 | tulu: 1.0 # Weight for the single dataset 13 | tokenizer: stanford-crfm/marin-tokenizer 14 | cache_dir: "gs://marin-us-central2/tokenized/marin-tokenizer/tulu-3-sft-mixture" 15 | shuffle: true 16 | # permutation_type: "feistel" # Removed due to draccus parsing error for Literal type 17 | # cache_dir: # Can optionally specify a top-level cache dir for the mixture if needed 18 | 19 | model: # 8B llama3 class model from sft_tootsie_mixture.yaml 20 | type: llama 21 | seq_len: 4096 22 | hidden_dim: 4096 23 | intermediate_dim: 14336 24 | num_layers: 32 25 | num_heads: 32 26 | num_kv_heads: 8 27 | use_flash_attention: True 28 | flash_attention_block_size: 512 29 | use_bias: false 30 | use_layer_norm_weight: true 31 | initializer_range: 0.02 32 | rope: 33 | type: "llama3" 34 | 35 | trainer: 36 | seed: 0 37 | tracker: 38 | type: wandb 39 | project: "marin" 40 | tags: ["dolma", "olmo", "llama", "tulu", "train_lm"] # Adjusted tags 41 | wandb: 42 | project: "marin" 43 | name: "llama3_tulu_sft_seed0_shuffle_fixed_tokenizer" # Adjusted name 44 | 45 | mp: p=f32,c=bfloat16 46 | train_batch_size: 128 47 | num_train_steps: 3834 48 | steps_per_eval: 1000 # Note: No eval dataset specified, so this might not do much unless one is added 49 | tensor_parallel_axes: ["mlp", "heads"] 50 | fsdp_axis: "embed" 51 | batch_axis: "batch" 52 | checkpointer: 53 | base_path: "gs://marin-us-central2/checkpoints/llama3_tulu_sft_fixed_tokenizer/seed_0/" # Adjusted path 54 | 55 | optimizer: 56 | learning_rate: 5e-6 57 | weight_decay: 0.0 58 | min_lr_ratio: 0.0 59 | lr_schedule: "linear" 60 | warmup: 0.03 61 | 62 | # Initialization from the specific HF checkpoint used in sft_tootsie_mixture.yaml 63 | initialize_from_hf: "meta-llama/Llama-3.1-8B" #"gs://marin-us-central2/checkpoints/tootsie-8b-hypnotic-spoonbill-2/hf/step-829999/" 64 | use_hf_model_config: False # Use the model config defined above 65 | 66 | # HF Saving config from sft_tootsie_mixture.yaml 67 | hf_save_steps: 1000 68 | hf_save_path: "gs://marin-us-central2/checkpoints/llama3_tulu_sft_fixed_tokenizer/hf/seed_0/" # Adjusted path 69 | 70 | # Defaults or settings not applicable/present in sft_tootsie_mixture.yaml for train_lm: 71 | # z_loss_weight: 0.0 72 | # epoch: 0 73 | # data_seed: None 74 | # eval_harness: None 75 | # eval_harness_steps: 10000 76 | # log_entropy: False 77 | # reinit_tokens: Not supported by train_lm.py 78 | 
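A config like the one above is consumed by Levanter's `train_lm` entry point. A minimal launch sketch (assuming a working Levanter installation and access to the GCS paths referenced in the config; the module path below is the usual `levanter.main.train_lm` entry point and may differ in your checkout):

```bash
# Sketch of a single-host launch; the config path is the file shown above,
# and accelerator/cluster setup is assumed to already be in place.
python -m levanter.main.train_lm --config_path config/train_lm_llama3_tulu_sft.yaml
```

On TPU pods the same command is usually wrapped by the launcher described in Getting-Started-TPU-VM.md rather than invoked directly.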
-------------------------------------------------------------------------------- /config/whisper_tiny_librispeech.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | cache_dir: "gs://diva-flash/processed/mixture" 3 | # The Whisper Tokenizer is way too large for Librispeech 4 | tokenizer: "facebook/wav2vec2-base-960h" 5 | configs: 6 | librispeech: 7 | id: WillHeld/librispeech_parquet 8 | cache_dir: "gs://diva-flash/processed/librispeech" 9 | train_split: "train.360" 10 | validation_split: "validation" 11 | train_weights: 12 | librispeech: 1.0 13 | model: 14 | type: whisper 15 | vocab_size: 32 16 | trainer: 17 | tracker: 18 | - type: wandb 19 | project: "levanter" 20 | tags: [ "librispeech", "whisper"] 21 | 22 | mp: p=f32,c=bf16 23 | model_axis_size: 1 24 | per_device_parallelism: -1 25 | 26 | train_batch_size: 128 27 | num_train_steps: 16000 28 | optimizer: 29 | learning_rate: 3E-3 30 | weight_decay: 0.1 31 | warmup: 0.01 32 | hf_save_steps: 16000 33 | -------------------------------------------------------------------------------- /docker/nvidia/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ghcr.io/nvidia/jax:nightly-2023-10-25-linux-amd64 as base 2 | LABEL authors="dlwh" 3 | 4 | # copy everything from the repo into the container 5 | COPY . /opt/levanter 6 | 7 | ARG GIT_USER_EMAIL 8 | ARG GIT_USER_NAME 9 | RUN <=1.20 \ 38 | psutil \ 39 | # Required a recent version of setuptools to be compatible with python 3.12+. 40 | setuptools==71.1.0 \ 41 | "google-api-python-client==1.7.8" \ 42 | "google-oauth" 43 | 44 | 45 | # Install gcloud so we can get secrets (maybe we should just curl?) 46 | RUN curl https://dl.google.com/dl/cloudsdk/release/google-cloud-sdk.tar.gz > /tmp/google-cloud-sdk.tar.gz 47 | 48 | RUN mkdir -p /usr/local/gcloud \ 49 | && tar -C /usr/local/gcloud -xvf /tmp/google-cloud-sdk.tar.gz \ 50 | && /usr/local/gcloud/google-cloud-sdk/install.sh \ 51 | && rm -f /tmp/google-cloud-sdk.tar.gz 52 | 53 | # Adding the package path to local 54 | ENV PATH=$PATH:/usr/local/gcloud/google-cloud-sdk/bin 55 | 56 | # GCP doesn't like it when root ssh's into a machine 57 | RUN useradd -m -s /bin/bash levanter 58 | RUN echo "levanter ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers 59 | RUN usermod -aG docker levanter 60 | RUN mkdir -p $HOME && touch $HOME/.bashrc && chown -R levanter $HOME 61 | RUN echo "export PATH=$PATH" >> $HOME/.bashrc 62 | RUN adduser levanter docker 63 | 64 | RUN chown -R levanter /opt/levanter 65 | 66 | USER levanter 67 | 68 | # HACK until https://github.com/ray-project/ray/issues/47769 is resolved 69 | RUN pip install 'ray[default,gcp]==2.34.0' 70 | RUN git clone https://github.com/dlwh/ray.git ~/ray --branch tpu_docker_2.34 --depth 1 71 | RUN cp ~/ray/python/ray/autoscaler/_private/gcp/tpu_command_runner.py /opt/levanter/.venv/lib/python3.10/site-packages/ray/autoscaler/_private/gcp/tpu_command_runner.py 72 | 73 | 74 | WORKDIR /opt/levanter 75 | -------------------------------------------------------------------------------- /docker/tpu/Dockerfile.incremental: -------------------------------------------------------------------------------- 1 | ARG IMAGE=ghcr.io/stanford-crfm/levanter-base 2 | ARG TAG=latest 3 | 4 | FROM ${IMAGE}:${TAG} 5 | 6 | # This usually is a config directory so users can have their own config directory outside the repo. 
7 | ARG EXTRA_CTX=/config 8 | 9 | ENV TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS=60\ 10 | TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES=1024\ 11 | RAY_USAGE_STATS_ENABLED=0\ 12 | PATH=/opt/levanter/.venv/bin:$PATH\ 13 | PYTHONPATH=/opt/levanter:/opt/levanter/src:/opt/levanter/examples:/opt/levanter/tests\ 14 | HOME=/home/levanter 15 | 16 | WORKDIR /opt/levanter 17 | 18 | ADD pyproject.toml README.md /opt/levanter/ 19 | RUN mkdir -p /opt/levanter/src/levanter 20 | RUN pip install -e '.[test]' 21 | RUN pip install "lm-eval@git+https://github.com/dlwh/lm-evaluation-harness.git@no_torch" 22 | RUN pip install "draccus@git+https://github.com/dlwh/draccus.git" 23 | ADD . /opt/levanter 24 | 25 | # Add $EXTRA_CTX to the same location as in local machine. 26 | # it's already in the image, so we don't need to copy it. just move it if we set EXTRA_CTX 27 | RUN if [ -f ".mnt" ] || [ -d ".mnt" ]; then mkdir -p $(dirname $EXTRA_CTX) && mv .mnt $EXTRA_CTX; fi 28 | -------------------------------------------------------------------------------- /docs/css/mkdocstrings.css: -------------------------------------------------------------------------------- 1 | /* Indentation. */ 2 | div.doc-contents:not(.first) { 3 | padding-left: 25px; 4 | border-left: .05rem solid var(--md-typeset-table-color); 5 | } 6 | 7 | /* Mark external links as such. */ 8 | a.external::after, 9 | a.autorefs-external::after { 10 | /* https://primer.style/octicons/arrow-up-right-24 */ 11 | mask-image: url('data:image/svg+xml,'); 12 | content: ' '; 13 | 14 | display: inline-block; 15 | vertical-align: middle; 16 | position: relative; 17 | 18 | height: 1em; 19 | width: 1em; 20 | background-color: var(--md-typeset-a-color); 21 | } 22 | 23 | a.external:hover::after, 24 | a.autorefs-external:hover::after { 25 | background-color: var(--md-accent-fg-color); 26 | } 27 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | # Frequently Asked Questions 2 | 3 | ## Project 4 | 5 | ### Why is it called Levanter? 6 | 7 | Levanter is a wind that blows from the east in the Mediterranean. Stanford CRFM's first training project was 8 | called [Mistral](https://github.com/stanford-crfm/mistral), which is another Mediterranean wind. (That Mistral 9 | has no relation to the now more famous [Mistral AI](https://www.mistral.ai/). They took our name!) 10 | 11 | 12 | ## Installation Issues 13 | 14 | ### CUDA: `XLA requires ptxas version 11.8 or higher` 15 | 16 | `jaxlib.xla_extension.XlaRuntimeError: INTERNAL: XLA requires ptxas version 11.8 or higher` 17 | 18 | This error occurs when your local CUDA installation is too old. When you follow the 19 | [GPU installation instructions](Getting-Started-GPU.md), you install a version of CUDA in a pip environment. 20 | If you have another version of CUDA installed on your machine, it may be interfering with the pip environment. 21 | The usual solution for this is to either upgrade your local CUDA installation or hide it from your PATH. Usually this works: 22 | 23 | ```bash 24 | export PATH=$(echo $PATH | sed 's|:/usr/local/cuda/bin||') 25 | ``` 26 | 27 | You should add that to your `.bashrc` or `.zshrc` or whatever shell you use and restart your shell. 28 | 29 | 30 | ## Nuisance Warnings 31 | 32 | ### Transformers: `None of PyTorch, TensorFlow >= 2.0, or Flax have been found.` 33 | 34 | `None of PyTorch, TensorFlow >= 2.0, or Flax have been found. 
Models won't be available and only tokenizers, configuration and file/data utilities can be used.` 35 | 36 | Hugging Face's Transformers library is a dependency of Levanter, but we only use it for the tokenizers and a few utilities. 37 | If you don't plan on using the Hugging Face models, you can safely ignore this warning. If the warning bothers you, 38 | you can set the environment variable `export TRANSFORMERS_VERBOSITY=error` to silence it. 39 | 40 | ## Cloud Issues 41 | 42 | ### Permission 'storage.objects.get' denied on resource 43 | 44 | ``` 45 | gcsfs.retry.HttpError: Anonymous caller does not have storage.objects.get access to the Google Cloud Storage object. 46 | Permission 'storage.objects.get' denied on resource (or it may not exist)., 401 47 | ``` 48 | 49 | If you're using Google Cloud Storage on a non-TPU machine, you might get errors like this. The solution is to log 50 | into your Google Cloud account on the machine: 51 | 52 | ```bash 53 | gcloud auth login 54 | gcloud auth application-default login 55 | ``` 56 | 57 | ## Ray Issues 58 | 59 | ### RuntimeError: Failed to start ray head with exit code 256 60 | 61 | Probably ray is still running and Levanter didn't clean up the ray cluster (or another user is using the same port). 62 | If the former, you can kill the ray cluster with `ray stop`. If the latter, there's not much you can do about it. 63 | [Ray doesn't work super well when multiple users are running Ray on the same machine.](https://github.com/ray-project/ray/issues/20634) 64 | Try docker? 65 | 66 | Another reason could be the ports are not open in your VM. If using GCP, check the firewall settings of your VPC and expose port `61964` (used by ray). 67 | -------------------------------------------------------------------------------- /docs/figures/bitwise_repro_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/bitwise_repro_curve.png -------------------------------------------------------------------------------- /docs/figures/data_parallel_mesh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/data_parallel_mesh.png -------------------------------------------------------------------------------- /docs/figures/data_parallel_mesh_replicated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/data_parallel_mesh_replicated.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_1d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/device_mesh_1d.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_1d_zero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/device_mesh_1d_zero.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_2d.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/device_mesh_2d.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_2d_batch_partitioned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/device_mesh_2d_batch_partitioned.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_2d_data_replicated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/device_mesh_2d_data_replicated.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_2d_data_replicated_mlp_partitioned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/device_mesh_2d_data_replicated_mlp_partitioned.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_2d_intermediate_fully_partitioned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/device_mesh_2d_intermediate_fully_partitioned.png -------------------------------------------------------------------------------- /docs/figures/device_mesh_2d_zero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/device_mesh_2d_zero.png -------------------------------------------------------------------------------- /docs/figures/finetune_func_cm_full_weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/finetune_func_cm_full_weight.png -------------------------------------------------------------------------------- /docs/figures/finetune_func_cm_lora.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/finetune_func_cm_lora.png -------------------------------------------------------------------------------- /docs/figures/helm-gsm8k-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/helm-gsm8k-results.png -------------------------------------------------------------------------------- /docs/figures/helm-instance-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/helm-instance-example.png -------------------------------------------------------------------------------- /docs/figures/lora-diagram.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/lora-diagram.png -------------------------------------------------------------------------------- /docs/figures/palm_mfu_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/palm_mfu_table.png -------------------------------------------------------------------------------- /docs/figures/resumed_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/resumed_curve.png -------------------------------------------------------------------------------- /docs/figures/stopped_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/stopped_curve.png -------------------------------------------------------------------------------- /docs/figures/token_probabilities.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/docs/figures/token_probabilities.mov -------------------------------------------------------------------------------- /docs/guides/Direct-Cache-Construction.md: -------------------------------------------------------------------------------- 1 | # Direct Cache Construction 2 | 3 | (See also [Training on Your Own Data](../Training-On-Your-Data.md) for more details.) 4 | 5 | Levanter typically handles cache construction automatically, but if you have custom preprocessing logic, or if Ray isn't 6 | working for you, you can construct a cache of preprocessed data directly. 7 | 8 | To do this without Ray, use [levanter.store.SerialCacheWriter](https://github.com/stanford-crfm/levanter/blob/main/src/levanter/store/cache.py) 9 | to write batches directly. Here's an example: 10 | 11 | ```python 12 | import numpy as np 13 | 14 | from levanter.store import SerialCacheWriter 15 | 16 | exemplar = { 17 | "input_ids": np.zeros((0), dtype=np.int32), 18 | } 19 | 20 | def process_batches(): 21 | for i in range(0, 1000): 22 | yield [{"input_ids": np.array([i])} for _ in range(1000)] 23 | 24 | cache_dir = "gs://path/to/cache" 25 | 26 | with SerialCacheWriter(cache_dir, exemplar) as writer: 27 | for batch in process_batches(): 28 | # batch is a list of dicts whose keys (and dtypes) match the exemplar, here just "input_ids" 29 | writer.write_batch(batch) 30 | ``` 31 | 32 | In this case, each `batch` should be a list of dicts whose keys and dtypes match the exemplar; depending on the data format, this might also include keys like `"attention_mask"` and `"labels"`. 33 | To work with `train_lm`'s `text` format, it should have an `input_ids` key that is a list of `int`s. 34 | See the [Data Formats Reference](../reference/Data-Formats.md) for more details of other formats. 35 | 36 | ## Passthrough Tokenizers 37 | 38 | Oftentimes, if you're using direct cache construction, you'll want to use a passthrough tokenizer.
For instance, 39 | in our music work, tokens were actually parts of a custom formatting of MIDI files and there was no actual tokenizer. 40 | 41 | To use a cache like this, you can use the `passthrough` tokenizer: 42 | 43 | ```yaml 44 | data: 45 | cache_dir: "gs://path/to/cache" 46 | tokenizer: "passthrough" 47 | vocab_size: 5567 48 | ``` 49 | 50 | The passthrough tokenizer is a special tokenizer that just passes through the input ids without any processing. 51 | Basically, you just need to tell Levanter what the vocab size is. 52 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | {% 2 | include-markdown "../README.md" 3 | start="" 4 | end="" 5 | %} 6 | 7 | The code is released on GitHub: [Levanter repository](https://github.com/stanford-crfm/levanter/). 8 | 9 | To get started, please refer to the User Guide's chapters: 10 | 11 | - [Getting Started on GPU](Getting-Started-GPU.md) 12 | - [Getting Started on TPU](Getting-Started-TPU-VM.md) 13 | - [Getting Started Training](Getting-Started-Training.md) 14 | 15 | Please also see the guides in the menu on the left. 16 | 17 | To contribute, please refer to the [Contributing Guide](https://github.com/stanford-crfm/levanter/blob/main/CONTRIBUTING.md). 18 | -------------------------------------------------------------------------------- /docs/javascripts/mathjax.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | inlineMath: [["\\(", "\\)"]], 4 | displayMath: [["\\[", "\\]"]], 5 | processEscapes: true, 6 | processEnvironments: true 7 | }, 8 | options: { 9 | ignoreHtmlClass: ".*|", 10 | processHtmlClass: "arithmatex" 11 | } 12 | }; 13 | 14 | document$.subscribe(() => { 15 | MathJax.typesetPromise() 16 | }) 17 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs 2 | mkdocstrings 3 | mkdocstrings-python 4 | mkdocs-material 5 | mkdocs-material-extensions 6 | mkdocs-autorefs 7 | mkdocs-include-markdown-plugin 8 | mkdocs-literate-nav 9 | mkdocs-macros-plugin 10 | haliax 11 | -------------------------------------------------------------------------------- /examples/alpaca-lora/alpaca-lora-llama2.yaml: -------------------------------------------------------------------------------- 1 | # cf https://github.com/tatsu-lab/stanford_alpaca#fine-tuning 2 | model_name_or_path: meta-llama/Llama-2-7b-hf 3 | trainer: 4 | mp: p=f32,c=bfloat16 5 | wandb: 6 | project: "levanter-alpaca" 7 | tags: ["lora", "llama2"] 8 | num_train_steps: 1218 # 128 * 1218 = 155904, which is almost but not quite 3 epochs, which is what alpaca did 9 | train_batch_size: 128 10 | 11 | # if using model parallelism, this is useful: 12 | tensor_parallel_axes: ["mlp", "heads"] 13 | optimizer: 14 | learning_rate: 3e-4 15 | weight_decay: 0.0 16 | -------------------------------------------------------------------------------- /examples/alpaca-lora/alpaca-lora.yaml: -------------------------------------------------------------------------------- 1 | # cf https://github.com/tatsu-lab/stanford_alpaca#fine-tuning 2 | model_name_or_path: huggyllama/llama-7b 3 | trainer: 4 | mp: p=f32,c=bfloat16 5 | wandb: 6 | project: "levanter-alpaca" 7 | tags: ["lora", "llama1"] 8 | num_train_steps: 1218 # 128 * 1218 = 155904, which is almost but not quite 3 epochs, which is 
what alpaca did 9 | train_batch_size: 128 10 | 11 | # if using model parallelism, this is useful: 12 | tensor_parallel_axes: ["mlp", "heads"] 13 | optimizer: 14 | learning_rate: 3e-4 15 | weight_decay: 0.0 16 | -------------------------------------------------------------------------------- /examples/alpaca-lora/code-alpaca-lora.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: huggyllama/llama-7b 2 | data: lucasmccabe-lmi/CodeAlpaca-20k 3 | data_cache_dir: code_alpaca_cache 4 | prompts: 5 | prompt_input: |- 6 | ### Instruction: {instruction} 7 | ### Input: {input} 8 | ### Output: 9 | prompt_no_input: |- 10 | ### Instruction: {instruction} 11 | ### Output: 12 | trainer: 13 | mp: p=f32,c=bfloat16 14 | wandb: 15 | project: "levanter-alpaca" 16 | tags: ["code", "lora", "llama1"] 17 | num_train_steps: 500 # 128 * 500 = 64000, which is a bit more than 3 epochs 18 | train_batch_size: 128 19 | 20 | # if using model parallelism, this is useful: 21 | tensor_parallel_axes: ["mlp", "heads"] 22 | optimizer: 23 | learning_rate: 3e-4 24 | weight_decay: 0.0 25 | -------------------------------------------------------------------------------- /examples/alpaca-lora/hf_lora_inference.py: -------------------------------------------------------------------------------- 1 | # Simple example script to use peft to do inference with an HF model. 2 | import os 3 | import sys 4 | 5 | import torch 6 | from peft import PeftConfig, PeftModel 7 | from transformers import AutoModelForCausalLM, AutoTokenizer 8 | 9 | 10 | # from ..alpaca.alpaca import DEFAULT_PROMPT_DICT 11 | SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) 12 | sys.path.append(os.path.join(SCRIPT_DIR, "..", "alpaca")) 13 | import alpaca # noqa: E402 14 | 15 | 16 | PROMPT = alpaca.DEFAULT_PROMPT_DICT["prompt_no_input"] 17 | 18 | 19 | def format_prompt(**kwargs): 20 | return PROMPT.format(**kwargs) 21 | 22 | 23 | def main(peft_model_id): 24 | config = PeftConfig.from_pretrained(peft_model_id) 25 | model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, device_map="auto") 26 | model = PeftModel.from_pretrained(model, peft_model_id, device_map="auto") 27 | 28 | model.eval() 29 | 30 | tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path) 31 | 32 | if tokenizer.pad_token is None: 33 | tokenizer.pad_token = tokenizer.unk_token 34 | 35 | # if on macos, set cpu b/c mps doesn't work yet w/ int64.cumsum 36 | if model.device.type == "mps": 37 | model.cpu() 38 | 39 | while True: 40 | msg = input("> ") 41 | msg = format_prompt(instruction=msg) 42 | inputs = tokenizer(msg, return_tensors="pt") 43 | print("... 
", end="") 44 | with torch.no_grad(): 45 | outputs = model.generate(**inputs, max_new_tokens=200, do_sample=True, min_new_tokens=1) 46 | print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)) 47 | 48 | 49 | if __name__ == "__main__": 50 | if len(sys.argv) > 2: 51 | print("Usage: python hf_lora_inference.py [model_name_or_path]") 52 | sys.exit(1) 53 | 54 | model_name_or_path = sys.argv[1] if len(sys.argv) == 2 else "dlwh/levanter-lora-test" 55 | 56 | main(model_name_or_path) 57 | -------------------------------------------------------------------------------- /examples/alpaca/alpaca-llama2.yaml: -------------------------------------------------------------------------------- 1 | # cf https://github.com/tatsu-lab/stanford_alpaca#fine-tuning 2 | model_name_or_path: meta-llama/Llama-2-7b-hf 3 | trainer: 4 | mp: p=f32,c=bfloat16 5 | wandb: 6 | project: "levanter-alpaca" 7 | tags: ["llama2"] 8 | num_train_steps: 1218 # 128 * 1218 = 155904, which is almost but not quite 3 epochs, which is what alpaca did 9 | train_batch_size: 128 10 | 11 | # if using model parallelism, this is useful: 12 | tensor_parallel_axes: ["mlp", "heads"] 13 | optimizer: 14 | learning_rate: 2e-5 15 | weight_decay: 0.0 16 | prompts: 17 | # |- means multiline string, keeping all but the final newline 18 | prompt_input: |- 19 | Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 20 | 21 | ### Instruction: 22 | {instruction} 23 | 24 | ### Input: 25 | {input} 26 | 27 | ### Response: 28 | prompt_no_input: |- 29 | Below is an instruction that describes a task. Write a response that appropriately completes the request. 30 | 31 | ### Instruction: 32 | {instruction} 33 | 34 | ### Response: 35 | -------------------------------------------------------------------------------- /examples/alpaca/alpaca.yaml: -------------------------------------------------------------------------------- 1 | # cf https://github.com/tatsu-lab/stanford_alpaca#fine-tuning 2 | model_name_or_path: huggyllama/llama-7b 3 | trainer: 4 | mp: p=f32,c=bfloat16 5 | wandb: 6 | project: "levanter-alpaca" 7 | tags: ["llama1"] 8 | num_train_steps: 1218 # 128 * 1218 = 155904, which is almost but not quite 3 epochs, which is what alpaca did 9 | train_batch_size: 128 10 | 11 | # if using model parallelism, this is useful: 12 | tensor_parallel_axes: ["mlp", "heads"] 13 | optimizer: 14 | learning_rate: 2e-5 15 | weight_decay: 0.0 16 | prompts: 17 | # |- means multiline string, keeping all but the final newline 18 | prompt_input: |- 19 | Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 20 | 21 | ### Instruction: 22 | {instruction} 23 | 24 | ### Input: 25 | {input} 26 | 27 | ### Response: 28 | prompt_no_input: |- 29 | Below is an instruction that describes a task. Write a response that appropriately completes the request. 
30 | 31 | ### Instruction: 32 | {instruction} 33 | 34 | ### Response: 35 | -------------------------------------------------------------------------------- /examples/gsm8k-lora/gsm8k-llama2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: llama 3 | data: gsm8k 4 | data_cache_dir: gsm8k_cache 5 | trainer: 6 | mp: p=f32,c=bfloat16 7 | wandb: 8 | project: "levanter-gsm8k" 9 | tags: ["gsm8k", "lora", "llama2"] 10 | num_train_steps: 550 # 64 * 550 = 35200, which is a bit more than 4 epochs 11 | train_batch_size: 64 12 | 13 | # if using model parallelism, this is useful: 14 | tensor_parallel_axes: ["mlp", "heads"] 15 | optimizer: 16 | # values in qlora paper 17 | learning_rate: 2e-4 18 | weight_decay: 0.0 19 | lr_schedule: "constant" 20 | lora: 21 | # These are the defaults, but just so you can see them 22 | r: 8 # rank of LoRA transform 23 | alpha: 8.0 # scaling factor for LoRA transform 24 | dropout: 0.0 # dropout probability for LoRA layers 25 | -------------------------------------------------------------------------------- /examples/sft/alpaca-llama-sft.yaml: -------------------------------------------------------------------------------- 1 | # Model configuration 2 | model: 3 | type: llama 4 | seq_len: 2048 5 | hidden_dim: 4096 6 | intermediate_dim: 11008 7 | num_layers: 32 8 | num_heads: 32 9 | num_kv_heads: 32 10 | use_flash_attention: true 11 | flash_attention_block_size: 512 12 | use_bias: false 13 | use_layer_norm_weight: false 14 | 15 | # Training configuration 16 | trainer: 17 | mp: p=f32,c=bfloat16 18 | tracker: 19 | type: wandb 20 | project: "levanter-sft" 21 | tags: ["llama", "sft"] 22 | num_train_steps: 750000 23 | train_batch_size: 64 24 | tensor_parallel_axes: ["mlp", "heads"] 25 | fsdp_axis: "embed" 26 | batch_axis: "batch" 27 | steps_per_eval: 1000 28 | 29 | # Optimizer settings 30 | optimizer: 31 | learning_rate: 2e-5 32 | weight_decay: 0.0 33 | min_lr_ratio: 0.1 34 | warmup: 100 35 | 36 | # Supervised data configuration 37 | supervised_data: 38 | cache_dir: "gs://levanter-checkpoints/marin/sft_cache/alpaca-olmo" 39 | input_field: "instruction" 40 | output_field: "output" 41 | hf_dataset_name: "tatsu-lab/alpaca" # Changed from id 42 | hf_dataset_split: "train" 43 | name: "alpaca" # Optional metadata 44 | tags: ["instruction-tuning"] # Optional metadata 45 | validation_urls: [] # Empty list for no validation files 46 | 47 | # Additional settings 48 | tokenizer: "allenai/OLMo-1B" 49 | max_tune_length: 2048 50 | epoch: 0 51 | 52 | initialize_from_hf: false 53 | -------------------------------------------------------------------------------- /examples/sft/alpaca-llama.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: meta-llama/Llama-2-7b-hf 2 | 3 | # Training configuration 4 | trainer: 5 | mp: p=f32,c=bfloat16 6 | wandb: 7 | project: "levanter-sft" 8 | tags: ["llama2", "alpaca"] 9 | num_train_steps: 1218 10 | train_batch_size: 64 11 | # If using model parallelism 12 | tensor_parallel_axes: ["mlp", "heads"] 13 | 14 | # Optimizer settings 15 | optimizer: 16 | learning_rate: 2e-5 17 | weight_decay: 0.0 18 | 19 | supervised_data: 20 | hf_dataset_name: "tatsu-lab/alpaca" 21 | hf_dataset_split: "train" 22 | input_field: "instruction" # change from prompt 23 | output_field: "output" # this is correct 24 | cache_dir: "gs://levanter-checkpoints/marin/sft_cache/alpaca-new" 25 | 26 | max_tune_length: 2048 27 | trust_remote_code: false 28 | model_cache_dir: null 29 | 30 
| hf_save_path: "sft_hf_ckpts" 31 | hf_upload: false 32 | hf_save_steps: 1000 33 | -------------------------------------------------------------------------------- /examples/sft/dolly-llama.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: meta-llama/Llama-2-7b-hf 2 | 3 | # Training configuration 4 | trainer: 5 | mp: p=f32,c=bfloat16 6 | wandb: 7 | project: "levanter-sft" 8 | tags: ["llama2", "oasst"] 9 | num_train_steps: 1218 10 | train_batch_size: 128 11 | # If using model parallelism 12 | tensor_parallel_axes: ["mlp", "heads"] 13 | 14 | # Optimizer settings 15 | optimizer: 16 | learning_rate: 2e-5 17 | weight_decay: 0.0 18 | 19 | supervised_data: 20 | hf_dataset_name: "databricks/databricks-dolly-15k" 21 | hf_dataset_split: "train" 22 | input_field: "instruction" # change from prompt 23 | output_field: "response" # this is correct 24 | cache_dir: "cache/dolly" 25 | 26 | max_tune_length: 2048 27 | trust_remote_code: false 28 | model_cache_dir: null 29 | 30 | hf_save_path: "sft_hf_ckpts" 31 | hf_upload: false 32 | hf_save_steps: 1000 33 | -------------------------------------------------------------------------------- /examples/sft/oasst-llama.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: meta-llama/Llama-2-7b-hf 2 | 3 | # Training configuration 4 | trainer: 5 | mp: p=f32,c=bfloat16 6 | wandb: 7 | project: "levanter-sft" 8 | tags: ["llama2", "oasst"] 9 | num_train_steps: 1218 10 | train_batch_size: 128 11 | 12 | # If using model parallelism 13 | tensor_parallel_axes: ["mlp", "heads"] 14 | 15 | # Optimizer settings 16 | optimizer: 17 | learning_rate: 2e-5 18 | weight_decay: 0.0 19 | 20 | # Supervised data configuration 21 | supervised_data: 22 | # For HF dataset 23 | id: "databricks/databricks-dolly-15k" 24 | input_field: "instruction" # adjust based on dataset 25 | output_field: "response" # adjust based on dataset 26 | cache_dir: "cache/dolly" 27 | 28 | # Model configuration 29 | max_tune_length: 2048 30 | trust_remote_code: false 31 | model_cache_dir: null 32 | 33 | # Checkpoint saving configuration 34 | hf_save_path: "sft_hf_ckpts" 35 | hf_upload: false 36 | hf_save_steps: 1000 37 | 38 | # python examples/sft/sft.py --config_path examples/sft/oasst-llama2.yaml 39 | -------------------------------------------------------------------------------- /examples/sft/tulu-llama-sft.yaml: -------------------------------------------------------------------------------- 1 | # Model configuration 2 | model: 3 | type: llama 4 | seq_len: 2048 5 | hidden_dim: 4096 6 | intermediate_dim: 11008 7 | num_layers: 32 8 | num_heads: 32 9 | num_kv_heads: 32 10 | use_flash_attention: true 11 | flash_attention_block_size: 512 12 | use_bias: false 13 | use_layer_norm_weight: false 14 | 15 | # Training configuration 16 | trainer: 17 | mp: p=f32,c=bfloat16 18 | tracker: 19 | type: wandb 20 | project: "levanter-sft" 21 | tags: ["llama", "sft"] 22 | num_train_steps: 750000 23 | train_batch_size: 64 24 | tensor_parallel_axes: ["mlp", "heads"] 25 | fsdp_axis: "embed" 26 | batch_axis: "batch" 27 | steps_per_eval: 1000 28 | 29 | # Optimizer settings 30 | optimizer: 31 | learning_rate: 2e-5 32 | weight_decay: 0.0 33 | min_lr_ratio: 0.1 34 | warmup: 100 35 | 36 | # Supervised data configuration 37 | dataset_type: chat_jsonl 38 | chat_train_urls: 39 | - "gs://marin-us-central2/documents/allenai--tulu-v2-sft-mixture-0ba27c/data/**/*.jsonl.gz" 40 | supervised_data: 41 | cache_dir: 
"gs://levanter-checkpoints/marin/sft_cache/chat-data" 42 | messages_field: "messages" 43 | input_role: "user" 44 | output_role: "assistant" 45 | 46 | # Additional settings 47 | tokenizer: "EleutherAI/gpt-neox-20b" 48 | max_tune_length: 2048 49 | epoch: 0 50 | 51 | initialize_from_hf: false 52 | -------------------------------------------------------------------------------- /infra/babysit-tpu-vm: -------------------------------------------------------------------------------- 1 | babysit-tpu-vm.sh -------------------------------------------------------------------------------- /infra/cluster/push_cluster_docker.sh: -------------------------------------------------------------------------------- 1 | python infra/push_docker.py --docker_file docker/tpu/Dockerfile.cluster --image levanter-cluster --tag latest $* 2 | -------------------------------------------------------------------------------- /infra/helpers/gen-id.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # on osx tr has trouble reading /dev/urandom 4 | TR=$(which gtr) 5 | 6 | if [[ $? -ne 0 ]]; then 7 | TR=tr 8 | fi 9 | 10 | cat /dev/urandom | $TR -C -d a-z0-9 | head -c 8 11 | -------------------------------------------------------------------------------- /infra/launch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is used for launching on TPU pods (or other direct run environments) via remote ssh with a virtual env 3 | set -e 4 | umask 000 5 | LEV_ROOT=$(dirname "$(readlink -f $0)")/.. 6 | 7 | # figure out venv, first check if we wrote a path in infra/venv_path 8 | if [ ! -d "$VENV" ] && [ -f "$LEV_ROOT/infra/venv_path.txt" ]; then 9 | VENV=$(cat "$LEV_ROOT"/infra/venv_path.txt) 10 | fi 11 | 12 | # if we still don't have a venv, we'll look in our default 13 | if [ ! -d "$VENV" ]; then 14 | VENV=/files/venv32 15 | fi 16 | 17 | if [ ! -d "$VENV" ]; then 18 | VENV=~/files/venv310 19 | fi 20 | 21 | source $VENV/bin/activate 22 | 23 | PYTHONPATH=${LEV_ROOT}:${LEV_ROOT}/src:${LEV_ROOT}/examples:$PYTHONPATH nohup "$@" >& "~/log-$(hostname).log" & 24 | -------------------------------------------------------------------------------- /infra/push_docker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | """ 4 | Build and deploy the Levanter base image to Artifact Registry or Docker Hub. 5 | 6 | It is not necessary to run this yourself unless you are deploying a new base image: the launch 7 | script will automatically build and deploy an image based on your current code. 
8 | """ 9 | import argparse 10 | 11 | from levanter.infra import cli_helpers as cli 12 | from levanter.infra import docker 13 | from levanter.infra.docker import build_docker, push_to_gcp, push_to_github 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser(description="Build and push Docker image to Artifact Registry.") 18 | config = cli.load_config() 19 | cli.add_arg(parser, config, ["--project"], help="GCP project ID") 20 | cli.add_arg(parser, config, ["--region"], help="Artifact Registry region (e.g., us-west4)") 21 | cli.add_arg(parser, config, ["--repository"], default="levanter", help="Artifact Registry repository name") 22 | cli.add_arg(parser, config, ["--image"], default="levanter", help="Docker image name.") 23 | cli.add_arg(parser, config, ["--tag"], default="latest", help="Docker image tag.") 24 | cli.add_arg(parser, config, ["--github_user"], default=None, help="Github user name.") 25 | cli.add_arg(parser, config, ["--github_token"], default=None, help="Github token.") 26 | cli.add_arg(parser, config, ["--docker_file"], default="docker/tpu/Dockerfile.base", help="Dockerfile to use.") 27 | cli.add_arg(parser, config, ["--extra_context"], required=False, default=None) 28 | 29 | # push to either github or GCP 30 | cli.add_arg(parser, config, ["--docker_target"], choices=["github", "gcp", "ghcr"], required=True) 31 | 32 | args = parser.parse_args() 33 | 34 | with docker.copy_extra_ctx(args.extra_context) as extra_ctx: 35 | build_args = {"EXTRA_CTX": extra_ctx} if extra_ctx else None 36 | local_id = build_docker(docker_file=args.docker_file, image_name=args.image, tag=args.tag) 37 | 38 | if args.docker_target in ["github", "ghcr"]: 39 | assert args.github_user, "Must specify --github_user when pushing to Github" 40 | assert args.github_token, "Must specify --github_token when pushing to Github" 41 | push_to_github(local_id=local_id, github_user=args.github_user, github_token=args.github_token) 42 | else: 43 | assert args.region, "Must specify --region when pushing to GCP" 44 | assert args.project, "Must specify --project when pushing to GCP" 45 | assert args.repository, "Must specify --repository when pushing to GCP" 46 | 47 | push_to_gcp(local_id, args.project, args.region, args.repository) 48 | -------------------------------------------------------------------------------- /infra/run-slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is used for launching on a slurm cluster or other queueing systems 3 | # This script assumes your virtual env or conda is already activated 4 | set -e 5 | LEV_ROOT=$(dirname "$(readlink -f $0)")/.. 6 | 7 | PYTHONPATH=${LEV_ROOT}:${LEV_ROOT}/src:${LEV_ROOT}/examples:$PYTHONPATH "$@" > >(tee -a stdout-$(hostname).log) 2> >(tee -a stderr-$(hostname).log >&2) 8 | -------------------------------------------------------------------------------- /infra/run.sh: -------------------------------------------------------------------------------- 1 | umask 000 2 | LEV_ROOT=$(dirname "$(readlink -f $0)")/.. 3 | ulimit -s 65536 4 | 5 | # figure out venv, first check if we wrote a path in infra/venv_path 6 | if [ ! -d "$VENV" ] && [ -f "$LEV_ROOT/infra/venv_path.txt" ]; then 7 | VENV=$(cat "$LEV_ROOT"/infra/venv_path.txt) 8 | fi 9 | 10 | # if we still don't have a venv, we'll look in our default 11 | if [ ! 
-d "$VENV" ]; then 12 | VENV=/files/venv32 13 | fi 14 | 15 | source $VENV/bin/activate 16 | 17 | 18 | PYTHONPATH=${LEV_ROOT}:${LEV_ROOT}/src:${LEV_ROOT}/examples:$PYTHONPATH "$@" 19 | -------------------------------------------------------------------------------- /infra/spin-up-vm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" 4 | 5 | . "${SCRIPT_DIR}"/helpers/parse-tpu-creation-args.sh "$@" 6 | 7 | # error out if we didn't set a name 8 | if [ -z "$VM_NAME" ]; then 9 | echo "Error: VM name not set" 10 | exit 1 11 | fi 12 | 13 | # Determine the correct gcloud command based on USE_ALPHA flag 14 | GCLOUD_CMD="gcloud compute" 15 | if [ "$USE_ALPHA" = true ]; then 16 | GCLOUD_CMD="gcloud alpha compute" 17 | fi 18 | 19 | # first delete if we're supposed to 20 | if [ "$AUTODELETE" = "true" ]; then 21 | # check if it's there 22 | $GCLOUD_CMD tpus tpu-vm describe --zone $ZONE $VM_NAME &> /dev/null 23 | if [ $? -eq 0 ]; then 24 | echo "Deleting existing VM $VM_NAME" 25 | $GCLOUD_CMD tpus tpu-vm delete --zone $ZONE $VM_NAME 26 | fi 27 | fi 28 | 29 | # if ssh-agent isn't running, complain 30 | if [ -z "$SSH_AUTH_SOCK" ]; then 31 | echo "Error: ssh-agent not running" 32 | exit 1 33 | fi 34 | 35 | # create the vm 36 | # spin loop until we get a good error code 37 | echo "Creating VM $VM_NAME" 38 | # create the command. note that --preemptible doesn't accept a value, so just append it if we want it 39 | CMD="$GCLOUD_CMD tpus tpu-vm create $VM_NAME \ 40 | --zone=$ZONE \ 41 | --accelerator-type=$TYPE \ 42 | --version=$VM_IMAGE" 43 | 44 | # if network isn't 'default', set it 45 | if [ "$NETWORK" != "default" ]; then 46 | CMD="$CMD --network=$NETWORK" 47 | fi 48 | 49 | if [ "$PREEMPTIBLE" = true ]; then 50 | CMD="$CMD --preemptible" 51 | fi 52 | echo "Running command: $CMD" 53 | 54 | while ! $CMD; do 55 | echo "Error creating VM, retrying in 5 seconds" 56 | sleep 5 57 | done 58 | 59 | echo "Giving the VM a few seconds to start up" 60 | sleep 15 61 | 62 | echo "Adding ssh keys just in case..." 63 | echo ssh-add ~/.ssh/google_compute_engine 64 | ssh-add ~/.ssh/google_compute_engine 65 | 66 | # upload the setup script 67 | SETUP_SCRIPT_NAME=$(basename $SETUP_SCRIPT) 68 | # note that gcloud scp doesn't always work... so we do it a few times to just be sure 69 | for i in {1..5}; do 70 | echo "Uploading $SETUP_SCRIPT to VM $VM_NAME" 71 | $GCLOUD_CMD tpus tpu-vm scp --zone=$ZONE $SETUP_SCRIPT $VM_NAME:~/ --worker=all 72 | # check to see if the file exists on all nodes 73 | if $GCLOUD_CMD tpus tpu-vm ssh --zone=$ZONE $VM_NAME --command="ls ~/$SETUP_SCRIPT_NAME" --worker=all; then 74 | break 75 | fi 76 | if [ 5 -eq $i ]; then 77 | echo "Error uploading ${SETUP_SCRIPT_NAME}, giving up. Note that the machine is still (probably) running" 78 | exit 1 79 | fi 80 | echo "Error uploading ${SETUP_SCRIPT_NAME}, retrying in 5 seconds" 81 | sleep 5 82 | done 83 | 84 | # run the setup script 85 | for i in {1..5}; do 86 | $GCLOUD_CMD tpus tpu-vm ssh --zone=$ZONE $VM_NAME --command="bash ~/$SETUP_SCRIPT_NAME --branch ${GIT_BRANCH} --repo ${GIT_REPO} > setup.out" --worker=all 87 | if [ $? -eq 0 ]; then 88 | break 89 | fi 90 | if [ 5 -eq $i ]; then 91 | echo "Error running ${SETUP_SCRIPT_NAME}, giving up. 
Note that the machine is still (probably) running" 92 | exit 1 93 | fi 94 | echo "Error running ${SETUP_SCRIPT_NAME}, retrying in 5 seconds" 95 | sleep 5 96 | done 97 | 98 | # print out the IP addresses 99 | echo "VM $VM_NAME IP addresses:" 100 | $GCLOUD_CMD tpus tpu-vm describe --zone $ZONE $VM_NAME | awk '/externalIp: (.*)/ {print $2}' 101 | -------------------------------------------------------------------------------- /scripts/launch_gpt2_small_fast_gpu.sh: -------------------------------------------------------------------------------- 1 | # 4 gpus, 3090s 2 | # TODO: maybe move to the a100s or a6000s? 3 | srun --account=nlp --cpus-per-task=2 --gres=gpu:3090:4 --job-name=dlwh-job-1681253 --mem=16G --open-mode=append --partition=jag-standard --time=14-0 \ 4 | bash infra/run-slurm.sh python src/levanter/main/train_lm.py \ 5 | --config_path config/gpt2_small_fast.yaml \ 6 | --trainer.checkpointer.save_interval 30m \ 7 | --trainer.per_device_parallelism -1 $* 8 | -------------------------------------------------------------------------------- /scripts/launch_gpt2_small_fast_supervised_tpu.sh: -------------------------------------------------------------------------------- 1 | # Launches the "gpt_small_fast" model on a TPU node 2 | 3 | python infra/launch.py --foreground --tpu_name $(whoami)-levanter-itest-32 --zone us-central2-b --tpu_type v4-32 --preemptible -- \ 4 | python -m levanter.main.train_lm \ 5 | --config_path config/gpt2_small_fast_supervised.yaml \ 6 | --trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $* 7 | -------------------------------------------------------------------------------- /scripts/launch_gpt2_small_fast_tpu.sh: -------------------------------------------------------------------------------- 1 | # Launches the "gpt_small_fast" model on a TPU node 2 | 3 | python infra/launch.py --foreground --tpu_name $(whoami)-levanter-itest-32 --zone us-central2-b --tpu_type v4-32 --preemptible -- \ 4 | python -m levanter.main.train_lm \ 5 | --config_path config/gpt2_small_fast.yaml \ 6 | --trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $* 7 | -------------------------------------------------------------------------------- /scripts/launch_gpt2_small_itest_tpu.sh: -------------------------------------------------------------------------------- 1 | # Launches the "gpt_small_fast" model on a TPU node 2 | 3 | python infra/launch.py --foreground --tpu_name $(whoami)-levanter-itest-32 --zone us-central2-b --tpu_type v4-32 --preemptible -- \ 4 | python -m levanter.main.train_lm \ 5 | --config_path config/gpt2_small_itest.yaml \ 6 | --trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $* 7 | -------------------------------------------------------------------------------- /scripts/loss_history.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | import wandb 4 | 5 | 6 | DEFAULT_METRIC = "train/loss" 7 | 8 | 9 | def fetch_metric_for_sha(sha, metric_name=DEFAULT_METRIC, steps=range(0, 101, 10)): 10 | """ 11 | Fetches the metrics for a given git sha using the wandb API. 
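Returns a dict mapping training step to the metric value for the first run whose config.git_commit matches `sha`, restricted to the requested `steps`; an empty dict is returned if no matching run is found.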
12 | """ 13 | # Setting up the wandb API 14 | api = wandb.Api() 15 | 16 | # Search for the run tagged with this git sha 17 | runs = api.runs("levanter", {"config.git_commit": sha}) 18 | 19 | metric_values = {} 20 | 21 | if runs: 22 | # Get the history of metrics for this run 23 | metrics = runs[0].scan_history(keys=["_step", metric_name], min_step=min(steps), max_step=max(steps)) 24 | 25 | for row in metrics: 26 | step = row["_step"] 27 | loss = row[metric_name] 28 | if step not in steps: 29 | continue 30 | metric_values[int(step)] = float(loss) 31 | 32 | return metric_values 33 | 34 | 35 | def visualize_git_tree_with_metric(metric_name=DEFAULT_METRIC): 36 | """ 37 | Fetches the git log, and associates each commit with the metric. 38 | """ 39 | # Getting the git log with shas 40 | result = subprocess.run(["git", "log", "--pretty=format:%H %s"], stdout=subprocess.PIPE) 41 | log_lines = result.stdout.decode("utf-8").strip().split("\n") 42 | 43 | for line in log_lines: 44 | sha, message = line.split(" ", 1) 45 | metric_values = fetch_metric_for_sha(sha, metric_name) 46 | 47 | metrics_str = ", ".join(f"{step}: {value}" for step, value in metric_values.items()) 48 | print(f"{sha} - {message} -> {metrics_str}") 49 | 50 | 51 | if __name__ == "__main__": 52 | visualize_git_tree_with_metric() 53 | -------------------------------------------------------------------------------- /scripts/preproc/split-pile-shards.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from pathlib import Path 5 | 6 | import fsspec 7 | import tqdm 8 | 9 | 10 | OUT_PATH = "gs://levanter-data/pile-domains" 11 | 12 | categories_to_out_names = { 13 | "ArXiv": "arxiv", 14 | "BookCorpus2": "books2", 15 | "Books3": "books3", 16 | "DM Mathematics": "dm_math", 17 | "Enron Emails": "enron", 18 | "EuroParl": "europarl", 19 | "FreeLaw": "freelaw", 20 | "Github": "github", 21 | "Gutenberg (PG-19)": "pg_19", 22 | "HackerNews": "hackernews", 23 | "NIH ExPorter": "nih", 24 | "OpenSubtitles": "opensubtitles", 25 | "OpenWebText2": "owt2", 26 | "PhilPapers": "philpapers", 27 | "Pile-CC": "pile_cc", 28 | "PubMed Abstracts": "pubmed_abs", 29 | "PubMed Central": "pubmed_central", 30 | "StackExchange": "stack_exchange", 31 | "USPTO Backgrounds": "uspto", 32 | "Ubuntu IRC": "ubuntu_irc", 33 | "Wikipedia (en)": "wiki_en", 34 | "YoutubeSubtitles": "youtube_subtitles", 35 | } 36 | 37 | 38 | def format_category(category): 39 | return categories_to_out_names[category] 40 | 41 | 42 | def process_file(input_file_path): 43 | base_file = Path(input_file_path).stem 44 | compressors = {} 45 | 46 | with fsspec.open(input_file_path, "r", compression="infer") as text_stream: 47 | for line in tqdm.tqdm(text_stream): 48 | if not line.strip(): 49 | continue # Skip empty lines 50 | 51 | # Decode line to string and load as JSON 52 | data = json.loads(line) 53 | category = data["meta"]["pile_set_name"] 54 | category = format_category(category) 55 | output_file_path = os.path.join(OUT_PATH, category, f"{base_file}.zst") 56 | 57 | # Check if compressor exists for this category, if not create it 58 | if category not in compressors: 59 | # output_file = open(output_file_path, 'wb') 60 | output_file = fsspec.open(str(output_file_path), "wb", compression="infer").open() 61 | print("opened", output_file_path) 62 | compressors[category] = output_file 63 | 64 | # Write to the compressor 65 | compressors[category].write(line.encode("utf-8")) 66 | compressors[category].flush() 67 | 68 | # Close 
all open compressors 69 | for compressor in compressors.values(): 70 | compressor.close() 71 | 72 | 73 | if __name__ == "__main__": 74 | for path in sys.argv[1:]: 75 | process_file(path) 76 | -------------------------------------------------------------------------------- /src/levanter/__init__.py: -------------------------------------------------------------------------------- 1 | import levanter.analysis as analysis 2 | import levanter.callbacks as callbacks 3 | import levanter.checkpoint as checkpoint 4 | import levanter.config as config 5 | import levanter.data as data 6 | import levanter.distributed as distributed 7 | import levanter.eval as eval 8 | import levanter.eval_harness as eval_harness 9 | import levanter.models as models 10 | import levanter.optim as optim 11 | import levanter.tracker as tracker 12 | import levanter.trainer as trainer 13 | import levanter.visualization as visualization 14 | from levanter.tracker import current_tracker 15 | from levanter.trainer import initialize 16 | 17 | 18 | __version__ = "1.2" 19 | -------------------------------------------------------------------------------- /src/levanter/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | from .entropy import cb_compute_entropies, cb_compute_top2_gap, compute_entropy_histogram, compute_top2_gap_histogram 2 | from .tree_stats import summary_statistics_for_tree 3 | from .visualization import cb_compute_and_visualize_log_probs, visualize_log_prob_diff, visualize_log_probs 4 | -------------------------------------------------------------------------------- /src/levanter/compat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/src/levanter/compat/__init__.py -------------------------------------------------------------------------------- /src/levanter/data/__init__.py: -------------------------------------------------------------------------------- 1 | from ._preprocessor import BatchProcessor 2 | from .dataset import AsyncDataset, ListAsyncDataset, MappedAsyncDataset, SyncDataset 3 | from .loader import DataLoader 4 | from .mixture import MixtureDataset, StopStrategy 5 | from .permutation import EraShufflingDataset, PermutationDataset 6 | from .sharded_datasource import ShardedDataSource, datasource_from_hf, datasource_from_json, datasource_from_jsonl 7 | from .utils import batched 8 | 9 | 10 | __all__ = [ 11 | "AsyncDataset", 12 | "BatchProcessor", 13 | "DataLoader", 14 | "ListAsyncDataset", 15 | "MappedAsyncDataset", 16 | "MixtureDataset", 17 | "ShardedDataSource", 18 | "StopStrategy", 19 | "SyncDataset", 20 | "batched", 21 | "datasource_from_hf", 22 | "datasource_from_json", 23 | "datasource_from_jsonl", 24 | ] 25 | -------------------------------------------------------------------------------- /src/levanter/data/metrics_monitor.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import logging as pylogging 3 | import time 4 | from dataclasses import dataclass 5 | from typing import Any, Dict, Optional, Protocol, Union 6 | 7 | import jax 8 | from dataclasses_json import dataclass_json 9 | 10 | import levanter.tracker 11 | 12 | 13 | # TODO: should we just make the ledger have all this? 
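# InProgressCacheMetrics is a JSON-serializable snapshot of preprocessing progress; the monitor classes below receive it and report shard/row counts.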
14 | @dataclass_json 15 | @dataclass 16 | class InProgressCacheMetrics: 17 | rows_finished: int = 0 18 | shards_finished: int = 0 19 | field_counts: Dict[str, int] = dataclasses.field(default_factory=dict) 20 | is_finished: bool = False 21 | 22 | 23 | class MetricsMonitor(Protocol): 24 | def __call__(self, metrics: InProgressCacheMetrics): 25 | ... 26 | 27 | 28 | class LoggingMetricsMonitor(MetricsMonitor): 29 | last_metrics: Optional[InProgressCacheMetrics] 30 | last_time: Optional[float] 31 | 32 | def __init__(self, prefix: str = "preproc", commit=False): 33 | """ 34 | :param prefix: 35 | :param commit: Forwarded to wandb.log. Use False (default) if it's part of a simultaneous training run, 36 | and True if you're running standalone. 37 | """ 38 | self.prefix = prefix 39 | self.commit = commit 40 | self.last_metrics = None 41 | self.last_time = None 42 | 43 | def __call__(self, metrics: InProgressCacheMetrics): 44 | to_log: Dict[str, Any] = {} 45 | 46 | to_log[f"{self.prefix}/shards"] = metrics.shards_finished 47 | to_log[f"{self.prefix}/rows"] = metrics.rows_finished 48 | 49 | for field, count in metrics.field_counts.items(): 50 | to_log[f"{self.prefix}/{field}"] = count 51 | 52 | if metrics.is_finished: 53 | to_log[f"{self.prefix}/finished"] = 1 54 | 55 | self.last_metrics = metrics 56 | self.last_time = time.time() 57 | 58 | levanter.tracker.log(to_log, step=None, commit=self.commit) 59 | 60 | 61 | class LoggerMetricsMonitor(MetricsMonitor): 62 | # TODO: I'd like to get the trainer pbar migrated to rich and just use rich everywhere, but until then, 63 | # we have separate logging 64 | def __init__( 65 | self, 66 | logger: Optional[Union[pylogging.Logger, str]] = None, 67 | level=pylogging.INFO, 68 | log_interval: float | int = 30.0, 69 | ): 70 | if isinstance(logger, str): 71 | logger = pylogging.getLogger(logger) 72 | self.logger = logger or pylogging.getLogger(__name__) 73 | self.level = level 74 | self.log_interval = log_interval 75 | self._last_log_time = time.time() 76 | 77 | def __call__(self, metrics: InProgressCacheMetrics): 78 | if jax.process_index() == 0: 79 | if time.time() - self._last_log_time > self.log_interval: 80 | self._last_log_time = time.time() 81 | 82 | self.logger.log( 83 | self.level, 84 | f" done: Shards: {metrics.shards_finished} | Docs: {metrics.rows_finished}", 85 | ) 86 | 87 | if metrics.is_finished: 88 | self.logger.info("Cache creation finished") 89 | -------------------------------------------------------------------------------- /src/levanter/data/passthrough_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import numpy as np 4 | from transformers import PreTrainedTokenizer 5 | 6 | 7 | class PassthroughTokenizer(PreTrainedTokenizer): 8 | """ 9 | A tokenizer that takes plain-text integers, parses them as integers, and returns them as-is. 
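For example, the text "10 20 30" is tokenized to the ids [10, 20, 30] (illustrative; see `_tokenize` below).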
10 | """ 11 | 12 | def __init__(self, vocab_size, **kwargs): 13 | self._vocab = {i: i for i in range(vocab_size)} 14 | self._vocab_size = vocab_size 15 | super().__init__(**kwargs) 16 | 17 | @property 18 | def vocab_size(self) -> int: 19 | return self._vocab_size 20 | 21 | def get_vocab(self): 22 | return self._vocab 23 | 24 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]: 25 | return () 26 | 27 | def _tokenize(self, text, **kwargs): 28 | tokens = np.fromstring(text, dtype=int, sep=" ") 29 | return tokens 30 | 31 | def _convert_token_to_id(self, token: str) -> int: 32 | return int(token) 33 | 34 | def _convert_id_to_token(self, index: int) -> str: 35 | return str(index) 36 | -------------------------------------------------------------------------------- /src/levanter/data/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Iterator, List, TypeVar 2 | 3 | 4 | T = TypeVar("T") 5 | 6 | 7 | def batched(iterable: Iterable[T], batch_size: int) -> Iterator[List[T]]: 8 | """Yields batches of the given size from the given iterable.""" 9 | batch = [] 10 | for item in iterable: 11 | batch.append(item) 12 | if len(batch) == batch_size: 13 | yield batch 14 | batch = [] 15 | 16 | if len(batch) > 0: 17 | yield batch 18 | -------------------------------------------------------------------------------- /src/levanter/infra/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/src/levanter/infra/__init__.py -------------------------------------------------------------------------------- /src/levanter/main/cache_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from dataclasses import dataclass, field 4 | 5 | import levanter 6 | from levanter.data.metrics_monitor import LoggingMetricsMonitor 7 | from levanter.data.text import BatchTokenizer, SingleDatasetLMConfigBase 8 | from levanter.distributed import RayConfig 9 | from levanter.store.cache import build_or_load_cache 10 | from levanter.tracker import NoopConfig, TrackerConfig 11 | from levanter.utils.logging import init_logging 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | @dataclass 18 | class RayCachedLMDatasetConfig(SingleDatasetLMConfigBase, RayConfig): 19 | tracker: TrackerConfig = field(default_factory=NoopConfig) 20 | 21 | 22 | @levanter.config.main() 23 | def main(args: RayCachedLMDatasetConfig): 24 | """Caches two different kinds of datasets. 
It can cache a dataset from a list of urls, or a dataset from a hf dataset""" 25 | init_logging(".", "cache_dataset.log") 26 | args.initialize() 27 | 28 | tokenizer = args.the_tokenizer 29 | 30 | for split in ["train", "validation"]: 31 | print(f"Caching {split} to {args.cache_dir}.") 32 | # connect or start the actor 33 | batch_tokenizer = BatchTokenizer(tokenizer, enforce_eos=args.enforce_eos) 34 | split_cache_dir = os.path.join(args.cache_dir, split) # type: ignore 35 | source = args.get_shard_source(split) 36 | 37 | if source is None: 38 | logger.warning(f"Skipping {split} because it is empty.") 39 | continue 40 | 41 | monitors: list = [] 42 | if not isinstance(args.tracker, NoopConfig): 43 | monitors.append(LoggingMetricsMonitor("preprocess/" + split, commit=True)) 44 | 45 | cache = build_or_load_cache( 46 | cache_dir=split_cache_dir, 47 | source=source, 48 | processor=batch_tokenizer, 49 | await_finished=False, 50 | monitors=monitors, 51 | ) 52 | 53 | cache.await_finished() 54 | print(f"Finished caching {split} to {split_cache_dir}.") 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /src/levanter/main/export_lm_to_hf.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from functools import cached_property 4 | from typing import Optional 5 | 6 | import equinox as eqx 7 | import jax 8 | from jax.sharding import Mesh 9 | 10 | from haliax import Axis 11 | 12 | import levanter 13 | from levanter.checkpoint import load_checkpoint 14 | from levanter.compat.hf_checkpoints import RepoRef, load_tokenizer 15 | from levanter.models.gpt2 import Gpt2Config 16 | from levanter.models.lm_model import LmConfig, LmHeadModel 17 | from levanter.utils.jax_utils import is_inexact_arrayish, use_cpu_device 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | @dataclass 24 | class ConvertLmConfig: 25 | checkpoint_path: str 26 | output_dir: str 27 | upload_to_hf: Optional[RepoRef] = None # if specified, attempt to upload this checkpoint to the hf hub 28 | 29 | model: LmConfig = Gpt2Config() 30 | save_tokenizer: bool = True # if True, save the tokenizer to the output directory 31 | tokenizer: str = "gpt2" 32 | override_vocab_size: Optional[int] = None # if specified, override the vocab size in the config 33 | 34 | config_overrides: Optional[dict] = None # if specified, override the config with these values 35 | 36 | @cached_property 37 | def the_tokenizer(self): 38 | return load_tokenizer(self.tokenizer) 39 | 40 | 41 | def main(config: ConvertLmConfig): 42 | logger.setLevel(logging.INFO) 43 | tokenizer = config.the_tokenizer 44 | 45 | vocab_size = config.override_vocab_size or len(tokenizer) 46 | Vocab = Axis("vocab", vocab_size) 47 | 48 | key = jax.random.PRNGKey(0) 49 | 50 | with use_cpu_device(), Mesh([jax.local_devices(backend="cpu")[0]], "dev"): 51 | model: LmHeadModel = eqx.filter_eval_shape(config.model.build, Vocab, key=key) 52 | trainable, non_trainable = eqx.partition(model, is_inexact_arrayish) 53 | # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model 54 | trainable = load_checkpoint(trainable, config.checkpoint_path, subpath="model") 55 | 56 | assert trainable is not None 57 | model = eqx.combine(trainable, non_trainable) 58 | 59 | if config.override_vocab_size: 60 | model = model.resize_vocab(config.override_vocab_size) 61 | 62 | converter = 
model.config.hf_checkpoint_converter().replaced(tokenizer=tokenizer) 63 | 64 | converter.save_pretrained( 65 | model, 66 | config.output_dir, 67 | upload_to_hf=config.upload_to_hf or False, 68 | save_tokenizer=config.save_tokenizer, 69 | ) 70 | 71 | 72 | if __name__ == "__main__": 73 | levanter.config.main(main)() 74 | -------------------------------------------------------------------------------- /src/levanter/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/src/levanter/models/__init__.py -------------------------------------------------------------------------------- /src/levanter/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import AdamConfig, OptimizerConfig 2 | from .sophia import ( # SophiaGConfig,; SophiaGObjective, 3 | ScaleBySophiaState, 4 | SophiaHConfig, 5 | scale_by_sophia_g, 6 | scale_by_sophia_h, 7 | ) 8 | -------------------------------------------------------------------------------- /src/levanter/optim/model_averaging.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import dataclasses 3 | from typing import Generic, TypeVar 4 | 5 | import draccus 6 | import equinox as eqx 7 | import optax 8 | 9 | 10 | S = TypeVar("S") 11 | M = TypeVar("M") 12 | 13 | 14 | class ModelAveraging(eqx.Module, Generic[M]): 15 | """ 16 | This is the interface for model averaging algorithms. Model averaging algorithms are used to average 17 | the parameters of a model over multiple training steps. This is useful for improving generalization 18 | """ 19 | 20 | @abc.abstractmethod 21 | def update(self: S, model: M, step: int) -> S: 22 | pass 23 | 24 | @property 25 | @abc.abstractmethod 26 | def model_params(self) -> M: 27 | pass 28 | 29 | 30 | class EmaModelAveraging(ModelAveraging[M]): 31 | """ 32 | Exponential moving average model averaging 33 | """ 34 | 35 | model: M 36 | beta: float = eqx.field(static=True) 37 | 38 | def update(self: S, new_model: M, step: int) -> S: 39 | del step 40 | # 1 - beta because increment_update expects the weight of the new model 41 | return dataclasses.replace(self, model=optax.incremental_update(new_model, self.model, 1 - self.beta)) # type: ignore 42 | 43 | @property 44 | def model_params(self) -> M: 45 | return self.model 46 | 47 | 48 | class ModelAveragingConfig(abc.ABC, draccus.ChoiceRegistry, Generic[M]): 49 | @abc.abstractmethod 50 | def create(self, model: M) -> ModelAveraging[M]: 51 | pass 52 | 53 | 54 | @ModelAveragingConfig.register_subclass("ema") 55 | @dataclasses.dataclass 56 | class EmaModelAveragingConfig(ModelAveragingConfig[M]): 57 | beta: float = 0.999 58 | 59 | def create(self, model: M) -> EmaModelAveraging[M]: 60 | return EmaModelAveraging(model=model, beta=self.beta) 61 | -------------------------------------------------------------------------------- /src/levanter/optim/util.py: -------------------------------------------------------------------------------- 1 | import equinox as eqx 2 | import jax 3 | 4 | from levanter.utils.jax_utils import is_inexact_arrayish 5 | 6 | 7 | def hvp(f, x, v): 8 | """Compute the Hessian-vector product of a function.""" 9 | return eqx.filter_jvp(eqx.filter_grad(f), (x,), (v,))[1] 10 | 11 | 12 | def tree_gaussian_like(key, tree): 13 | """ 14 | Samples a tree of gaussian noise with the same structure as `tree`, except for leaves which are not 
inexact arrays, 15 | for which it returns None 16 | """ 17 | leaves, structure = jax.tree_util.tree_flatten(tree) 18 | keys = jax.random.split(key, len(leaves)) 19 | rand_n = lambda x, key: jax.random.normal(key, x.shape) if is_inexact_arrayish(x) else None 20 | g = jax.tree_util.tree_map(rand_n, leaves, list(keys)) 21 | g = jax.tree_util.tree_unflatten(structure, g) 22 | 23 | return g 24 | -------------------------------------------------------------------------------- /src/levanter/shapes.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from math import prod 3 | from typing import Optional, Tuple, Type, TypeAlias, Union 4 | 5 | import jax 6 | import numpy as np 7 | from jax import ShapeDtypeStruct 8 | from jaxtyping import PyTree 9 | 10 | from haliax import Axis 11 | from haliax.util import is_named_array 12 | 13 | 14 | DType = Union[np.dtype, Type[int], Type[float], Type[bool]] 15 | 16 | ShapeSpec: TypeAlias = ShapeDtypeStruct 17 | 18 | 19 | @dataclass(frozen=True) 20 | class NamedShapeSpec: 21 | """A shape specification with named axes.""" 22 | 23 | shape: Optional[Tuple[Axis, ...]] 24 | dtype: Optional[DType] 25 | 26 | size = property(lambda self: prod(ax.size for ax in self.shape)) 27 | ndim = property(lambda self: len(self.shape)) 28 | 29 | 30 | def to_raw_shape(shape: Union[ShapeSpec, NamedShapeSpec]) -> Optional[Tuple[int, ...]]: 31 | if isinstance(shape, ShapeDtypeStruct): 32 | return shape.shape 33 | else: 34 | raw = shape.shape 35 | if raw is None: 36 | return None 37 | return tuple(ax.size for ax in raw) 38 | 39 | 40 | def shape_spec_of(tree: PyTree) -> PyTree[Union[ShapeSpec, NamedShapeSpec]]: 41 | """Get the shape specification of a tree.""" 42 | 43 | def _leaf_spec(leaf): 44 | if is_named_array(leaf): 45 | return NamedShapeSpec(leaf.axes, leaf.dtype) 46 | else: 47 | return ShapeDtypeStruct(leaf.shape, leaf.dtype) 48 | 49 | return jax.tree_util.tree_map(_leaf_spec, tree, is_leaf=is_named_array) 50 | 51 | 52 | def conforms(shape: PyTree[Union[ShapeSpec, NamedShapeSpec]], tree: PyTree) -> bool: 53 | """Check if a tree conforms to a shape specification.""" 54 | 55 | def _leaf_conforms(shape_spec: Union[ShapeSpec, NamedShapeSpec], leaf): 56 | if isinstance(shape_spec, ShapeSpec): # type: ignore 57 | return shape_spec.shape == leaf.shape and shape_spec.dtype == leaf.dtype 58 | else: 59 | return (shape_spec.shape is None or shape_spec.shape == leaf.axes) and ( 60 | shape_spec.dtype is None or shape_spec.dtype == leaf.dtype 61 | ) 62 | 63 | return jax.tree_util.tree_all(jax.tree_util.tree_map(_leaf_conforms, shape, tree, is_leaf=is_named_array)) 64 | -------------------------------------------------------------------------------- /src/levanter/store/__init__.py: -------------------------------------------------------------------------------- 1 | from .cache import SerialCacheWriter, TreeCache, build_or_load_cache 2 | from .jagged_array import JaggedArrayStore 3 | from .tree_store import TreeStore 4 | 5 | 6 | __all__ = ["TreeCache", "build_or_load_cache", "SerialCacheWriter", "JaggedArrayStore", "TreeStore"] 7 | -------------------------------------------------------------------------------- /src/levanter/tracker/__init__.py: -------------------------------------------------------------------------------- 1 | from levanter.tracker.helpers import capture_time, log_optimizer_hyperparams 2 | from levanter.tracker.tracker import CompositeTracker, NoopConfig, NoopTracker, Tracker, TrackerConfig 3 | from 
levanter.tracker.tracker_fns import ( 4 | LoggableValue, 5 | current_tracker, 6 | defer_tracker_for_jit, 7 | get_tracker, 8 | jit_log, 9 | log, 10 | log_configuration, 11 | log_hyperparameters, 12 | log_metrics, 13 | log_summary, 14 | set_global_tracker, 15 | ) 16 | 17 | 18 | __all__ = [ 19 | "Tracker", 20 | "TrackerConfig", 21 | "CompositeTracker", 22 | "log_optimizer_hyperparams", 23 | "NoopTracker", 24 | "current_tracker", 25 | "get_tracker", 26 | "jit_log", 27 | "log_configuration", 28 | "log", 29 | "log_summary", 30 | "log_hyperparameters", 31 | "set_global_tracker", 32 | "capture_time", 33 | "log_metrics", 34 | "LoggableValue", 35 | "defer_tracker_for_jit", 36 | "NoopConfig", 37 | ] 38 | -------------------------------------------------------------------------------- /src/levanter/tracker/helpers.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import dataclasses 3 | import logging 4 | import os 5 | import time 6 | from typing import Optional 7 | 8 | from git import InvalidGitRepositoryError, NoSuchPathError, Repo 9 | 10 | import levanter.tracker 11 | from levanter.utils.jax_utils import jnp_to_python 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): 18 | try: 19 | from optax._src.wrappers import MultiStepsState 20 | 21 | if isinstance(opt_state, MultiStepsState): 22 | opt_state = opt_state.inner_opt_state 23 | except ImportError: 24 | pass 25 | 26 | def wrap_key(key): 27 | if prefix: 28 | return f"{prefix}/{key}" 29 | return key 30 | 31 | if hasattr(opt_state, "hyperparams"): 32 | params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} 33 | levanter.tracker.log(params, step=step) 34 | 35 | 36 | def hparams_to_dict(hparams, **extra_hparams): 37 | if hparams is None: 38 | hparams_to_save = {} 39 | elif dataclasses.is_dataclass(hparams): 40 | hparams_to_save = dataclasses.asdict(hparams) 41 | else: 42 | hparams_to_save = dict(hparams) 43 | if extra_hparams: 44 | hparams_to_save.update(extra_hparams) 45 | return hparams_to_save 46 | 47 | 48 | def infer_experiment_git_root() -> Optional[str | os.PathLike[str]]: 49 | # sniff out the main directory (since we typically don't run from the root of the repo) 50 | # we'll walk the stack and directories for the files in the stack the until we're at a git root 51 | import os 52 | import traceback 53 | 54 | stack = traceback.extract_stack() 55 | # start from the top of the stack and work our way down since we want to hit the main file first 56 | top_git_root = None 57 | for frame in stack: 58 | dirname = os.path.dirname(frame.filename) 59 | # bit hacky but we want to skip anything that's in the python env 60 | if any(x in dirname for x in ["site-packages", "dist-packages", "venv", "opt/homebrew", "conda", "pyenv"]): 61 | continue 62 | # see if it's under a git root 63 | try: 64 | repo = Repo(dirname, search_parent_directories=True) 65 | top_git_root = repo.working_dir 66 | break 67 | except (NoSuchPathError, InvalidGitRepositoryError): 68 | logger.debug(f"Skipping {dirname} since it's not a git root") 69 | pass 70 | return top_git_root 71 | 72 | 73 | def generate_pip_freeze(): 74 | from importlib.metadata import distributions 75 | 76 | dists = distributions() 77 | return "\n".join(f"{dist.name}=={dist.version}" for dist in dists) 78 | 79 | 80 | @contextlib.contextmanager 81 | def capture_time(): 82 | start = time.perf_counter() 83 | done = False 84 | 85 | def 
fn(): 86 | if done: 87 | return end - start 88 | else: 89 | return time.perf_counter() - start 90 | 91 | yield fn 92 | end = time.perf_counter() 93 | done = True 94 | -------------------------------------------------------------------------------- /src/levanter/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/src/levanter/utils/__init__.py -------------------------------------------------------------------------------- /src/levanter/utils/activation.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import typing 3 | from functools import partial 4 | 5 | import jax 6 | 7 | import haliax as hax 8 | import haliax.nn as hnn 9 | 10 | 11 | _A = typing.TypeVar("_A", hax.Scalar, hax.NamedArray, jax.Array) 12 | ActivationFunction = typing.Callable[[_A], _A] 13 | 14 | 15 | class ActivationFunctionEnum(str, enum.Enum): 16 | relu = "relu" 17 | silu = "silu" 18 | swish = "swish" 19 | gelu = "gelu" 20 | gelu_new = "gelu_new" 21 | quick_gelu = "quick_gelu" 22 | tanh = "tanh" 23 | 24 | def to_fn(self) -> ActivationFunction: 25 | return TO_FN[self] 26 | 27 | 28 | # type: ignore 29 | TO_FN: dict[ActivationFunctionEnum, ActivationFunction] = { 30 | ActivationFunctionEnum.relu: hnn.relu, 31 | ActivationFunctionEnum.silu: hnn.silu, 32 | ActivationFunctionEnum.swish: hnn.swish, 33 | ActivationFunctionEnum.gelu: partial(hnn.gelu, approximate=False), 34 | ActivationFunctionEnum.gelu_new: partial(hnn.gelu, approximate=True), 35 | ActivationFunctionEnum.quick_gelu: hnn.quick_gelu, 36 | ActivationFunctionEnum.tanh: hax.tanh, 37 | } 38 | -------------------------------------------------------------------------------- /src/levanter/utils/datetime_utils.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | 3 | import pytimeparse 4 | 5 | 6 | def parse_timedelta(td_str) -> timedelta: 7 | td = timedelta(seconds=pytimeparse.parse(td_str)) 8 | if td.total_seconds() < 0: 9 | raise ValueError("Cannot encode negative timedelta") # not worth the trouble 10 | 11 | return td 12 | 13 | 14 | def encode_timedelta(td: timedelta) -> str: 15 | """Encodes a timedelta as a string that can be parsed by parse_timedelta/pytimeparse.""" 16 | out = "" 17 | if td.total_seconds() < 0: 18 | raise ValueError("Cannot encode negative timedelta") # not worth the trouble 19 | 20 | if td.days: 21 | out += f"{td.days}d" 22 | 23 | seconds: float = td.seconds 24 | 25 | if seconds > 3600: 26 | hours = seconds // 3600 27 | seconds -= hours * 3600 28 | out += f"{hours}h" 29 | if seconds > 60: 30 | minutes = seconds // 60 31 | seconds -= minutes * 60 32 | out += f"{minutes}m" 33 | 34 | if td.microseconds: 35 | seconds += td.microseconds / 1e6 36 | 37 | if seconds: 38 | out += f"{seconds}s" 39 | 40 | assert parse_timedelta(out) == td, f"Failed to encode {td} as {out}" 41 | return out 42 | -------------------------------------------------------------------------------- /src/levanter/utils/fsspec_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import braceexpand 5 | import fsspec 6 | from fsspec.asyn import AsyncFileSystem 7 | 8 | 9 | def exists(url, **kwargs) -> bool: 10 | """Check if a file exists on a remote filesystem.""" 11 | fs, path = fsspec.core.url_to_fs(url, **kwargs) 12 | return fs.exists(path) 13 | 14 
| 15 | def mkdirs(path): 16 | """Create a directory and any necessary parent directories.""" 17 | fs, path = fsspec.core.url_to_fs(path) 18 | fs.makedirs(path, exist_ok=True) 19 | 20 | 21 | def expand_glob(url): 22 | """ 23 | Yield every URL produced by brace and glob expansion. 24 | 25 | Examples 26 | -------- 27 | >>> list(expand_glob("s3://bucket/{2023,2024}/*.json")) 28 | ['s3://bucket/2023/a.json', 's3://bucket/2024/b.json', ...] 29 | """ 30 | for candidate in braceexpand.braceexpand(url): 31 | fs, path = fsspec.core.url_to_fs(candidate) 32 | 33 | if glob.has_magic(path): 34 | proto = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0] 35 | for p in fs.glob(path): 36 | yield f"{proto}://{p}" if proto else p 37 | else: 38 | yield candidate 39 | 40 | 41 | def remove(url, *, recursive=False, **kwargs): 42 | """Remove a file from a remote filesystem.""" 43 | # TODO: better to use a STS deletion policy or job for this one. 44 | fs, path = fsspec.core.url_to_fs(url, **kwargs) 45 | 46 | fs.rm(path, recursive=recursive) 47 | 48 | 49 | async def async_remove(url, *, recursive=False, **kwargs): 50 | """Remove a file from a remote filesystem.""" 51 | fs, path = fsspec.core.url_to_fs(url, **kwargs) 52 | 53 | if isinstance(fs, AsyncFileSystem): 54 | return await fs._rm(path, recursive=recursive) 55 | else: 56 | fs.rm(path, recursive=recursive) 57 | 58 | 59 | def join_path(lhs, rhs): 60 | """ 61 | Join parts of a path together. Similar to plain old os.path.join except when there is a protocol in the rhs, it 62 | is treated as an absolute path. However, the lhs protocol and rhs protocol must match if the rhs has one. 63 | 64 | """ 65 | 66 | lhs_protocol, lhs_rest = fsspec.core.split_protocol(lhs) 67 | rhs_protocol, rhs_rest = fsspec.core.split_protocol(rhs) 68 | 69 | if rhs_protocol is not None and lhs_protocol is not None and lhs_protocol != rhs_protocol: 70 | raise ValueError(f"Cannot join paths with different protocols: {lhs} and {rhs}") 71 | 72 | if rhs_protocol is not None: 73 | return rhs 74 | else: 75 | return os.path.join(lhs, rhs) 76 | -------------------------------------------------------------------------------- /src/levanter/utils/hf_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from typing import TypeAlias 4 | 5 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 6 | 7 | from levanter.utils.logging import silence_transformer_nag 8 | from levanter.utils.py_utils import logical_cpu_core_count 9 | 10 | 11 | silence_transformer_nag() 12 | 13 | _HF_TOKENIZER_OFF_VALUES = {"off", "false", "f", "no", "n", "0"} 14 | 15 | HfTokenizer: TypeAlias = PreTrainedTokenizerFast | PreTrainedTokenizer 16 | """ 17 | Type alias for a Hugging Face tokenizer. This is a union of the two tokenizer types. 18 | While there is PreTrainedTokenizerBase, it doesn't have all methods that are implemented in both 19 | PreTrainedTokenizer and PreTrainedTokenizerFast. grumble grumble. 20 | """ 21 | 22 | 23 | def num_cpus_used_by_tokenizer(tokenizer: HfTokenizer) -> int: 24 | if getattr(tokenizer, "is_fast", False): 25 | if os.getenv("TOKENIZERS_PARALLELISM", "true").lower() in _HF_TOKENIZER_OFF_VALUES: 26 | return 1 27 | else: 28 | # This is a bit hacky, but HF's fast tokenizers are parallelized under the hood. 29 | # we reserve a couple of cores just so Ray has somewhere to run the coordinator. 
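# (concretely: the logical core count minus 4, clamped to the range [1, 12] in the return below)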
30 | # Empirically it doesn't usually exceed 16-20, and it's useful to have some slack 31 | return min(max(1, logical_cpu_core_count() - 4), 12) 32 | else: 33 | return 1 34 | 35 | 36 | def byte_length_of_token(tokenizer, idx: int) -> int: 37 | # this is a pain because we want the prefix spaces, but we don't want extra noise for bytes 38 | # e.g. in llama 39 | # >>> t.convert_ids_to_tokens(q[2]) 40 | # '▁this' 41 | # >>> t.convert_ids_to_tokens(25) 42 | # '<0x16>' 43 | # We want the _ (as a single byte, not the 3 it's encoded as) but not the <0x16>, which should instead be a single byte \x16 44 | # decode strips the prefix spaces, but does correctly handle the <0x16> case 45 | # we can avoid prefix space issues by prepending another token before decoding, then stripping 46 | repr = tokenizer.convert_ids_to_tokens(idx) 47 | if idx in tokenizer.all_special_ids: 48 | # NB: special tokens don't have bytes, but they contribute to perplexity/bits 49 | return 0 50 | # handle bytes specially. This is a bit of a hack, but there's no other way 51 | elif m := re.match(r"<0x([0-9A-Fa-f]+)>", repr): 52 | return len(bytes.fromhex(m.group(1))) 53 | else: 54 | extra_token = tokenizer(".", add_special_tokens=False)["input_ids"][0] 55 | excess_bytes = len(".".encode("utf-8")) 56 | decoded = tokenizer.decode([extra_token, idx]).encode("utf-8") 57 | return len(decoded) - excess_bytes 58 | -------------------------------------------------------------------------------- /src/levanter/utils/index.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, Iterable, Iterator, TypeVar 2 | 3 | 4 | T = TypeVar("T") 5 | 6 | 7 | class Index(Generic[T]): 8 | """ 9 | Index is a bidirectional mapping from (incremental) integers to objects. 10 | 11 | Needs to be fast, so it exposes the underlying data structures. 
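A minimal usage sketch of the methods defined below:

>>> index = Index(["foo", "bar"])
>>> index.get_index("bar")
1
>>> index.append("baz")
2
>>> index[0]
'foo'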
12 | """ 13 | 14 | def __init__(self, objs: Iterable[T] = ()): 15 | self._index_to_obj: list[T] = [] 16 | self._obj_to_index: dict[T, int] = {} 17 | for obj in objs: 18 | self.append(obj) 19 | 20 | def __len__(self): 21 | return len(self._index_to_obj) 22 | 23 | def __getitem__(self, index: int) -> T: 24 | return self._index_to_obj[index] 25 | 26 | def __setitem__(self, index: int, obj: T): 27 | self._index_to_obj[index] = obj 28 | self._obj_to_index[obj] = index 29 | 30 | def append(self, obj: T) -> int: 31 | index = len(self) 32 | self._index_to_obj.append(obj) 33 | self._obj_to_index[obj] = index 34 | return index 35 | 36 | def get_index(self, obj: T) -> int: 37 | return self._obj_to_index[obj] 38 | 39 | def get_obj(self, index: int) -> T: 40 | return self._index_to_obj[index] 41 | 42 | def __contains__(self, obj: T) -> bool: 43 | return obj in self._obj_to_index 44 | 45 | def __iter__(self) -> Iterator[T]: 46 | return iter(self._index_to_obj) 47 | -------------------------------------------------------------------------------- /src/levanter/utils/json_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from levanter.utils.activation import ActivationFunctionEnum 4 | 5 | 6 | class ConfigJSONEncoder(json.JSONEncoder): 7 | """Supports all the custom types we put into configs.""" 8 | 9 | def default(self, o): 10 | # We can probably get rid of this if we require python 3.11 11 | # and change ActivationFunctionEnum to a StrEnum 12 | if isinstance(o, ActivationFunctionEnum): 13 | return o.name 14 | return super().default(o) 15 | -------------------------------------------------------------------------------- /src/levanter/utils/stat_utils.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import equinox as eqx 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | import haliax as hax 8 | 9 | 10 | Arrayish: typing.TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray 11 | 12 | 13 | class RunningMean(eqx.Module): 14 | mean: Arrayish 15 | total: Arrayish 16 | 17 | @staticmethod 18 | def zeros_like(x: Arrayish) -> "RunningMean": 19 | return RunningMean(x * 0.0, x * 0.0) 20 | 21 | def add(self, x: Arrayish, total: Arrayish) -> "RunningMean": 22 | delta = x - self.mean 23 | # careful: total and self.total can be 0 24 | new_total = self.total + total 25 | ratio = hax.where(new_total, total / new_total, 0.0) 26 | new_mean = self.mean + delta * ratio 27 | new_total = self.total + total 28 | return RunningMean(new_mean, new_total) 29 | 30 | def __add__(self, other: "RunningMean"): 31 | return self.add(other.mean, other.total) 32 | 33 | def __str__(self): 34 | return f"RunningMean(mean={self.mean}, total={self.total})" 35 | -------------------------------------------------------------------------------- /src/levanter/utils/thread_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import threading 3 | from concurrent.futures import ThreadPoolExecutor 4 | from typing import Iterator 5 | 6 | 7 | # Create a ThreadPoolExecutor 8 | _executor = ThreadPoolExecutor(max_workers=10) 9 | 10 | 11 | def blocking_wait(coro): 12 | """ 13 | This will only work if there are fewer than 10 levels of nested coroutines... 
14 | """ 15 | try: 16 | loop = asyncio.get_running_loop() 17 | except RuntimeError: 18 | loop = None 19 | 20 | if loop is not None and loop.is_running(): 21 | future = _executor.submit(lambda: asyncio.run(coro)) 22 | return future.result() 23 | else: 24 | return asyncio.run(coro) 25 | 26 | 27 | def future_from_value(value): 28 | future = asyncio.Future() 29 | future.set_result(value) 30 | return future 31 | 32 | 33 | class AsyncIteratorWrapper(Iterator): 34 | def __init__(self, async_iter): 35 | self.async_iter = async_iter 36 | self.loop = asyncio.new_event_loop() 37 | self.thread = threading.Thread(target=self._run_loop, daemon=True) 38 | self.thread.start() 39 | self._exhausted = False # Flag to indicate if the iterator is exhausted 40 | 41 | def _run_loop(self): 42 | asyncio.set_event_loop(self.loop) 43 | self.loop.run_forever() 44 | 45 | def _run_async_task(self, coro): 46 | if not self.loop.is_running() or not self.thread.is_alive(): 47 | raise StopIteration 48 | try: 49 | future = asyncio.run_coroutine_threadsafe(coro, self.loop) 50 | return future.result() 51 | except (RuntimeError, asyncio.CancelledError): 52 | raise StopIteration 53 | 54 | def __iter__(self): 55 | return self 56 | 57 | def __next__(self): 58 | if self._exhausted: 59 | raise StopIteration 60 | try: 61 | return self._run_async_task(self.async_iter.__anext__()) 62 | except StopAsyncIteration: 63 | self._exhausted = True # Mark the iterator as exhausted 64 | if self.loop.is_running(): 65 | self.loop.call_soon_threadsafe(self.loop.stop) 66 | self.thread.join() 67 | raise StopIteration 68 | 69 | def close(self): 70 | """Close the event loop and thread gracefully.""" 71 | if self.loop.is_running(): 72 | self.loop.call_soon_threadsafe(self.loop.stop) 73 | self.thread.join() 74 | self.loop.close() 75 | 76 | 77 | class ExceptionTrackingThread(threading.Thread): 78 | """A thread that will store exceptions that occur in the target function and 79 | re-raise them in the main thread.""" 80 | 81 | def __init__(self, *args, **kwargs): 82 | super().__init__(*args, **kwargs) 83 | self._exception = None 84 | 85 | def run(self): 86 | try: 87 | super().run() 88 | except Exception as e: 89 | self._exception = e 90 | 91 | def join(self, *args, **kwargs): 92 | super().join(*args, **kwargs) 93 | if self._exception: 94 | raise self._exception 95 | 96 | def check_raise(self): 97 | if self._exception: 98 | raise self._exception 99 | -------------------------------------------------------------------------------- /src/levanter/utils/tree_utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import TypeVar 3 | 4 | import equinox as eqx 5 | import jax 6 | from jaxtyping import PyTree 7 | 8 | from haliax.util import StringHolderEnum 9 | 10 | 11 | T = TypeVar("T", bound=PyTree) 12 | 13 | 14 | class NonePolicy(StringHolderEnum): 15 | PRESERVE = "preserve" 16 | REPLACE = "replace" 17 | ERROR = "error" 18 | 19 | 20 | def inference_mode(tree: T, value: bool, none_policy: str = NonePolicy.REPLACE) -> T: 21 | """ 22 | Analogous to [equinox.nn.inference_mode][] (neé [equinox.tree_inference][]), but 23 | it works in the presence of nones for the `inference` argument. 
24 | """ 25 | 26 | has_inference = lambda leaf: hasattr(leaf, "inference") # noqa: E731 27 | 28 | def replace_fn(node): 29 | if not has_inference(node): 30 | return node 31 | 32 | if node.inference is None: 33 | if none_policy == NonePolicy.PRESERVE: 34 | return node 35 | elif none_policy == NonePolicy.ERROR: 36 | raise ValueError(f"None found in {tree}.inference with none_policy={none_policy}") 37 | else: 38 | assert none_policy == NonePolicy.REPLACE, f"Unknown none_policy {none_policy}" 39 | 40 | if dataclasses.is_dataclass(node): 41 | return dataclasses.replace(node, inference=value) 42 | else: 43 | return eqx.tree_at(lambda x: x.inference, node, value, is_leaf=lambda x: x is node) 44 | 45 | def rec_set(tree): 46 | if has_inference(tree): 47 | tree = replace_fn(tree) 48 | 49 | if jax.tree_util.tree_leaves(tree) == [tree]: 50 | return tree 51 | 52 | return jax.tree_util.tree_map(rec_set, tree, is_leaf=lambda x: has_inference(x) and tree is not x) 53 | 54 | return rec_set(tree) 55 | -------------------------------------------------------------------------------- /src/levanter/utils/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional, Protocol, Tuple, TypeVar, Union 2 | 3 | from jaxtyping import PyTree 4 | 5 | import haliax as hax 6 | from haliax.types import Scalar 7 | 8 | 9 | M = TypeVar("M") # Model 10 | M_con = TypeVar("M_con", contravariant=True) # Model 11 | X = TypeVar("X", contravariant=True) # Input 12 | 13 | try: 14 | from haliax.nn.scan import BlockFoldable 15 | except ImportError: 16 | 17 | class BlockFoldable(Protocol[M]): # type: ignore 18 | def fold(self, *args, **kwargs): 19 | ... 20 | 21 | def scan(self, *args, **kwargs): 22 | ... 23 | 24 | 25 | class ValAndGradFn(Protocol[M, X]): 26 | def __call__(self, model: M, *inputs: X, **input_kwargs) -> Tuple[Scalar, M]: 27 | ... 28 | 29 | 30 | class ValFn(Protocol[M_con, X]): 31 | def __call__(self, model: M_con, *inputs: X, **input_kwargs) -> Scalar: 32 | ... 33 | 34 | 35 | FilterSpec = Union[bool, Callable[[Any], bool]] 36 | """ 37 | A filter specification. Typically used on a pytree to filter out certain subtrees. Boolean values are 38 | treated as-is, while callables are called on each element of the pytree. If the callable returns True, the element 39 | is kept, otherwise it is filtered out. 40 | """ 41 | 42 | FilterTree = FilterSpec | PyTree[FilterSpec] 43 | 44 | 45 | class ComputeLossFunction(Protocol[M_con, X]): 46 | """ 47 | Function signature for "compute_loss" functions in Levanter: these 48 | couple the computation of the logits and the evaluation of the loss 49 | """ 50 | 51 | def __call__( 52 | self, 53 | model: M_con, 54 | *inputs: X, 55 | reduction: Optional[hax.ReductionFunction] = hax.mean, 56 | reduction_axis: Optional[hax.AxisSelection] = None, 57 | **kwargs, 58 | ) -> Scalar | hax.NamedArray: 59 | ... 
60 | -------------------------------------------------------------------------------- /src/levanter/visualization.py: -------------------------------------------------------------------------------- 1 | # This module has been moved to levanter.analysis.visualization 2 | # This file is kept for backward compatibility 3 | 4 | import numpy as np 5 | 6 | from levanter.analysis.visualization import compute_and_diff_log_probs # noqa 7 | from levanter.analysis.visualization import compute_and_visualize_log_probs # noqa 8 | from levanter.analysis.visualization import visualize_log_prob_diff, visualize_log_probs 9 | 10 | 11 | # dumb main to test it out 12 | if __name__ == "__main__": 13 | np.random.seed(1) 14 | tokens = [["Hello", "world", "!"], ["This", "is", "a", "test", "."]] 15 | log_probs = np.log(np.random.uniform(size=(2, 5))) 16 | visualize_log_probs(tokens, log_probs, "test.html") 17 | 18 | # test diff 19 | log_probs_a = np.log(np.random.uniform(size=(2, 5))) 20 | log_probs_b = np.log(np.random.uniform(size=(2, 5))) 21 | visualize_log_prob_diff(tokens, log_probs_a, log_probs_b, "test_diff.html") 22 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/tests/__init__.py -------------------------------------------------------------------------------- /tests/data/hero_data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/tests/data/hero_data.npy -------------------------------------------------------------------------------- /tests/gpt2_test.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import jax.numpy as jnp 4 | import pytest 5 | from jax.random import PRNGKey 6 | 7 | import haliax as hax 8 | from haliax import Axis 9 | 10 | from levanter.models.attention import AttentionBackend, AttentionMask 11 | from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel 12 | from test_utils import check_load_config, check_model_works_with_seqlen, parameterize_with_configs 13 | 14 | 15 | @pytest.mark.parametrize("num_blocks", [1, 4, 12]) 16 | @pytest.mark.parametrize("attn_backend", [AttentionBackend.JAX_FLASH, AttentionBackend.VANILLA]) 17 | def test_gradient_checkpointing(num_blocks, attn_backend): 18 | # ensure that gradient checkpointing doesn't change the output 19 | # (this is a regression test for a bug that caused the output to change) 20 | config = Gpt2Config( 21 | seq_len=64, 22 | hidden_dim=64, 23 | num_layers=num_blocks, 24 | num_heads=8, 25 | gradient_checkpointing=False, 26 | # use_flash_attention=True, 27 | attn_backend=attn_backend, 28 | ) 29 | config_checkpoint = dataclasses.replace(config, gradient_checkpointing=True) 30 | key = PRNGKey(0) 31 | 32 | Vocab = Axis("vocab", 128) 33 | 34 | model = Gpt2LMHeadModel.init(Vocab, config, key=key) 35 | model_checkpoint = Gpt2LMHeadModel.init(Vocab, config_checkpoint, key=key) 36 | 37 | input_ids = hax.arange(config.Pos, dtype=jnp.int32) 38 | 39 | causal_mask = AttentionMask.causal() 40 | 41 | a1 = model(input_ids, key=key, attn_mask=causal_mask) 42 | a2 = model_checkpoint(input_ids, key=key, attn_mask=causal_mask) 43 | 44 | assert hax.all(hax.isclose(a1, a2, rtol=1e-4, atol=1e-5)), f"failed with num_blocks={num_blocks}" 
45 | 46 | 47 | @parameterize_with_configs("gpt2*.yaml") 48 | def test_gpt2_configs(config_file): 49 | from levanter.main.train_lm import TrainLmConfig 50 | 51 | check_load_config(TrainLmConfig, config_file) 52 | 53 | 54 | def test_pass_different_length_seq_to_gpt2(): 55 | config = Gpt2Config( 56 | seq_len=64, 57 | hidden_dim=16, 58 | num_layers=4, 59 | num_heads=2, 60 | gradient_checkpointing=False, 61 | use_flash_attention=True, 62 | ) 63 | check_model_works_with_seqlen(Gpt2LMHeadModel, config, 16) 64 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | flake8 2 | pytest 3 | soundfile 4 | librosa 5 | pytest-forked 6 | pytest-asyncio 7 | tensorboard 8 | -------------------------------------------------------------------------------- /tests/test_datetime_utils.py: -------------------------------------------------------------------------------- 1 | from levanter.utils.datetime_utils import encode_timedelta, parse_timedelta 2 | 3 | 4 | def test_encode_timedelta(): 5 | # various time strings from the unit tests for pytimeparse 6 | # https://github.com/wroberts/pytimeparse/blob/master/pytimeparse/tests/testtimeparse.py 7 | # i skipped negative ones because they're not supported, and fractional things because really 8 | 9 | def ensure_roundtrip(td_str, expected_seconds): 10 | # we don't enforce that the output is the same as the input, 11 | # but we do enforce that it can be parsed to the same timedelta 12 | td = parse_timedelta(td_str) 13 | assert td.total_seconds() == expected_seconds 14 | assert parse_timedelta(encode_timedelta(td)) == td, f"Failed to roundtrip {td_str}: {encode_timedelta(td)}" 15 | 16 | ensure_roundtrip("1d", 86400) 17 | ensure_roundtrip("+32 m 1 s", 1921) 18 | ensure_roundtrip("+ 32 m 1 s", 1921) 19 | ensure_roundtrip("32m", 1920) 20 | ensure_roundtrip("+32m", 1920) 21 | ensure_roundtrip("2h32m", 9120) 22 | ensure_roundtrip("+2h32m", 9120) 23 | ensure_roundtrip("3d2h32m", 268320) 24 | ensure_roundtrip("+3d2h32m", 268320) 25 | ensure_roundtrip("1w3d2h32m", 873120) 26 | ensure_roundtrip("1w 3d 2h 32m", 873120) 27 | ensure_roundtrip("1 w 3 d 2 h 32 m", 873120) 28 | ensure_roundtrip("4:13", 253) 29 | ensure_roundtrip(":13", 13) 30 | ensure_roundtrip("4:13:02", 15182) 31 | ensure_roundtrip("4:13:02.266", 15182.266) 32 | ensure_roundtrip("2:04:13:02.266", 187982.266) 33 | ensure_roundtrip("2 days, 4:13:02", 187982) 34 | ensure_roundtrip("5hr34m56s", 20096) 35 | ensure_roundtrip("5 hours, 34 minutes, 56 seconds", 20096) 36 | ensure_roundtrip("5 hrs, 34 mins, 56 secs", 20096) 37 | ensure_roundtrip("2 days, 5 hours, 34 minutes, 56 seconds", 192896) 38 | ensure_roundtrip("172 hr", 619200) 39 | -------------------------------------------------------------------------------- /tests/test_distributed.py: -------------------------------------------------------------------------------- 1 | from levanter.distributed import _square_brace_expand 2 | 3 | 4 | def test_square_brace_expand(): 5 | custom_sequence = "node[001-004,007]suffix" 6 | expanded_nodes = _square_brace_expand(custom_sequence) 7 | assert expanded_nodes == ["node001suffix", "node002suffix", "node003suffix", "node004suffix", "node007suffix"] 8 | 9 | custom_sequence_2 = "prefix[001-002]node[005-006]suffix" 10 | expanded_nodes_2 = _square_brace_expand(custom_sequence_2) 11 | assert expanded_nodes_2 == [ 12 | "prefix001node005suffix", 13 | "prefix001node006suffix", 14 | 
"prefix002node005suffix", 15 | "prefix002node006suffix", 16 | ] 17 | 18 | custom_sequence_3 = "node[1-11]suffix" 19 | expanded_nodes_3 = _square_brace_expand(custom_sequence_3) 20 | assert expanded_nodes_3 == [f"node{i}suffix" for i in range(1, 12)] 21 | 22 | custom_sequence_3 = "node[1-11,21]suffix" 23 | expanded_nodes_3 = _square_brace_expand(custom_sequence_3) 24 | assert expanded_nodes_3 == [f"node{i}suffix" for i in range(1, 12)] + ["node21suffix"] 25 | -------------------------------------------------------------------------------- /tests/test_export_to_hf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import equinox as eqx 5 | import jax 6 | import pytest 7 | from transformers import AutoModelForCausalLM 8 | 9 | import haliax 10 | 11 | import levanter.main.export_lm_to_hf as export_lm_to_hf 12 | import tiny_test_corpus 13 | from levanter.checkpoint import save_checkpoint 14 | from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel 15 | from levanter.utils.jax_utils import is_inexact_arrayish 16 | from test_utils import has_torch 17 | 18 | 19 | @pytest.mark.entry 20 | def test_export_lm_to_hf(): 21 | # just testing if train_lm has a pulse 22 | model_config = Gpt2Config( 23 | num_layers=2, 24 | num_heads=2, 25 | seq_len=32, 26 | use_flash_attention=True, 27 | hidden_dim=32, 28 | ) 29 | 30 | with tempfile.TemporaryDirectory() as tmpdir: 31 | data_config = tiny_test_corpus.tiny_corpus_config(tmpdir) 32 | tok = data_config.the_tokenizer 33 | Vocab = haliax.Axis("vocab", len(tok)) 34 | model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0)) 35 | # in our trainer, we only export the trainable params 36 | trainable, non_trainable = eqx.partition(model, is_inexact_arrayish) 37 | 38 | save_checkpoint({"model": trainable}, 0, f"{tmpdir}/ckpt") 39 | 40 | try: 41 | config = export_lm_to_hf.ConvertLmConfig( 42 | checkpoint_path=f"{tmpdir}/ckpt", 43 | output_dir=f"{tmpdir}/output", 44 | model=model_config, 45 | ) 46 | export_lm_to_hf.main(config) 47 | 48 | if has_torch(): 49 | AutoModelForCausalLM.from_pretrained(f"{tmpdir}/output") 50 | 51 | finally: 52 | try: 53 | os.unlink("wandb") 54 | except Exception: 55 | pass 56 | -------------------------------------------------------------------------------- /tests/test_grad_accum.py: -------------------------------------------------------------------------------- 1 | import equinox as eqx 2 | import jax 3 | import pytest 4 | from chex import assert_trees_all_close 5 | from jax.sharding import Mesh 6 | 7 | import haliax 8 | import haliax as hax 9 | import haliax.nn as hnn 10 | 11 | from levanter.grad_accum import microbatched 12 | 13 | 14 | class Mlp(eqx.Module): 15 | """ 16 | Simple 1 hidden layer MLP implementation 17 | """ 18 | 19 | w_in: hax.NamedArray 20 | w_out: hax.NamedArray 21 | In: hax.Axis = eqx.field(static=True) 22 | Out: hax.Axis = eqx.field(static=True) 23 | Mid: hax.Axis = eqx.field(static=True) 24 | 25 | @staticmethod 26 | def init(In: hax.Axis, Out: hax.Axis, Mid: hax.Axis, *, key): 27 | w_in = hax.random.normal(key, hax.concat_axis_specs(In, Mid)) * 0.02 28 | w_out = hax.random.normal(key, hax.concat_axis_specs(Mid, Out)) * 0.02 29 | return Mlp(w_in, w_out, In, Out, Mid) 30 | 31 | def __call__(self, x): 32 | x = hax.dot(self.w_in, x, axis=self.In) 33 | x = hnn.relu(x) 34 | x = hax.dot(self.w_out, x, axis=self.Mid) 35 | return x 36 | 37 | 38 | @pytest.mark.parametrize("parallelism", [1, 2, 4]) 39 | 
@pytest.mark.parametrize("accum_steps", [1, 3]) 40 | def test_accumulate_gradients_sharded(parallelism, accum_steps): 41 | In = hax.Axis("In", 32) 42 | Out = hax.Axis("Out", 32) 43 | Mid = hax.Axis("Mid", 32) 44 | Batch = hax.Axis("Batch", len(jax.devices()) * parallelism * accum_steps) 45 | mlp = Mlp.init(In, Out, Mid, key=jax.random.PRNGKey(0)) 46 | 47 | def loss_fn(mlp, x): 48 | return mlp(x).mean().scalar() 49 | 50 | x = hax.random.normal(jax.random.PRNGKey(0), (Batch, In)) 51 | 52 | x = jax.device_put(x, jax.sharding.PositionalSharding(jax.devices()).reshape((-1, 1))) 53 | 54 | axis_mapping = {"Batch": "data"} 55 | 56 | mesh = Mesh(jax.devices(), ("data",)) 57 | 58 | @hax.partitioning.named_jit(axis_resources=axis_mapping) 59 | def jit_grad_accum(mlp, x): 60 | grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) 61 | grad_fn = microbatched(grad_fn, Batch, parallelism, axis_mapping, axis_mapping) 62 | acc_v, acc_g = grad_fn( 63 | mlp, 64 | x, 65 | ) 66 | return acc_v, acc_g 67 | 68 | with mesh: 69 | mlp = haliax.shard(mlp, axis_mapping) 70 | x = haliax.shard(x, axis_mapping) 71 | grad_fn = eqx.filter_value_and_grad(loss_fn) 72 | acc_v, acc_g = jit_grad_accum(mlp, x) 73 | v, g = grad_fn(mlp, x) 74 | 75 | assert_trees_all_close(acc_v, v, atol=1e-3, rtol=1e-3) 76 | 77 | for l1, l2 in zip(jax.tree_util.tree_leaves(acc_g), jax.tree_util.tree_leaves(g)): 78 | assert_trees_all_close(l1, l2, atol=1e-3, rtol=1e-3) 79 | -------------------------------------------------------------------------------- /tests/test_histogram.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import equinox 4 | import jax 5 | import numpy as np 6 | from jax.random import PRNGKey 7 | from jax.sharding import Mesh 8 | 9 | import haliax as hax 10 | from haliax.partitioning import ResourceAxis 11 | 12 | import levanter.tracker.histogram 13 | from test_utils import skip_if_not_enough_devices 14 | 15 | 16 | def test_sharded_histogram_simple(): 17 | mesh = Mesh((jax.devices()), (ResourceAxis.DATA,)) 18 | 19 | Batch = hax.Axis("batch", 64) 20 | Feature = hax.Axis("feature", 128) 21 | 22 | with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): 23 | a = hax.random.normal(PRNGKey(1), (Batch, Feature)) 24 | a = hax.shard(a) 25 | hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=32) 26 | 27 | hist_normal, bins_normal = jax.numpy.histogram(a.array, bins=32) 28 | 29 | assert jax.numpy.allclose(hist, hist_normal) 30 | assert jax.numpy.allclose(bins, bins_normal) 31 | 32 | 33 | @skip_if_not_enough_devices(2) 34 | def test_sharded_histogram_tp(): 35 | mesh = Mesh(np.array(jax.devices()).reshape(-1, 2), (ResourceAxis.DATA, ResourceAxis.MODEL)) 36 | 37 | Batch = hax.Axis("batch", 64) 38 | Feature = hax.Axis("feature", 128) 39 | 40 | with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA, "feature": ResourceAxis.MODEL}): 41 | a = hax.random.normal(PRNGKey(0), (Batch, Feature)) * 100 42 | a = hax.shard(a) 43 | hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=64) 44 | 45 | jnp_hist, jnp_bins = jax.numpy.histogram(a.array, bins=64) 46 | 47 | assert jax.numpy.allclose(hist, jnp_hist) 48 | assert jax.numpy.allclose(bins, jnp_bins) 49 | 50 | 51 | def test_sharded_histogram_with_vmap(): 52 | mesh = Mesh((jax.devices()), (ResourceAxis.DATA,)) 53 | 54 | Layer = hax.Axis("layer", 4) 55 | Batch = hax.Axis("batch", 16) 56 | Feature = hax.Axis("feature", 128) 57 | 58 | @equinox.filter_jit 59 | def jit_vmap_hist(a): 60 | """ 61 | This function 
will be JIT compiled and VMapped. 62 | """ 63 | # Call the sharded histogram function 64 | hist, bins = hax.vmap(levanter.tracker.histogram.sharded_histogram, Layer)(a, bins=32) 65 | return hist, bins 66 | 67 | with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): 68 | a = hax.random.normal(PRNGKey(1), (Layer, Batch, Feature)) 69 | a = hax.shard(a) 70 | hist, bins = jit_vmap_hist(a) 71 | 72 | hist_normal, bins_normal = jax.vmap(functools.partial(jax.numpy.histogram, bins=32), in_axes=0)(a.array) 73 | 74 | assert jax.numpy.allclose(hist, hist_normal) 75 | assert jax.numpy.allclose(bins, bins_normal) 76 | -------------------------------------------------------------------------------- /tests/test_hyena.py: -------------------------------------------------------------------------------- 1 | import chex 2 | import jax 3 | 4 | import haliax as hax 5 | 6 | from levanter.models.hyena import HyenaConfig, HyenaOperator 7 | from levanter.utils.activation import ActivationFunctionEnum 8 | 9 | 10 | def test_causality(): 11 | """ 12 | Test that the Hyena operator is causal - future tokens 13 | should not affect predictions for past tokens. 14 | """ 15 | # Create a test config that matches the PyTorch example 16 | config = HyenaConfig( 17 | seq_len=1024, 18 | hidden_dim=512, 19 | order=2, 20 | filter_order=64, 21 | activation=ActivationFunctionEnum.gelu_new, 22 | ) 23 | 24 | # Initialize the model with a fixed key for reproducibility 25 | key = jax.random.PRNGKey(0) 26 | model_key, input_key = jax.random.split(key) 27 | model = HyenaOperator.init(config, key=model_key) 28 | 29 | # Create a random input tensor with shape matching the PyTorch example 30 | Pos = config.Pos 31 | Embed = config.Embed 32 | x = hax.random.normal(input_key, (Pos, Embed)) 33 | 34 | # Define a function to compute the sum of a specific position's output 35 | loss_pos = 10 36 | 37 | def loss_fn(x): 38 | y = model(x) 39 | return hax.sum(y.slice(Pos, start=loss_pos, length=1)).array 40 | 41 | # Compute gradients using JAX's grad 42 | grad_fn = jax.grad(loss_fn) 43 | grads = grad_fn(x) 44 | 45 | # Check that gradients flow from past to present but not from future to past 46 | # Position 10 should affect itself 47 | pos_10_grad_sum = hax.sum(hax.abs(grads.slice(Pos, start=loss_pos, length=1))) 48 | assert pos_10_grad_sum > 0, "Position should affect itself" 49 | 50 | # Position 9 should affect position 10 (past affects future) 51 | pos_9_grad_sum = hax.sum(hax.abs(grads.slice(Pos, start=loss_pos - 1, length=1))) 52 | assert pos_9_grad_sum > 0, "Past should affect future" 53 | 54 | # Position 11 should NOT affect position 10 (future should not affect past) 55 | pos_11_grad_sum = hax.sum(hax.abs(grads.slice(Pos, start=loss_pos + 1, length=1))) 56 | assert pos_11_grad_sum == 0.0, "Future should not affect past (causality violation detected)" 57 | 58 | # Additional test: all positions greater than 10 should have zero gradient 59 | future_positions_grads = grads.slice(Pos, start=loss_pos + 1, length=Pos.size - loss_pos - 1) 60 | chex.assert_trees_all_close(future_positions_grads, hax.zeros_like(future_positions_grads)) 61 | -------------------------------------------------------------------------------- /tests/test_logging.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import pytest 4 | from git import InvalidGitRepositoryError, NoSuchPathError, Repo 5 | 6 | from levanter.tracker.helpers import infer_experiment_git_root 7 | 8 | 9 | def test_infer_experiment_git_root(): 10 
| # make sure this test is running in a git repo 11 | try: 12 | repo = Repo(pathlib.Path(__file__), search_parent_directories=True) 13 | except (InvalidGitRepositoryError, NoSuchPathError): 14 | pytest.skip("test not running in a git repo") 15 | 16 | root = infer_experiment_git_root() 17 | 18 | # ensure that 1) this is a git root and 2) this source file is underneath 19 | assert root is not None 20 | assert pathlib.Path(root).exists() 21 | repo = Repo(root) 22 | assert repo.working_dir == root 23 | assert pathlib.Path(__file__).is_relative_to(root), f"{__file__} is not relative to {root}" 24 | -------------------------------------------------------------------------------- /tests/test_py_utils.py: -------------------------------------------------------------------------------- 1 | from levanter.utils.py_utils import actual_sizeof 2 | 3 | 4 | def test_actual_sizeof(): 5 | d1 = {"a": 1, "b": 2} 6 | d2 = {"a": "this is a string", "b": "this is another string"} 7 | 8 | assert actual_sizeof(d1) < actual_sizeof(d2) 9 | -------------------------------------------------------------------------------- /tests/test_scheduler.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from levanter.schedule import BatchSchedule, ScheduleStep 4 | 5 | 6 | @pytest.fixture 7 | def scheduler(): 8 | """ 9 | A pytest fixture that sets up a BatchScheduler with the following schedule: 10 | - Use batch size 32 until step 1000 11 | - Then batch size 64 until step 100000 12 | - Then batch size 128 forever 13 | """ 14 | schedule = [ 15 | ScheduleStep(start=0, value=32), 16 | ScheduleStep(start=1000, value=64), 17 | ScheduleStep(start=100000, value=128), 18 | ] 19 | return BatchSchedule(schedule) 20 | 21 | 22 | @pytest.mark.parametrize( 23 | "step, expected_bs, expected_offset, expected_indices", 24 | [ 25 | (0, 32, 0, (0, 32)), 26 | (500, 32, 500 * 32, (500 * 32, 500 * 32 + 32)), 27 | (999, 32, 999 * 32, (999 * 32, 999 * 32 + 32)), 28 | (1000, 64, 1000 * 32, (1000 * 32, 1000 * 32 + 64)), 29 | (50000, 64, 32000 + (50000 - 1000) * 64, (32000 + (50000 - 1000) * 64, 32000 + (50000 - 1000) * 64 + 64)), 30 | ( 31 | 100000, 32 | 128, 33 | 32000 + (100000 - 1000) * 64, 34 | (32000 + (100000 - 1000) * 64, 32000 + (100000 - 1000) * 64 + 128), 35 | ), 36 | ( 37 | 150000, 38 | 128, 39 | 32000 + (100000 - 1000) * 64 + (150000 - 100000) * 128, 40 | ( 41 | 32000 + (100000 - 1000) * 64 + (150000 - 100000) * 128, 42 | 32000 + (100000 - 1000) * 64 + (150000 - 100000) * 128 + 128, 43 | ), 44 | ), 45 | ], 46 | ) 47 | def test_batch_scheduler(scheduler, step, expected_bs, expected_offset, expected_indices): 48 | """ 49 | Parametric test to ensure the batch scheduler returns the correct 50 | batch size, data offset, and batch indices for given training steps. 
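For example, at step 1000 the run has completed 1000 steps at batch size 32, so the global data offset is 1000 * 32 = 32000, and (since the batch size has just switched to 64) that step covers example indices [32000, 32064).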
51 | """ 52 | bs = scheduler.batch_size_at_step(step) 53 | offset = scheduler.global_data_offset_by_step(step) 54 | # indices = scheduler.batch_indices_at_step(step) 55 | 56 | assert bs == expected_bs, f"Unexpected batch size at step {step}" 57 | assert offset == expected_offset, f"Unexpected data offset at step {step}" 58 | # assert indices == expected_indices, f"Unexpected batch indices at step {step}" 59 | -------------------------------------------------------------------------------- /tests/test_sophia.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | 4 | import equinox as eqx 5 | import equinox.nn as nn 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | from chex import assert_trees_all_close 10 | 11 | import levanter 12 | import levanter.optim.sophia 13 | 14 | 15 | def test_sophia_h(): 16 | key = jax.random.PRNGKey(0) 17 | model = nn.Linear(4, 4, use_bias=False, key=key) 18 | data = np.load(f"{os.path.dirname(__file__)}/data/hero_data.npy").astype("float32") 19 | optimizer = levanter.optim.sophia.sophia_h( 20 | lr=1, 21 | b1=0, 22 | b2=0.99, 23 | gamma=2, 24 | weight_decay=0.0, 25 | clip_threshold=1, 26 | key=key, 27 | update_interval=1, 28 | ) 29 | model = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), model) 30 | zero_grad = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), model) 31 | 32 | opt_state = optimizer.init(model) 33 | 34 | def loss_fn(model, data): 35 | out = eqx.filter_vmap(model)(data) 36 | return jnp.mean(out**2) * 4 37 | 38 | jit_update = eqx.filter_jit(optimizer.update) 39 | 40 | obj_fn = functools.partial(loss_fn, data=data) 41 | for i in range(1000): 42 | _, opt_state = jit_update(zero_grad, opt_state, params=model, obj_fn=obj_fn) 43 | 44 | # print('Test-estimated hessian: most coordinates should be approximately 2') 45 | # print('Estimated hessian:', opt_state[0].h.weight) 46 | assert_trees_all_close(opt_state[0].h.weight, 2, rtol=0.2, atol=0.3) # this is very approximate 47 | 48 | grad_loss_fn = eqx.filter_jit(eqx.filter_value_and_grad(loss_fn)) 49 | 50 | loss, grad = grad_loss_fn(model, data) 51 | model_updates, opt_state = optimizer.update(grad, opt_state, params=model, obj_fn=obj_fn) 52 | model = eqx.apply_updates(model, model_updates) 53 | 54 | assert_trees_all_close(loss, 15.74834156036377, rtol=1e-3, atol=1e-3) 55 | 56 | # print("Test-model param after 1 step: most coordinates should be very loosely 0.5") 57 | assert_trees_all_close(model.weight, 0.5, rtol=0.2, atol=0.1) # this is very approximate 58 | 59 | # print("Test-loss: loss should shrink by approximately 75% after each iteration") 60 | for i in range(10): 61 | loss, grad = grad_loss_fn(model, data) 62 | model_updates, opt_state = optimizer.update(grad, opt_state, params=model, obj_fn=obj_fn) 63 | model = eqx.apply_updates(model, model_updates) 64 | 65 | # print('Step:', i , "Loss:", loss.item()) 66 | assert loss < 15.74834156036377 * 0.75 ** (i + 1) 67 | -------------------------------------------------------------------------------- /tests/test_supervised.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from transformers import AutoTokenizer 3 | 4 | import haliax 5 | from haliax import Axis 6 | 7 | from levanter.data.text import SupervisedProcessor, _prepare_supervised_examples 8 | 9 | 10 | def test_supervised_eval(): 11 | examples = [ 12 | { 13 | "input": "Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\nA. 0\nB. 1\nC. 2\nD. 
3\nAnswer:", 14 | "output": "B", 15 | } 16 | ] 17 | tokenizer = AutoTokenizer.from_pretrained("gpt2") 18 | 19 | if tokenizer.pad_token is None: 20 | tokenizer.pad_token = tokenizer.eos_token 21 | 22 | # output = _preprocess_supervised_example(examples, tokenizer, "input", "output") 23 | processor = SupervisedProcessor(tokenizer, "input", "output") 24 | output = processor(examples) 25 | assert len(output["input_ids"][0]) == output["sources_len"][0] + 1 26 | 27 | ex = { 28 | "input_ids": np.array( 29 | [ 30 | 16742, 31 | 477, 32 | 269, 33 | 287, 34 | 1168, 35 | 62, 36 | 18, 37 | 884, 38 | 326, 39 | 1168, 40 | 62, 41 | 18, 42 | 58, 43 | 87, 44 | 60, 45 | 29006, 46 | 87, 47 | 61, 48 | 17, 49 | 1343, 50 | 269, 51 | 8, 52 | 318, 53 | 257, 54 | 2214, 55 | 13, 56 | 198, 57 | 32, 58 | 13, 59 | 657, 60 | 198, 61 | 33, 62 | 13, 63 | 352, 64 | 198, 65 | 34, 66 | 13, 67 | 362, 68 | 198, 69 | 35, 70 | 13, 71 | 513, 72 | 198, 73 | 33706, 74 | 25, 75 | 33, 76 | ], 77 | dtype=np.int32, 78 | ), 79 | "sources_len": np.array(45, dtype=np.int32), 80 | } 81 | 82 | lm_ex = _prepare_supervised_examples([ex], tokenizer, Axis("position", 128))[0] 83 | 84 | assert lm_ex.loss_mask["position", 44] 85 | assert haliax.sum(lm_ex.loss_mask) == 1 86 | -------------------------------------------------------------------------------- /tests/test_tensorboard.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | from tensorboardX import SummaryWriter 6 | 7 | from levanter.tracker.histogram import Histogram 8 | from levanter.tracker.tensorboard import TensorboardTracker 9 | 10 | 11 | def test_log_summary(): 12 | with tempfile.TemporaryDirectory() as tmpdir: 13 | with SummaryWriter(logdir=tmpdir) as writer: 14 | tracker = TensorboardTracker(writer) 15 | tracker.log_summary({"float": 2.0}) 16 | tracker.log_summary({"str": "test"}) 17 | tracker.log_summary({"scalar_jax_array": jnp.array(3.0)}) 18 | tracker.log_summary({"scalar_np_array": np.array(3.0)}) 19 | 20 | 21 | def test_log(): 22 | with tempfile.TemporaryDirectory() as tmpdir: 23 | with SummaryWriter(logdir=tmpdir) as writer: 24 | tracker = TensorboardTracker(writer) 25 | tracker.log({"float": 2.0}, step=0) 26 | tracker.log({"str": "test"}, step=0) 27 | tracker.log({"scalar_jax_array": jnp.array(3.0)}, step=0) 28 | tracker.log({"scalar_np_array": np.array(3.0)}, step=0) 29 | tracker.log({"histogram": Histogram.from_array(jnp.array([1.0, 2.0, 3.0]))}, step=0) 30 | -------------------------------------------------------------------------------- /tests/test_torch_serialization.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanford-crfm/levanter/6e2d938a369f248c48c19b06971a844b8153926f/tests/test_torch_serialization.py -------------------------------------------------------------------------------- /tests/test_tracker.py: -------------------------------------------------------------------------------- 1 | # NOTE: Do not explicitly import wandb/other trackers here, as this will cause the tests to trivially pass. 
2 | import dataclasses 3 | from typing import Tuple 4 | 5 | import pytest 6 | import yaml 7 | 8 | import levanter.tracker 9 | from levanter.tracker import CompositeTracker, TrackerConfig 10 | 11 | 12 | def test_tracker_plugin_stuff_works(): 13 | assert TrackerConfig.get_choice_class("wandb") is not None 14 | with pytest.raises(KeyError): 15 | TrackerConfig.get_choice_class("foo") 16 | 17 | 18 | def test_tracker_plugin_default_works(): 19 | config = """ 20 | tracker: 21 | entity: foo 22 | """ 23 | parsed = yaml.safe_load(config) 24 | 25 | @dataclasses.dataclass 26 | class ConfigHolder: 27 | tracker: TrackerConfig 28 | 29 | import draccus 30 | 31 | tconfig = draccus.decode(ConfigHolder, parsed).tracker 32 | 33 | assert isinstance(tconfig, TrackerConfig.get_choice_class("wandb")) 34 | 35 | assert tconfig.entity == "foo" # type: ignore 36 | 37 | 38 | def test_tracker_plugin_multi_parsing_work(): 39 | config = """ 40 | tracker: 41 | type: noop 42 | """ 43 | parsed = yaml.safe_load(config) 44 | 45 | @dataclasses.dataclass 46 | class ConfigHolder: 47 | tracker: TrackerConfig | Tuple[TrackerConfig, ...] 48 | 49 | import draccus 50 | 51 | from levanter.tracker.tracker import NoopConfig 52 | 53 | assert isinstance(draccus.decode(ConfigHolder, parsed).tracker, NoopConfig) 54 | 55 | config = """ 56 | tracker: 57 | - type: noop 58 | - type: wandb 59 | """ 60 | parsed = yaml.safe_load(config) 61 | decoded = draccus.decode(ConfigHolder, parsed).tracker 62 | assert decoded == (NoopConfig(), TrackerConfig.get_choice_class("wandb")()) 63 | 64 | 65 | def test_get_tracker_by_name(): 66 | wandb_config = TrackerConfig.get_choice_class("wandb") 67 | if wandb_config is None: 68 | pytest.skip("wandb not installed") 69 | 70 | from levanter.tracker import NoopTracker 71 | 72 | wandb1 = wandb_config(mode="disabled").init(None) 73 | tracker = CompositeTracker([wandb1, NoopTracker()]) 74 | 75 | with tracker: 76 | assert levanter.tracker.get_tracker("wandb") is wandb1 77 | assert levanter.tracker.get_tracker("noop") is not None 78 | 79 | with pytest.raises(KeyError): 80 | levanter.tracker.get_tracker("foo") 81 | -------------------------------------------------------------------------------- /tests/test_train_asr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import jax 5 | import pytest 6 | 7 | import levanter.main.train_asr as train_asr 8 | import tiny_test_corpus 9 | from levanter.distributed import RayConfig 10 | from levanter.tracker.wandb import WandbConfig 11 | from test_utils import skip_if_no_soundlibs 12 | 13 | 14 | @pytest.mark.skip 15 | @pytest.mark.entry 16 | @skip_if_no_soundlibs 17 | def test_train_asr(): 18 | # just testing if train_asr has a pulse 19 | with tempfile.TemporaryDirectory() as tmpdir: 20 | data_config = tiny_test_corpus.tiny_asr_corpus_config(tmpdir) 21 | try: 22 | config = train_asr.TrainASRConfig( 23 | data=data_config, 24 | model=train_asr.WhisperConfig( 25 | d_model=32, 26 | ), 27 | trainer=train_asr.TrainerConfig( 28 | num_train_steps=2, 29 | train_batch_size=len(jax.devices()), 30 | max_eval_batches=1, 31 | wandb=WandbConfig(mode="disabled"), 32 | require_accelerator=False, 33 | ray=RayConfig(auto_start_cluster=False), 34 | ), 35 | hf_save_path=f"{tmpdir}/hf_asr_output", 36 | ) 37 | train_asr.main(config) 38 | finally: 39 | try: 40 | os.unlink("wandb") 41 | except Exception: 42 | pass 43 | -------------------------------------------------------------------------------- /tests/test_train_lm.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import jax 5 | import pytest 6 | 7 | from haliax.quantization import QuantizationConfig 8 | 9 | import levanter.main.train_lm as train_lm 10 | import tiny_test_corpus 11 | from levanter.distributed import RayConfig 12 | from levanter.tracker.wandb import WandbConfig 13 | 14 | 15 | @pytest.mark.entry 16 | def test_train_lm(): 17 | # just testing if train_lm has a pulse 18 | with tempfile.TemporaryDirectory() as tmpdir: 19 | data_config, _ = tiny_test_corpus.construct_small_data_cache(tmpdir) 20 | try: 21 | config = train_lm.TrainLmConfig( 22 | data=data_config, 23 | model=train_lm.Gpt2Config( 24 | num_layers=2, 25 | num_heads=2, 26 | seq_len=64, 27 | hidden_dim=32, 28 | attn_backend=None, # use default for platform 29 | ), 30 | trainer=train_lm.TrainerConfig( 31 | num_train_steps=2, 32 | train_batch_size=len(jax.devices()), 33 | max_eval_batches=1, 34 | wandb=WandbConfig(mode="disabled"), 35 | require_accelerator=False, 36 | ray=RayConfig(auto_start_cluster=False), 37 | ), 38 | ) 39 | train_lm.main(config) 40 | finally: 41 | try: 42 | os.unlink("wandb") 43 | except Exception: 44 | pass 45 | 46 | 47 | @pytest.mark.entry 48 | def test_train_lm_fp8(): 49 | # just testing if train_lm has a pulse 50 | with tempfile.TemporaryDirectory() as tmpdir: 51 | data_config, _ = tiny_test_corpus.construct_small_data_cache(tmpdir) 52 | try: 53 | config = train_lm.TrainLmConfig( 54 | data=data_config, 55 | model=train_lm.Gpt2Config( 56 | num_layers=2, 57 | num_heads=2, 58 | seq_len=64, 59 | hidden_dim=32, 60 | attn_backend=None, # use default for platform 61 | ), 62 | trainer=train_lm.TrainerConfig( 63 | quantization=QuantizationConfig(fp8=True), 64 | num_train_steps=2, 65 | train_batch_size=len(jax.devices()), 66 | max_eval_batches=1, 67 | wandb=WandbConfig(mode="disabled"), 68 | require_accelerator=False, 69 | ray=RayConfig(auto_start_cluster=False), 70 | ), 71 | ) 72 | train_lm.main(config) 73 | finally: 74 | try: 75 | os.unlink("wandb") 76 | except Exception: 77 | pass 78 | -------------------------------------------------------------------------------- /tests/test_viz_lm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import jax 5 | import pytest 6 | 7 | import haliax 8 | 9 | import levanter.main.viz_logprobs as viz_logprobs 10 | import tiny_test_corpus 11 | from levanter.checkpoint import save_checkpoint 12 | from levanter.distributed import RayConfig 13 | from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel 14 | from levanter.tracker.wandb import WandbConfig 15 | 16 | 17 | @pytest.mark.entry 18 | def test_viz_lm(): 19 | # just testing if eval_lm has a pulse 20 | # save a checkpoint 21 | model_config = Gpt2Config( 22 | num_layers=2, 23 | num_heads=2, 24 | hidden_dim=32, 25 | seq_len=64, 26 | use_flash_attention=True, 27 | ) 28 | 29 | with tempfile.TemporaryDirectory() as f: 30 | try: 31 | data_config, _ = tiny_test_corpus.construct_small_data_cache(f) 32 | tok = data_config.the_tokenizer 33 | Vocab = haliax.Axis("vocab", len(tok)) 34 | model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0)) 35 | 36 | save_checkpoint({"model": model}, 0, f"{f}/ckpt") 37 | 38 | config = viz_logprobs.VizLmConfig( 39 | data=data_config, 40 | model=model_config, 41 | trainer=viz_logprobs.TrainerConfig( 42 | per_device_eval_parallelism=len(jax.devices()), 43 | max_eval_batches=1, 44 | 
wandb=WandbConfig(mode="disabled"), 45 | require_accelerator=False, 46 | ray=RayConfig(auto_start_cluster=False), 47 | ), 48 | checkpoint_path=f"{f}/ckpt", 49 | num_docs=len(jax.devices()), 50 | path=f"{f}/viz", 51 | ) 52 | viz_logprobs.main(config) 53 | finally: 54 | try: 55 | os.unlink("wandb") 56 | except Exception: 57 | pass 58 | -------------------------------------------------------------------------------- /tests/tiny_test_corpus.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy 5 | import numpy as np 6 | 7 | from levanter.data.audio import AudioIODatasetConfig 8 | from levanter.data.text import UrlSingleDatasetLMConfig 9 | from levanter.store.cache import TreeCache 10 | 11 | 12 | def _write_tiny_corpus(path): 13 | os.makedirs(f"{path}/train", exist_ok=True) 14 | with open(f"{path}/train/docs.jsonl", "w") as f: 15 | for i in range(10): 16 | f.write(json.dumps({"text": f"hello world {i} " * 100})) 17 | f.write("\n") 18 | 19 | os.makedirs(f"{path}/validation", exist_ok=True) 20 | with open(f"{path}/validation/docs.jsonl", "w") as f: 21 | for i in range(10): 22 | f.write(json.dumps({"text": f"bye world {i} " * 100})) 23 | f.write("\n") 24 | 25 | 26 | def tiny_corpus_config(path): 27 | _write_tiny_corpus(path) 28 | return UrlSingleDatasetLMConfig( 29 | train_urls=[f"file://{path}/train/docs.jsonl"], 30 | validation_urls=[f"file://{path}/validation/docs.jsonl"], 31 | cache_dir=f"{path}/cache", 32 | ) 33 | 34 | 35 | def tiny_asr_corpus_config(path): 36 | return AudioIODatasetConfig( 37 | id="WillHeld/test_librispeech_parquet", 38 | text_key="text", 39 | train_split="validation", 40 | validation_split="validation", 41 | cache_dir=f"{path}/cache_asr", 42 | ) 43 | 44 | 45 | def construct_small_data_cache( 46 | path, num_shards=8, chunk_size=512, doc_len=128, vocab_size=1024 47 | ) -> tuple[UrlSingleDatasetLMConfig, dict[str, TreeCache]]: 48 | from levanter.store.cache import SerialCacheWriter 49 | 50 | rng = numpy.random.default_rng(0) 51 | 52 | caches: dict[str, TreeCache] = {} 53 | 54 | exemplar = {"input_ids": numpy.zeros((doc_len,), dtype=numpy.int32)} 55 | 56 | for split in ["train", "validation"]: 57 | with SerialCacheWriter(f"{path}/cache/{split}", exemplar) as writer: 58 | for shard in range(num_shards): 59 | writer.write_batch( 60 | [ 61 | {"input_ids": rng.integers(0, vocab_size, size=(doc_len,), dtype=np.int32)} 62 | for _ in range(chunk_size) 63 | ] 64 | ) 65 | caches[split] = writer.result() 66 | 67 | config = UrlSingleDatasetLMConfig( 68 | train_urls=[], 69 | validation_urls=[], 70 | cache_dir=f"{path}/cache", 71 | vocab_size=vocab_size, 72 | tokenizer="gpt2", 73 | ) 74 | 75 | return config, caches 76 | --------------------------------------------------------------------------------